mindspore-2.7.0-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/runtime/memory.py CHANGED
@@ -14,9 +14,15 @@
 # ============================================================================
 
 """Memory interfaces."""
+import contextlib
+import ctypes
 import os
 from mindspore._c_expression import RuntimeConf, DeviceManagerConf, _memory_stats, \
     _reset_max_mem_reserved, _reset_max_mem_allocated, DeviceContextManager, _empty_cache, _memory_replay
+try:
+    from mindspore._c_expression import _enable_pluggable_allocator, _disable_pluggable_allocator
+except ImportError:
+    pass
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import args_type_check
 from mindspore import log as logger
@@ -406,3 +412,109 @@ def memory_replay(file_path):
         >>> ms.runtime.memory_replay("/data/memory_block.csv")
     """
     _memory_replay(os.path.realpath(file_path))
+
+
+class PluggableAllocator():
+    r"""
+    Receive a .so file via ctypes, and dynamically load the alloc and free functions within it.
+    It needs to be used in conjunction with :class:`mindspore.runtime.MemPool` and
+    :func:`mindspore.runtime.use_mem_pool` to take over the memory allocation and free
+    in the MindSpore memory pool.
+
+    .. warning::
+        This is currently supported only on Unix operating systems.
+
+    Args:
+        path_to_so_file (str): Path in the file system to the `.so` file containing
+            the allocator functions.
+        alloc_fn_name (str): Name of the function to perform the memory allocation
+            in the so file. The signature must be:
+            `void* alloc_fn(size_t size, int device, aclrtStream stream);` .
+        free_fn_name (str): Name of the function to perform the memory release
+            in the so file. The signature must be:
+            `void free_fn(void* ptr, size_t size, aclrtStream stream);` .
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
+        allocator = ctypes.CDLL(path_to_so_file)
+        alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
+        free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
+        if alloc_fn is None:
+            raise ValueError(f"Cannot find allocator function {alloc_fn_name} in {path_to_so_file}")
+        if free_fn is None:
+            raise ValueError(f"Cannot find free function {free_fn_name} in {path_to_so_file}")
+        self._alloc_fn = alloc_fn
+        self._free_fn = free_fn
+
+    @property
+    def alloc_fn_ptr(self) -> int:
+        """Function pointer of the allocator function."""
+        return self._alloc_fn
+
+    @property
+    def free_fn_ptr(self) -> int:
+        """Function pointer of the free function."""
+        return self._free_fn
+
+
+class MemPool():
+    r"""
+    A MemPool wraps a :class:`mindspore.runtime.PluggableAllocator`
+    and passes it to :func:`mindspore.runtime.use_mem_pool`.
+
+    Args:
+        allocator (mindspore.runtime.PluggableAllocator): a mindspore.runtime.PluggableAllocator
+            that can be used to define how memory gets allocated and freed in the pool.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    def __init__(self, allocator: PluggableAllocator):
+        self._allocator = allocator
+
+    @property
+    def allocator(self) -> PluggableAllocator:
+        """The allocator used by the pool."""
+        return self._allocator
+
+
+@contextlib.contextmanager
+def use_mem_pool(pool: MemPool):
+    r"""
+    A context manager that routes allocations and deallocations to a given pool.
+
+    Note:
+        - This context manager makes only the current thread's allocations route to the given pool.
+        - If a new thread is spawned inside the context manager, the allocations in that thread
+          will not route to the given pool.
+        - Only device memory allocated inside the context manager is routed to the given pool.
+        - Only Atlas A2 training series products support this interface.
+
+    Args:
+        pool (mindspore.runtime.MemPool): a MemPool object that wraps a PluggableAllocator.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> path = "/path/to/allocator.so"
+        >>> allocator = ms.runtime.PluggableAllocator(path, "Alloc", "Free")
+        >>> mem_pool = ms.runtime.MemPool(allocator)
+        >>> shape = (1024, 1024)
+        >>> x = ms.ops.Ones()(shape, ms.float32)
+        >>> with ms.runtime.use_mem_pool(mem_pool):
+        ...     y = ms.ops.Ones()(shape, ms.float32)
+        ...     output = x + y
+    """
+    allocator = pool.allocator
+    _enable_pluggable_allocator(allocator.alloc_fn_ptr, allocator.free_fn_ptr)
+    try:
+        yield
+    finally:
+        _disable_pluggable_allocator()
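The pluggable-allocator hooks above reduce to one ctypes trick: load a shared object and hand MindSpore the raw symbol addresses. A minimal, hedged sketch of that mechanism, with libm's cos standing in for a real Alloc/Free symbol (Unix only, matching the class's own restriction; not part of the diff):

import ctypes
import ctypes.util

# Load a shared object and cast a symbol to a raw void* address; this is the
# same pattern PluggableAllocator.__init__ uses. libm/cos is a stand-in here.
libm = ctypes.CDLL(ctypes.util.find_library("m"))
cos_ptr = ctypes.cast(getattr(libm, "cos"), ctypes.c_void_p).value
print(hex(cos_ptr))  # an integer function pointer, like alloc_fn_ptr/free_fn_ptr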
mindspore/swresample-4.dll CHANGED
Binary file
mindspore/swscale-6.dll CHANGED
Binary file
mindspore/tinyxml2.dll CHANGED
Binary file
mindspore/{experimental/llm_boost → tools}/__init__.py RENAMED
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""
+"""Tools module."""
 from __future__ import absolute_import
 
-
-from mindspore.experimental.llm_boost.ascend_native import *
-from mindspore.experimental.llm_boost.register import LlmBoostRegister
+__all__ = ["stress_detect", "sdc_detect_start", "sdc_detect_stop", "get_sdc_detect_result", "set_dump"]
 
-
+from .stress_detect import stress_detect
+from .sdc_detect import sdc_detect_start, sdc_detect_stop, get_sdc_detect_result
+from .data_dump import set_dump
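Since the file list above also shows mindspore/utils/sdc_detect.py and mindspore/utils/stress_detect.py, a hedged compatibility sketch for scripts that must import these helpers on either wheel (the mindspore.utils fallback is an assumption, not taken from this diff):

# Prefer the mindspore.tools namespace introduced here; fall back to
# mindspore.utils, assumed to export the same names on older wheels.
try:
    from mindspore.tools import stress_detect, sdc_detect_start, sdc_detect_stop
except ImportError:
    from mindspore.utils import stress_detect, sdc_detect_start, sdc_detect_stop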
mindspore/tools/data_dump.py ADDED
@@ -0,0 +1,130 @@
+# Copyright 2021-2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Controlling dump behavior."""
+from __future__ import absolute_import
+from warnings import warn
+
+import mindspore.context as context
+from mindspore._c_expression import security
+
+
+def set_dump(target, enabled=True):
+    """
+    Enable or disable dump for the `target` and its contents.
+
+    `target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
+    Please note that this API takes effect only when the Dump function is enabled, and the `dump_mode`
+    field in the Dump configuration file is set to `"2"` with the `ms_backend` compilation backend
+    (please refer to the backend parameter in
+    `jit <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html>`_).
+    See the `dump document <https://www.mindspore.cn/tutorials/en/master/debug/dump.html>`_ for details.
+    By default, instances of :class:`mindspore.nn.Cell` and :class:`mindspore.ops.Primitive` do not enable
+    the Dump data feature.
+
+    Note:
+        1. This API is only available for JIT compilation, requires 'Ascend' as the device_target and
+           `ms_backend` as the compilation backend (please refer to the backend parameter in
+           `jit <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html>`_),
+           and does not support fused operators.
+        2. This API only supports being called before training starts.
+           If you call this API during training, it may not be effective.
+        3. After using `set_dump(Cell, True)` , operators in forward and backward
+           computation (computation generated by the grad operations) of the
+           cell will be dumped.
+        4. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
+           computation and backward computation use the same set of
+           operators. So you can only see dump data from backward computation.
+           Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use
+           the above operators internally when initialized with `sparse=True` and
+           `reduction="mean"` .
+
+    Args:
+        target (Union[Cell, Primitive]): The Cell instance or Primitive instance
+            to which the dump flag is set.
+        enabled (bool, optional): ``True`` indicates that the dump is enabled, and ``False`` indicates that
+            the dump is disabled.
+            Default: ``True`` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
+            in dump config file to 2 before running this example.
+            See `dump document <https://www.mindspore.cn/tutorials/en/master/debug/dump.html>`_ for details.
+
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, jit
+        >>> from mindspore.tools import set_dump
+        >>>
+        >>> ms.set_device(device_target="Ascend")
+        >>>
+        >>> class MyNet(nn.Cell):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
+        ...         self.relu1 = nn.ReLU()
+        ...
+        ...     @jit
+        ...     def construct(self, x):
+        ...         x = self.conv1(x)
+        ...         x = self.relu1(x)
+        ...         return x
+        >>>
+        >>> if __name__ == "__main__":
+        ...     net = MyNet()
+        ...     set_dump(net.conv1)
+        ...     input_tensor = Tensor(np.ones([1, 5, 10, 10], dtype=np.float32))
+        ...     output = net(input_tensor)
+    """
+    if security.enable_security():
+        raise ValueError('The set_dump API is not supported, please recompile '
+                         'source without "-s on".')
+
+    import mindspore.nn as nn  # avoid circular import
+    from mindspore.ops import Primitive
+    if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
+        raise ValueError(f"The \"target\" parameter must be an instance of "
+                         f"Cell or Primitive, "
+                         f"but got an instance of {type(target)}.")
+
+    if not isinstance(enabled, bool):
+        raise ValueError("The \"enabled\" parameter must be bool.")
+
+    # Checking for device target and mode.
+    current_target = context.get_context("device_target")
+    if current_target != "Ascend":
+        # We will not return here in case user changed device_target later.
+        warn("Current device_target is {}, which is not supported by set_dump. "
+             "Only Ascend device target is supported currently. "
+             "If you have Ascend device, consider set device_target to Ascend "
+             "before calling set_dump.".format(current_target))
+
+    # The actual set dump logic.
+    if isinstance(target, nn.Cell):
+        target.add_flags(dump=enabled)
+        for cell in target.cells():
+            set_dump(cell, enabled)
+
+        primitives = getattr(target, "_primitives", {})
+        for value in primitives.values():
+            if value and "dump" in value.attrs:
+                set_dump(value, enabled)
+
+    if isinstance(target, Primitive):
+        target.add_prim_attr("dump", "true" if enabled else "false")
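Because set_dump walks target.cells() and each cell's registered primitives, the flag can be layered; a hedged, self-contained sketch of that recursive semantics (TinyNet is illustrative):

import mindspore.nn as nn
from mindspore.tools import set_dump

# Hedged sketch: enable dump for a whole network, then switch it back off for
# one subcell; a later call simply overrides the flag set by an earlier one.
class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
        self.relu1 = nn.ReLU()

    def construct(self, x):
        return self.relu1(self.conv1(x))

net = TinyNet()
set_dump(net)               # flags net, net.conv1, net.relu1 and their primitives
set_dump(net.relu1, False)  # dump stays enabled everywhere except relu1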
mindspore/tools/sdc_detect.py ADDED
@@ -0,0 +1,91 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SDC detect."""
+from mindspore import _c_expression
+
+
+def sdc_detect_start():
+    """
+    Start silent data corruption detection. It will check the inputs and outputs of MatMul operations during the
+    forward and backward computations on the current device, which may increase execution time. The overhead of the
+    check time decreases as the matrix shapes increase. Starting sdc detection results in approximately 100%
+    performance degradation for a single 4096-sized MatMul computation, and approximately 90% degradation on the
+    Llama2-7B model (model parallel is 4, pipeline parallel is 2, and using qkv concatenation and ffn concatenation in
+    decoder layers).
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.tools import sdc_detect_start
+        >>> sdc_detect_start()
+    """
+    _c_expression.sdc_detect_start()
+
+
+def sdc_detect_stop():
+    """
+    Stop silent data corruption detection.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.tools import sdc_detect_stop
+        >>> sdc_detect_stop()
+    """
+    _c_expression.sdc_detect_stop()
+
+
+def get_sdc_detect_result():
+    """
+    Get the result of silent data corruption detection.
+
+    Returns:
+        bool, indicating whether silent data corruption has occurred after detection start.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.tools import get_sdc_detect_result
+        >>> result = get_sdc_detect_result()
+        >>> print(result)
+        False
+    """
+    return _c_expression.get_sdc_detect_result()
+
+
+class _SdcDetector:
+    """
+    Manager of feature value sampling for SDC detect.
+    """
+    def __init__(self):
+        self.param_count = -1
+
+    def need_sample(self):
+        """If need to sample feature value."""
+        if not _c_expression.is_silent_detect_enable():
+            return False
+        grad_sample_interval = _c_expression.get_silent_detect_config('grad_sample_interval')
+        self.param_count = (self.param_count + 1) % grad_sample_interval
+        return self.param_count == 0
+
+    @staticmethod
+    def get_dump_name(param_name):
+        """Get dump file name with sdc prefix."""
+        return _c_expression.get_silent_detect_feature_name(param_name)
+
+_sdc_detector = _SdcDetector()
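A hedged end-to-end sketch of the three functions added above, on an Ascend device (the computation in the middle is a placeholder):

from mindspore.tools import sdc_detect_start, sdc_detect_stop, get_sdc_detect_result

# Bracket the suspect computation: MatMul inputs/outputs on the current device
# are checked between start and stop.
sdc_detect_start()
# ... run the forward/backward steps to be checked (placeholder) ...
sdc_detect_stop()
print(get_sdc_detect_result())  # True if corruption was observed since start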
mindspore/tools/stress_detect.py ADDED
@@ -0,0 +1,63 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Stress detect."""
+from mindspore import _c_expression
+from mindspore import log as logger
+from mindspore.communication import init, create_group, get_rank
+from mindspore.communication import get_local_rank_size
+
+
+def stress_detect(detect_type="aic"):
+    """
+    Used to detect whether there are faults in hardware accuracy or communication between links.
+    The common usage scenario is to initiate a new thread or call this interface through a Callback function
+    at each step or when saving checkpoints, to check whether hardware malfunctions could affect accuracy.
+
+    Args:
+        detect_type (str, optional): The type of stress test to perform. There are two options available: ``'aic'`` and
+            ``'hccs'``, which perform AiCore and HCCS link stress tests on the device, respectively. Default: ``"aic"``.
+
+    Returns:
+        int, the return value represents the error type. 0 indicates normal. 1 indicates failure to start some or
+        all test cases. 2 indicates a hardware failure, and it is recommended to replace the device.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.tools import stress_detect
+        >>> ret = stress_detect()
+        >>> print(ret)
+        0
+    """
+    if detect_type not in ["aic", "hccs"]:
+        logger.error(f"For stress detect, detection type must be 'aic' or 'hccs'. "
+                     f"But got {detect_type}. Exiting stress detect.")
+        return 1
+
+    if detect_type == "aic":
+        return _c_expression.stress_detect("aic")
+
+    init()
+    local_ranks = []
+    local_rank_size = get_local_rank_size()
+    node_num = get_rank() // local_rank_size
+    for i in range(local_rank_size):
+        local_ranks.append(local_rank_size * node_num + i)
+    if get_rank() in local_ranks:
+        group = f"new_group_{node_num}"
+        create_group(group, local_ranks)
+
+    return _c_expression.stress_detect(group)
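For detect_type='hccs', every rank derives its own node's rank list from the division above; a worked sketch of that arithmetic (standalone, no communication, example values assumed):

# Worked sketch of the per-node grouping used by the 'hccs' branch: with 8
# devices per node, global rank 11 lands on node 1, whose ranks are 8..15.
local_rank_size = 8   # assumed devices per node
rank = 11             # assumed global rank
node_num = rank // local_rank_size
local_ranks = [local_rank_size * node_num + i for i in range(local_rank_size)]
print(node_num, local_ranks, f"new_group_{node_num}")
# 1 [8, 9, 10, 11, 12, 13, 14, 15] new_group_1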
mindspore/train/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ from mindspore.train import amp
 from mindspore.train.amp import build_train_network
 from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
-    load,
-    load_checkpoint_async,
+    load, async_ckpt_thread_status, export_split_mindir, \
+    load_checkpoint_async, get_ckpt_path_with_strategy, ckpt_to_safetensors, safetensors_to_ckpt, \
     build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, restore_group_info_list
 from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
     CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, FlopsUtilizationCollector, \
@@ -37,9 +37,9 @@ from mindspore.train.metrics import *
 from mindspore.train.data_sink import data_sink
 
 __all__ = ["Model", "DatasetHelper", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
-           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
-           "load_param_into_net", "export", "load", "export_split_mindir", "
-           "
+           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
+           "load_param_into_net", "export", "load", "export_split_mindir", "async_ckpt_thread_status",
+           "data_sink", "load_checkpoint_async", "get_ckpt_path_with_strategy", "ckpt_to_safetensors",
            "safetensors_to_ckpt", "build_searched_strategy", "merge_sliced_parameter", "load_distributed_checkpoint",
            "restore_group_info_list"]
 __all__.extend(callback.__all__)
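The newly re-exported converters are then importable from the package root; a hedged sketch (the file paths and the save_path keyword are illustrative assumptions, not taken from this diff):

from mindspore.train import ckpt_to_safetensors, safetensors_to_ckpt

# Round-trip a checkpoint through the safetensors format; arguments beyond the
# input path are assumptions for illustration.
ckpt_to_safetensors("/path/to/model.ckpt", save_path="/path/to/st_dir")
safetensors_to_ckpt("/path/to/st_dir/model.safetensors", save_path="/path/to/ckpt_dir")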
mindspore/train/_utils.py CHANGED
@@ -344,15 +344,7 @@ def _get_layout_opt_shard(layout_obj, param_redundancy_dict):
     """Layout ckpt append opt shard."""
     for key, value in layout_obj.items():
         if value[5]:
-
-            if value[5] in world_groups:
-                opt_para_num = get_group_size()
-            elif "-" in value[5]:
-                opt_para_str = value[5].split("-")[0]
-                opt_para_num = int(opt_para_str)
-            else:
-                raise ValueError(f"For get_parameter_redundancy, the format of the parallel communication domain for "
-                                 f"the optimizer is incorrect.")
+            opt_para_num = get_group_size(value[5])
             param_redundancy_ranks = param_redundancy_dict.get(key)
             res = []
             for param_ranks in param_redundancy_ranks:
@@ -582,17 +574,12 @@ def _progress_bar(iterable, total=None):
         print_progress_bar(i)
 
 
-def _load_and_transform(path, name_map, load_func
+def _load_and_transform(path, name_map, load_func):
     """use load_func to load and use transform_func to convert"""
-
-        param_dict = load_func(path)
-    else:
-        param_dict = path
+    param_dict = load_func(path)
     transform_dict = {}
+
     for k, v in param_dict.items():
         new_name = name_map.get(k, k) if name_map is not None else k
-
-            transform_dict[new_name] = transform_func(v, new_name)
-        else:
-            transform_dict[new_name] = v
+        transform_dict[new_name] = v
     return transform_dict
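The second hunk reduces _load_and_transform to a plain load-and-rename; a hedged standalone equivalent of the new behavior (function and argument names here are illustrative):

# Load a parameter dict via load_func and rename keys through name_map,
# leaving values untouched (mirrors the simplified _load_and_transform).
def load_and_rename(path, name_map, load_func):
    param_dict = load_func(path)
    return {(name_map.get(k, k) if name_map is not None else k): v
            for k, v in param_dict.items()}

print(load_and_rename("unused", {"w": "weight"}, lambda _: {"w": 1, "b": 2}))
# {'weight': 1, 'b': 2}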
mindspore/train/amp.py CHANGED
@@ -818,8 +818,10 @@ def get_white_list():
         <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
-        <class 'mindspore.ops.
-        <class 'mindspore.ops.
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.MatMul'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.BatchMatMul'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.PReLU'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.ReLU'>,
         <class 'mindspore.ops.operations.math_ops.Ger'>]
     """
     white_list = AMP_WHITE_LIST.copy()
@@ -871,8 +873,8 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
             white list is not used.
         black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
             black list is not used.
-        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
-            default: ``mstype.float16`` .
+        dtype (Type, optional): The type used in lower precision calculations, can be ``mstype.float16`` or
+            ``mstype.bfloat16`` , default: ``mstype.float16`` .
 
     Returns:
         network (Cell), A network supporting mixed precision.
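A hedged usage sketch of the clarified dtype option, combining it with the white list documented above (the toy network and argument choices are illustrative, only one of white_list/black_list is passed here):

import mindspore as ms
from mindspore import nn
from mindspore.train.amp import get_white_list, custom_mixed_precision

# Run the white-listed primitives of a toy network in bfloat16 instead of the
# default float16.
net = nn.Dense(16, 16)
net = custom_mixed_precision(net, white_list=get_white_list(), dtype=ms.bfloat16)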
mindspore/train/callback/_checkpoint.py CHANGED
@@ -27,7 +27,6 @@ from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
     _wait_async_thread_save_ckpt, _check_async_save
 from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy, _get_pp_size_from_redundancy_map
 from mindspore.train.callback._callback import Callback
@@ -509,9 +508,6 @@ class ModelCheckpoint(Callback):
         if callable(prefix):
             self._prefix_func = prefix
 
-        if context.get_context("device_target") == "GPU" and _get_recovery_context("enable_recovery"):
-            _set_recovery_context(ckpt_path=self._directory)
-
         if config is None:
             self._config = CheckpointConfig()
         else:
@@ -577,11 +573,6 @@ class ModelCheckpoint(Callback):
             self._directory = self._directory_func(cb_params)
             _make_directory(self._directory)
         collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
-        # In disaster recovery scenario, the training process may be rolled back to the last step where
-        # the ckpt was successfully saved, so the _last_triggered_step should be updated.
-        if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
-            self._last_triggered_step = cb_params.last_save_ckpt_step
-            cb_params.last_save_ckpt_step = None
 
         # save graph (only once)
         if not self._graph_saved: