mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. (The original page linked to further details here; the link target was not preserved in this copy.)

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -44,7 +44,7 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
44
44
  from mindspore._c_expression.amp import get_curr_amp_strategy
45
45
  from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
46
46
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
47
- _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
47
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, TensorPy as Tensor, dump_func_graph, _GraphFragment_
48
48
  from mindspore.parallel._ps_context import _is_role_sched
49
49
  from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
50
50
  _is_parallel_mode
@@ -208,6 +208,11 @@ def _handle_func_args(func, *args, **kwargs):
208
208
  args = bound_arguments.args
209
209
  kwargs = bound_arguments.kwargs
210
210
 
211
+ return args, kwargs
212
+
213
+
214
+ def _check_func_args(func, *args):
215
+ """Check the *args inputs of the function"""
211
216
  positional_args = 0
212
217
  default_args = 0
213
218
  has_var = False
@@ -221,14 +226,13 @@ def _handle_func_args(func, *args, **kwargs):
221
226
  default_args += 1
222
227
 
223
228
  if has_var:
224
- return args, kwargs
229
+ return
225
230
 
226
231
  if len(args) < positional_args:
227
232
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
228
233
  if len(args) > positional_args + default_args:
229
234
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
230
235
  f"default argument, total {positional_args + default_args}, but got {len(args)}.")
231
- return args, kwargs
232
236
 
233
237
 
234
238
  sys_path = list(sys.path)
@@ -349,7 +353,7 @@ def _get_parameter_layout():
349
353
  return layout
350
354
 
351
355
 
352
- def _handle_arg(obj, arg, has_mutable_arg):
356
+ def _handle_arg(obj, arg, has_mutable_arg, is_predict):
353
357
  """Handle arg for runtime .If need handle the arg, return True"""
354
358
  from mindspore._extends.parse import compile_config
355
359
  if isinstance(arg, PythonTensor):
@@ -364,7 +368,7 @@ def _handle_arg(obj, arg, has_mutable_arg):
364
368
  if isinstance(arg, list) and not arg:
365
369
  return None
366
370
  return arg
367
- elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
371
+ elif not is_predict and (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
368
372
  isinstance(arg, (int, float)):
369
373
  return arg
370
374
  elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
@@ -394,17 +398,16 @@ def _handle_arg_predict(obj, arg, has_mutable_arg):
394
398
  return arg
395
399
 
396
400
 
397
- def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
401
+ def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict=False):
398
402
  """Get the actual input args and kwargs for runtime."""
399
403
  new_args = []
400
- fn = _handle_arg_predict if is_predict else _handle_arg
401
404
  for arg, has_mutable_arg in zip(args, has_mutable_args_list):
402
- new_arg = fn(obj, arg, has_mutable_arg)
405
+ new_arg = _handle_arg(obj, arg, has_mutable_arg, is_predict)
403
406
  if new_arg is not None:
404
407
  new_args.append(new_arg)
405
408
 
406
409
  for _, value in kwargs.items():
407
- new_value = fn(obj, value, None)
410
+ new_value = _handle_arg(obj, value, None, is_predict)
408
411
  if new_value is not None:
409
412
  new_args.append(new_value)
410
413
 
@@ -609,7 +612,7 @@ class _JitExecutor:
609
612
  else:
610
613
  self._graph_executor = GraphExecutor_.get_instance()
611
614
  self._create_time = ms_create_time
612
- self._compile_args = None
615
+ self._mutable_flags = None
613
616
  self._enable_auto_dynamic = dynamic == 1
614
617
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
615
618
  self._cell_cache_key_extend = cell_cache_key_extend
@@ -634,16 +637,8 @@ class _JitExecutor:
634
637
  except Exception as err:
635
638
  _pynative_executor.clear_res()
636
639
  raise err
637
- else: # get compiled args to generate run args by _generate_run_args
638
- compile_args = self._generate_compile_args(args_list)
639
- key_id = self._get_key_id()
640
- if self.input_signature is None:
641
- compile_args = get_auto_dynamic_shape_args(
642
- compile_args, key_id, self._enable_auto_dynamic
643
- )
644
- self._compile_args = compile_args
645
640
 
646
- new_inputs = self._generate_run_args(args_list, kwargs)
641
+ new_inputs = self._generate_run_args(args_list, kwargs, is_predict=True)
647
642
  if self.jit_config_dict:
648
643
  jit_config_dict = self.jit_config_dict
649
644
  else:
@@ -656,11 +651,25 @@ class _JitExecutor:
656
651
  res = _convert_python_data(output)
657
652
  return True, res
658
653
 
654
+ def compile_frontend(self, *args, **kwargs):
655
+ """Only compile to the frontend graph."""
656
+ args_list = args
657
+ if self.obj is not None:
658
+ args_list = args_list[1:]
659
+ os.environ['MS_DEV_PRECOMPILE_ONLY'] = '1'
660
+ phase = ""
661
+ _pynative_executor.set_jit_compile_phase(phase)
662
+ phase = self.compile(self.fn.__name__, *args_list, **kwargs)
663
+ _pynative_executor.set_jit_compile_phase(phase)
664
+ os.unsetenv('MS_DEV_PRECOMPILE_ONLY')
665
+ return self._graph_executor.get_func_graph(phase), self._mutable_flags, phase, self.enable_tuple_broaden
666
+
659
667
  @_wrap_func
660
668
  def __call__(self, *args, **kwargs):
661
669
  predict, res = self._predict(*args, **kwargs)
662
670
  if predict:
663
671
  return res
672
+ _check_func_args(self.fn, *args)
664
673
  if jit_context() and jit_context().is_nested():
665
674
  return jit_context().run_graph("", None, *())
666
675
  args_list = args
@@ -668,9 +677,9 @@ class _JitExecutor:
668
677
  args_list = args_list[1:]
669
678
  phase = ""
670
679
  try:
671
- _pynative_executor.set_jit_compile_status(True, phase)
680
+ _pynative_executor.set_jit_compile_phase(phase)
672
681
  phase = self.compile(self.fn.__name__, *args_list, **kwargs)
673
- _pynative_executor.set_jit_compile_status(False, phase)
682
+ _pynative_executor.set_jit_compile_phase(phase)
674
683
  except Exception as err:
675
684
  _pynative_executor.clear_res()
676
685
  raise err
@@ -694,6 +703,7 @@ class _JitExecutor:
694
703
  def compile(self, method_name, *args, **kwargs):
695
704
  """Returns pipeline for the given args."""
696
705
  # Chose dynamic shape tensors or actual input tensors as compile args.
706
+ self._graph_executor.set_real_args(args, kwargs)
697
707
  compile_args = self._generate_compile_args(args)
698
708
  key_id = self._get_key_id()
699
709
  if self.input_signature is None:
@@ -705,7 +715,11 @@ class _JitExecutor:
705
715
  # 1) Origin args is mutable.
706
716
  # 2) Args contains sequence with gradient tensor.
707
717
  compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
708
- self._compile_args = compile_args
718
+ mutable_flags = _get_mutable_flags(compile_args)
719
+ self._mutable_flags = mutable_flags
720
+ # Store the _mutable_flags in the cell obj for incremental inference.
721
+ if self.obj is not None:
722
+ self.obj._mutable_flags = mutable_flags
709
723
  generate_name, echo_function_name = self._get_generate_name()
710
724
  # The full Function name
711
725
  full_function_name = generate_name
@@ -839,6 +853,7 @@ class _JitExecutor:
839
853
  else:
840
854
  _pynative_executor.set_dynamic_input(self.fn, *compile_args)
841
855
  logger.info(f"dynamic shape compile_args: {compile_args}")
856
+ Validator.check_symbolic_shape(compile_args, args_list)
842
857
  return compile_args
843
858
 
844
859
  def _generate_compile_args_by_set_inputs(self, args_list):
@@ -895,7 +910,7 @@ class _JitExecutor:
895
910
  # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
896
911
  return _pynative_executor.get_dynamic_input(args_list)
897
912
 
898
- def _generate_run_args(self, args_list, kwargs):
913
+ def _generate_run_args(self, args_list, kwargs, is_predict=False):
899
914
  """
900
915
  Generate input args, which are required for running.
901
916
 
@@ -906,7 +921,11 @@ class _JitExecutor:
906
921
  Returns:
907
922
  new_inputs, new input args, which are required for running.
908
923
  """
