mindspore 2.7.0-cp310-cp310-win_amd64.whl → 2.7.0rc1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (196)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  9. mindspore/_extends/parse/__init__.py +3 -3
  10. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -0
  11. mindspore/_extends/parse/parser.py +22 -28
  12. mindspore/_extends/parse/standard_method.py +1 -15
  13. mindspore/_extends/pijit/pijit_func_white_list.py +5 -2
  14. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  15. mindspore/amp.py +18 -0
  16. mindspore/avcodec-59.dll +0 -0
  17. mindspore/avdevice-59.dll +0 -0
  18. mindspore/avfilter-8.dll +0 -0
  19. mindspore/avformat-59.dll +0 -0
  20. mindspore/avutil-57.dll +0 -0
  21. mindspore/common/__init__.py +12 -18
  22. mindspore/common/_tensor_cpp_method.py +1 -1
  23. mindspore/common/_tensor_docs.py +38 -102
  24. mindspore/common/_utils.py +1 -9
  25. mindspore/common/api.py +106 -155
  26. mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} +23 -17
  27. mindspore/common/dtype.py +57 -98
  28. mindspore/common/dump.py +1 -1
  29. mindspore/common/file_system.py +9 -59
  30. mindspore/common/hook_handle.py +3 -22
  31. mindspore/common/np_dtype.py +3 -3
  32. mindspore/common/parameter.py +20 -4
  33. mindspore/common/recompute.py +4 -2
  34. mindspore/common/tensor.py +52 -38
  35. mindspore/communication/_hccl_management.py +297 -0
  36. mindspore/context.py +21 -15
  37. mindspore/dataset/__init__.py +1 -1
  38. mindspore/dataset/audio/transforms.py +1 -1
  39. mindspore/dataset/core/config.py +1 -35
  40. mindspore/dataset/engine/datasets.py +315 -330
  41. mindspore/dataset/engine/datasets_user_defined.py +22 -38
  42. mindspore/dataset/transforms/c_transforms.py +2 -2
  43. mindspore/dataset/transforms/transforms.py +3 -3
  44. mindspore/dataset/vision/__init__.py +1 -1
  45. mindspore/dataset/vision/py_transforms.py +8 -8
  46. mindspore/dataset/vision/transforms.py +5 -17
  47. mindspore/dataset/vision/utils.py +21 -632
  48. mindspore/device_context/ascend/op_tuning.py +1 -35
  49. mindspore/dnnl.dll +0 -0
  50. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -3
  51. mindspore/include/api/cell.h +4 -28
  52. mindspore/include/api/cfg.h +7 -24
  53. mindspore/include/api/context.h +0 -1
  54. mindspore/include/api/delegate.h +2 -0
  55. mindspore/include/api/dual_abi_helper.h +19 -100
  56. mindspore/include/api/graph.h +1 -14
  57. mindspore/include/api/kernel.h +3 -16
  58. mindspore/include/api/kernel_api.h +1 -9
  59. mindspore/include/api/metrics/accuracy.h +0 -9
  60. mindspore/include/api/model.h +1 -5
  61. mindspore/include/api/model_group.h +0 -4
  62. mindspore/include/api/model_parallel_runner.h +0 -2
  63. mindspore/include/api/status.h +10 -48
  64. mindspore/include/api/types.h +1 -6
  65. mindspore/include/dataset/constants.h +0 -9
  66. mindspore/jpeg62.dll +0 -0
  67. mindspore/mindrecord/tools/cifar10.py +2 -3
  68. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -5
  69. mindspore/mindspore_backend_common.dll +0 -0
  70. mindspore/mindspore_backend_manager.dll +0 -0
  71. mindspore/mindspore_common.dll +0 -0
  72. mindspore/mindspore_core.dll +0 -0
  73. mindspore/mindspore_cpu_res_manager.dll +0 -0
  74. mindspore/mindspore_dump.dll +0 -0
  75. mindspore/mindspore_frontend.dll +0 -0
  76. mindspore/mindspore_glog.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/mindspore_ops_host.dll +0 -0
  81. mindspore/mindspore_ops_kernel_common.dll +0 -0
  82. mindspore/mindspore_profiler.dll +0 -0
  83. mindspore/mindspore_pyboost.dll +0 -0
  84. mindspore/mindspore_pynative.dll +0 -0
  85. mindspore/mindspore_res_manager.dll +0 -0
  86. mindspore/mindspore_runtime_pipeline.dll +0 -0
  87. mindspore/mint/distributed/__init__.py +0 -4
  88. mindspore/mint/distributed/distributed.py +14 -217
  89. mindspore/mint/nn/layer/_functions.py +2 -1
  90. mindspore/mint/nn/layer/conv.py +6 -6
  91. mindspore/mint/nn/layer/normalization.py +3 -3
  92. mindspore/nn/cell.py +174 -216
  93. mindspore/nn/layer/activation.py +2 -4
  94. mindspore/nn/layer/basic.py +13 -7
  95. mindspore/nn/layer/image.py +1 -1
  96. mindspore/nn/optim/adam.py +3 -1
  97. mindspore/nn/optim/lamb.py +3 -1
  98. mindspore/nn/optim/tft_wrapper.py +3 -2
  99. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  100. mindspore/nn/wrap/cell_wrapper.py +5 -39
  101. mindspore/nn/wrap/grad_reducer.py +15 -0
  102. mindspore/numpy/array_creations.py +2 -2
  103. mindspore/numpy/utils_const.py +1 -1
  104. mindspore/opencv_core452.dll +0 -0
  105. mindspore/opencv_imgcodecs452.dll +0 -0
  106. mindspore/opencv_imgproc452.dll +0 -0
  107. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  108. mindspore/ops/_op_impl/cpu/__init__.py +0 -1
  109. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +2 -12
  110. mindspore/ops/auto_generate/gen_extend_func.py +4 -4
  111. mindspore/ops/auto_generate/gen_ops_def.py +16 -290
  112. mindspore/ops/auto_generate/gen_ops_prim.py +76 -563
  113. mindspore/ops/composite/base.py +1 -1
  114. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  115. mindspore/ops/function/__init__.py +0 -1
  116. mindspore/ops/function/array_func.py +6 -10
  117. mindspore/ops/function/debug_func.py +2 -4
  118. mindspore/ops/function/grad/grad_func.py +12 -4
  119. mindspore/ops/function/math_func.py +32 -44
  120. mindspore/ops/function/nn_func.py +20 -18
  121. mindspore/ops/functional.py +1 -2
  122. mindspore/ops/functional_overload.py +12 -23
  123. mindspore/ops/operations/_inner_ops.py +12 -11
  124. mindspore/ops/operations/array_ops.py +50 -4
  125. mindspore/ops/operations/comm_ops.py +15 -1
  126. mindspore/ops/operations/custom_ops.py +4 -10
  127. mindspore/ops/operations/debug_ops.py +6 -6
  128. mindspore/ops/operations/manually_defined/ops_def.py +12 -12
  129. mindspore/ops/operations/math_ops.py +5 -5
  130. mindspore/ops/operations/nn_ops.py +1 -1
  131. mindspore/ops/primitive.py +10 -3
  132. mindspore/ops/tensor_method.py +7 -16
  133. mindspore/ops_generate/pyboost/gen_pyboost_func.py +16 -0
  134. mindspore/parallel/_auto_parallel_context.py +15 -5
  135. mindspore/parallel/_parallel_serialization.py +2 -3
  136. mindspore/parallel/_ps_context.py +2 -2
  137. mindspore/parallel/_transformer/transformer.py +4 -4
  138. mindspore/parallel/_utils.py +11 -5
  139. mindspore/parallel/auto_parallel.py +9 -23
  140. mindspore/parallel/checkpoint_transform.py +0 -2
  141. mindspore/parallel/cluster/process_entity/_api.py +1 -4
  142. mindspore/parallel/cluster/run.py +3 -5
  143. mindspore/parallel/function/reshard_func.py +5 -6
  144. mindspore/parallel/nn/parallel_cell_wrapper.py +3 -40
  145. mindspore/parallel/nn/parallel_grad_reducer.py +8 -0
  146. mindspore/parallel/shard.py +21 -7
  147. mindspore/parallel/transform_safetensors.py +4 -10
  148. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +9 -10
  149. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  150. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  151. mindspore/profiler/common/path_manager.py +0 -9
  152. mindspore/profiler/common/profiler_context.py +2 -25
  153. mindspore/profiler/common/profiler_meta_data.py +0 -1
  154. mindspore/profiler/common/profiler_op_analyse.py +6 -10
  155. mindspore/{ops/_op_impl/cpu/joinedstr_op.py → profiler/common/validator/__init__.py} +1 -15
  156. mindspore/profiler/common/validator/validate_path.py +84 -0
  157. mindspore/profiler/dynamic_profiler.py +46 -91
  158. mindspore/profiler/envprofiler.py +5 -30
  159. mindspore/profiler/experimental_config.py +1 -16
  160. mindspore/profiler/platform/cpu_profiler.py +4 -10
  161. mindspore/profiler/platform/npu_profiler.py +1 -1
  162. mindspore/profiler/profiler.py +145 -193
  163. mindspore/profiler/profiler_action_controller.py +1 -1
  164. mindspore/profiler/profiler_interface.py +2 -2
  165. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  166. mindspore/runtime/__init__.py +4 -6
  167. mindspore/runtime/executor.py +0 -27
  168. mindspore/runtime/memory.py +0 -1
  169. mindspore/runtime/thread_bind_core.py +1 -1
  170. mindspore/swresample-4.dll +0 -0
  171. mindspore/swscale-6.dll +0 -0
  172. mindspore/tinyxml2.dll +0 -0
  173. mindspore/train/_utils.py +3 -3
  174. mindspore/train/amp.py +3 -0
  175. mindspore/train/callback/_callback.py +1 -2
  176. mindspore/train/callback/_checkpoint.py +8 -1
  177. mindspore/train/callback/_flops_collector.py +6 -10
  178. mindspore/train/callback/_train_fault_tolerance.py +7 -3
  179. mindspore/train/data_sink.py +4 -4
  180. mindspore/train/dataset_helper.py +5 -5
  181. mindspore/train/model.py +20 -4
  182. mindspore/train/serialization.py +15 -35
  183. mindspore/train/train_thor/model_thor.py +2 -2
  184. mindspore/turbojpeg.dll +0 -0
  185. mindspore/utils/hooks.py +81 -0
  186. mindspore/utils/utils.py +8 -8
  187. mindspore/version.py +1 -1
  188. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +1 -1
  189. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +193 -192
  190. mindspore/_extends/parallel_compile/akg_compiler/custom.py +0 -1109
  191. mindspore/common/dynamic_shape/__init__.py +0 -0
  192. mindspore/common/dynamic_shape/enable_dynamic.py +0 -197
  193. /mindspore/common/{dynamic_shape/_auto_dynamic.py → _auto_dynamic.py} +0 -0
  194. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
  195. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
  196. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -50,14 +50,12 @@ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcas
     _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
-from mindspore.common._utils import is_shape_unknown, get_func
+from mindspore.common._utils import is_shape_unknown
 from mindspore.common.mutable import mutable, _check_element_type
-from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
-    update_auto_dynamic_shape_phase
-from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
+from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
+    get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
 from mindspore.common._pijit_context import PIJitCaptureContext
-from mindspore.common.parameter import Parameter
-from mindspore.common.hook_handle import _hook_version
+from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
 from mindspore.common.jit_context import jit_context
 from mindspore.common.jit_trace import _jit_trace
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
@@ -76,11 +74,6 @@ ARG_SPECIFIED = "arg_specified_infos"
 TOTAL_ARG_LEN = "total_arg_length"
 
 
-def _real_phase(phase, obj):
-    real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
-    return real_phase
-
-
 def _check_recompile_args(compile_args, kwargs):
     """Check recompile of graph"""
 
@@ -545,12 +538,10 @@ def _get_parameter_ids(args, kwargs):
             parameter_ids += str(id(value))
     return parameter_ids
 
-
 def _get_tensor_hook_key(tensor):
     """Get the hook key of Tensor/Parameter"""
     return ".".join(map(str, map(id, tensor.hooks())))
 
-
 def _get_hook_key(*args, **kwargs):
     """Get the hook key of Tensors/Parameters"""
     hook_key = ""
@@ -597,8 +588,6 @@ class _JitExecutor:
 
         self.fn = fn
         self.input_signature = input_signature
-        self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
-        self.enable_jit_dynamic = self.dynamic_args_shapes is not None
         self.obj = None
         if obj and hasattr(obj, fn.__name__):
             self.obj = obj
@@ -637,10 +626,12 @@ class _JitExecutor:
         else:  # get compiled args to generate run args by _generate_run_args
             compile_args = self._generate_compile_args(args_list)
             key_id = self._get_key_id()
-            if self.input_signature is None:
-                compile_args = get_auto_dynamic_shape_args(
-                    compile_args, key_id, self._enable_auto_dynamic
-                )
+            compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
+                compile_args,
+                key_id,
+                self.input_signature,
+                self._enable_auto_dynamic
+            )
             self._compile_args = compile_args
 
         new_inputs = self._generate_run_args(args_list, kwargs)
@@ -693,13 +684,18 @@ class _JitExecutor:
 
     def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
+        # Check whether hook function registered on Cell object.
+        if self.obj and hasattr(self.obj, "_hook_fn_registered"):
+            if self.obj._hook_fn_registered():
+                logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
+                               f"If you want to use hook function, please use context.set_context to set "
+                               f"pynative mode and remove 'jit' decorator.")
         # Chose dynamic shape tensors or actual input tensors as compile args.
         compile_args = self._generate_compile_args(args)
         key_id = self._get_key_id()
-        if self.input_signature is None:
-            compile_args = get_auto_dynamic_shape_args(
-                compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
-            )
+        compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
+                                                                              self.input_signature,
+                                                                              self._enable_auto_dynamic)
 
         # Add mutable for compile_args for two scene:
         # 1) Origin args is mutable.
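The situation the new warning targets can be reproduced in a few lines; a minimal sketch (the Cell, the hook body, and the input are invented for illustration, and the hook signature varies across MindSpore versions, hence *hook_args):

    import mindspore as ms
    from mindspore import nn, ops, Tensor

    class Net(nn.Cell):
        def construct(self, x):
            return ops.relu(x)

    def fwd_hook(*hook_args):
        # Hooks only fire reliably in PyNative mode.
        print("forward hook fired")

    net = Net()
    net.register_forward_hook(fwd_hook)
    jit_net = ms.jit(net.construct)   # compiling a hooked Cell now logs the warning above
    out = jit_net(Tensor([[-1.0, 2.0]]))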
@@ -739,23 +735,20 @@ class _JitExecutor:
 
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
-        key = str(key)
 
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-            key += '.' + parameter_ids
+            key = str(key) + '.' + parameter_ids
 
-        key += "." + _get_hook_key(*args, **kwargs)
-        key += "." + str(_hook_version())
+        key = str(key) + "." + _get_hook_key(*args, **kwargs)
 
-        phase = generate_name + '.' + key
+        phase = generate_name + '.' + str(key)
 
-        if self.input_signature is None:
-            update_auto_dynamic_shape_phase(compile_args, key_id, phase)
+        update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
 
         phase = phase + self._cell_cache_key_extend
 
-        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
+        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
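After this hunk, the hook identities are folded straight into the argument key instead of the removed global `_hook_version()` counter; schematically the cache phase is now composed like this (all values invented for illustration):

    # <generate_name>.<arguments_key>[.<parameter_ids>].<hook_key><cell_cache_key_extend>
    key = "140211"                      # from generate_arguments_key
    key = key + "." + "140032140033"    # ids of Parameter inputs, only when present
    key = key + "." + "140077.140078"   # ids of registered tensor hooks (_get_hook_key)
    phase = "fn_name" + "." + key       # plus the cell cache key extension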
@@ -772,9 +765,16 @@ class _JitExecutor:
 
         if self.obj is None:
             # Set an attribute to fn as an identifier.
-            setattr(get_func(self.fn), "__jit_function__", True)
-            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
-            delattr(get_func(self.fn), "__jit_function__")
+            if isinstance(self.fn, types.MethodType):
+                setattr(self.fn.__func__, "__jit_function__", True)
+            else:
+                setattr(self.fn, "__jit_function__", True)
+            is_compile = self._graph_executor.compile(
+                self.fn, compile_args, kwargs, phase, jit_config_dict)
+            if isinstance(self.fn, types.MethodType):
+                delattr(self.fn.__func__, "__jit_function__")
+            else:
+                delattr(self.fn, "__jit_function__")
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
@@ -783,6 +783,7 @@ class _JitExecutor:
 
         if not is_compile:
             raise RuntimeError("Executor compile failed.")
+        set_parameter_hook_updated(False)
         ms_compile_cache.add(phase)
         if hasattr(self.obj, "phase"):
             self.obj.phase_cache[self.obj.phase] = phase
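Paired with the `not parameter_hook_updated()` condition added to the cache lookup above, this gives a one-shot invalidation; schematically (step 1 is implied by the API name rather than shown in this diff):

    # 1. changing a Parameter hook presumably flips the flag via set_parameter_hook_updated(True)
    # 2. the next compile() then treats the cached phase as stale and recompiles
    # 3. a successful compile calls set_parameter_hook_updated(False), re-arming the cache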
@@ -830,70 +831,41 @@ class _JitExecutor:
         if enable_compile_cache is True or enable_compile_cache == "1":
             self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
 
-    def _generate_compile_args_by_enable_dynamic(self, args_list):
-        """Generate compile args by enable_dynamic."""
-        compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
-        compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
-        if self.obj is not None:
-            _pynative_executor.set_dynamic_input(self.obj, *compile_args)
-        else:
-            _pynative_executor.set_dynamic_input(self.fn, *compile_args)
-        logger.info(f"dynamic shape compile_args: {compile_args}")
-        return compile_args
-
-    def _generate_compile_args_by_set_inputs(self, args_list):
-        """Generate compile args by set_inputs."""
-        compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
-        if len(compile_args) != len(args_list):
-            raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
-                             f"dynamic shape tensors: {len(compile_args)}.")
-        self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
-        Validator.check_symbolic_shape(compile_args, args_list)
-        return compile_args
-
-    def _generate_compile_args_by_input_signature(self, args_list):
-        """Generate compile args by input_signature."""
-        compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
-        dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
-        Validator.check_symbolic_shape(self.input_signature, args_list)
-        if dyn_shape:
-            # Checkout whether the `sens` has been added to args_list.
-            if len(compile_args) == len(args_list) - 1:
-                logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
-                               f"of input_signature args '{len(compile_args)}'. The last actual args may "
-                               f"be 'sens' and added it to compile args.")
-                compile_args.append(args_list[-1])
-            compile_args = tuple(compile_args)
-            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
-            if self.obj is not None:
-                _pynative_executor.set_dynamic_input(self.obj, *compile_args)
-            else:
-                _pynative_executor.set_dynamic_input(self.fn, *compile_args)
-        else:
-            if not verify_inputs_signature(compile_args, args_list):
-                raise ValueError("The input args is incompatible with the args in `input_signature`!")
-        return compile_args
-
-    def _check_set_inputs(self):
-        """Check if the `set_inputs()` of Cell object has been set."""
-        return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
-
     def _generate_compile_args(self, args_list):
         """Chose dynamic shape tensors or actual input tensors as compile args."""
-        # Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
-        if self.enable_jit_dynamic and self._check_set_inputs():
-            raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
-        # Case: The `enable_dynamic` is provided.
-        if self.enable_jit_dynamic:
-            return self._generate_compile_args_by_enable_dynamic(args_list)
+        # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
+        compile_args = _pynative_executor.get_dynamic_input(args_list)
         # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
-        if self._check_set_inputs():
-            return self._generate_compile_args_by_set_inputs(args_list)
+        if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
+            compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
+            if len(compile_args) != len(args_list):
+                raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
+                                 f"dynamic shape tensors: {len(compile_args)}.")
+            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
+            Validator.check_symbolic_shape(compile_args, args_list)
+
         # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
         if self.input_signature is not None:
-            return self._generate_compile_args_by_input_signature(args_list)
-        # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
-        return _pynative_executor.get_dynamic_input(args_list)
+            compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
+            dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
+            Validator.check_symbolic_shape(self.input_signature, args_list)
+            if dyn_shape:
+                # Checkout whether the `sens` has been added to args_list.
+                if len(compile_args) == len(args_list) - 1:
+                    logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
+                                   f"of input_signature args '{len(compile_args)}'. The last actual args may "
+                                   f"be 'sens' and added it to compile args.")
+                    compile_args.append(args_list[-1])
+                compile_args = tuple(compile_args)
+                self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
+                if self.obj is not None:
+                    _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+                else:
+                    _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+            else:
+                if not verify_inputs_signature(compile_args, args_list):
+                    raise ValueError("The input args is incompatible with the args in `input_signature`!")
+        return compile_args
 
     def _generate_run_args(self, args_list, kwargs):
         """
@@ -1105,7 +1077,10 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
         process_obj = args[0]
         # Handle auto mixed precision strategy.
         if not hasattr(func, "amp_strategy"):
-            setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
+            if isinstance(func, types.MethodType):
+                setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+            else:
+                setattr(func, "amp_strategy", get_curr_amp_strategy())
 
         jit_graph_name = ''
         if hasattr(staging_specialize, "__jit_graph_name__"):
@@ -1113,8 +1088,6 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
         jit_executor = _JitExecutor(
             func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
         out = jit_executor(*args, **kwargs)
-        if isinstance(process_obj, ms.nn.Cell):
-            _clear_auto_parallel_context(process_obj)
         return out
 
     # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
@@ -1154,26 +1127,28 @@ def jit(
 
     Keyword Args:
         capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
-            should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .
+            should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
 
-            - ast: Parse Python ast to build graph.
-            - bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
-              that is subject to change and/or deletion.
-            - trace: Trace the execution of Python code to build graph. This is an experimental prototype
-              that is subject to change and/or deletion.
+            - `ast <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#ast>`_ :
+              Parse Python ast to build graph.
+            - `bytecode <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#bytecode>`_ :
+              Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
+              change and/or deletion.
+            - `trace <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#trace>`_ :
+              Trace the execution of Python code to build graph. This is an experimental prototype that is
+              subject to change and/or deletion.
 
         jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
-            with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .
+            with ms_backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
 
-            - O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
-            - O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
+            - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
+            - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
               level is experimental and is being improved.
 
        dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
            is as follows:
 
-            - 0: Do not perform dynamic shape compilation.
-            - 1: Enable dynamic shape compilation and automatically detect shape changes.
+            - `0`: Do not perform dynamic shape compilation.
+            - `1`: Enable dynamic shape compilation and automatically detect shape changes.
 
        fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
            be compatible with all Python syntax in the function as much as possible. If True, we require that the
@@ -1181,14 +1156,12 @@ def jit(
            not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
            or ``bytecode``. Default: ``False``.
        backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
-            use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
-            Atlas A2 training series products by default.
+            use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
+            A2 training series products by default.
 
-            - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
-              optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
-            - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
-              for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
-              and can be executed only on Ascend hardware.
+            - `ms_backend`: Adopt KernelByKernel execution mode.
+            - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
+              the top cell of model. And only can be used in Ascend platform.
 
        **options (dict): A dictionary of options to pass to the compilation backend.
 
@@ -1211,11 +1184,11 @@ def jit(
            `disable_format_transform` can be set to ``True`` to try to improve training performance.
            Default: ``False`` .
        - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
-            methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .
+            methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
 
-            - bfs: The default sorting method, breadth priority, good communication masking, relatively good
+            - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
              performance.
-            - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
+            - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
              of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
              other execution orders run out of memory (OOM).
@@ -1226,11 +1199,11 @@ def jit(
        - global (dict): Set global options.
        - session (dict): Set session options.
 
-            - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
+            - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
              the inference mode is disabled. The range is as follows:
 
-            - on: Enable inference mode, get better infer performance.
-            - off: Disable inference mode, use forward for inference. The performance is poor.
+            - `on`: Enable inference mode, get better infer performance.
+            - `off`: Disable inference mode, use forward for inference. The performance is poor.
 
    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
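A short usage sketch of the keyword arguments documented above (the argument values come from the docstring; the decorated function itself is invented):

    import mindspore as ms
    from mindspore import Tensor, ops

    @ms.jit(capture_mode="ast", jit_level="O0", dynamic=1, fullgraph=False, backend="ms_backend")
    def scaled_sum(x, y):
        return ops.add(x, y) * 0.5

    out = scaled_sum(Tensor([1.0, 2.0]), Tensor([3.0, 4.0]))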
@@ -1921,19 +1894,6 @@ class _PyNativeExecutor:
         """
         return self._executor.constant_folding(*args)
 
-    def set_creation_type(self, tensor, creation_type):
-        """
-        Set tensor's view creation type
-
-        Args:
-            tensor (Tensor): input tensor.
-            creation_type (CreationType): The type of view tensor when it is created.
-
-        Return:
-            None.
-        """
-        return self._executor.set_creation_type(tensor, creation_type)
-
 
 class _CellGraphExecutor:
     """
@@ -2042,11 +2002,6 @@ class _CellGraphExecutor:
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
                 'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
-        inner_func = inspect.unwrap(obj.construct)
-        if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
-            raise ValueError(
-                "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
-            )
         key_id = str(id(obj)) + str(obj.create_time)
         args = get_auto_dynamic_shape_args(args, key_id)
 
@@ -2057,25 +2012,20 @@ class _CellGraphExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
 
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
-        key = str(key)
+        obj.arguments_key = str(key)
+
+        obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
 
         # When exist parameter in the top graph inputs, need check if the parameter object has changed.
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-            key += '.' + parameter_ids
-
-        key += "." + _get_hook_key(*args, **kwargs)
-        key += "." + str(_hook_version())
-
-        obj.arguments_key = key
-
+            obj.arguments_key = obj.arguments_key + '.' + parameter_ids
         raw_phase = phase
-
-        phase = _real_phase(phase, obj)
+        phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
         obj.current_phase = phase
-        if phase in obj.compile_cache and self.has_compiled(phase):
+        if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
             logger.debug("%r graph has existed.", phase)
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
@@ -2101,6 +2051,7 @@ class _CellGraphExecutor:
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
+        set_parameter_hook_updated(False)
         graph = self._graph_executor.get_func_graph(phase)
 
         if graph is None:
@@ -2125,15 +2076,15 @@ class _CellGraphExecutor:
         return self._graph_executor.updata_param_node_default_input(phase, new_param)
 
     def _get_shard_strategy(self, obj):
-        real_phase = _real_phase(obj.phase, obj)
+        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         return self._graph_executor.get_strategy(real_phase)
 
     def _get_num_parallel_ops(self, obj):
-        real_phase = _real_phase(obj.phase, obj)
+        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         return self._graph_executor.get_num_parallel_ops(real_phase)
 
     def _get_allreduce_fusion(self, obj):
-        real_phase = _real_phase(obj.phase, obj)
+        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         return self._graph_executor.get_allreduce_fusion(real_phase)
 
     def __call__(self, obj, *args, phase='predict'):
@@ -2185,10 +2136,10 @@ class _CellGraphExecutor:
             Tensor/Tuple, return execute result.
         """
         if phase == 'save':
-            exe_phase = _real_phase(phase, obj)
+            exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
             return self._graph_executor((), exe_phase)
 
-        phase_real = _real_phase(phase, obj)
+        phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         if self.has_compiled(phase_real):
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -2215,7 +2166,7 @@ class _CellGraphExecutor:
 
     def get_optimize_graph_proto(self, obj):
         """Return optimize graph binary proto."""
-        exec_id = _real_phase(obj.phase, obj)
+        exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
         graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} RENAMED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-2025 Huawei Technologies Co., Ltd
+# Copyright 2020-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -261,12 +261,7 @@ class _AutoIdentifyDynamicShape:
                 return False
         return True
 
-    @staticmethod
-    def _is_invalid_shape(shape):
-        """Check if input shape is valid"""
-        return is_shape_unknown(shape) or not shape
-
-    def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode, enable_jit_dynamic=False):
+    def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode):
         """is enable auto identify shape"""
         if not is_sink_mode and not args_list:
             return False
@@ -275,12 +270,10 @@ class _AutoIdentifyDynamicShape:
                 continue
             if not isinstance(elem, (list, tuple, Tensor, int, float)):
                 return False
-            if isinstance(elem, Tensor) and \
-                    self._is_invalid_shape(elem.shape) and \
-                    not enable_jit_dynamic:
+            if isinstance(elem, Tensor) and (is_shape_unknown(elem.shape) or (not elem.shape)):
                 return False
             if not is_sink_mode and isinstance(elem, (list, tuple)):
-                return self._is_enable_auto_dynamic_shape(elem, is_sink_mode, enable_jit_dynamic)
+                return self._is_enable_auto_dynamic_shape(elem, is_sink_mode)
         return True
 
     @staticmethod
@@ -335,10 +328,10 @@ class _AutoIdentifyDynamicShape:
             logger.info((f'generalize with generalize shape cache, compile args shape = {res_shape}'))
             return new_generalize_shape
 
-    def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode, enable_jit_dynamic=False):
+    def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode):
         """generate compile args in auto dynamic shape"""
         if not self.is_enable_auto_dynamic_shape or \
-                not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode, enable_jit_dynamic) or \
+                not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode) or \
                 not self._check_input_number_and_type(args_list):
             self.is_enable_auto_dynamic_shape = False
             return args_list
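From the caller's side, the generalization above means repeated calls with varying shapes converge on one dynamic graph; a hedged sketch (the function and shapes are invented, and the exact generalization thresholds are internal):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    @ms.jit(dynamic=1)   # opt in to automatic dynamic shape detection
    def mean_all(x):
        return ops.mean(x)

    for n in (2, 3, 4, 5):
        # Shape changes on the same key are detected and the compiled graph is
        # generalized rather than rebuilt for every new first dimension.
        mean_all(Tensor(np.ones((n, 16), np.float32)))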
@@ -482,13 +475,11 @@
 _auto_dynamic_shape = _AutoIdentifyDynamicShape()
 
 
-def get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic=False, enable_jit_dynamic=False):
+def get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic=False):
     """get auto dynamic shape args."""
     if key_id not in auto_dynamic_shape_dict:
         auto_dynamic_shape_dict[key_id] = _AutoIdentifyDynamicShape(enable_auto_dynamic)
-    compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(
-        compile_args, False, enable_jit_dynamic
-    )
+    compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(compile_args, False)
     return compile_args
 
 
@@ -496,3 +487,18 @@ def update_auto_dynamic_shape_phase(compile_args, key_id, phase):
     """update auto dynamic shape phase."""
     if key_id in auto_dynamic_shape_dict:
         auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)
+
+
+def get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id, input_signature,
+                                                           enable_auto_dynamic=False):
+    """get auto dynamic shape args."""
+    if input_signature is None:
+        return get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic)
+    return compile_args
+
+
+def update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, input_signature):
+    """update auto dynamic shape phase."""
+    if input_signature is None:
+        if key_id in auto_dynamic_shape_dict:
+            auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)
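The two new wrappers simply gate the existing helpers on `input_signature`; their contract, restated compactly:

    # Gating contract of the *_with_check_input_signature helpers:
    #   input_signature is None  -> delegate to get_auto_dynamic_shape_args /
    #                               update_auto_dynamic_shape_phase (auto dynamic shape stays on)
    #   input_signature provided -> auto dynamic shape is bypassed; compile_args pass
    #                               through unchanged and the signature's (possibly
    #                               dynamic) shapes drive compilation instead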