mindspore-2.7.0-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@
  """The removable handle for cell hook function."""
  from __future__ import absolute_import
  import weakref
+ from collections import OrderedDict
  from mindspore._c_expression import TensorPy as Tensor_
  from mindspore._check_jit_forbidden_api import jit_forbidden_register

@@ -173,3 +174,62 @@ class HookHandle:
          extra_dict = self.extra_dict_ref()
          if extra_dict is not None and self.handle_id in extra_dict:
              del extra_dict[self.handle_id]
+
+
+ def _check_hook_results(pre_res, new_res, hook_fn):
+     if not isinstance(new_res, tuple):
+         raise RuntimeError(f"hook {hook_fn.__name__} should return a tuple of grad.")
+
+     new_res_len = len(new_res)
+     pre_res_len = len(pre_res)
+     if new_res_len != pre_res_len:
+         raise RuntimeError(
+             f"hook {hook_fn.__name__} returned incorrect length {new_res_len}, expected {pre_res_len}."
+         )
+
+
+ class _HookUtils:
+     r"""
+     Internal utility class for hook registration and execution.
+     """
+
+     @staticmethod
+     def register_hook(hook_dict, hook_fn):
+         """
+         Register hook
+
+         Args:
+             hook_dict (dict): hook dict.
+             hook_fn (function): hook function.
+
+         Returns:
+             tuple: Updated hook_dict and HookHandle object.
+         """
+         if hook_dict is None:
+             hook_dict = OrderedDict()
+         handle = HookHandle(hook_dict)
+         hook_dict[handle.handle_id] = hook_fn
+         return hook_dict, handle
+
+     @staticmethod
+     def run_hook(hook_dict, args):
+         """
+         Run all hooks in the hook_dict with the given arguments.
+
+         Args:
+             hook_dict (dict): Dictionary of registered hooks.
+             args (tuple): Arguments to pass to the hook functions.
+
+         Returns:
+             Modified first argument if any hook returns a new value; otherwise, None.
+         """
+         is_modify = False
+         args_list = list(args)
+         # Note: We create a list from hook_dict.values() to ensure safe iteration.
+         for hook_fn in list(hook_dict.values()):
+             res = hook_fn(*args_list)
+             if res is not None:
+                 _check_hook_results(args_list[0], res, hook_fn)
+                 args_list[0] = res
+                 is_modify = True
+         return args_list[0] if is_modify else None
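
The `_HookUtils` helpers above (mindspore/common/hook_handle.py) back the removable-handle pattern used by the public hook APIs. A minimal, hedged sketch of that pattern as it reaches users, assuming `Parameter.register_hook` and `HookHandle.remove` behave as documented (register a gradient hook, keep the returned handle, remove it later):

    import numpy as np
    from mindspore import Parameter, Tensor

    # A gradient hook registered on a Parameter; the returned HookHandle
    # (backed by the OrderedDict-based machinery shown above) removes it again.
    w = Parameter(Tensor(np.ones((2,), np.float32)), name="w")

    def scale_grad(grad):
        # Returning a tensor replaces the gradient; returning None keeps it unchanged.
        return grad * 2

    handle = w.register_hook(scale_grad)
    # ... run a forward/backward pass here; the hook fires when w's gradient is produced ...
    handle.remove()  # unregister through the removable handle
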
@@ -27,7 +27,11 @@ class JitConfig:
                adopt KernelByKernel execution mode.
              - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations,
                adopt KernelByKernel execution mode.
-             - ``"O2"``: Ultimate performance optimization, adopt Sink execution mode.
+             - ``"O2"``: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
+               for Ascend model compilation and execution. Note: O2 only supports GRAPH Mode in Ascend,
+               only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
+               dynamic shape scenes. In addition, this mode incurs additional compilation costs and is difficult to
+               debug and tune.

          exc_mode (str, optional): Control the execution mode of the model.
              Supports ["auto", "sink", "no_sink"]. Default: ``"auto"`` .
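
For reference, a hedged sketch of selecting the GE-backed ``"O2"`` level described above. It assumes `Cell.set_jit_config` accepts a `JitConfig` as in earlier releases, and that the constraints from the new docstring (Ascend target, graph mode, static shapes) are met:

    import mindspore as ms
    from mindspore import nn, JitConfig

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 2)

        def construct(self, x):
            return self.dense(x)

    # O2 routes compilation through CANN's GraphEngine: Ascend + graph mode only,
    # whole-graph/sub-graph sinking, no dynamic shapes, and longer compile times.
    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
    net = Net()
    net.set_jit_config(JitConfig(jit_level="O2"))
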
@@ -28,6 +28,7 @@ from mindspore._c_expression import TraceRecorder as tr
  from mindspore._c_expression import JitExecutor_
  from mindspore._c_expression import TensorPy as Tensor, CSRTensor, COOTensor
  from mindspore._c_expression import typing
+ from mindspore.common.jit_config import JitConfig


  class TraceJitContext(JitContext):
@@ -123,19 +124,19 @@ def nested_run(obj, cell, *args):
      return file_names, linenos, res


- def _jit_trace():
+ def _jit_trace(jit_config):
      """Return the wrapped function for trace mode jit."""
      def wrap_func(fn):
          if hasattr(fn, "construct"):
              if isinstance(fn, ms.nn.Cell):
                  # Bound the cell object to get the self arg.
-                 return types.MethodType(_jit_trace()(fn.construct.__func__), fn)
+                 return types.MethodType(_jit_trace(jit_config)(fn.construct.__func__), fn)
              if isinstance(fn, type) and issubclass(fn, ms.nn.Cell):
-                 fn.construct = _jit_trace()(fn.construct)
+                 fn.construct = _jit_trace(jit_config)(fn.construct)
              return fn

          if isinstance(fn, types.MethodType):
-             return types.MethodType(_jit_trace()(fn.__func__), fn.__self__)
+             return types.MethodType(_jit_trace(jit_config)(fn.__func__), fn.__self__)

          if not isinstance(fn, types.FunctionType):
              logger.warning(f"The fn should be function, method or cell instance/class, but got {fn}")
@@ -150,6 +151,10 @@ def _jit_trace():
          if jit_context():
              return fn(*args, **kwargs)
          # Start trace process.
+         if jit_config:
+             jit_config_dict = jit_config.jit_config_dict
+         else:
+             jit_config_dict = JitConfig().jit_config_dict
          if kwargs:
              bound_arguments = inspect.signature(fn).bind(*args, **kwargs)
              bound_arguments.apply_defaults()
@@ -170,14 +175,16 @@
          line_str = fn.__code__.co_filename + ":" + str(fn.__code__.co_firstlineno)
          generate_name = generate_name + '#[' + line_str + ']'

-         new_compile = _jit_trace_begin(generate_name, *jit_args)
+         new_compile = _jit_trace_begin(
+             generate_name, *jit_args, jit_config=jit_config_dict)
          if new_compile:
              fn_res = fn(*args, **kwargs)
              logger.debug(f'fn: {fn}, fn_res: {fn_res}, line: {line_str}')
              # Use fn's output to build func graph's output.
-             output = _jit_trace_end(fn_res)
+             output = _jit_trace_end(fn_res, jit_config=jit_config_dict)
          else:
-             output = _jit_trace_end(None) # Run with compilation.
+             # Run with compilation.
+             output = _jit_trace_end(None, jit_config=jit_config_dict)
          logger.debug(f'output: {output}')
          return output

@@ -224,7 +231,7 @@ def _get_args_for_run(args):
      return tuple(new_args)


- def _jit_trace_begin(fn_name, *args):
+ def _jit_trace_begin(fn_name, *args, **kwargs):
      """
      Start to build a MindIR func graph for a code snippet by trace method.

@@ -257,6 +264,10 @@ def _jit_trace_begin(fn_name, *args):
          ...
          >>> out = tensor_add(x, y)
      """
+     if "jit_config" in kwargs:
+         jit_config = kwargs["jit_config"]
+     else:
+         jit_config = JitConfig().jit_config_dict
      global _using_trace
      if _using_trace:
          raise RuntimeError(
@@ -279,7 +290,7 @@
      if not _compile_only and phase in _trace_compile_cache:
          logger.debug('Had compiled, just run.')
          _trace_jit_context.compiled = True
-         output = tr.get_instance().run_graph(phase, args)
+         output = tr.get_instance().run_graph(phase, jit_config, args)
          from mindspore.common.api import _convert_python_data
          _trace_jit_context.result = _convert_python_data(output)
          logger.debug(f'jit trace result: {_trace_jit_context.result}')
@@ -295,7 +306,7 @@
      return True


- def _jit_trace_end(*output_args):
+ def _jit_trace_end(*output_args, **kwargs):
      """
      Finish building a MindIR func graph for a code snippet by trace method.

@@ -330,19 +341,23 @@
          ...
          >>> out = tensor_add(x, y)
      """
+     if "jit_config" in kwargs:
+         jit_config = kwargs["jit_config"]
+     else:
+         jit_config = JitConfig().jit_config_dict
      if _trace_jit_context.compiled:
          output = _trace_jit_context.result
          logger.debug(f'jit trace result: {output}')
      else:
          logger.debug(f'output_args: {output_args}')
          file_names, linenos = _get_caller_lines()
-         tr.get_instance().end_graph(file_names, linenos, *output_args)
+         tr.get_instance().end_graph(file_names, linenos, jit_config, *output_args)
          if _compile_only:
              output = output_args[0] if len(output_args) == 1 else output_args
          else:
              args = _get_args_for_run(_trace_jit_context.args)
              output = tr.get_instance().run_graph(
-                 _trace_jit_context.phase, args)
+                 _trace_jit_context.phase, jit_config, args)
          from mindspore.common.api import _convert_python_data
          output = _convert_python_data(output)
          logger.debug(f'jit trace result: {output}')
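
The hunks above (mindspore/common/jit_trace.py) thread a `jit_config_dict` into `_jit_trace_begin`/`_jit_trace_end` and on into `TraceRecorder.end_graph`/`run_graph`, so trace-mode compilation no longer falls back to a default `JitConfig()`. A hedged sketch of how that surfaces through the public API, assuming `mindspore.jit` exposes `capture_mode="trace"` and `jit_level` as in recent releases:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    # With 2.7.1 the jit_level chosen here is forwarded to the trace recorder
    # rather than being dropped during trace capture.
    @ms.jit(capture_mode="trace", jit_level="O1")
    def tensor_add(x, y):
        return x + y

    x = Tensor(np.ones((2, 2), np.float32))
    y = Tensor(np.ones((2, 2), np.float32))
    out = tensor_add(x, y)
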
@@ -32,9 +32,11 @@ def lazy_inline(fn=None, attrs=None, policy=None):
      static_graph_expert_programming.html#using-lazy-inline-decorator>`_ .

      .. warning::
-         This feature is only supported on Ascend and is not supported on other hardwares.
-         The construct parameters must be positional or key word arguments and have not default values.
-         The cell has not switch sub graph.
+         - This feature is only supported on Ascend and is not supported on other hardwares.
+         - The construct parameters must be positional or key word arguments and have not default values.
+         - The cell has not switch sub graph.
+         - In the gradient accumulation scenario, it is recommended to use the @lazy_inline decorator to
+           reduce compilation time, and this decorator is only allowed to configure on the outermost cell.

      Args:
          fn (function): `__init__` function of a cell.
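
A hedged sketch of the decorator this warning documents (`mindspore.lazy_inline`): it marks a cell's `__init__` so repeated sub-graphs are compiled once and inlined lazily. Per the new bullet, in gradient-accumulation scenarios configure it on the outermost cell only; the network below is illustrative and assumes an Ascend target.

    from mindspore import nn, lazy_inline

    # The decorated cell's construct() parameters must be positional or keyword
    # arguments without default values, per the warning above.
    class Block(nn.Cell):
        @lazy_inline
        def __init__(self, hidden):
            super().__init__()
            self.dense = nn.Dense(hidden, hidden)
            self.relu = nn.ReLU()

        def construct(self, x):
            return self.relu(self.dense(x))

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.blocks = nn.SequentialCell([Block(16) for _ in range(4)])

        def construct(self, x):
            return self.blocks(x)
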
@@ -21,7 +21,6 @@ from copy import copy
  import time
  import os
  import sys
- import math
  import numbers
  import numpy as np

@@ -29,8 +28,6 @@ from mindspore import log as logger
  from mindspore.log import _LogActionOnce
  from mindspore._c_expression import ParamInfo
  from mindspore.common import dtype as mstype
- from mindspore import context
- from mindspore.common._utils import get_slice_num, get_slice_shape
  from mindspore.common.initializer import initializer
  from mindspore.common.tensor import Tensor, _TensorMeta
  from mindspore.common.hook_handle import _update_hook_version
@@ -39,10 +36,6 @@ from mindspore._check_jit_forbidden_api import jit_forbidden_register
  from mindspore._c_expression import TensorPy as Tensor_
  from mindspore.parallel._tensor import _get_slice_index
  from mindspore.parallel._auto_parallel_context import auto_parallel_context
- from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table, \
-     _is_ps_mode
- from mindspore.parallel._ps_context import _reinsert_hash_table_size, _insert_accumu_init_info, _cache_enable
- from mindspore.common._decorator import deprecated
  from mindspore.communication._comm_helper import _is_initialized
  from mindspore.communication import get_group_size, get_rank
  import mindspore.common._monad as monad
@@ -138,11 +131,7 @@
      Args:
          data: The parameter data to offload.
      """
-     if not context.get_context("memory_offload") or data is None:
-         return
-
-     offload_context = context.get_offload_context()
-     if offload_context.get("offload_param", None) != "disk":
+     if data is None:
          return

      data_size_threshold = 512
@@ -219,7 +208,10 @@ class Parameter(Tensor_):
              self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
              self.param_tuple = (self.param_a, self.param_a)

-         requires_grad (bool): True if the parameter requires gradient. Default: ``True`` .
+         requires_grad (bool): It is Used to filter parameters in :func:`mindspore.nn.Cell.trainable_params()`.
+             If it is ``False``, the filter parameters will not be returned in
+             :func:`mindspore.nn.Cell.trainable_params()`.
+             Default: ``True`` .
          layerwise_parallel (bool): When `layerwise_parallel` is true in data/hybrid parallel mode,
              broadcast and gradients communication would not be applied to the `Parameter`. Default: ``False`` .
          parallel_optimizer (bool): It is used to filter the weight shard operation in parallel mode. It works only when
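
The new wording above makes `requires_grad` purely a filter for :func:`mindspore.nn.Cell.trainable_params`. A short hedged illustration:

    import numpy as np
    from mindspore import nn, Parameter, Tensor

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.w = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w")
            self.frozen = Parameter(Tensor(np.zeros((2, 2), np.float32)),
                                    name="frozen", requires_grad=False)

        def construct(self, x):
            return x * self.w + self.frozen

    net = Net()
    # Only "w" is returned; "frozen" is filtered out because requires_grad=False.
    print([p.name for p in net.trainable_params()])
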
@@ -230,10 +222,8 @@ class Parameter(Tensor_):
          device(str): Only Ascend device target is supported. It is used to specify the device which the parameter is
              stored. By default, the parameter will be stored on NPU while computing. When the device is specified as
              ``"CPU"``, the parameter will be loaded into the device when it needs to be used, and unloaded to the CPU
-             after use. It takes effext only when `memory_offload` is ``"ON"``, `jit_level` is not ``"O2"`` and
-             `memory_optimize_level` is ``O0`` in :func:`mindspore.set_context`.
-             Less device memory is needed when device is
-             specified as ``"CPU"``.
+             after use. It takes effext only when `jit_level` is not ``"O2"`` and `memory_optimize_level` is ``O0``
+             in :func:`mindspore.set_context`. Less device memory is needed when device is specified as ``"CPU"``.

      Examples:
          >>> import numpy as np
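
A hedged sketch of the `device` argument whose docstring is trimmed above: keeping a large parameter on host memory so it is staged onto the NPU only while it is used. The shape and names are illustrative; the `jit_level`/`memory_optimize_level` constraints quoted in the docstring (the defaults) still apply.

    import numpy as np
    import mindspore as ms
    from mindspore import Parameter, Tensor

    ms.set_context(device_target="Ascend")
    # Stored on CPU; loaded to the device on demand and unloaded after use.
    big_table = Parameter(Tensor(np.zeros((100000, 128), np.float32)),
                          name="big_table", device="CPU")
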
@@ -272,8 +262,6 @@ class Parameter(Tensor_):
          obj.is_default_input_init = init_data_flag
          if obj.has_init:
              obj.init_mode = default_input
-         else:
-             _offload_if_config(obj)
          return obj

      def __reduce_ex__(self, _):
@@ -289,7 +277,6 @@ class Parameter(Tensor_):
      def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
                   storage_format="", device=None):
          self.param_info = ParamInfo()
-         self.init_in_server = False
          self.name = name
          self.requires_grad = requires_grad
          self.layerwise_parallel = layerwise_parallel
@@ -300,32 +287,15 @@ class Parameter(Tensor_):
          self.is_init = False
          self._inited_param = None
          self._sliced = False
-         self.is_param_ps = False
-         self.push_weight_to_server = False
-         self.pull_weight_from_server = False
          self.requires_aggr = True
          self._cast_type = None
          self._unique = False
          self.is_in_parallel = _is_in_auto_parallel_mode()
          self._pipeline_stage_list = []
-         self.slice_num = 1
          if -1 in self.shape:
              raise ValueError(f"All shape elements of the Parameter must be positive. But got None.")
          if isinstance(default_input, (Tensor_, Tensor)):
-             # At embedding cache scenes, we need limit the size of memory for parameter.
-             # And save out range data to persistent storage to support TB-Level size parameter.
-             slice_num_of_persistent_data = get_slice_num(default_input.dtype, default_input.shape)
-             if slice_num_of_persistent_data > 1:
-                 data_shape = list(default_input.shape)
-                 slice_first_dim = math.ceil(data_shape[0] / slice_num_of_persistent_data)
-                 data_shape[0] = slice_first_dim
-                 self.param_info.use_persistent_storage = True
-                 self.param_info.origin_shape = default_input.shape
-                 self.slice_num = slice_num_of_persistent_data
-                 Tensor_.__init__(self, dtype=default_input.dtype, shape=tuple(data_shape))
-             else:
-                 Tensor_.__init__(self, dtype=default_input.dtype, shape=default_input.shape)
-
+             Tensor_.__init__(self, dtype=default_input.dtype, shape=default_input.shape)
          elif isinstance(default_input, int):
              Tensor_.__init__(self, dtype=mstype.int64, shape=())
          elif isinstance(default_input, float):
@@ -387,11 +357,10 @@ class Parameter(Tensor_):
                  return (Tensor, data.asnumpy(), mstype.qint4x2)
              return (Tensor, data.asnumpy())

-             not_init_data = not init_param or _is_role_sched() or (_is_role_pserver() and _cache_enable()) \
-                 or _is_in_auto_parallel_mode() or _is_parallel_mode()
+             not_init_data = not init_param or _is_in_auto_parallel_mode() or _is_parallel_mode()
              if not_init_data:
                  # do not init data while in auto parallel.
-                 return (Tensor, None, data.dtype, get_slice_shape(data.dtype, data.shape), data.init)
+                 return (Tensor, None, data.dtype, data.shape, data.init)
              return (Tensor, data.init_data())
          if isinstance(data, int):
              return (Tensor, data, mstype.int32)
@@ -399,29 +368,6 @@ class Parameter(Tensor_):
              return (Tensor, data, mstype.float32)
          return (Tensor, data)

-     def set_param_ps(self, init_in_server=False):
-         """
-         Set whether the trainable parameter is updated by parameter server and whether the
-         trainable parameter is initialized on server.
-
-         Note:
-             It only works when a running task is in the parameter server mode.
-             It is supported only in graph mode.
-
-         Args:
-             init_in_server (bool): Whether trainable parameter updated by parameter server is
-                 initialized on server. Default: ``False``.
-
-         """
-         if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
-             raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
-                                "1. context.set_ps_context(enable_ps=True) \n"
-                                "2. export MS_ROLE environment variable \n"
-                                "Please refer to the official website for detailed usage.")
-         self.is_param_ps = True
-         self.init_in_server = init_in_server
-         self.param_info.init_in_server = init_in_server
-
      def copy(self):
          """
          Copy the parameter.
@@ -437,16 +383,6 @@ class Parameter(Tensor_):
          """
          return self.clone(init='same')

-     @deprecated("1.8", "set_param_fl")
-     def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
-         if push_to_server:
-             self.push_weight_to_server = True
-         if pull_from_server:
-             self.pull_weight_from_server = True
-         if not requires_aggr:
-             self.requires_aggr = False
-             self.param_info.requires_aggr = False
-
      @property
      def inited_param(self):
          """
@@ -512,8 +448,6 @@ class Parameter(Tensor_):
              raise ValueError("The type of the Parameter's name should be 'string' or 'None', "
                               "but got {}.".format(type(name_)))

-         if _is_role_worker() and self.cache_enable:
-             _reinsert_hash_table_size(name_, self.param_info.name)
          self.param_info.name = name_

      @property
@@ -642,8 +576,6 @@ class Parameter(Tensor_):
          x.param_info = param_info_clone
          x.is_init = False
          x.init = self.init
-         x.is_param_ps = self.is_param_ps
-         x.init_in_server = self.init_in_server
          x.cache_enable = self.cache_enable
          if x.cache_enable:
              x.key = _get_unique_parameter_key()
@@ -651,7 +583,7 @@
          if self.cache_shape:
              x.cache_shape = self.cache_shape
          if init != 'same':
-             shape = self.shape if self.slice_num == 1 else self.param_info.origin_shape
+             shape = self.shape
              dtype = self.dtype
              tensor = initializer(init, shape=shape, dtype=dtype)
              x.set_data(tensor)
@@ -796,6 +728,7 @@
              raise TypeError("The argument `requires_grad` must be bool type")
          Tensor_.wait_pipeline(self)
          self.param_info.requires_grad = value
+         self._requires_grad = value

      @property
      def data(self):
@@ -862,20 +795,6 @@
              raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
                              "Please initialize 'data' before call this method.")

-     @staticmethod
-     def _from_tensor(tensor, *args, **kwargs):
-         """Create a `Parameter` that data is shared from a `Tensor`."""
-         if not isinstance(tensor, Tensor_):
-             raise TypeError(f"The type of input must be Tensor, but got {type(tensor)}.")
-         param = Tensor_.__new__(Parameter)
-         Tensor_.__init__(param, tensor)
-         param.init = None
-         param.init_mode = None
-         param.has_init = False
-         param.is_default_input_init = False
-         Parameter.__init__(param, tensor, *args, **kwargs)
-         return param
-
      @jit_forbidden_register
      def set_data(self, data, slice_shape=False):
          """
@@ -981,16 +900,7 @@

          init_data_args = self._get_init_data_args(layout)

-         if _is_role_sched():
-             return self
-         if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
-                 self.init_mode.init is not None and _is_role_worker():
-             if self.cache_enable:
-                 data = self.init_mode.init_data(*init_data_args)
-             else:
-                 data = self.init_mode.init_data(0, [1])
-         else:
-             data = self.init_mode.init_data(*init_data_args)
+         data = self.init_mode.init_data(*init_data_args)
          origin_dtype = self.dtype
          obj = self._update_tensor_data(data)
          if self.dtype != origin_dtype:
@@ -999,7 +909,6 @@
          self._inited_param = obj
          obj.init_mode = None
          obj.sliced = set_sliced
-         _offload_if_config(obj)
          return obj

      def register_hook(self, hook_fn):
@@ -1154,9 +1063,6 @@ class ParameterTuple(tuple):
              if not x1.cache_enable:
                  continue

-             if _is_role_worker():
-                 _clone_hash_table(x.name, x.key, x1.name, x1.key)
-                 _insert_accumu_init_info(x1.name, init_to_value(init))
          return ParameterTuple(new)

      def __parameter_tuple__(self):
@@ -22,11 +22,10 @@ from mindspore.common.tensor import Tensor
  from mindspore import ops
  from mindspore.ops.composite import GradOperation
  from mindspore.common._register_for_recompute import recompute_registry
- from mindspore.common.api import _pynative_executor, _no_grad
+ from mindspore.common.api import _pynative_executor, _no_grad, _run_in_jit
  from mindspore.common.generator import get_rng_state, set_rng_state
  from mindspore.train.amp import AmpDecorator
  from mindspore._c_expression.amp import get_curr_amp_strategy
- from mindspore._check_jit_forbidden_api import jit_forbidden_register


  class _WrapCell(Cell):
@@ -211,22 +210,15 @@ def _detach_input(input_arg):
  def _check_validation(block):
      if not isinstance(block, Cell):
          raise TypeError("Recompute function now only support block which inherited from Cell!")
-     if block.construct.__code__.co_name == "staging_specialize":
-         logger.warning('Block\'s construct method decorated by @jit that recompute '
-                        'function will not come into effect.')


- @jit_forbidden_register
  def recompute(block, *args, **kwargs):
      r"""
      This function is used to reduce memory, when run block, rather than
      storing the intermediate activation computed in forward pass, we will recompute it in backward pass.

      Note:
-         - Recompute function only support block which inherited from Cell object.
-         - This function interface now only support pynative mode. you can use Cell.recompute interface
-           in graph mode.
-         - When use recompute function, block object should not decorated by @jit.
+         Recompute function only support block which inherited from Cell object.

      Args:
          block (Cell): Block to be recompute.
@@ -238,7 +230,6 @@ def recompute(block, *args, **kwargs):

      Raises:
          TypeError: If `block` is not Cell object.
-         AssertionError: If execute mode is not PYNATIVE_MODE.

      Supported Platforms:
          ``Ascend`` ``GPU`` ``CPU``
@@ -272,6 +263,8 @@
      """

      _check_validation(block)
+     if _run_in_jit(): # @jit.cond: True
+         return ops.recompute_block(block)(*args, **kwargs)
      return _RecomputeCell(block)(*args, **kwargs)

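
The final hunk lets :func:`mindspore.recompute` work under graph compilation by dispatching to `ops.recompute_block` when running inside `@jit`. A hedged usage sketch (the network is illustrative); activations of the wrapped block are recomputed during the backward pass instead of being stored:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, ops, Tensor

    class Block(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 4)

        def construct(self, x):
            return ops.relu(self.dense(x))

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.block = Block()

        def construct(self, x):
            # Recompute self.block's activations in the backward pass;
            # with 2.7.1 this path is also valid inside @jit-compiled code.
            return ms.recompute(self.block, x)

    net = Net()
    grad_fn = ms.grad(net)
    grads = grad_fn(Tensor(np.ones((2, 4), np.float32)))
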