mindspore 2.3.0-cp310-cp310-win_amd64.whl → 2.4.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/callback/_checkpoint.py

@@ -18,8 +18,8 @@ from __future__ import absolute_import
 import os
 import stat
 import time
-
 import threading
+
 import mindspore.context as context
 from mindspore import log as logger
 from mindspore import nn
@@ -37,14 +37,22 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.generator import Generator
 from mindspore.common.api import _cell_graph_executor
-from mindspore._c_expression import _collect_host_info
-
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 _cur_dir = os.getcwd()
 SAVE_DIR = _cur_dir
 _info_list = ["epoch_num", "step_num"]
 
 
+def _wait_async_save_ckpt(async_save=False):
+    """Waiting for asynchronous saving of ckpt to complete."""
+    if async_save:
+        thread_list = threading.enumerate()
+        for thread in thread_list:
+            if thread.getName() == "asyn_save_ckpt":
+                thread.join()
+
+
 def _get_dp_tp_from_redundancy(redundancy_tuple):
     """From redundancy get dp and tp"""
     dp = []
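Note: the new module-level helper _wait_async_save_ckpt centralizes the wait-for-async-writer logic that previously appeared inline in step_end and end (removed in later hunks of this file). A minimal stand-alone sketch of the same join-by-name pattern, using only the standard library; the worker function below is hypothetical, and Thread.getName() in the diff is simply the older alias of Thread.name:

import threading
import time

def fake_ckpt_writer():
    time.sleep(1)  # stands in for serializing a checkpoint to disk

# The asynchronous writer runs in a thread named "asyn_save_ckpt" ...
writer = threading.Thread(target=fake_ckpt_writer, name="asyn_save_ckpt")
writer.start()

# ... and the caller blocks until every thread with that name has finished.
for thread in threading.enumerate():
    if thread.name == "asyn_save_ckpt":
        thread.join()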
@@ -88,9 +96,9 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
         if index == 0:
             suffix_num = max(suffix_num, 1)
         elif index != -1:
-            num = filename[pre_len+1:pre_len+index]
+            num = filename[pre_len + 1:pre_len + index]
             if num.isdigit():
-                suffix_num = max(suffix_num, int(num)+1)
+                suffix_num = max(suffix_num, int(num) + 1)
 
     if suffix_num != 0:
         prefix = f'{prefix}_{suffix_num}'
@@ -98,6 +106,14 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
     return prefix
 
 
+def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, exception_save=False,
+                                   map_param_inc=False, global_step_num=None):
+    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
+                         or exception_save or map_param_inc or global_step_num is not None)
+    if format == "safetensors" and param_not_default:
+        raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
+
+
 class CheckpointConfig:
     """
     The configuration of model checkpoint.
@@ -136,6 +152,10 @@ class CheckpointConfig:
         exception_save (bool): Whether to save the current checkpoint when an exception occurs. Default: ``False`` .
         crc_check (bool): Whether to perform crc32 calculation when saving checkpoint and save the calculation
             result to the end of ckpt. Default: ``False`` .
+        remove_redundancy (bool): Whether to enable saving the checkpoint with redundancy removal.
+            Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
+            redundant-free saving is not enabled.
+        format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
         kwargs (dict): Configuration options dictionary.
 
     Raises:
@@ -188,6 +208,8 @@ class CheckpointConfig:
                  enc_mode='AES-GCM',
                  exception_save=False,
                  crc_check=False,
+                 remove_redundancy=False,
+                 format="ckpt",
                  **kwargs):
 
         if save_checkpoint_steps is not None:
@@ -231,8 +253,13 @@ class CheckpointConfig:
         self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
         self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
         self._crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+        self._format = Validator.check_isinstance('format', format, str)
         self._map_param_inc = kwargs.get('incremental', False)
         self.enable_redundance = kwargs.get('enable_redundance', False)
+        self.remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+
+        _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, exception_save,
+                                       self._map_param_inc)
 
     @property
     def save_checkpoint_steps(self):
@@ -333,6 +360,10 @@ class CheckpointConfig:
         """
         return self._crc_check
 
+    @property
+    def format(self):
+        return self._format
+
     @property
     def append_dict(self):
         """
@@ -495,10 +526,10 @@ class ModelCheckpoint(Callback):
         self._aiturbo_init_flag = os.getenv("AITURBO") == "1"
         # get existing checkpoint files
         if self._aiturbo_init_flag:
-            import aiturbo
-            self._manager = aiturbo.CheckpointShmManager()
+            from aiturbo.checkpoint.aiturbo_mindspore_ckpt import CheckpointShmManager
+            self._manager = CheckpointShmManager()
         else:
-            self._manager = CheckpointManager()
+            self._manager = CheckpointManager(self._config.format)
         if not callable(directory) and not callable(prefix):
             self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._append_dict = self._config.append_dict or {}
@@ -517,7 +548,7 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         if self._aiturbo_init_flag:
-            import aiturbo
+            from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
             ckpt_storage_path = self._directory
             rank_id = get_rank()
             stage_num = _get_auto_parallel_context("pipeline_stages")
@@ -536,7 +567,7 @@ class ModelCheckpoint(Callback):
                       "stage_layout": param_redundancy_dict}
             single_params = remove_param_redundancy(param_redundancy_dict)
             single_params = {device_id: list(params) for device_id, params in single_params.items()}
-            aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, self._config.enable_redundance, dp)
+            aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, not self._config.enable_redundance, dp)
             self._aiturbo_init_flag = False
         if self._prefix_func:
             self._prefix = self._prefix_func(cb_params)
@@ -546,14 +577,14 @@ class ModelCheckpoint(Callback):
                                  "string that does not contain '/', but got {}.".format(self._prefix))
         if self._directory_func:
             self._directory = self._directory_func(cb_params)
-        _collect_host_info("Callback", "ModelCheckpoint", "step_end", level=1)
+        _make_directory(self._directory)
+        collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
         # In disaster recovery scenario, the training process may be rolled back to the last step where
         # the ckpt was successfully saved, so the _last_triggered_step should be updated.
         if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
             self._last_triggered_step = cb_params.last_save_ckpt_step
             cb_params.last_save_ckpt_step = None
 
-        _make_directory(self._directory)
         # save graph (only once)
         if not self._graph_saved:
             graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
@@ -561,10 +592,6 @@ class ModelCheckpoint(Callback):
                 os.remove(graph_file_name)
             _save_graph(cb_params.train_network, graph_file_name)
             self._graph_saved = True
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
         self._save_ckpt(cb_params)
 
     def end(self, run_context):
@@ -575,15 +602,12 @@ class ModelCheckpoint(Callback):
             run_context (RunContext): Context of the train running.
         """
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "ModelCheckpoint", "end", level=1)
+        collect_host_info("Callback", "ModelCheckpoint", "end", start_time=get_clock_syscnt(), level=1)
        _to_save_last_ckpt = True
 
         self._save_ckpt(cb_params, _to_save_last_ckpt)
 
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
+        _wait_async_save_ckpt(self._config.async_save)
 
         destroy_allgather_cell()
 
@@ -601,6 +625,13 @@ class ModelCheckpoint(Callback):
 
         return False
 
+    def _append_dict_content(self, epoch_num, step_num):
+        """Append append_dict content."""
+        if "epoch_num" in self._append_dict:
+            self._append_dict["epoch_num"] = self._append_epoch_num + epoch_num
+        if "step_num" in self._append_dict:
+            self._append_dict["step_num"] = self._append_step_num + step_num
+
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
@@ -614,11 +645,12 @@ class ModelCheckpoint(Callback):
         step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
 
         if save_ckpt:
+            _wait_async_save_ckpt(self._config.async_save)
            if self._prefix_func:
-                cur_ckpoint_file = self._prefix + ".ckpt"
+                cur_ckpoint_file = self._prefix + f".{self._config.format}"
             else:
                 cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-                                   + str(step_num_in_epoch) + ".ckpt"
+                                   + str(step_num_in_epoch) + f".{self._config.format}"
             # update checkpoint file list.
             self._manager.update_ckpoint_filelist(self._directory, self._prefix)
             # keep checkpoint files number equal max number.
@@ -644,20 +676,51 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.add_flags(ge_sync_data=True)
                 _cell_graph_executor(cb_params.train_network, phase='save')
-            if "epoch_num" in self._append_dict:
-                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
-            if "step_num" in self._append_dict:
-                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
+            self._append_dict_content(cb_params.cur_epoch_num, cb_params.cur_step_num)
             network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
             if os.getenv("AITURBO") == "1":
                 save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                 self._append_dict, self._config.enc_key, self._config.enc_mode,
                                 crc_check=self._config.crc_check, incremental=self._map_param_inc,
                                 global_step_num=cb_params.cur_step_num)
+            elif self._config.remove_redundancy:
+                parallel_mode = context.get_auto_parallel_context("parallel_mode")
+                if parallel_mode == "stand_alone":
+                    raise TypeError(f"The deduplication feature for saving checkpoint can only be used "
+                                    f"in parallel scenarios, but got {parallel_mode}.")
+                param_layout = network.parameter_layout_dict
+                rank_id = get_rank()
+                if param_layout:
+                    device_num = _get_device_num()
+                    stage_num = _get_auto_parallel_context("pipeline_stages")
+                    chunk_size = device_num // stage_num
+                    initial_rank = (rank_id // chunk_size) * chunk_size
+                    param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+                    single_params = remove_param_redundancy(param_redundancy_dict)
+                    save_param_names = single_params.get(rank_id)
+                    param_layout_set = set(param_layout.keys())
+                    if save_param_names == param_layout.keys():
+                        logger.warning(
+                            f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.")
+
+                    def choice_func(x):
+                        return x not in param_layout_set or x in save_param_names
+                else:
+                    param_redundancy_dict = get_parameter_redundancy(network)
+                    single_params = remove_param_redundancy(param_redundancy_dict)
+                    save_param_names = single_params.get(rank_id)
+
+                    def choice_func(x):
+                        return x in save_param_names
+                save_checkpoint(network, cur_file, False, self._config.async_save,
+                                self._append_dict, self._config.enc_key, self._config.enc_mode,
+                                crc_check=self._config.crc_check, format=self._config.format,
+                                incremental=self._map_param_inc, choice_func=choice_func)
             else:
                 save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                 self._append_dict, self._config.enc_key, self._config.enc_mode,
-                                crc_check=self._config.crc_check, incremental=self._map_param_inc)
+                                crc_check=self._config.crc_check, format=self._config.format,
+                                incremental=self._map_param_inc)
 
             self._latest_ckpt_file_name = cur_file
 
@@ -691,8 +754,9 @@ class ModelCheckpoint(Callback):
 class CheckpointManager:
     """Manage checkpoint files according to train_config of checkpoint."""
 
-    def __init__(self):
+    def __init__(self, format='ckpt'):
         self._ckpoint_filelist = []
+        self._format = format
 
     @property
     def ckpoint_filelist(self):
@@ -707,10 +771,12 @@ class CheckpointManager:
     def update_ckpoint_filelist(self, directory, prefix):
         """Update the checkpoint file list."""
         self._ckpoint_filelist = []
+        format = self._format
+        format_length = len(format) + 1
         files = os.listdir(directory)
         for filename in files:
-            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
-                mid_name = filename[len(prefix):-5]
+            if os.path.splitext(filename)[-1] == f".{format}" and filename.startswith(prefix + "-"):
+                mid_name = filename[len(prefix):-format_length]
                 flag = not (True in [char.isalpha() for char in mid_name])
                 if flag:
                     self._ckpoint_filelist.append(os.path.join(directory, filename))
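Note: in the remove_redundancy branch above, the callback builds a choice_func predicate from the per-rank parameter sets returned by get_parameter_redundancy / remove_param_redundancy and hands it to save_checkpoint, which then writes only the parameters the predicate accepts. A simplified sketch of that filtering idea; the toy network and the ownership set are assumptions for illustration:

import mindspore as ms
from mindspore import nn

net = nn.Dense(3, 2)  # toy network standing in for the training network

# Pretend this rank owns the non-redundant copy of only some parameters;
# in the diff the set comes from remove_param_redundancy(...).get(rank_id).
save_param_names = {"weight"}

def choice_func(param_name):
    # save a parameter only if this rank is responsible for it
    return param_name in save_param_names

ms.save_checkpoint(net, "rank0_dedup.ckpt", choice_func=choice_func)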
mindspore/train/callback/_cluster_monitor.py

@@ -150,7 +150,7 @@ class ClusterMonitor(Callback):
         with _perf_mutex:
             dir_path = os.path.dirname(self.full_path)
             if not os.path.exists(dir_path):
-                os.makedirs(dir_path)
+                os.makedirs(dir_path, mode=0o700)
             if os.path.exists(self.full_path):
                 os.chmod(self.full_path, stat.S_IWUSR)
                 os.remove(self.full_path)
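Note: the single-line change above hardens the directory ClusterMonitor creates for its perf dump. When reading it, keep in mind that the mode argument of os.makedirs is still filtered by the process umask, so 0o700 is an upper bound on the resulting permissions, not a guarantee. A small stand-alone illustration (the directory name is made up):

import os
import stat

os.makedirs("perf_dump_dir", mode=0o700, exist_ok=True)
# effective bits are mode & ~umask; with the common umask 0o022 this still yields 0o700
print(oct(stat.S_IMODE(os.stat("perf_dump_dir").st_mode)))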
mindspore/train/callback/_flops_collector.py

@@ -65,6 +65,7 @@ class FlopsUtilizationCollector(Callback):
     Raises:
         TypeError: If data_size is not positive int.
         TypeError: If full_flops is not bool.
+        AssertionError: If the training mode is not a static graph or not a static shape.
 
     Examples:
         >>> import numpy as np
mindspore/train/callback/_loss_monitor.py

@@ -19,7 +19,7 @@ import numpy as np
 
 from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback, _handle_loss
-from mindspore._c_expression import _collect_host_info
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 
 class LossMonitor(Callback):
@@ -70,7 +70,7 @@ class LossMonitor(Callback):
             please refer to :class:`mindspore.train.RunContext`.
         """
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "LossMonitor", "step_end", level=1)
+        collect_host_info("Callback", "LossMonitor", "step_end", start_time=get_clock_syscnt(), level=1)
         cur_epoch_num = cb_params.get("cur_epoch_num", 1)
         loss = _handle_loss(cb_params.net_outputs)
 
@@ -101,7 +101,7 @@ class LossMonitor(Callback):
             please refer to :class:`mindspore.train.RunContext`.
         """
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "LossMonitor", "train_epoch_end", level=1)
+        collect_host_info("Callback", "LossMonitor", "train_epoch_end", start_time=get_clock_syscnt(), level=1)
         metrics = cb_params.get("metrics")
         if metrics:
             print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics))
mindspore/train/callback/_on_request_exit.py

@@ -16,12 +16,19 @@
 
 from __future__ import absolute_import
 import os
+import json
 import signal
-
-from mindspore import log
+import threading
+from mindspore.common import dtype as mstype
+from mindspore import context
+from mindspore import log as logger
+from mindspore.common.tensor import Tensor
+from mindspore.train._utils import _make_directory
 from mindspore import _checkparam as Validator
 from mindspore.train.serialization import load_checkpoint, save_checkpoint, export
 from mindspore.train.callback._callback import Callback
+from mindspore.parallel._utils import _get_parallel_mode
+from mindspore.context import ParallelMode
 
 
 class OnRequestExit(Callback):
@@ -29,7 +36,8 @@ class OnRequestExit(Callback):
     Respond to the user's closing request, exit the training or eval process, and save the checkpoint and mindir.
 
     Register OnRequestExit Callback before training, when the user want to exit the training process
-    and save the training data, could send the registered exit signal 'sig' to the training process.
+    and save the training data, could send the registered exit signal 'sig' to the training process or modify the
+    'GracefulExit' that a key in the json file specified by the 'config_file' to '1'.
     After the training process executes the current step, saves the current training status,
     including checkpoint and mindir, and then exit the training process.
 
@@ -38,9 +46,12 @@ class OnRequestExit(Callback):
         save_mindir (bool): Whether save the mindir before the training process exit. Default: ``True`` .
         file_name (str): The saved checkpoint and mindir file name,
            the checkpoint file add suffix '.ckpt', the mindir file add suffix '.mindir'. Default: ``'Net'`` .
-        directory (str): The directory save checkpoint and mindir. Default: ``'./'`` .
+        directory (str): The path to save files. It will generate a 'rank_{id}' path by rank_id
+            to save checkpoint and mindir. Default: ``'./'`` .
         sig (int): The user registered exit signal, it must be a captureable and negligible signal.
             When the process receives the signal, exits the training or eval process. Default: ``signal.SIGTERM`` .
+        config_file (str): A json config file used to exit training process gracefully. Key: ``{"GracefulExit": 1}`` .
+            Default: ``None`` .
 
     Raises:
         ValueError: If the 'save_ckpt' is not a bool.
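Note: the docstring changes above describe the second exit trigger added to OnRequestExit: besides the registered signal, the callback can now poll a JSON file for {"GracefulExit": 1}. A hedged usage sketch; the MS_ENABLE_GRACEFUL_EXIT toggle, the config_file argument and the JSON key come from this diff, while the file names and the commented model.train call are illustrative:

import json
import os
from mindspore.train.callback import OnRequestExit

# Enable the graceful-exit check performed at step begin (see later hunks of this file).
os.environ["MS_ENABLE_GRACEFUL_EXIT"] = "1"

cb = OnRequestExit(save_ckpt=True, save_mindir=False, file_name="Net",
                   directory="./ckpt", config_file="./graceful_exit.json")
# model.train(epoch, dataset, callbacks=[cb])

# From an operator process: request a graceful stop by flipping the key to 1.
with open("./graceful_exit.json", "w") as f:
    json.dump({"GracefulExit": 1}, f)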
@@ -67,20 +78,28 @@ class OnRequestExit(Callback):
         >>> model.train(10, dataset, callbacks=on_request_exit)
     """
 
-    def __init__(self, save_ckpt=True, save_mindir=True, file_name='Net', directory='./', sig=signal.SIGTERM):
+    def __init__(self, save_ckpt=True, save_mindir=True, file_name='Net', directory='./', config_file=None,
+                 sig=signal.SIGTERM):
         super(OnRequestExit, self).__init__()
         self.save_ckpt = Validator.check_isinstance('save_ckpt', save_ckpt, bool)
         self.save_mindir = Validator.check_isinstance('save_mindir', save_mindir, bool)
-        if self.save_ckpt or self.save_mindir:
-            file_name = Validator.check_isinstance('file_name', file_name, str)
-            directory = Validator.check_isinstance('directory', directory, str)
-            os.makedirs(os.path.abspath(directory), exist_ok=True)
-            self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
-            self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
         self.sig = Validator.check_isinstance('sig', sig, int)
         if hasattr(signal, "SIGKILL") and self.sig == signal.SIGKILL:
             raise ValueError("Not support send exit request by signal SIGKILL.")
-        self.exit = False
+        self.exit = False  # used signal to exit the training process
+        self.lock = threading.Lock()
+        self.save_path = directory
+        self.key = "GracefulExit"
+        self.remote_config_file = config_file  # used config file to save checkpoint and exit training process
+        self.use_graceful = os.environ.get("MS_ENABLE_GRACEFUL_EXIT") == "1"
+        self.is_distributed = _get_parallel_mode() != ParallelMode.STAND_ALONE
+        self.integrated_save = True
+        if self.is_distributed:
+            self.integrated_save = _get_parallel_mode() == ParallelMode.AUTO_PARALLEL
+        self.stop_train = False
+        self.need_do_step_end = False
+        if self.save_ckpt or self.save_mindir:
+            self.train_name, self.eval_name = self._get_save_path(file_name)
 
     def on_train_begin(self, run_context):
         """
@@ -91,22 +110,31 @@ class OnRequestExit(Callback):
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
         signal.signal(self.sig, self._handle_signal)
-        if self.save_ckpt and os.path.isfile(f"{self.train_file_path}.ckpt"):
+        if self.save_ckpt and os.path.isfile(f"{self.train_name}.ckpt"):
             cb_params = run_context.original_args()
             train_net = cb_params.train_network
-            load_checkpoint(f"{self.train_file_path}.ckpt", net=train_net)
+            load_checkpoint(f"{self.train_name}.ckpt", net=train_net)
+
+    def on_train_step_begin(self, run_context):
+        """
+        Check whether received the exit signal or
+        whether the value of 'GracefulExit' in 'config_file' was changed to '1'.
+
+        Args:
+            run_context (RunContext): Context information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        self._do_step_begin(run_context)
 
     def on_train_step_end(self, run_context):
         """
-        When the train step end, if received the exit signal, set the 'run_context' attribute '_stop_requested' to True.
-        Then exit the training process after this step training.
+        Save checkpoint file or mindir file according to config, and exit the training process.
 
         Args:
             run_context (RunContext): Include some information of the model.
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
-        if self.exit:
-            run_context.request_stop()
+        self._do_step_end(run_context)
 
     def on_train_epoch_end(self, run_context):
         """
@@ -118,8 +146,7 @@ class OnRequestExit(Callback):
             run_context (RunContext): Include some information of the model.
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
-        if self.exit:
-            run_context.request_stop()
+        self._do_step_end(run_context)
 
     def on_train_end(self, run_context):
         """
@@ -135,10 +162,10 @@ class OnRequestExit(Callback):
         cb_params = run_context.original_args()
         train_net = cb_params.train_network
         if self.save_ckpt:
-            save_checkpoint(train_net, ckpt_file_name=self.train_file_path)
+            save_checkpoint(train_net, ckpt_file_name=self.train_name)
         if self.save_mindir:
             inputs = cb_params.train_dataset_element
-            export(train_net, *inputs, file_name=self.train_file_path, file_format='MINDIR')
+            export(train_net, *inputs, file_name=self.train_name, file_format='MINDIR')
 
     def on_eval_begin(self, run_context):
         """
@@ -153,15 +180,15 @@ class OnRequestExit(Callback):
             return
         cb_params = run_context.original_args()
         eval_net = cb_params.eval_network
-        if os.path.isfile(f"{self.eval_file_path}.ckpt"):
-            load_checkpoint(f"{self.eval_file_path}.ckpt", net=eval_net)
-        elif os.path.isfile(f"{self.train_file_path}.ckpt"):
-            load_checkpoint(f"{self.train_file_path}.ckpt", net=eval_net)
+        if os.path.isfile(f"{self.eval_name}.ckpt"):
+            load_checkpoint(f"{self.eval_name}.ckpt", net=eval_net)
+        elif os.path.isfile(f"{self.train_name}.ckpt"):
+            load_checkpoint(f"{self.train_name}.ckpt", net=eval_net)
 
     def on_eval_step_end(self, run_context):
         """
-        When the eval step end, if received the exit signal, set the 'run_context' attribute '_stop_requested' to True.
-        Then exit the eval process after this step eval.
+        When the eval step end, if received the exit signal, set attribute '_stop_requested' of the
+        'run_context' to True. Then exit the eval process after this step eval.
 
         Args:
             run_context (RunContext): Include some information of the model.
@@ -184,12 +211,99 @@ class OnRequestExit(Callback):
         cb_params = run_context.original_args()
         eval_net = cb_params.eval_network
         if self.save_ckpt:
-            save_checkpoint(eval_net, ckpt_file_name=self.eval_file_path)
+            save_checkpoint(eval_net, ckpt_file_name=self.eval_name)
         if self.save_mindir:
             inputs = cb_params.eval_dataset_element
-            export(eval_net, *inputs, file_name=self.eval_file_path, file_format='MINDIR')
+            export(eval_net, *inputs, file_name=self.eval_name, file_format='MINDIR')
 
     def _handle_signal(self, signum, frame):
         """Handle the received signal"""
-        log.debug(f"signum: {signum}, frame: {frame}")
+        logger.debug(f"signum: {signum}, frame: {frame}")
         self.exit = True
+
+    def _do_step_end(self, run_context):
+        """
+        Save the checkpoint or mindir, and then exit training process.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        with self.lock:
+            # save once
+            if self.stop_train or not self.need_do_step_end:
+                return
+            logger.info("Gracefully exiting training process on step end.")
+            call_params = run_context.original_args()
+            net = call_params.train_network
+            for _, param in net.parameters_and_names():
+                if param.name == "graceful_exit" and param.asnumpy() == True:  # pylint: disable=C0121
+                    logger.warning("Graceful exit is triggered, stop training.")
+                    if self.save_ckpt:
+                        append_dict = {"epoch_num": call_params.cur_epoch_num,
+                                       "step_num": call_params.cur_step_num,
+                                       "batch_num": call_params.batch_num}
+                        if call_params.loss_scale_mananger is not None:
+                            append_dict["loss_scale"] = call_params.loss_scale_mananger.get_loss_scale()
+                        if call_params.optimizer is not None:
+                            global_step = int(call_params.optimizer.global_step.data)
+                        else:
+                            global_step = int(call_params.network.optimizer.global_step.data)
+                        append_dict["global_step"] = global_step
+                        save_checkpoint(net, self.train_name, integrated_save=self.integrated_save,
+                                        append_dict=append_dict)
+                    if self.save_mindir:
+                        inputs = call_params.train_dataset_element
+                        export(net, *inputs, file_name=self.train_name, file_format='MINDIR')
+                    run_context.request_stop()
+                    self.stop_train = True
+
+    def _do_step_begin(self, run_context):
+        """
+        Check training process exit configuration at the step begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        with self.lock:
+            # no env
+            if not self.use_graceful:
+                return
+            if self._check_config_info() or self.exit:
+                call_params = run_context.original_args()
+                net = call_params.train_network
+                for _, param in net.parameters_and_names():
+                    if not self.is_distributed and param.name == "graceful_exit":
+                        param.set_data(Tensor(True, mstype.bool_))
+                        self.need_do_step_end = True
+                        break
+                    if param.name == "graceful_init":
+                        param.set_data(Tensor([1], mstype.int32))
+                        self.need_do_step_end = True
+                        break
+
+    def _check_config_info(self):
+        """check json config info"""
+        if self.remote_config_file is not None and os.path.exists(self.remote_config_file):
+            with open(self.remote_config_file, "r") as f:
+                try:
+                    config_info = json.load(f)
+                except json.JSONDecodeError as e:
+                    logger.warning(f"Parse json file failed: {e}, please check json file: {self.remote_config_file}")
+                    return False
+            if self.key in config_info and config_info[self.key] == 1:
+                return True
+        return False
+
+    def _get_save_path(self, file_name):
+        """path to save checkpoint files or mindir files"""
+        device_id = context.get_context("device_id")
+        if self.save_path is None:
+            tmp = os.path.join(os.getcwd(), r"rank_" + str(device_id))
+            path_ = _make_directory(tmp)
+            return os.path.join(path_, f"{file_name}_train"), os.path.join(path_, f"{file_name}_eval")
+
+        save_path = os.path.join(self.save_path, r"rank_" + str(device_id))
+        save_path = _make_directory(save_path)
+        return os.path.join(save_path, f"{file_name}_train"), os.path.join(save_path, f"{file_name}_eval")
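Note: the signal path is unchanged in spirit: _handle_signal still only records the request, and the actual save-and-stop now happens in _do_step_begin / _do_step_end. Triggering it externally is just a matter of sending the registered signal to the training process; a minimal sketch with a hypothetical PID:

import os
import signal

training_pid = 12345  # hypothetical PID of the training process
os.kill(training_pid, signal.SIGTERM)  # SIGTERM is the default `sig` of OnRequestExit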
mindspore/train/callback/_summary_collector.py

@@ -41,7 +41,7 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.loss.loss import LossBase
 from mindspore.train._utils import check_value_type, _make_directory
 from mindspore._c_expression import security
-from mindspore._c_expression import _collect_host_info
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
 HYPER_CONFIG_LEN_LIMIT = 100000
@@ -472,7 +472,7 @@ class SummaryCollector(Callback):
 
     def begin(self, run_context):
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "SummaryCollector", "begin", level=1)
+        collect_host_info("Callback", "SummaryCollector", "begin", start_time=get_clock_syscnt(), level=1)
         self._check_callbacks(cb_params)
 
         if cb_params.mode not in ModeEnum.to_list():
@@ -484,7 +484,7 @@ class SummaryCollector(Callback):
 
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "SummaryCollector", "step_end", level=1)
+        collect_host_info("Callback", "SummaryCollector", "step_end", start_time=get_clock_syscnt(), level=1)
         if cb_params.mode != ModeEnum.TRAIN.value:
             return
 
@@ -559,7 +559,7 @@ class SummaryCollector(Callback):
 
     def epoch_end(self, run_context):
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "SummaryCollector", "epoch_end", level=1)
+        collect_host_info("Callback", "SummaryCollector", "epoch_end", start_time=get_clock_syscnt(), level=1)
         self._collect_tensor_data(cb_params)
         collect_landscape = self._collect_specified_data.get('collect_landscape')
         if collect_landscape is not None:
@@ -576,7 +576,7 @@ class SummaryCollector(Callback):
 
     def end(self, run_context):
         cb_params = run_context.original_args()
-        _collect_host_info("Callback", "SummaryCollector", "end", level=1)
+        collect_host_info("Callback", "SummaryCollector", "end", start_time=get_clock_syscnt(), level=1)
         if cb_params.mode == ModeEnum.TRAIN.value:
             self._collect_train_lineage(cb_params)
         else: