mindspore 2.4.0__cp311-none-any.whl → 2.4.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (106)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cpython-311-aarch64-linux-gnu.so +0 -0
  3. mindspore/_c_expression.cpython-311-aarch64-linux-gnu.so +0 -0
  4. mindspore/common/initializer.py +51 -15
  5. mindspore/common/parameter.py +18 -4
  6. mindspore/common/tensor.py +15 -49
  7. mindspore/communication/comm_func.py +7 -7
  8. mindspore/context.py +9 -0
  9. mindspore/include/mindapi/base/format.h +13 -0
  10. mindspore/lib/libmindspore_backend.so +0 -0
  11. mindspore/lib/libmindspore_common.so +0 -0
  12. mindspore/lib/libmindspore_core.so +0 -0
  13. mindspore/lib/libmindspore_ops.so +0 -0
  14. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +10 -10
  15. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +8 -8
  16. mindspore/lib/plugin/ascend/custom_compiler/setup.py +1 -1
  17. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  18. mindspore/lib/plugin/ascend/libmindspore_internal_kernels.so +0 -0
  19. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  20. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +5 -5
  21. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  22. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  23. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/liblcal.so +0 -0
  24. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/liblcal_static.a +0 -0
  25. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +1 -0
  26. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/paged_attention_op.h +6 -1
  27. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/rms_norm_op.h +4 -3
  28. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  29. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  30. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layer_norm_impl.so +0 -0
  31. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  32. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_quant_acme_impl.so +0 -0
  33. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_310p_impl.so +0 -0
  34. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_310p_old_impl.so +0 -0
  35. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  36. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_old_impl.so +0 -0
  37. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  38. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  39. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  40. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  41. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmulti_weight_matmul_kernel_impl.so +0 -0
  42. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  43. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  44. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_nz_impl.so +0 -0
  45. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_nz_old_impl.so +0 -0
  46. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  47. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  48. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  49. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  50. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  51. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  52. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  53. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  54. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/matmul_add_rmsnorm/matmul_add_rmsnorm_bf16_bf16.o +0 -0
  55. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/matmul_add_rmsnorm/matmul_add_rmsnorm_bf16_fp16.o +0 -0
  56. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/matmul_add_rmsnorm/matmul_add_rmsnorm_bf16_fp32.o +0 -0
  57. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/matmul_add_rmsnorm/matmul_add_rmsnorm_fp16_bf16.o +0 -0
  58. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/matmul_add_rmsnorm/matmul_add_rmsnorm_fp16_fp16.o +0 -0
  59. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/matmul_add_rmsnorm/matmul_add_rmsnorm_fp16_fp32.o +0 -0
  60. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_bf16_bnsd_mix.o +0 -0
  61. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_bf16_bsh_mix.o +0 -0
  62. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_fp16_bnsd_mix.o +0 -0
  63. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_fp16_bsh_mix.o +0 -0
  64. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  65. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  66. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  67. mindspore/mint/__init__.py +490 -2
  68. mindspore/mint/nn/__init__.py +2 -2
  69. mindspore/mint/optim/adamw.py +6 -14
  70. mindspore/nn/cell.py +1 -3
  71. mindspore/nn/layer/basic.py +24 -7
  72. mindspore/nn/layer/embedding.py +31 -14
  73. mindspore/nn/optim/tft_wrapper.py +12 -15
  74. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  75. mindspore/ops/_grad_experimental/grad_comm_ops.py +20 -1
  76. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +6 -0
  77. mindspore/ops/auto_generate/gen_extend_func.py +33 -0
  78. mindspore/ops/auto_generate/gen_ops_def.py +52 -3
  79. mindspore/ops/auto_generate/gen_ops_prim.py +155 -6
  80. mindspore/ops/function/array_func.py +2 -0
  81. mindspore/ops/function/math_func.py +7 -1
  82. mindspore/ops/function/random_func.py +221 -7
  83. mindspore/ops/operations/__init__.py +1 -1
  84. mindspore/ops/operations/array_ops.py +3 -1
  85. mindspore/ops/operations/comm_ops.py +21 -0
  86. mindspore/ops/operations/manually_defined/ops_def.py +8 -10
  87. mindspore/parallel/_auto_parallel_context.py +3 -1
  88. mindspore/parallel/_cell_wrapper.py +2 -0
  89. mindspore/parallel/_tensor.py +46 -2
  90. mindspore/parallel/_utils.py +40 -21
  91. mindspore/parallel/transform_safetensors.py +196 -43
  92. mindspore/profiler/profiling.py +5 -1
  93. mindspore/run_check/_check_version.py +4 -2
  94. mindspore/train/_utils.py +92 -32
  95. mindspore/train/callback/_checkpoint.py +12 -9
  96. mindspore/train/callback/_on_request_exit.py +12 -1
  97. mindspore/train/callback/_tft_register.py +27 -4
  98. mindspore/train/dataset_helper.py +10 -2
  99. mindspore/train/model.py +20 -0
  100. mindspore/train/serialization.py +8 -18
  101. mindspore/version.py +1 -1
  102. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +8 -6
  103. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +106 -106
  104. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  105. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  106. {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/_utils.py CHANGED
@@ -16,6 +16,8 @@
  from __future__ import absolute_import
 
  import os
+ import threading
+ from datetime import datetime
  import json
  from collections.abc import Iterable
 
@@ -76,7 +78,14 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
      queue_name = str("")
 
      use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
-     if use_pipeline_parallel:
+
+     # temp env to disable dynamic feature of sink size 1
+     dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
+     dynamic_sink1 = True
+     if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
+         dynamic_sink1 = False
+
+     if use_pipeline_parallel or not dynamic_sink1:
          create_data_info_queue = False
 
      exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,
@@ -303,10 +312,68 @@ def parse_strategy_ckpt(file_name):
 
          for ele in param.parallel_layouts.tensor_map[0].ListFields()[0][1]:
              tensor_map.append(ele)
-         layout_dict[param.param_name] = [dev_matrix, tensor_map]
+         layout_dict[param.param_name] = [dev_matrix, tensor_map, param.parallel_layouts.opt_weight_shard_step,
+                                          param.parallel_layouts.opt_weight_shard_size]
      return layout_dict
 
 
+ def _get_strategy_opt_shard(param_redundancy_dict, parameter_layout_opt_shard):
+     """Strategy ckpt append opt shard."""
+     for key, value in parameter_layout_opt_shard.items():
+         if value[1] not in (-1, 0):
+             opt_para_num = value[1]
+             param_redundancy_ranks = param_redundancy_dict.get(key)
+             res = []
+             for param_ranks in param_redundancy_ranks:
+                 if len(param_ranks) % opt_para_num == 0:
+                     for i in range(0, opt_para_num):
+                         res.append(param_ranks[i::opt_para_num])
+             param_redundancy_dict[key] = tuple(res)
+
+
+ def _get_layout_opt_shard(layout_obj, param_redundancy_dict):
+     """Layout ckpt append opt shard."""
+     for key, value in layout_obj.items():
+         if value[5]:
+             world_groups = ("hccl_world_group", "nccl_world_group", "mccl_world_group")
+             if value[5] in world_groups:
+                 opt_para_num = get_group_size()
+             elif "-" in value[5]:
+                 opt_para_str = value[5].split("-")[0]
+                 opt_para_num = int(opt_para_str)
+             else:
+                 raise ValueError(f"For get_parameter_redundancy, the format of the parallel communication domain for "
+                                  f"the optimizer is incorrect.")
+             param_redundancy_ranks = param_redundancy_dict.get(key)
+             res = []
+             for param_ranks in param_redundancy_ranks:
+                 if len(param_ranks) % opt_para_num == 0:
+                     for i in range(0, opt_para_num):
+                         res.append(param_ranks[i::opt_para_num])
+             param_redundancy_dict[key] = tuple(res)
+
+
+ def _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank):
+     """Get parameter redundancy without opt shard."""
+     for key, (slices, deploy_loc, *_) in parameter_layout.items():
+         redundancy_matrix = np.zeros(shape=slices + [len(slices)], dtype=np.int8)
+         for i in deploy_loc:
+             internal_slice = tuple(slice(None) for _ in range(i))
+             for j in range(slices[-i - 1]):
+                 if i == -1:
+                     continue
+                 else:
+                     redundancy_matrix[(..., j) + internal_slice + (i,)] = j
+         locate_list = redundancy_matrix.reshape((-1, len(slices))).tolist()
+         redundancy_dict = {}
+         for index, locate in enumerate(locate_list):
+             redundancy_dict.setdefault(tuple(locate), []).append(index + initial_rank)
+         redundancy_list = []
+         for _, indices in sorted(redundancy_dict.items()):
+             redundancy_list.append(tuple(indices))
+         param_redundancy_dict[key] = tuple(redundancy_list)
+
+
  def get_parameter_redundancy(layout_obj, initial_rank=0):
      """
      Get parameter redundancy map.
@@ -327,7 +394,12 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
          'param4': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15))}
      """
      if isinstance(layout_obj, str):
-         parameter_layout = parse_strategy_ckpt(layout_obj)
+         parameter_layout_total = parse_strategy_ckpt(layout_obj)
+         parameter_layout = {}
+         parameter_layout_opt_shard = {}
+         for key, value in parameter_layout_total.items():
+             parameter_layout[key] = value[0:2]
+             parameter_layout_opt_shard[key] = value[2:]
      elif isinstance(layout_obj, Cell):
          from mindspore.communication.management import get_process_group_ranks
          groups_ranks = (tuple(get_process_group_ranks()),)
@@ -339,37 +411,14 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
              parameter_layout[k] = v[:2]
 
      param_redundancy_dict = {}
-     for key, (slices, deploy_loc, *_) in parameter_layout.items():
-         redundancy_matrix = np.zeros(shape=slices + [len(slices)], dtype=np.int8)
-         for i in deploy_loc:
-             internal_slice = tuple(slice(None) for _ in range(i))
-             for j in range(slices[-i - 1]):
-                 if i == -1:
-                     continue
-                 else:
-                     redundancy_matrix[(..., j) + internal_slice + (i,)] = j
-         locate_list = redundancy_matrix.reshape((-1, len(slices))).tolist()
-         redundancy_dict = {}
-         for index, locate in enumerate(locate_list):
-             redundancy_dict.setdefault(tuple(locate), []).append(index + initial_rank)
-         redundancy_list = []
-         for _, indices in sorted(redundancy_dict.items()):
-             redundancy_list.append(tuple(indices))
-         param_redundancy_dict[key] = tuple(redundancy_list)
+
+     _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank)
+
      if isinstance(layout_obj, str):
-         return param_redundancy_dict
+         _get_strategy_opt_shard(param_redundancy_dict, parameter_layout_opt_shard)
+     else:
+         _get_layout_opt_shard(layout_obj, param_redundancy_dict)
 
-     for key, value in layout_obj.items():
-         if value[5]:
-             world_groups = ("hccl_world_group", "nccl_world_group", "mccl_world_group")
-             opt_para_num = int(value[5][0]) if value[5] not in world_groups else get_group_size()
-             param_redundancy_ranks = param_redundancy_dict.get(key)
-             res = []
-             for param_ranks in param_redundancy_ranks:
-                 if len(param_ranks) % opt_para_num == 0:
-                     for i in range(0, opt_para_num):
-                         res.append(param_ranks[i::opt_para_num])
-             param_redundancy_dict[key] = tuple(res)
      return param_redundancy_dict
 
 
@@ -463,3 +512,14 @@ def parse_hccl_file(hccl_file_path):
          rankid_dict[int(device["rank_id"])] = device["device_ip"]
 
      return rankid_dict
+
+
+ def vlog_print(level, module, file, line, message):
+     '''Read environment variable VLOG_v and print to log'''
+     if os.environ.get("VLOG_v") == level:
+         now = datetime.now()
+         formatted_time = now.strftime("%Y-%m-%d-%H:%M:%S.%f")[:-3] + f".{now.microsecond // 1000}"
+         path = 'mindspore' + file.split("mindspore")[-1]
+         pid = os.getpid()
+         thread_id = threading.get_ident()
+         print(f"[V{level}] {module}({pid},{thread_id},python):{formatted_time} [{path}:{line}] {message}", flush=True)
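Note: the new opt-shard helpers in `_utils.py` all reduce to the same interleaved slicing step, `param_ranks[i::opt_para_num]`, applied to each redundancy group. A minimal standalone sketch of that step; the redundancy dictionary below is made-up example data, not taken from a real strategy file:

```python
# Sketch of the rank-splitting step shared by _get_strategy_opt_shard and
# _get_layout_opt_shard above. The redundancy data here is hypothetical.
def split_redundancy_by_opt_shard(param_redundancy_dict, opt_para_num):
    """Split each redundancy group into opt_para_num interleaved sub-groups."""
    for key, groups in param_redundancy_dict.items():
        res = []
        for param_ranks in groups:
            # Only split groups whose size divides evenly, mirroring the diff.
            if len(param_ranks) % opt_para_num == 0:
                for i in range(opt_para_num):
                    res.append(param_ranks[i::opt_para_num])
        param_redundancy_dict[key] = tuple(res)


redundancy = {"param1": ((0, 4, 8, 12), (1, 5, 9, 13))}
split_redundancy_by_opt_shard(redundancy, 2)
print(redundancy)  # {'param1': ((0, 8), (4, 12), (1, 9), (5, 13))}
```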
mindspore/train/callback/_checkpoint.py CHANGED
@@ -44,6 +44,15 @@ SAVE_DIR = _cur_dir
  _info_list = ["epoch_num", "step_num"]
 
 
+ def _wait_async_save_ckpt(async_save=False):
+     """Waiting for asynchronous saving of ckpt to complete."""
+     if async_save:
+         thread_list = threading.enumerate()
+         for thread in thread_list:
+             if thread.getName() == "asyn_save_ckpt":
+                 thread.join()
+
+
  def _get_dp_tp_from_redundancy(redundancy_tuple):
      """From redundancy get dp and tp"""
      dp = []
@@ -568,6 +577,7 @@ class ModelCheckpoint(Callback):
                               "string that does not contain '/', but got {}.".format(self._prefix))
          if self._directory_func:
              self._directory = self._directory_func(cb_params)
+             _make_directory(self._directory)
          collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
          # In disaster recovery scenario, the training process may be rolled back to the last step where
          # the ckpt was successfully saved, so the _last_triggered_step should be updated.
@@ -575,7 +585,6 @@
              self._last_triggered_step = cb_params.last_save_ckpt_step
              cb_params.last_save_ckpt_step = None
 
-         _make_directory(self._directory)
          # save graph (only once)
          if not self._graph_saved:
              graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
@@ -583,10 +592,6 @@
                  os.remove(graph_file_name)
              _save_graph(cb_params.train_network, graph_file_name)
              self._graph_saved = True
-         thread_list = threading.enumerate()
-         for thread in thread_list:
-             if thread.getName() == "asyn_save_ckpt":
-                 thread.join()
          self._save_ckpt(cb_params)
 
      def end(self, run_context):
@@ -602,10 +607,7 @@
 
          self._save_ckpt(cb_params, _to_save_last_ckpt)
 
-         thread_list = threading.enumerate()
-         for thread in thread_list:
-             if thread.getName() == "asyn_save_ckpt":
-                 thread.join()
+         _wait_async_save_ckpt(self._config.async_save)
 
          destroy_allgather_cell()
 
@@ -643,6 +645,7 @@
          step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
 
          if save_ckpt:
+             _wait_async_save_ckpt(self._config.async_save)
              if self._prefix_func:
                  cur_ckpoint_file = self._prefix + f".{self._config.format}"
              else:
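The new `_wait_async_save_ckpt` helper replaces the thread-join loops that were previously inlined in `step_end` and `end`. The underlying pattern is plain `threading`: enumerate live threads and join the named checkpoint writer. A self-contained sketch of that idea; the worker function below is a stand-in, not MindSpore code:

```python
import threading
import time


def fake_async_save():
    """Stand-in for an asynchronous checkpoint writer thread."""
    time.sleep(0.1)


def wait_named_thread(name="asyn_save_ckpt"):
    """Join every live thread with the given name, like _wait_async_save_ckpt."""
    for thread in threading.enumerate():
        if thread.name == name:
            thread.join()


saver = threading.Thread(target=fake_async_save, name="asyn_save_ckpt")
saver.start()
wait_named_thread()          # blocks until the saver thread finishes
assert not saver.is_alive()
```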
mindspore/train/callback/_on_request_exit.py CHANGED
@@ -240,7 +240,18 @@ class OnRequestExit(Callback):
              if param.name == "graceful_exit" and param.asnumpy() == True: # pylint: disable=C0121
                  logger.warning("Graceful exit is triggered, stop training.")
                  if self.save_ckpt:
-                     save_checkpoint(net, self.train_name, integrated_save=self.integrated_save)
+                     append_dict = {"epoch_num": call_params.cur_epoch_num,
+                                    "step_num": call_params.cur_step_num,
+                                    "batch_num": call_params.batch_num}
+                     if call_params.loss_scale_mananger is not None:
+                         append_dict["loss_scale"] = call_params.loss_scale_mananger.get_loss_scale()
+                     if call_params.optimizer is not None:
+                         global_step = int(call_params.optimizer.global_step.data)
+                     else:
+                         global_step = int(call_params.network.optimizer.global_step.data)
+                     append_dict["global_step"] = global_step
+                     save_checkpoint(net, self.train_name, integrated_save=self.integrated_save,
+                                     append_dict=append_dict)
                  if self.save_mindir:
                      inputs = call_params.train_dataset_element
                      export(net, *inputs, file_name=self.train_name, file_format='MINDIR')
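The graceful-exit path now records training progress through `save_checkpoint`'s `append_dict` argument. A hedged sketch of the same pattern from user code; `SmallNet` and the metadata values are placeholders, not taken from the diff:

```python
# Placeholder network and metadata; only the append_dict mechanism mirrors the diff.
import mindspore as ms
from mindspore import nn


class SmallNet(nn.Cell):
    """Tiny stand-in network so save_checkpoint has parameters to serialize."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)


# Scalars passed via append_dict are stored in the checkpoint next to the weights.
append_dict = {"epoch_num": 3, "step_num": 1200, "global_step": 1200, "loss_scale": 1024.0}
ms.save_checkpoint(SmallNet(), "graceful_exit.ckpt", append_dict=append_dict)
```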
mindspore/train/callback/_tft_register.py CHANGED
@@ -21,6 +21,7 @@ from mindspore import _checkparam as Validator
  from mindspore.train.callback._callback import Callback
  from mindspore import context
  from mindspore.common.parameter import Parameter
+ from mindspore.common.tensor import Tensor
  from mindspore.communication import get_rank, get_group_size
  from mindspore import log as logger
  from mindspore.train.serialization import _get_cur_rank_dp
@@ -29,6 +30,9 @@ from mindspore._c_expression import clean_tdt_channel
  from mindspore._c_expression import send_recv
  from mindspore._c_expression import CollectiveManager
  from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
+ from mindspore._c_expression import Tensor as Tensor_
+ import mindspore
+ import mindspore.common.dtype as mstype
 
  def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
      """ Common func to generate ckpt dir name."""
@@ -39,6 +43,9 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
  def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
      """ Callback used for TFT save ckpt function when errors occur."""
      logger.info("Enter _save_checkpoint_on_failure function")
+     if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
+         raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
+
      ckpt_save_path = cb_ctx.ckpt_save_path
      cb_params = args
      cur_rank = get_rank()
@@ -83,8 +90,6 @@ def _tft_exit_cb(ctx):
      _tft_sem_post()
      os._exit(1) # pylint: disable=W0212
 
-
-
  def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
      """ Callback used for TFT repair function."""
      logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
@@ -129,6 +134,8 @@ def _tft_stop_callback(cb_ctx):
      """ Callback used for TFT stop function."""
      logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
      _stop_device(cb_ctx.device_id)
+     if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
+         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
      logger.info("Finish _tft_stop_callback")
 
 
@@ -260,9 +267,22 @@ class TFTRegister(Callback):
          self._controller_ip = ctrl_ip
          self._controller_rank_id = ctrl_rank_id
          self._controller_port = ctrl_port
+         self.cb_params = None
          self.device_id = context.get_context("device_id")
          self._init_tft()
          self.ckpt_save_path = ckpt_save_path
+         self.assign = mindspore.ops.Assign()
+         self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
+         self.s1 = mindspore.hal.Stream()
+
+     def _is_params_consistent(self):
+         for key, param in self.cb_params.train_network.parameters_and_names():
+             if "tft_g_one_flag" in key:
+                 with mindspore.hal.StreamCtx(self.s1):
+                     tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
+                 self.s1.synchronize()
+                 return int(tft_g_one_flag) == 1
+         return False
 
      def _set_tft_optimizer_replica(self, run_context):
          """ set Mindio TFT optimizer replica info, used internal. """
@@ -328,12 +348,14 @@
              self.has_init_replica = True
              self._set_tft_optimizer_replica(run_context)
          cb_params = run_context.original_args()
+         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
+         self.tft.tft_end_updating_os(cb_params.cur_step_num)
          if cb_params.optimizer is not None:
              self.global_step = int(cb_params.optimizer.global_step.data)
+             self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
          else:
              self.global_step = int(cb_params.network.optimizer.global_step.data)
-         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
-         self.tft.tft_end_updating_os(cb_params.cur_step_num)
+             self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
          logger.info("END Set optimizer finish step status to TFT.")
 
 
@@ -344,6 +366,7 @@
              raise ValueError("TFT feature doesn't support sink_size > 1.")
          logger.info("Set set args to TFT.")
          self.tft.tft_set_step_args(cb_params)
+         self.cb_params = cb_params
 
      def end(self, run_context):
          cur_rank = get_rank()
mindspore/train/dataset_helper.py CHANGED
@@ -15,6 +15,7 @@
  """Dataset help for minddata dataset"""
  from __future__ import absolute_import
 
+ import os
  import math
  import copy
 
@@ -264,7 +265,14 @@ def connect_network_with_dataset(network, dataset_helper):
      queue_name = dataset.__transfer_dataset__.queue_name
      # In pipeline parallel, some stages have no GetNext, should not get in.
      use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
-     if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel:
+
+     # temp env to disable dynamic feature of sink size 1
+     dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
+     dynamic_sink1 = True
+     if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
+         dynamic_sink1 = False
+
+     if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and dynamic_sink1:
          dataset_types, dataset_shapes = dataset_helper.get_data_info()
          # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
          if _need_to_full():
@@ -306,7 +314,7 @@
      aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
 
      if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
-             not use_pipeline_parallel:
+             not use_pipeline_parallel and dynamic_sink1:
          dataset_helper.get_data_info()
      network.add_flags(sink_mode=True)
      return network
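Both `dataset_helper.py` and `_utils.py` gate the sink-size-1 dynamic behaviour on the temporary `MS_DEV_DYNAMIC_SINK1` environment variable, treating only the strings 'False'/'false' as a disable switch. A standalone sketch of that parsing rule:

```python
import os


def dynamic_sink1_enabled():
    """Mirror the MS_DEV_DYNAMIC_SINK1 check: the sink-size-1 dynamic feature
    stays on unless the variable is explicitly 'False' or 'false'."""
    value = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
    if value and value.strip() in ['False', 'false']:
        return False
    return True


os.environ["MS_DEV_DYNAMIC_SINK1"] = " false "
print(dynamic_sink1_enabled())   # False

os.environ["MS_DEV_DYNAMIC_SINK1"] = "0"    # any other value keeps the feature on
print(dynamic_sink1_enabled())   # True
```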
mindspore/train/model.py CHANGED
@@ -46,6 +46,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
  from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
      _cache_enable, _enable_distributed_mindrt
  from mindspore.train.metrics import Loss
+ from mindspore.train._utils import vlog_print
  from mindspore import nn
  from mindspore.boost import AutoBoost
  from mindspore.context import ParallelMode
@@ -654,10 +655,12 @@
              dataset.__loop_size__ = 1
 
          if dataset_helper is None:
+             vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to create DatasetHelper.")
              logger.info("Begin to create DatasetHelper.")
              dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
 
          if dataset_sink_mode:
+             vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to connect network with dataset.")
              logger.info("Begin to connect network with dataset.")
              network = connect_network_with_dataset(network, dataset_helper)
 
@@ -779,6 +782,7 @@
          if not train_dataset and not valid_dataset:
              raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")
 
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to check device number in model.build().")
          logger.info("Begin to check device number in model.build() procedure.")
          _device_number_check(self._parallel_mode, self._device_number)
 
@@ -787,17 +791,21 @@
              raise TypeError("The type of 'train_dataset' must be `Dataset`, "
                              "but got {}.".format(type(train_dataset)))
 
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                    "Begin to check parameter broadcast in model.build().")
          logger.info("Begin to check parameter broadcast in model.build() procedure.")
          _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
          if self._parameter_broadcast:
              self._train_network.set_broadcast_flag()
 
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to exec preprocess in model.build().")
          logger.info("Begin to exec preprocess in model.build() procedure.")
          train_dataset.__no_send__ = True
          train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                      dataset=train_dataset,
                                                                      dataset_sink_mode=True,
                                                                      sink_size=sink_size)
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
          logger.info("Begin to warmup dataset in model.build() procedure.")
          self._warmup_dataset(epoch, train_dataset, sink_size)
 
@@ -805,13 +813,19 @@
              delattr(train_dataset, "__no_send__")
 
          # Waiting for the dataset warmup ready
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                    "Begin waiting for dataset warmup in model.build().")
          logger.info("Begin waiting for dataset warmup in model.build() procedure.")
          self._waiting_for_dataset_warmup_ready(train_dataset)
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                    "The dataset warmup was successful in model.build().")
          logger.info("The dataset warmup was successful in model.build() procedure.")
 
          if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
              train_network.add_flags_recursive(is_first_iteration=True)
          for inputs in train_dataset_helper:
+             vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                        "Begin to compile train network in model.build().")
              logger.info("Begin to compile train network in model.build() procedure.")
              train_network.compile(*inputs)
              self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
@@ -832,6 +846,8 @@
          if context.get_auto_parallel_context("pipeline_stages") > 1:
              eval_network.add_flags_recursive(is_first_iteration=False)
          for inputs in valid_dataset_helper:
+             vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                        "Begin to compile eval network in model.build().")
              logger.info("Begin to compile eval network in model.build() procedure.")
              eval_network.compile(*inputs)
              break
@@ -905,6 +921,7 @@
              epoch = 1
          cb_params.last_save_ckpt_step = None
          cb_params.latest_ckpt_file = None
+         cb_params.loss_scale_mananger = self._loss_scale_manager
 
          # build callback list
          with _CallbackManager(callbacks) as list_callback:
@@ -1567,8 +1584,11 @@
          if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
              self._train_network.check_names_and_refresh_name()
              self._train_network._is_check_and_refresh = True
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
          logger.info("Begin to init dataset in model.build() procedure.")
          self._init(train_dataset, valid_dataset, sink_size, epoch)
+         vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                    "The model.build() which contains dataset warmup and network compile is success.")
          logger.info("The model.build() which contains dataset warmup and network compile is success.")
 
      def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
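The `vlog_print` calls threaded through `model.py` only emit when the `VLOG_v` environment variable equals the requested level. A hedged usage sketch, assuming `vlog_print` is importable from `mindspore.train._utils` as added in this release; the sample output line is approximate:

```python
import os
import sys

from mindspore.train._utils import vlog_print   # helper added in _utils.py above

os.environ["VLOG_v"] = "1"    # enable level-1 verbose logging for this process
# Emits something like:
# [V1] ME(<pid>,<tid>,python):2024-11-20-12:00:00.123.123 [mindspore/...:9] Begin to build model.
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to build model.")

os.environ["VLOG_v"] = "2"    # level mismatch: nothing is printed
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Suppressed message.")
```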
mindspore/train/serialization.py CHANGED
@@ -64,7 +64,7 @@ from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_paramet
  from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
  from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
  from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
-     _get_device_num, _is_parallel_mode
+     _get_device_num
  from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
  from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
      _restore_group_info_list, _get_param_list_when_first_dim_sharded
@@ -1569,6 +1569,9 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
                  raise ValueError("For 'load_checkpoint', the crc check is failed, "
                                   "please check whether the ckpt file is damaged.")
          checkpoint_list.ParseFromString(pb_content)
+     except google.protobuf.message.DecodeError as e:
+         raise ValueError(f"Failed to read the checkpoint file {ckpt_file_name}. "
+                          f"The file may be corrupted, and the content cannot be parsed.") from e
      except BaseException as e:
          if _is_cipher_file(ckpt_file_name):
              err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tempered with, " \
@@ -1598,19 +1601,6 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
      return whether_load
 
 
- def _init_parameter_data_in_parallel_mode(net, parameter_dict):
-     """In parallel mode, only init the paraemters in ckpt."""
-     is_train_phase = net.phase.startswith('train')
-     for _, param in net.parameters_and_names():
-         if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
-             param.shape = tuple(parameter_dict[param.name].shape)
-             continue
-         if param.name in parameter_dict and param.has_init:
-             logger.warning("{} is not init while load ckpt.".format(param.name))
-             new_tensor = param.init_data()
-             param._update_tensor_data(new_tensor)
-
-
  def _check_load_param_into_net(net, parameter_dict):
      """check load_param_into_net"""
      if not isinstance(net, nn.Cell):
@@ -1682,10 +1672,6 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
      logger.info("Execute the process of loading parameters into net.")
      for _, param in net.parameters_and_names():
          param.from_ckpt = True
-     if not (_is_in_auto_parallel_mode() or _is_parallel_mode()):
-         net.init_parameters_data()
-     else:
-         _init_parameter_data_in_parallel_mode(net, parameter_dict)
      param_not_load = []
      ckpt_not_load = list(parameter_dict.keys())
      for _, param in net.parameters_and_names():
@@ -1698,6 +1684,8 @@
                  continue
              new_param = parameter_dict[param.name]
              _update_param(param, new_param, strict_load)
+             if hasattr(param, "init_param") and not param.init_param:
+                 param.init_param = True
              ckpt_not_load.remove(param.name)
          else:
              param_not_load.append(param.name)
@@ -1822,6 +1810,8 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_loa
          if param.name in param_not_load and new_param_name in parameter_dict:
              new_param = parameter_dict[new_param_name]
              _update_param(param, new_param, strict_load)
+             if hasattr(param, "init_param") and not param.init_param:
+                 param.init_param = True
              param_not_load.remove(param.name)
 
 
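The new `except google.protobuf.message.DecodeError` branch converts a low-level protobuf parsing failure into a readable `ValueError` before the generic handler runs. A sketch of the same pattern in isolation, using a stock protobuf message (`Timestamp`) as a stand-in for MindSpore's generated checkpoint proto; the malformed byte string is an assumption chosen to trigger `DecodeError`:

```python
# Timestamp stands in for MindSpore's generated Checkpoint proto.
import google.protobuf.message
from google.protobuf.timestamp_pb2 import Timestamp


def parse_or_explain(pb_content, file_name="demo.ckpt"):
    msg = Timestamp()
    try:
        msg.ParseFromString(pb_content)
    except google.protobuf.message.DecodeError as e:
        # Same translation as the diff: raise a readable error instead of the
        # raw protobuf exception.
        raise ValueError(f"Failed to read the checkpoint file {file_name}. "
                         f"The file may be corrupted, and the content cannot be parsed.") from e
    return msg


try:
    parse_or_explain(b"\xff\xff\xff\xff")   # truncated varint -> DecodeError -> ValueError
except ValueError as err:
    print(err)
```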
mindspore/version.py CHANGED
@@ -1 +1 @@
- __version__ = '2.4.0'
+ __version__ = '2.4.1'
{mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mindspore
- Version: 2.4.0
+ Version: 2.4.1
  Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
  Home-page: https://www.mindspore.cn
  Download-URL: https://github.com/mindspore-ai/mindspore/tags
@@ -316,15 +316,17 @@ Project stable branches will be in one of the following states:
  | Development | 3 months | Features are under development. |
  | Maintained | 6 - 12 months | All bugfixes are appropriate. Releases produced. |
  | Unmaintained| 0 - 3 months | All bugfixes are appropriate. No Maintainers and No Releases produced. |
- | End Of Life (EOL) | N/A | Branch no longer accepting changes. |
+ | End Of Life (EOL) | N/A | Version no longer accepting changes. |
 
  ## Maintenance status
 
- | **Branch** | **Status** | **Initial Release Date** | **Next Phase** | **EOL Date**|
+ | **Version** | **Status** | **Initial Release Date** | **Next Phase** | **EOL Date**|
  |------------|--------------|--------------------------|----------------------------------------|-------------|
- | **r2.2** | Maintained | 2023-10-18 | Unmaintained <br> 2024-10-18 estimated | |
- | **r2.1** | Maintained | 2023-07-29 | Unmaintained <br> 2024-07-29 estimated | |
- | **r2.0** | Maintained | 2023-06-15 | Unmaintained <br> 2024-06-15 estimated | |
+ | **r2.4** | Maintained | 2024-10-30 | Unmaintained <br> 2025-10-30 estimated | 2025-10-30 |
+ | **r2.3** | Maintained | 2024-07-15 | Unmaintained <br> 2025-07-15 estimated | 2025-07-15 |
+ | **r2.2** | End Of Life | 2023-10-18 | | 2024-10-18 |
+ | **r2.1** | End Of Life | 2023-07-29 | | 2024-07-29 |
+ | **r2.0** | End Of Life | 2023-06-15 | | 2024-06-15 |
  | **r1.10** | End Of Life | 2023-02-02 | | 2024-02-02 |
  | **r1.9** | End Of Life | 2022-10-26 | | 2023-10-26 |
  | **r1.8** | End Of Life | 2022-07-29 | | 2023-07-29 |