mindspore 2.1.0__cp39-cp39-macosx_11_0_arm64.whl → 2.2.11__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's release notes for details.

Files changed (497)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cpython-39-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-39-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-39-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -1
  7. mindspore/_checkparam.py +23 -29
  8. mindspore/_extends/graph_kernel/__init__.py +0 -1
  9. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  10. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  11. mindspore/_extends/graph_kernel/splitter.py +4 -11
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  13. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  14. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  15. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  17. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  20. mindspore/_extends/parse/__init__.py +13 -15
  21. mindspore/_extends/parse/namespace.py +7 -33
  22. mindspore/_extends/parse/parser.py +67 -72
  23. mindspore/_extends/parse/resources.py +1 -1
  24. mindspore/_extends/parse/standard_method.py +86 -106
  25. mindspore/_extends/parse/trope.py +1 -1
  26. mindspore/_extends/remote/kernel_build_server.py +25 -7
  27. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  28. mindspore/_install_custom.py +43 -0
  29. mindspore/_mindspore_offline_debug.cpython-39-darwin.so +0 -0
  30. mindspore/amp.py +47 -11
  31. mindspore/boost/boost.py +1 -8
  32. mindspore/boost/boost_cell_wrapper.py +3 -2
  33. mindspore/boost/grad_accumulation.py +1 -1
  34. mindspore/boost/group_loss_scale_manager.py +8 -7
  35. mindspore/common/__init__.py +5 -3
  36. mindspore/common/_jit_fallback_utils.py +6 -0
  37. mindspore/common/_register_for_adapter.py +2 -0
  38. mindspore/common/_register_for_tensor.py +2 -2
  39. mindspore/common/_stub_tensor.py +13 -0
  40. mindspore/common/_utils.py +29 -0
  41. mindspore/common/api.py +174 -259
  42. mindspore/common/auto_dynamic_shape.py +494 -0
  43. mindspore/common/dtype.py +18 -11
  44. mindspore/common/dump.py +6 -4
  45. mindspore/common/initializer.py +14 -14
  46. mindspore/common/jit_config.py +33 -15
  47. mindspore/common/lazy_inline.py +126 -7
  48. mindspore/common/mindir_util.py +101 -0
  49. mindspore/common/parameter.py +51 -41
  50. mindspore/common/seed.py +4 -4
  51. mindspore/common/sparse_tensor.py +13 -14
  52. mindspore/common/tensor.py +243 -165
  53. mindspore/communication/__init__.py +7 -4
  54. mindspore/communication/_comm_helper.py +83 -4
  55. mindspore/communication/management.py +152 -84
  56. mindspore/config/op_info.config +14 -3
  57. mindspore/context.py +152 -61
  58. mindspore/dataset/__init__.py +5 -5
  59. mindspore/dataset/audio/__init__.py +2 -2
  60. mindspore/dataset/audio/transforms.py +52 -52
  61. mindspore/dataset/callback/ds_callback.py +16 -2
  62. mindspore/dataset/core/config.py +68 -51
  63. mindspore/dataset/engine/cache_client.py +33 -7
  64. mindspore/dataset/engine/datasets.py +250 -112
  65. mindspore/dataset/engine/datasets_audio.py +43 -211
  66. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  67. mindspore/dataset/engine/datasets_text.py +43 -67
  68. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  69. mindspore/dataset/engine/datasets_vision.py +219 -1029
  70. mindspore/dataset/engine/iterators.py +11 -4
  71. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  72. mindspore/dataset/engine/obs/util.py +3 -0
  73. mindspore/dataset/engine/samplers.py +1 -1
  74. mindspore/dataset/engine/validators.py +19 -5
  75. mindspore/dataset/text/__init__.py +3 -3
  76. mindspore/dataset/text/transforms.py +101 -127
  77. mindspore/dataset/text/utils.py +205 -138
  78. mindspore/dataset/transforms/__init__.py +1 -1
  79. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  80. mindspore/dataset/transforms/transforms.py +95 -40
  81. mindspore/dataset/utils/browse_dataset.py +8 -2
  82. mindspore/dataset/utils/line_reader.py +17 -19
  83. mindspore/dataset/vision/__init__.py +3 -3
  84. mindspore/dataset/vision/c_transforms.py +6 -3
  85. mindspore/dataset/vision/transforms.py +409 -287
  86. mindspore/dataset/vision/utils.py +13 -14
  87. mindspore/dataset/vision/validators.py +11 -1
  88. mindspore/experimental/map_parameter.py +14 -0
  89. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  90. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  91. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  92. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  93. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  94. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  95. mindspore/gen_ops.py +273 -0
  96. mindspore/include/OWNERS +0 -1
  97. mindspore/include/api/data_type.h +2 -1
  98. mindspore/include/api/graph.h +0 -15
  99. mindspore/include/api/kernel.h +2 -0
  100. mindspore/include/api/kernel_api.h +37 -12
  101. mindspore/include/api/model.h +17 -14
  102. mindspore/include/api/status.h +8 -3
  103. mindspore/include/api/types.h +37 -4
  104. mindspore/include/c_api/ms/abstract.h +67 -0
  105. mindspore/include/c_api/ms/attribute.h +197 -0
  106. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  107. mindspore/include/c_api/ms/base/macros.h +32 -0
  108. mindspore/include/c_api/ms/base/status.h +33 -0
  109. mindspore/include/c_api/ms/base/types.h +282 -0
  110. mindspore/include/c_api/ms/context.h +102 -0
  111. mindspore/include/c_api/ms/graph.h +160 -0
  112. mindspore/include/c_api/ms/node.h +606 -0
  113. mindspore/include/c_api/ms/tensor.h +161 -0
  114. mindspore/include/c_api/ms/value.h +84 -0
  115. mindspore/include/dataset/constants.h +6 -5
  116. mindspore/include/dataset/execute.h +23 -13
  117. mindspore/include/dataset/text.h +26 -26
  118. mindspore/include/dataset/transforms.h +13 -13
  119. mindspore/include/dataset/vision.h +60 -60
  120. mindspore/include/dataset/vision_ascend.h +5 -6
  121. mindspore/include/dataset/vision_lite.h +17 -17
  122. mindspore/lib/libdnnl.2.dylib +0 -0
  123. mindspore/lib/libmindspore_backend.dylib +0 -0
  124. mindspore/lib/libmindspore_common.dylib +0 -0
  125. mindspore/lib/libmindspore_core.dylib +0 -0
  126. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  127. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  128. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  129. mindspore/lib/libmindspore_shared_lib.dylib +0 -0
  130. mindspore/lib/libnnacl.dylib +0 -0
  131. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  132. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  133. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  134. mindspore/lib/libps_cache.dylib +0 -0
  135. mindspore/lib/libtinyxml2.8.dylib +0 -0
  136. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  137. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  138. mindspore/nn/__init__.py +0 -2
  139. mindspore/nn/cell.py +313 -74
  140. mindspore/nn/dynamic_lr.py +21 -21
  141. mindspore/nn/layer/activation.py +22 -30
  142. mindspore/nn/layer/basic.py +15 -13
  143. mindspore/nn/layer/channel_shuffle.py +1 -1
  144. mindspore/nn/layer/container.py +271 -9
  145. mindspore/nn/layer/conv.py +323 -204
  146. mindspore/nn/layer/dense.py +8 -5
  147. mindspore/nn/layer/embedding.py +33 -27
  148. mindspore/nn/layer/flash_attention.py +61 -95
  149. mindspore/nn/layer/image.py +8 -6
  150. mindspore/nn/layer/math.py +16 -25
  151. mindspore/nn/layer/normalization.py +107 -66
  152. mindspore/nn/layer/padding.py +1 -1
  153. mindspore/nn/layer/pooling.py +131 -109
  154. mindspore/nn/layer/rnn_cells.py +27 -22
  155. mindspore/nn/layer/rnns.py +13 -16
  156. mindspore/nn/layer/thor_layer.py +1 -1
  157. mindspore/nn/layer/transformer.py +221 -154
  158. mindspore/nn/learning_rate_schedule.py +9 -1
  159. mindspore/nn/loss/loss.py +235 -174
  160. mindspore/nn/optim/ada_grad.py +2 -1
  161. mindspore/nn/optim/adadelta.py +1 -0
  162. mindspore/nn/optim/adafactor.py +2 -1
  163. mindspore/nn/optim/adam.py +7 -4
  164. mindspore/nn/optim/adamax.py +3 -2
  165. mindspore/nn/optim/adasum.py +2 -2
  166. mindspore/nn/optim/asgd.py +2 -3
  167. mindspore/nn/optim/ftrl.py +6 -5
  168. mindspore/nn/optim/lamb.py +7 -4
  169. mindspore/nn/optim/lars.py +1 -1
  170. mindspore/nn/optim/lazyadam.py +5 -3
  171. mindspore/nn/optim/momentum.py +2 -1
  172. mindspore/nn/optim/optimizer.py +53 -4
  173. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  174. mindspore/nn/optim/rmsprop.py +4 -3
  175. mindspore/nn/optim/rprop.py +23 -12
  176. mindspore/nn/optim/sgd.py +26 -11
  177. mindspore/nn/optim/thor.py +9 -7
  178. mindspore/nn/probability/bijector/bijector.py +5 -5
  179. mindspore/nn/probability/bijector/power_transform.py +27 -27
  180. mindspore/nn/probability/bijector/softplus.py +3 -3
  181. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  182. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  183. mindspore/nn/probability/distribution/beta.py +3 -3
  184. mindspore/nn/probability/distribution/categorical.py +7 -7
  185. mindspore/nn/probability/distribution/cauchy.py +0 -1
  186. mindspore/nn/probability/distribution/distribution.py +3 -3
  187. mindspore/nn/probability/distribution/gamma.py +3 -3
  188. mindspore/nn/probability/distribution/geometric.py +4 -4
  189. mindspore/nn/probability/distribution/gumbel.py +4 -4
  190. mindspore/nn/probability/distribution/log_normal.py +2 -2
  191. mindspore/nn/probability/distribution/logistic.py +2 -2
  192. mindspore/nn/probability/distribution/poisson.py +4 -4
  193. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  194. mindspore/nn/probability/distribution/uniform.py +6 -6
  195. mindspore/nn/wrap/__init__.py +4 -2
  196. mindspore/nn/wrap/cell_wrapper.py +87 -34
  197. mindspore/nn/wrap/grad_reducer.py +8 -5
  198. mindspore/nn/wrap/loss_scale.py +105 -42
  199. mindspore/numpy/array_creations.py +1 -2
  200. mindspore/numpy/array_ops.py +3 -2
  201. mindspore/numpy/utils_const.py +5 -5
  202. mindspore/offline_debug/convert_async.py +2 -2
  203. mindspore/ops/_grad_experimental/__init__.py +0 -5
  204. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  205. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  206. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  207. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  208. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  209. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  210. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  211. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  212. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  213. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  214. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  215. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  216. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  217. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  218. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  219. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  220. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  221. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  222. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  223. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  224. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  225. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  226. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  227. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  228. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  229. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  230. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  231. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  232. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  233. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  234. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  235. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  236. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  237. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  238. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  239. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  240. mindspore/ops/_primitive_cache.py +1 -1
  241. mindspore/ops/_tracefunc.py +45 -13
  242. mindspore/ops/_utils/utils.py +6 -1
  243. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  244. mindspore/ops/_vmap/vmap_base.py +3 -3
  245. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  246. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  247. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  248. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  249. mindspore/ops/arg_dtype_cast.py +54 -0
  250. mindspore/ops/composite/base.py +37 -10
  251. mindspore/ops/composite/math_ops.py +5 -4
  252. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  253. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  254. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  255. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  256. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  257. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  258. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  259. mindspore/ops/deprecated.py +304 -0
  260. mindspore/ops/function/__init__.py +4 -1
  261. mindspore/ops/function/array_func.py +174 -193
  262. mindspore/ops/function/clip_func.py +81 -13
  263. mindspore/ops/function/debug_func.py +1 -1
  264. mindspore/ops/function/grad/grad_func.py +18 -9
  265. mindspore/ops/function/image_func.py +10 -4
  266. mindspore/ops/function/linalg_func.py +5 -5
  267. mindspore/ops/function/math_func.py +575 -386
  268. mindspore/ops/function/nn_func.py +568 -260
  269. mindspore/ops/function/random_func.py +88 -57
  270. mindspore/ops/function/sparse_func.py +1 -1
  271. mindspore/ops/function/sparse_unary_func.py +14 -12
  272. mindspore/ops/function/vmap_func.py +6 -5
  273. mindspore/ops/functional.py +15 -10
  274. mindspore/ops/op_info_register.py +244 -25
  275. mindspore/ops/operations/__init__.py +31 -19
  276. mindspore/ops/operations/_grad_ops.py +71 -7
  277. mindspore/ops/operations/_inner_ops.py +350 -17
  278. mindspore/ops/operations/_quant_ops.py +4 -8
  279. mindspore/ops/operations/_sequence_ops.py +42 -0
  280. mindspore/ops/operations/array_ops.py +68 -282
  281. mindspore/ops/operations/comm_ops.py +107 -59
  282. mindspore/ops/operations/custom_ops.py +94 -70
  283. mindspore/ops/operations/debug_ops.py +8 -4
  284. mindspore/ops/operations/image_ops.py +18 -12
  285. mindspore/ops/operations/inner_ops.py +26 -3
  286. mindspore/ops/operations/math_ops.py +192 -144
  287. mindspore/ops/operations/nn_ops.py +857 -489
  288. mindspore/ops/operations/other_ops.py +0 -22
  289. mindspore/ops/operations/random_ops.py +53 -111
  290. mindspore/ops/operations/sparse_ops.py +3 -1
  291. mindspore/ops/primitive.py +24 -18
  292. mindspore/parallel/_auto_parallel_context.py +68 -8
  293. mindspore/parallel/_cost_model_context.py +2 -2
  294. mindspore/parallel/_offload_context.py +17 -3
  295. mindspore/parallel/_parallel_serialization.py +12 -5
  296. mindspore/parallel/_ps_context.py +12 -0
  297. mindspore/parallel/_tensor.py +18 -13
  298. mindspore/parallel/_transformer/layers.py +5 -3
  299. mindspore/parallel/_transformer/loss.py +1 -0
  300. mindspore/parallel/_transformer/moe.py +2 -2
  301. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  302. mindspore/parallel/_transformer/transformer.py +23 -3
  303. mindspore/parallel/_utils.py +11 -7
  304. mindspore/parallel/algo_parameter_config.py +85 -5
  305. mindspore/parallel/checkpoint_transform.py +19 -12
  306. mindspore/parallel/shard.py +21 -14
  307. mindspore/profiler/common/struct_type.py +3 -3
  308. mindspore/profiler/common/util.py +4 -2
  309. mindspore/profiler/envprofiling.py +1 -1
  310. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  311. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  312. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  313. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  314. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  315. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  316. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  317. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  318. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  319. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  320. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  321. mindspore/profiler/parser/flops_parser.py +15 -11
  322. mindspore/profiler/parser/framework_parser.py +38 -22
  323. mindspore/profiler/parser/hccl_parser.py +16 -12
  324. mindspore/profiler/parser/integrator.py +22 -11
  325. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  326. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  327. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  328. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  329. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  330. mindspore/profiler/parser/optime_parser.py +1 -1
  331. mindspore/profiler/parser/profiler_info.py +21 -2
  332. mindspore/profiler/parser/step_trace_parser.py +11 -14
  333. mindspore/profiler/profiling.py +179 -89
  334. mindspore/rewrite/api/node.py +102 -19
  335. mindspore/rewrite/api/node_type.py +5 -1
  336. mindspore/rewrite/api/pattern_engine.py +1 -1
  337. mindspore/rewrite/api/scoped_value.py +9 -17
  338. mindspore/rewrite/api/symbol_tree.py +131 -47
  339. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  340. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  341. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  342. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  343. mindspore/rewrite/common/rewrite_elog.py +5 -1
  344. mindspore/rewrite/namer.py +33 -24
  345. mindspore/rewrite/namespace.py +14 -5
  346. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  347. mindspore/rewrite/node/call_function.py +79 -0
  348. mindspore/rewrite/node/cell_container.py +135 -0
  349. mindspore/rewrite/node/control_flow.py +88 -0
  350. mindspore/rewrite/{node.py → node/node.py} +273 -234
  351. mindspore/rewrite/node/node_manager.py +254 -0
  352. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  353. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  354. mindspore/rewrite/parsers/assign_parser.py +216 -221
  355. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  356. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  357. mindspore/rewrite/parsers/constant_parser.py +9 -6
  358. mindspore/rewrite/parsers/container_parser.py +9 -7
  359. mindspore/rewrite/parsers/for_parser.py +42 -21
  360. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  361. mindspore/rewrite/parsers/if_parser.py +28 -24
  362. mindspore/rewrite/parsers/module_parser.py +196 -25
  363. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  364. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  365. mindspore/rewrite/parsers/return_parser.py +6 -6
  366. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  367. mindspore/rewrite/sparsify/utils.py +1 -1
  368. mindspore/rewrite/symbol_tree.py +523 -578
  369. mindspore/rewrite/symbol_tree_builder.py +9 -193
  370. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  371. mindspore/run_check/_check_version.py +6 -4
  372. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  373. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  374. mindspore/scipy/linalg.py +1 -1
  375. mindspore/scipy/ops.py +55 -5
  376. mindspore/scipy/optimize/__init__.py +3 -2
  377. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  378. mindspore/scipy/optimize/minimize.py +7 -3
  379. mindspore/train/_utils.py +7 -3
  380. mindspore/train/amp.py +323 -123
  381. mindspore/train/anf_ir_pb2.py +14 -2
  382. mindspore/train/callback/_backup_and_restore.py +2 -12
  383. mindspore/train/callback/_callback.py +29 -4
  384. mindspore/train/callback/_checkpoint.py +23 -8
  385. mindspore/train/callback/_early_stop.py +2 -2
  386. mindspore/train/callback/_landscape.py +4 -4
  387. mindspore/train/callback/_loss_monitor.py +2 -2
  388. mindspore/train/callback/_on_request_exit.py +2 -2
  389. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  390. mindspore/train/callback/_summary_collector.py +15 -8
  391. mindspore/train/callback/_time_monitor.py +58 -5
  392. mindspore/train/data_sink.py +5 -11
  393. mindspore/train/dataset_helper.py +84 -57
  394. mindspore/train/loss_scale_manager.py +2 -2
  395. mindspore/train/metrics/__init__.py +3 -3
  396. mindspore/train/metrics/cosine_similarity.py +1 -1
  397. mindspore/train/metrics/hausdorff_distance.py +3 -2
  398. mindspore/train/metrics/mean_surface_distance.py +3 -2
  399. mindspore/train/metrics/metric.py +39 -19
  400. mindspore/train/metrics/roc.py +2 -2
  401. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  402. mindspore/train/mind_ir_pb2.py +85 -36
  403. mindspore/train/model.py +187 -47
  404. mindspore/train/serialization.py +487 -161
  405. mindspore/train/summary/_summary_adapter.py +1 -1
  406. mindspore/train/summary/_writer_pool.py +3 -2
  407. mindspore/train/summary/summary_record.py +37 -17
  408. mindspore/train/train_thor/convert_utils.py +3 -3
  409. mindspore/train/train_thor/dataset_helper.py +1 -1
  410. mindspore/version.py +1 -1
  411. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  412. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +415 -472
  413. mindspore/_extends/graph_kernel/expander.py +0 -80
  414. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  415. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  416. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  417. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  418. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  419. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  420. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  421. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  422. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  423. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  424. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  425. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  426. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  427. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  428. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  429. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  430. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  431. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  432. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  433. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  434. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  435. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  436. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  437. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  438. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  439. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  440. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  441. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  442. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  443. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  444. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  445. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  446. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  447. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  448. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  449. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  450. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  451. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  452. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  453. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  454. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  455. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  456. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  457. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  458. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  459. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  460. mindspore/dataset/datapreprocess/__init__.py +0 -20
  461. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  462. mindspore/include/api/net.h +0 -142
  463. mindspore/nn/lr_scheduler.py +0 -262
  464. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  465. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  466. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  467. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  468. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  469. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  470. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  471. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  472. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  473. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  474. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  475. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  476. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  477. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  478. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  479. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  480. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  481. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  482. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  483. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  484. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  485. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  486. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  487. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  488. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  489. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  490. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  491. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  492. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  493. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  494. mindspore/rewrite/node_visitor.py +0 -44
  495. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  496. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  497. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/.commit_id CHANGED
@@ -1 +1 @@
1
- __commit_id__ = ''[sha1]:5822529e,[branch]:(HEAD,origin/r2.1,r2.1)''
1
+ __commit_id__ = ''[sha1]:8c390933,[branch]:(HEAD,origin/r2.2,r2.2)''
mindspore/__init__.py CHANGED
@@ -20,7 +20,7 @@ from mindspore import common, dataset, mindrecord, train, log, amp
20
20
  from mindspore import profiler, communication, numpy, parallel
21
21
  from mindspore.common import *
22
22
  from mindspore.mindrecord import *
23
- from mindspore.ops import _op_impl, grad, value_and_grad, vjp, jvp, jacfwd, jacrev, vmap, get_grad
23
+ from mindspore.ops import _op_impl, grad, value_and_grad, vjp, jvp, jacfwd, jacrev, vmap, get_grad, constexpr
24
24
  from mindspore.train import *
25
25
  from mindspore.log import *
26
26
  from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_auto_parallel_context, \
@@ -31,8 +31,10 @@ from mindspore.profiler import Profiler, EnvProfiler
31
31
  from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
32
32
  rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard
33
33
  from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, TreeNodeHelper
34
+ from mindspore.safeguard import obfuscate_ckpt, load_obf_params_into_net
34
35
  from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module, \
35
36
  is_invalid_or_jit_forbidden_method
37
+ from mindspore import _install_custom
36
38
 
37
39
 
38
40
  __all__ = ["run_check"]
@@ -43,4 +45,5 @@ __all__.extend(log.__all__)
43
45
  __all__.extend(context.__all__)
44
46
  __all__.extend(parallel.__all__)
45
47
  __all__.extend(rewrite.__all__)
48
+ __all__.extend(safeguard.__all__)
46
49
  __all__.append("Profiler")
@@ -96,9 +96,11 @@ def is_invalid_or_jit_forbidden_method(obj, obj_type, attr):
96
96
  if not hasattr(obj, attr):
97
97
  raise AttributeError(f"'{obj_type}' object has no attribute '{attr}'")
98
98
  method = getattr(obj, attr)
99
- if not hasattr(method, "__module__"):
99
+ if not hasattr(method, "__module__") or method.__module__ is None:
100
100
  return False
101
101
  method_info = method.__module__ + '.' + method.__qualname__
102
102
  return method_info in _jit_forbidden_method
103
103
 
104
104
  add_jit_forbidden_module("mindspore.common.initializer")
105
+ add_jit_forbidden_module("mindspore.context")
106
+ add_jit_forbidden_module("mindspore.log")
mindspore/_checkparam.py CHANGED
@@ -84,21 +84,21 @@ def _check_inc_rel(val, lower, upper, rel):
84
84
  def _format_str_one_value(value, rel):
85
85
  """format string"""
86
86
  if rel == EQ:
87
- return "= {}".format(value)
87
+ return f"= {value}"
88
88
  if rel == NE:
89
- return "!= {}".format(value)
89
+ return f"!= {value}"
90
90
  if rel == LT:
91
- return "< {}".format(value)
91
+ return f"< {value}"
92
92
  if rel == LE:
93
- return "<= {}".format(value)
93
+ return f"<= {value}"
94
94
  if rel == GT:
95
- return "> {}".format(value)
95
+ return f"> {value}"
96
96
  if rel == GE:
97
- return ">= {}".format(value)
97
+ return f">= {value}"
98
98
  if rel == IN:
99
- return "in {}".format(value)
99
+ return f"in {value}"
100
100
  if rel == NOT_IN:
101
- return "not in {}".format(value)
101
+ return f"not in {value}"
102
102
 
103
103
  return ""
104
104
 
@@ -106,13 +106,13 @@ def _format_str_one_value(value, rel):
106
106
  def _format_str_two_value(val1, val2, rel):
107
107
  """format string"""
108
108
  if rel == INC_NEITHER:
109
- return "({}, {})".format(val1, val2)
109
+ return f"({val1}, {val2})"
110
110
  if rel == INC_LEFT:
111
- return "[{}, {})".format(val1, val2)
111
+ return f"[{val1}, {val2})"
112
112
  if rel == INC_RIGHT:
113
- return "({}, {}]".format(val1, val2)
113
+ return f"({val1}, {val2}]"
114
114
  if rel == INC_BOTH:
115
- return "[{}, {}]".format(val1, val2)
115
+ return f"[{val1}, {val2}]"
116
116
 
117
117
  return ""
118
118
 
@@ -556,8 +556,7 @@ def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
556
556
  reg = r"^\w+[0-9a-zA-Z\_\.]*$"
557
557
  if re.match(reg, target, flag) is None:
558
558
  prim_name = f"For '{prim_name}', the" if prim_name else "The"
559
- raise ValueError("{} '{}' is illegal, it must be match regular'{}' by flags'{}.'".format(
560
- prim_name, target, reg, flag))
559
+ raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular'{reg}' by flags'{flag}.'")
561
560
  return True
562
561
 
563
562
 
@@ -565,11 +564,10 @@ def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
565
564
  def check_str_and_none_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
566
565
  if reg is None:
567
566
  # Named string regular expression
568
- reg = r"^\w*[0-9a-zA-Z\_\.]*$"
567
+ reg = r"^\w*[0-9a-zA-Z\_\.\-]*$"
569
568
  if re.match(reg, target, flag) is None:
570
569
  prim_name = f"For '{prim_name}', the" if prim_name else "The"
571
- raise ValueError("{} '{}' is illegal, it must be match regular'{}' by flags'{}.'".format(
572
- prim_name, target, reg, flag))
570
+ raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular'{reg}' by flags'{flag}.'")
573
571
  return True
574
572
 
575
573
 
@@ -585,8 +583,7 @@ def check_file_name_by_regular(target, reg=None, prim_name=None):
585
583
  reg = r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$"
586
584
  if re.match(reg, target) is None:
587
585
  prim_name = f"For '{prim_name}', the" if prim_name else "The"
588
- raise ValueError("{} '{}' is illegal, it must be match regular '{}'.".format(
589
- prim_name, target, reg))
586
+ raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular '{reg}'.")
590
587
 
591
588
  return True
592
589
 
@@ -802,6 +799,9 @@ def check_transpose_axis(axes, ndim):
802
799
  # if only one argument provided, it must be tuple or list
803
800
  if isinstance(perm, list):
804
801
  perm = tuple(perm)
802
+ elif isinstance(perm, int):
803
+ perm = (perm,)
804
+ _check_dim()
805
805
  else:
806
806
  if not isinstance(perm, tuple):
807
807
  raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, " \
@@ -959,11 +959,6 @@ def is_stub_tensor(tensor):
959
959
  return hasattr(tensor, "stub")
960
960
 
961
961
 
962
- def is_pack_tensor(tensor):
963
- """Whether it is a PackTensor."""
964
- return hasattr(tensor, "__pack__")
965
-
966
-
967
962
  def expanded_shape(ndim, axis_size, axis):
968
963
  """
969
964
  Returns a shape with size = 1 for all dimensions
@@ -984,8 +979,8 @@ def infer_out_shape(*shapes):
984
979
  def _check(items, max_size, shapes):
985
980
  for item in items:
986
981
  if item not in (1, max_size):
987
- raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max on the axis' \
988
- f'to support broadcast, but got shapes {shapes,}')
982
+ raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis' \
983
+ f'to support broadcasting, but got shapes {shapes,}')
989
984
  shape_out = ()
990
985
  max_len = max([len(it) for it in shapes])
991
986
  for i in range(max_len):
@@ -1269,7 +1264,7 @@ def check_input_data(*data, data_class):
1269
1264
  if not ret:
1270
1265
  data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) if isinstance(
1271
1266
  data_class, (tuple, list)) else (data_class if data_class is None else data_class.__name__)
1272
- raise TypeError(f'The type of input data must be in the Union({data_class_str}, ' \
1267
+ raise TypeError(f'The types of input data must be in the Union({data_class_str}, ' \
1273
1268
  f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), ' \
1274
1269
  f'but got type {item if item is None else type(item).__name__}.')
1275
1270
 
@@ -1314,8 +1309,7 @@ def args_type_check(*type_args, **type_kwargs):
1314
1309
  for name, value in argument_dict.items():
1315
1310
  if name in bound_types:
1316
1311
  if value is not None and not isinstance(value, bound_types[name]):
1317
- raise TypeError("The parameter '{}' must be {}, but got {}"
1318
- .format(name, bound_types[name], type(value)))
1312
+ raise TypeError(f"The parameter '{name}' must be {bound_types[name]}, but got {type(value)}")
1319
1313
  return func(*args, **kwargs)
1320
1314
 
1321
1315
  return wrapper
@@ -14,5 +14,4 @@
14
14
  # ============================================================================
15
15
  """init"""
16
16
  from .splitter import split_with_json
17
- from .expander import get_op_expander, get_expander_op_list
18
17
  from .parallel_estimate import estimate_calculation_amount, estimate_ops
@@ -83,23 +83,23 @@ class CommonPattern:
83
83
  def reshape(dom):
84
84
  """fuse strategy for reshape dom"""
85
85
  if dom.pattern != PrimLib.RESHAPE:
86
- return []
86
+ return [], False
87
87
  min_area, forward_fuse = None, False
88
88
  for a, _ in dom.out_relations.items():
89
- if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
90
- (min_area is None or a.pattern < min_area.pattern):
91
- min_area = a
89
+ if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a):
90
+ if min_area is None or a.pattern < min_area.pattern:
91
+ min_area = a
92
92
  for a, _ in dom.in_relations.items():
93
- if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
94
- (min_area is None or a.pattern < min_area.pattern):
95
- min_area, forward_fuse = a, True
96
- return ([min_area], forward_fuse) if min_area else []
93
+ if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom):
94
+ if min_area is None or a.pattern < min_area.pattern:
95
+ min_area, forward_fuse = a, True
96
+ return ([min_area], forward_fuse) if min_area else ([], False)
97
97
 
98
98
  @staticmethod
99
99
  def isolate_reshape(dom):
100
100
  """fuse strategy for isolate reshape dom"""
101
101
  if dom.pattern != PrimLib.RESHAPE or len(dom.ops) != 1:
102
- return []
102
+ return [], False
103
103
  for a, _ in dom.out_relations.items():
104
104
  if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
105
105
  return [a], False
@@ -107,59 +107,61 @@ class CommonPattern:
107
107
  if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.pattern <= PrimLib.BROADCAST and \
108
108
  a.check_acyclic(dom):
109
109
  return [a], True
110
- return []
110
+ return [], False
111
111
 
112
112
  @staticmethod
113
113
  def elemwise_depth(dom):
114
114
  """fuse strategy in depth for elemwise dom"""
115
115
  if dom.pattern != PrimLib.ELEMWISE or len(dom.in_relations) != 1:
116
- return []
116
+ return [], False
117
117
  a, r = list(dom.in_relations.items())[0]
118
- if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE or \
119
- tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
120
- return []
118
+ if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
119
+ return [], False
120
+ if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
121
+ return [], False
121
122
  return [a], True
122
123
 
123
124
  @staticmethod
124
125
  def elemwise_width(dom):
125
126
  """fuse strategy in width for elemwise dom"""
126
127
  if dom.pattern != PrimLib.ELEMWISE:
127
- return []
128
+ return [], False
128
129
  fused = []
129
130
  for a, r in dom.in_relations.items():
130
- if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom) and \
131
- tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
132
- fused.append(a)
131
+ if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
132
+ if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
133
+ fused.append(a)
133
134
  return fused, True
134
135
 
135
136
  @staticmethod
136
137
  def broadcast_depth(dom):
137
138
  """fuse strategy in depth for broadcast dom"""
138
139
  if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
139
- return []
140
+ return [], False
140
141
  a, r = list(dom.in_relations.items())[0]
141
- if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE or \
142
- tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
143
- return []
142
+ if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
143
+ return [], False
144
+ if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
145
+ return [], False
144
146
  return [a], True
145
147
 
146
148
  @staticmethod
147
149
  def broadcast_width(dom):
148
150
  """fuse strategy in width for broadcast dom"""
149
151
  if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
150
- return []
152
+ return [], False
151
153
  fused = []
152
154
  for a, r in dom.in_relations.items():
153
- if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom) and \
154
- tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
155
- fused.append(a)
155
+ if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
156
+ if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
157
+ fused.append(a)
156
158
  return fused, True
157
159
 
158
160
  @staticmethod
159
161
  def assign(dom):
160
162
  """fuse strategy for assign dom"""
161
163
  if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
162
- return []
164
+ return [], False
163
165
  fused = []
164
166
  for a, _ in dom.in_relations.items():
165
167
  fused.append(a)
@@ -711,8 +713,9 @@ class GraphSplitByPattern:
711
713
  for i in range(len(areas) - 1):
712
714
  dom = areas[i]
713
715
  for a in areas[i + 1:]:
714
- if dom.check_acyclic(a) and a.check_acyclic(dom) and \
715
- selector(dom, a) and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a):
716
+ can_fuse = dom.check_acyclic(a) and a.check_acyclic(dom) and selector(dom, a) \
717
+ and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a)
718
+ if can_fuse:
716
719
  dom.fuse(a)
717
720
  self.set_area_map(a.ops, dom)
718
721
  self.areas.remove(a)
@@ -844,7 +847,7 @@ class GraphSplitByPattern:
844
847
  while stack:
845
848
  op = stack.pop()
846
849
  if len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST or len(ops) > max_weight:
847
- return []
850
+ return [], []
848
851
  ops.append(op)
849
852
  for t in op.inputs:
850
853
  if t.op in area.ops:
@@ -878,8 +881,8 @@ class GraphSplitByPattern:
878
881
  return []
879
882
  result = []
880
883
  for op in borders:
881
- if prods[op]:
882
- prod_ops, inputs = prods[op]
884
+ prod_ops, inputs = prods[op]
885
+ if prod_ops:
883
886
  if sum([t.get_size() for t in inputs]) <= op.output.get_size():
884
887
  pred = self.area_map.get(inputs[0].op) if inputs and inputs[0].op else None
885
888
  result.append([pred, prod_ops[::-1]])
@@ -938,23 +941,25 @@ class GraphSplitGpu(GraphSplitByPattern):
938
941
  return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
939
942
 
940
943
  def _broadcast_bwd_depth(dom):
941
- if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
942
- dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
943
- return []
944
+ if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
945
+ return [], False
946
+ if dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
947
+ return [], False
944
948
  a, r = list(dom.out_relations.items())[0]
945
949
  if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
946
- return []
950
+ return [], False
947
951
  return [a], False
948
952
 
949
953
  def _broadcast_bwd_width(dom):
950
954
  if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
951
955
  dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
952
- return []
956
+ return [], False
953
957
  fused = []
954
958
  for a, r in dom.out_relations.items():
955
- if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
956
- (fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output)):
957
- return []
959
+ if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a):
960
+ return [], False
961
+ if fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output):
962
+ return [], False
958
963
  fused.append(a)
959
964
  return fused, False
960
965
 
@@ -965,25 +970,25 @@ class GraphSplitGpu(GraphSplitByPattern):
965
970
 
966
971
  def _reduce_depth(dom):
967
972
  if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
968
- return []
973
+ return [], False
969
974
  a, r = list(dom.in_relations.items())[0]
970
- if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
971
- _is_atomic_add_available(dom):
972
- # to evade the precision problem.
973
- return []
975
+ if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
976
+ if len(a.ops) >= 10 and _is_atomic_add_available(dom):
977
+ # to evade the precision problem.
978
+ return [], False
974
979
  if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
975
980
  return []
976
981
  return [a], True
977
982
 
978
983
  def _reduce_width(dom):
979
984
  if dom.pattern != PrimLib.REDUCE:
980
- return []
985
+ return [], False
981
986
  fused = []
982
987
  for a, r in dom.in_relations.items():
983
- if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
984
- _is_atomic_add_available(dom):
985
- # to evade the precision problem.
986
- continue
988
+ if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
989
+ if len(a.ops) >= 10 and _is_atomic_add_available(dom):
990
+ # to evade the precision problem.
991
+ continue
987
992
  if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
988
993
  fused.append(a)
989
994
  return fused, True
@@ -1016,15 +1021,15 @@ class GraphSplitGpu(GraphSplitByPattern):
1016
1021
 
1017
1022
  def _reduce_output(dom):
1018
1023
  if dom.pattern != PrimLib.REDUCE:
1019
- return []
1024
+ return [], False
1020
1025
  if _may_multi_filter(dom.ops):
1021
- return []
1026
+ return [], False
1022
1027
  if _is_atomic_add_available(dom):
1023
- return []
1028
+ return [], False
1024
1029
  is_all_reduce = tensor_size(dom.ops[0].output) == 1
1025
1030
  # excluded large size all reduce
1026
1031
  if is_all_reduce and dom.ops[0].inputs and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
1027
- return []
1032
+ return [], False
1028
1033
 
1029
1034
  fused = []
1030
1035
  for a, r in dom.out_relations.items():
@@ -1034,11 +1039,11 @@ class GraphSplitGpu(GraphSplitByPattern):
1034
1039
 
1035
1040
  def _reduce_stitch(dom):
1036
1041
  if dom.pattern != PrimLib.REDUCE:
1037
- return []
1042
+ return [], False
1038
1043
  if tensor_size(dom.ops[0].output) == 1:
1039
- return []
1044
+ return [], False
1040
1045
  if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
1041
- return []
1046
+ return [], False
1042
1047
 
1043
1048
  fused = []
1044
1049
  for a, r in dom.out_relations.items():
@@ -1055,7 +1060,7 @@ class GraphSplitGpu(GraphSplitByPattern):
1055
1060
 
1056
1061
  def _transpose(dom):
1057
1062
  if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
1058
- return []
1063
+ return [], False
1059
1064
  fused = []
1060
1065
  for a, _ in dom.in_relations.items():
1061
1066
  if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
@@ -1064,7 +1069,7 @@ class GraphSplitGpu(GraphSplitByPattern):
1064
1069
 
1065
1070
  def _strided_slice(dom):
1066
1071
  if dom.dom_op().prim != "StridedSlice":
1067
- return []
1072
+ return [], False
1068
1073
  fused = []
1069
1074
  for a, _ in dom.in_relations.items():
1070
1075
  if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
@@ -1075,7 +1080,7 @@ class GraphSplitGpu(GraphSplitByPattern):
1075
1080
  def _gather_output(dom, reduce_fusion=False):
1076
1081
  gather_prims = ("Gather", "GatherNd", "CSRGather")
1077
1082
  if not dom.dom_op().prim in gather_prims:
1078
- return []
1083
+ return [], False
1079
1084
 
1080
1085
  def _reduce_exclude(op, axis_list):
1081
1086
  """ Whether this operator should be excluded.