909
- return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)
924
+ if self.obj is not None and hasattr(self.obj, '_mutable_flags'):
925
+ mutable_flags = self.obj._mutable_flags
926
+ else:
927
+ mutable_flags = self._mutable_flags
928
+ return _get_args_for_run(self, args_list, kwargs, mutable_flags, is_predict)
910
929
 
911
930
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
912
931
  """Get graph proto from pipeline."""
@@ -978,7 +997,7 @@ def _check_option_backend(option, backend):
978
997
  'ge_options': ['GE'],
979
998
  'infer_boost': ['ms_backend'],
980
999
  }
981
- if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
1000
+ if option in option_backend_cfgs and backend != '' and backend not in option_backend_cfgs[option]:
982
1001
  logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
983
1002
  f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
984
1003
 
@@ -1187,8 +1206,10 @@ def jit(
1187
1206
  - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
1188
1207
  optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
1189
1208
  - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
1190
- for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
1191
- and can be executed only on Ascend hardware.
1209
+ for Ascend model compilation and execution. Note: This backend only supports GRAPH Mode in Ascend,
1210
+ only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
1211
+ dynamic shape scenes. In addition, this backend incurs additional compilation costs and is difficult to
1212
+ debug and tune.
1192
1213
 
1193
1214
  **options (dict): A dictionary of options to pass to the compilation backend.
1194
1215
 
@@ -1333,9 +1354,8 @@ def jit(
1333
1354
  jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
1334
1355
  dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
1335
1356
  fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
1336
- if backend == "":
1337
- backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
1338
- backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1357
+ if backend != "":
1358
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1339
1359
  jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
1340
1360
  hash_obj = _get_hash_obj(options)
1341
1361
  _check_options(options, backend)
@@ -1350,7 +1370,7 @@ def jit(
1350
1370
  elif capture_mode == "bytecode":
1351
1371
  wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
1352
1372
  else:
1353
- wrap_func = _jit_trace()
1373
+ wrap_func = _jit_trace(jit_config)
1354
1374
 
1355
1375
  if function is not None:
1356
1376
  return wrap_func(function)
@@ -1557,6 +1577,20 @@ def _parameter_broadcast(obj):
1557
1577
  _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
1558
1578
 
1559
1579
 
1580
+ def _run_in_jit():
1581
+ """In jit, this function always returns true. Otherwise, returns false."""
1582
+ def _temp_func():
1583
+ return 0
1584
+
1585
+ from mindspore.ops.primitive import constexpr
1586
+
1587
+ @constexpr(check=False)
1588
+ def _check_func(func):
1589
+ return func is None
1590
+
1591
+ return _check_func(_temp_func)
1592
+
1593
+
1560
1594
  class _no_grad(contextlib.ContextDecorator):
1561
1595
  """
1562
1596
  Context Manager to disable gradient calculation. When enter this context, we will disable calculate
@@ -1826,17 +1860,16 @@ class _PyNativeExecutor:
1826
1860
  """
1827
1861
  return self._executor.requires_grad()
1828
1862
 
1829
- def set_jit_compile_status(self, status, phase):
1863
+ def set_jit_compile_phase(self, phase):
1830
1864
  """
1831
- Set jit is compiling
1865
+ Set jit phase
1832
1866
 
1833
1867
  Args:
1834
- status(bool): jit compile status
1835
1868
  phase (str): The phase of cell/function instance.
1836
1869
  Return:
1837
1870
  None.
1838
1871
  """
1839
- self._executor.set_jit_compile_status(status, phase)
1872
+ self._executor.set_jit_compile_phase(phase)
1840
1873
 
1841
1874
  def set_is_run_recompute(self, status):
1842
1875
  """
@@ -1934,6 +1967,19 @@ class _PyNativeExecutor:
1934
1967
  """
1935
1968
  return self._executor.set_creation_type(tensor, creation_type)
1936
1969
 
1970
+ def queue_backward_final_callback(self, callback):
1971
+ """
1972
+ add backward final callback
1973
+
1974
+ Args:
1975
+ callback(Function): callback function.
1976
+
1977
+ Return:
1978
+ None.
1979
+ """
1980
+ return self._executor.queue_backward_final_callback(callback)
1981
+
1982
+
1937
1983
 
1938
1984
  class _CellGraphExecutor:
1939
1985
  """
@@ -2075,6 +2121,8 @@ class _CellGraphExecutor:
2075
2121
  obj.phase_cache[raw_phase] = phase
2076
2122
  update_auto_dynamic_shape_phase(args, key_id, phase)
2077
2123
  obj.current_phase = phase
2124
+ obj._add_attr("compile_phase", phase)
2125
+ obj.compile_phase = phase
2078
2126
  if phase in obj.compile_cache and self.has_compiled(phase):
2079
2127
  logger.debug("%r graph has existed.", phase)
2080
2128
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
@@ -2124,6 +2172,10 @@ class _CellGraphExecutor:
2124
2172
  new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
2125
2173
  return self._graph_executor.updata_param_node_default_input(phase, new_param)
2126
2174
 
2175
+ def set_real_args(self, args, kwargs):
2176
+ """Set real arguments to graph executor."""
2177
+ self._graph_executor.set_real_args(args, kwargs)
2178
+
2127
2179
  def _get_shard_strategy(self, obj):
2128
2180
  real_phase = _real_phase(obj.phase, obj)
2129
2181
  return self._graph_executor.get_strategy(real_phase)
@@ -2213,6 +2265,19 @@ class _CellGraphExecutor:
2213
2265
  return None
2214
2266
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
2215
2267
 
2268
+ def _get_onnx_func_graph_proto(self, obj, exec_id, use_prefix=False, input_names=None, output_names=None,
2269
+ opset_version=11, export_params=True, keep_initializers_as_inputs=False,
2270
+ dynamic_axes=None, extra_save_params=False, save_file_dir=None):
2271
+ """Get graph proto from pipeline."""
2272
+ if use_prefix:
2273
+ exec_id = exec_id + '.' + obj.arguments_key
2274
+ if self._graph_executor.has_compiled(exec_id) is False:
2275
+ return None
2276
+
2277
+ return self._graph_executor.get_onnx_func_graph_proto(exec_id, input_names, output_names, opset_version,
2278
+ export_params, keep_initializers_as_inputs, dynamic_axes,
2279
+ extra_save_params, save_file_dir)
2280
+
2216
2281
  def get_optimize_graph_proto(self, obj):
2217
2282
  """Return optimize graph binary proto."""
2218
2283
  exec_id = _real_phase(obj.phase, obj)
@@ -2295,5 +2360,190 @@ def flops_collection(phase='train'):
2295
2360
  return _cell_graph_executor.flops_collection(phase)
2296
2361
 
2297
2362
 
2363
class _ScriptGraph:
    """
    Result of a frontend-only compilation.

    Keeps the graph produced by the frontend compiler together with the Python
    callable it was built from, the owning object (if any) and the metadata the
    frontend reported for this compilation.
    """

    def __init__(self, func_graph, func, origin_cell, mutable_flags, phase, enable_tuple_broaden):
        # Compiled graph and the Python callable it was built from.
        self.func_graph, self.func = func_graph, func
        # Owning object (e.g. a Cell) the graph was compiled for, or None.
        self.origin_cell = origin_cell
        # Compilation metadata returned alongside the graph; stored as-is.
        self.mutable_flags, self.phase, self.enable_tuple_broaden = mutable_flags, phase, enable_tuple_broaden

    def print(self):
        """Print the MindIR text of the stored frontend graph."""
        print(dump_func_graph(self.func_graph), flush=True)
2377
+
2378
+
2379
def _frontend_compile_ast(dynamic, jit_config, jit_graph_name=''):
    """
    Build a decorator that runs only the ast-mode frontend compiler on its target.

    The decorator accepts a plain function, a bound method or a ``mindspore.nn.Cell``
    instance. Plain functions are wrapped so that each call compiles them and returns a
    :class:`_ScriptGraph`. Bound methods and cells are unwrapped to their underlying
    function, wrapped the same way, then re-bound to their owner. Any other object is
    returned unchanged after a warning.

    Args:
        dynamic (int): Dynamic-shape flag forwarded to ``_JitExecutor``.
        jit_config (JitConfig): Compilation config forwarded to ``_JitExecutor``.
        jit_graph_name (str): Value stored on the wrapper as ``__jit_graph_name__``. Default: ``''``.

    Returns:
        Callable, the decorator.
    """
    def wrap_func(func):
        # Cell instance: wrap its ``construct`` and bind it back so the cell becomes the first argument.
        if hasattr(func, "construct") and isinstance(func, ms.nn.Cell):
            cell_decorator = _frontend_compile_ast(dynamic, jit_config, func._jit_graph_name)
            return types.MethodType(cell_decorator(func.construct.__func__), func)

        # Bound method: wrap the plain function, then re-bind it to the same owner.
        if isinstance(func, types.MethodType):
            method_decorator = _frontend_compile_ast(dynamic, jit_config)
            return types.MethodType(method_decorator(func.__func__), func.__self__)

        if not isinstance(func, types.FunctionType):
            logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
            return func

        # Timestamp-derived key handed to every _JitExecutor created for this wrapper.
        hash_obj = int(time.time() * 1e9)

        @wraps(func)
        def compile_frontend_only(*args, **kwargs):
            # MS_JIT=0 switches compilation off: run the Python function directly.
            if os.getenv("MS_JIT") == '0':
                return func(*args, **kwargs)

            args, kwargs = _handle_func_args(func, *args, **kwargs)

            # Treat the first argument as the owner when it is not a tensor and exposes
            # an attribute named like the wrapped function (e.g. a cell and ``construct``).
            owner = None
            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
                owner = args[0]

            # Record the current auto mixed precision strategy once per function.
            if not hasattr(func, "amp_strategy"):
                setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())

            graph_name = getattr(compile_frontend_only, "__jit_graph_name__", '')
            executor = _JitExecutor(func, hash_obj, None, owner, jit_config, dynamic, graph_name)
            func_graph, mutable_flags, phase, enable_tuple_broaden = executor.compile_frontend(*args, **kwargs)
            return _ScriptGraph(func_graph, func, owner, mutable_flags, phase, enable_tuple_broaden)

        # Make ``inspect`` report the original signature rather than ``(*args, **kwargs)``.
        compile_frontend_only.__signature__ = inspect.signature(func)
        # Set after ``wraps`` so it overrides anything copied from ``func.__dict__``.
        setattr(compile_frontend_only, "__jit_graph_name__", jit_graph_name)
        return compile_frontend_only

    return wrap_func
2424
+
2425
+
2426
def _frontend_compile(function: Callable,
                      *,
                      dynamic: int = 0,
                      fullgraph: bool = False):
    """
    Create a frontend MindSpore graph from a Python function by the ast capture mode.

    The returned wrapper does not execute `function`. Calling it with concrete inputs runs the
    frontend compiler for those inputs and returns the compiled graph as a :class:`_ScriptGraph`.

    Args:
        function (Callable): The Python function, bound method or Cell instance that will be compiled as a
            frontend graph.

    Keyword Args:
        dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
            is as follows:

            - `0`: Do not perform dynamic shape compilation.
            - `1`: Enable dynamic shape compilation and automatically detect shape changes.

        fullgraph (bool, optional): Whether to capture the entire function into graph. If ``False``, the compiler
            attempts to be compatible with as much of the Python syntax in the function as possible
            (``jit_syntax_level="LAX"``). If ``True``, the entire function must be capturable into graph
            (``jit_syntax_level="STRICT"``); if the function contains unsupported Python syntax, an exception
            is raised. Default: ``False``.

    Returns:
        Callable, a wrapper around `function`. Calling it with inputs returns a :class:`_ScriptGraph`.
        If the environment variable ``MS_JIT`` is ``'0'``, the wrapper runs `function` directly and returns
        its result instead. If `function` is not a function, method or Cell instance, it is returned
        unchanged and a warning is logged.

    Raises:
        TypeError: If `dynamic` is not an int or `fullgraph` is not a bool.
        ValueError: If `dynamic` is not ``0`` or ``1``.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.common.api import _frontend_compile
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
        >>> tensor_add_graph.print()
        ...
    """
    dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
    fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
    # fullgraph=True requires the whole function to be captured, which maps to the strict syntax level.
    jit_config = JitConfig(jit_syntax_level="STRICT" if fullgraph else "LAX")
    return _frontend_compile_ast(dynamic, jit_config)(function)
2479
+
2480
+
2481
class _GraphFragment(_GraphFragment_):
    """
    One fragment output by the backend graph split.

    Wraps a ``_GraphFragment_`` and exposes its underscored accessors
    (``id_``, ``is_graph_``, ``py_key_``, ``args_list_``) under plain names.
    """

    def __init__(self, frag):
        accepted = frag is not None and isinstance(frag, _GraphFragment_)
        if not accepted:
            raise TypeError(f"Expect input `frag` to be a _GraphFragment_, but got {type(frag)}")
        _GraphFragment_.__init__(self, frag)

    def __call__(self, *args):
        # The base-class call receives all positional inputs packed into a single tuple.
        packed_inputs = args
        return super().__call__(packed_inputs)

    def __repr__(self):
        # Delegate to ``__str__``, which this class does not override.
        return self.__str__()

    def id(self):
        """Return the value reported by ``id_()``."""
        return self.id_()

    def is_graph(self):
        """Return the value reported by ``is_graph_()``."""
        return self.is_graph_()

    def py_key(self):
        """Return the value reported by ``py_key_()``."""
        return self.py_key_()

    def args_list(self):
        """Return the value reported by ``args_list_()``."""
        return self.args_list_()
2507
+
2508
+
2509
def _graph_split(script_graph):
    """
    Split the script_graph into several fragments according to the nodes with the split op attribute.

    Args:
        script_graph (_ScriptGraph): The frontend graph to split, i.e. the object returned by calling the
            wrapper produced by ``_frontend_compile``. Only its ``func_graph`` attribute is used.

    Returns:
        list[_GraphFragment], one wrapper per fragment returned by the executor, in the same order.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore.common.api import _frontend_compile, _graph_split
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> add = ops.Add().add_prim_attr("split_op", True).add_prim_attr("func_id", "add_func")
        ...
        >>> def tensor_add(x, y):
        ...     z1 = x + y
        ...     z2 = add(z1, x)
        ...     return z2
        ...
        >>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
        >>> frags = _graph_split(tensor_add_graph)
        >>> print(frags)
        ...
    """
    raw_fragments = JitExecutor_.get_instance().split_graph(script_graph.func_graph)
    # Wrap each raw fragment so callers get the Python-side accessors of _GraphFragment.
    return [_GraphFragment(frag) for frag in raw_fragments]
2547
+
2298
2548
  # Module-level executor instances used by the functions in this module (e.g. ``flops_collection``).
  _cell_graph_executor = _CellGraphExecutor()
2299
2549
  _pynative_executor = _PyNativeExecutor()
mindspore/common/dump.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,115 +14,14 @@
14
14
  # ============================================================================
15
15
  """Controlling dump behavior."""
16
16
  from __future__ import absolute_import
17
- from warnings import warn
18
-
19
- import mindspore.context as context
20
- from mindspore._c_expression import security
17
+ from mindspore.tools import set_dump as tools_set_dump
18
+ from mindspore.common._decorator import deprecated
21
19
 
22
20
 
21
@deprecated("2.7.1", "mindspore.tools.set_dump", module_prefix="mindspore.")
def set_dump(target, enabled=True):
    """
    Enable or disable dump for `target` by delegating to :func:`mindspore.tools.set_dump`.

    .. warning::
        This API is deprecated since version 2.7.1 and will be removed in a future version.
        Please use :func:`mindspore.tools.set_dump` instead.

    Args:
        target: The object to set the dump flag on. It is forwarded unchanged to
            :func:`mindspore.tools.set_dump`, which defines the accepted types.
        enabled (bool, optional): Forwarded unchanged. ``True`` enables dump and ``False`` disables it.
            Default: ``True`` .
    """
    # Validation and flag handling live in the replacement API; this wrapper only forwards.
    tools_set_dump(target, enabled)
@@ -275,9 +275,7 @@ class _AutoIdentifyDynamicShape:
275
275
  continue
276
276
  if not isinstance(elem, (list, tuple, Tensor, int, float)):
277
277
  return False
278
- if isinstance(elem, Tensor) and \
279
- self._is_invalid_shape(elem.shape) and \
280
- not enable_jit_dynamic:
278
+ if isinstance(elem, Tensor) and self._is_invalid_shape(elem.shape) and not enable_jit_dynamic:
281
279
  return False
282
280
  if not is_sink_mode and isinstance(elem, (list, tuple)):
283
281
  return self._is_enable_auto_dynamic_shape(elem, is_sink_mode, enable_jit_dynamic)