@@ -1173,7 +1178,7 @@ class GraphSplitGpu(GraphSplitByPattern):
1173
1178
  for a, _ in dom.out_relations.items():
1174
1179
  if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
1175
1180
  return [a], False
1176
- return []
1181
+ return [], False
1177
1182
 
1178
1183
  def _broadcast_tot(dom):
1179
1184
  """Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
@@ -1182,13 +1187,13 @@ class GraphSplitGpu(GraphSplitByPattern):
1182
1187
  return bool(set(op1.inputs) & set(op2.inputs))
1183
1188
 
1184
1189
  if len(dom.ops) != 1:
1185
- return []
1190
+ return [], False
1186
1191
 
1187
1192
  # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
1188
1193
  fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
1189
1194
  arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
1190
1195
  if arg_idx == -1:
1191
- return []
1196
+ return [], False
1192
1197
  fuse_tensor = dom.dom_op().inputs[arg_idx]
1193
1198
 
1194
1199
  for a, _ in dom.in_relations.items():
@@ -1200,27 +1205,30 @@ class GraphSplitGpu(GraphSplitByPattern):
1200
1205
  # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
1201
1206
  if a.pattern <= PrimLib.BROADCAST and any((op.output in fuse_tensor for op in a.ops)):
1202
1207
  return [a], True
1203
- return []
1208
+ return [], False
1204
1209
 
1205
1210
  def _broadcast_onehot(dom, fwd=True):
1206
1211
  """Fuse rule for OneHot."""
1207
1212
  if dom.dom_op().prim != "OneHot":
1208
- return []
1213
+ return [], False
1209
1214
 
1210
1215
  fused = []
1211
1216
  neighbours = dom.in_relations.items() if fwd else dom.out_relations.items()
1212
1217
  for a, _ in neighbours:
1213
1218
  if a.pattern <= PrimLib.BROADCAST:
1214
- if (fwd and a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output) or \
1215
- (not fwd and dom.check_acyclic(a)):
1216
- fused.append(a)
1219
+ if fwd:
1220
+ if a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output:
1221
+ fused.append(a)
1222
+ else:
1223
+ if dom.check_acyclic(a):
1224
+ fused.append(a)
1217
1225
 
1218
1226
  return fused, fwd
1219
1227
 
1220
1228
  def _elemwise_elemany(dom):
1221
1229
  """Fuse rule for elemany."""
1222
1230
  if dom.dom_op().prim != "ElemAny":
1223
- return []
1231
+ return [], False
1224
1232
 
1225
1233
  fused = []
1226
1234
  for a, r in dom.in_relations.items():
@@ -1233,21 +1241,21 @@ class GraphSplitGpu(GraphSplitByPattern):
1233
1241
  """Fuse rule for injective """
1234
1242
  injective_ops = {"Transpose", "StridedSlice"}
1235
1243
  if dom.dom_op().prim not in injective_ops:
1236
- return []
1244
+ return [], False
1237
1245
  to_ops = dom.dom_op().output.to_ops
1238
1246
  if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
1239
- return []
1247
+ return [], False
1240
1248
  to_area = list(dom.out_relations.keys())[0]
1241
1249
  if (to_area.pattern >= PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
1242
1250
  to_ops[0] not in to_area.ops:
1243
- return []
1251
+ return [], False
1244
1252
  if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
1245
- return []
1253
+ return [], False
1246
1254
  return [to_area], False
1247
1255
 
1248
1256
  def _h_broadcast(dom, a):
1249
1257
  if dom.pattern > PrimLib.BROADCAST:
1250
- return []
1258
+ return [], False
1251
1259
  return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
1252
1260
 
1253
1261
  def _h_reduce(dom, a):
@@ -1274,7 +1282,7 @@ class GraphSplitGpu(GraphSplitByPattern):
1274
1282
  fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
1275
1283
  arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
1276
1284
  if arg_idx == -1:
1277
- return []
1285
+ return [], False
1278
1286
  fuse_tensor = dom.dom_op().inputs[arg_idx]
1279
1287
  for a, _ in dom.in_relations.items():
1280
1288
  if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
@@ -1283,7 +1291,7 @@ class GraphSplitGpu(GraphSplitByPattern):
1283
1291
  if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
1284
1292
  any([op.output in fuse_tensor for op in a.ops]):
1285
1293
  return [a], True
1286
- return []
1294
+ return [], False
1287
1295
 
1288
1296
  def _fuse_loop():
1289
1297
  self.fuse(CommonPattern.reshape)