mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED

@@ -38,12 +38,13 @@ from mindspore.common.tensor import Tensor as PythonTensor
 from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
+from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle, _bind_device_ctx
+    _ms_memory_recycle, _bind_device_ctx
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
-    _is_in_auto_parallel_mode
+    _is_in_auto_parallel_mode, _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown

@@ -51,6 +52,8 @@ from mindspore.common.mutable import mutable
 from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
+from mindspore.common._pijit_context import PIJitCaptureContext
+from mindspore.common.parameter import Parameter

 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()

@@ -513,6 +516,19 @@ def _generate_dyn_compile_args(compile_args, dyn_args):
     return tuple(new_compile_args)


+def _get_parameter_ids(args, kwargs):
+    """Get the ids of parameters."""
+    parameter_ids = ""
+    for arg in args:
+        if isinstance(arg, Parameter):
+            parameter_ids += str(id(arg))
+    for _, value in kwargs.items():
+        # The type of key is usually String type.
+        if isinstance(value, Parameter):
+            parameter_ids += str(id(value))
+    return parameter_ids
+
+
 class _MindsporeFunctionExecutor:
     """
     Represents a function compiled by graph compiler.

@@ -625,6 +641,10 @@ class _MindsporeFunctionExecutor:

         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            key = str(key) + '.' + parameter_ids
         phase = generate_name + '.' + str(key)

         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
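
The new `_get_parameter_ids` helper folds the identity of every `Parameter` passed as a top-level input into the compile-cache key, so two distinct `Parameter` objects that merely look alike no longer reuse one cached graph. A minimal standalone sketch of the effect (a toy stand-in, not the MindSpore internals):

    # Toy stand-in for mindspore's Parameter; only object identity matters here.
    class Parameter:
        pass

    def _get_parameter_ids(args, kwargs):
        # Mirrors the helper in the diff: concatenate id()s of Parameter inputs.
        parameter_ids = ""
        for arg in args:
            if isinstance(arg, Parameter):
                parameter_ids += str(id(arg))
        for _, value in kwargs.items():
            if isinstance(value, Parameter):
                parameter_ids += str(id(value))
        return parameter_ids

    base_key = "arguments_key"          # stands in for generate_arguments_key() output
    p1, p2 = Parameter(), Parameter()   # alike in practice, but different objects
    key1 = base_key + '.' + _get_parameter_ids((p1,), {})
    key2 = base_key + '.' + _get_parameter_ids((p2,), {})
    assert key1 != key2                 # distinct objects now compile into distinct phases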
@@ -783,31 +803,28 @@ def _get_jit_hash(hash_input):
     return _get_obj_id(hash_input)


-def _update_graph_executor_config(jit_config):
-    """Update GraphExecutor jit_config"""
-    if isinstance(jit_config, JitConfig):
-        jit_config = jit_config.jit_config_dict
-    if not isinstance(jit_config, dict):
-        return
-    valid_config = dict()
-    for k, v in jit_config.items():
-        valid_config[str(k)] = str(v)
-    GraphExecutor_.get_instance().set_jit_config(JitConfig(**valid_config).jit_config_dict)
-
-
 def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
     """
     Create a callable MindSpore graph from a Python function.

     This allows the MindSpore runtime to apply optimizations based on graph.

+    Note:
+        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
+          will not accept `**kwargs`.
+        - It is not supported to run a function with decoration @jit(mode="PIJit")
+          in static graph mode, in which case the decoration @jit(mode="PIJit") is considered invalid.
+        - Calls to functions with decorated @jit(mode="PIJit") inside functions
+          decorated with @jit(mode="PIJit") are not supported,
+          and the decoration @jit(mode="PIJit") is considered invalid.
+
     Args:
         fn (Function): The Python function that will be run as a graph. Default: ``None`` .
         mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .

-            - `PSJit <https://www.mindspore.cn/docs/en/master/
+            - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
              Parse python ast to build graph.
-            - `PIJit <https://www.mindspore.cn/docs/en/master/
+            - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
              Parse python bytecode to build graph at runtime.

         input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The

@@ -831,10 +848,6 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
             it was created again.
             Default: ``False`` .

-    Note:
-        If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-        will not accept `**kwargs`.
-
     Returns:
         Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
         None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is

@@ -938,45 +951,20 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
         # only the function or cell instance wrapped by shard will fall into this branch
         if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
             process_obj = hash_args
+        # Handle auto mixed precision strategy.
+        if not hasattr(func, "amp_strategy"):
+            if isinstance(func, types.MethodType):
+                setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+            else:
+                setattr(func, "amp_strategy", get_curr_amp_strategy())
         out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
         return out

     return staging_specialize

-    def pi_wrap_mindspore(decorated):
-        func = decorated
-        if isinstance(func, ms.nn.Cell):
-            func = func.construct
-        if isinstance(func, type) and issubclass(func, ms.nn.Cell):
-            func = func.construct
-        if isinstance(func, types.MethodType):
-            func = func.__func__
-        if not isinstance(func, types.FunctionType):
-            logger.warning("only support function and mindspore.nn.Cell instance")
-            return decorated
-
-        # generator, coroutine, awaitable and a function that return them is unsupported
-        UNSUPPORTED_CODE_TYPE = (inspect.CO_GENERATOR | inspect.CO_COROUTINE |
-                                 inspect.CO_ASYNC_GENERATOR | inspect.CO_ITERABLE_COROUTINE)
-        if func.__code__.co_flags & UNSUPPORTED_CODE_TYPE:
-            return decorated
-
-        _update_graph_executor_config(jit_config)
-        config = dict()
-        if isinstance(jit_config, JitConfig):
-            config.update(jit_config.jit_config_dict)
-        elif jit_config is not None:
-            config.update(jit_config)
-        jit_mode_pi_enable()
-
-        if jit_mode_pi_compile(func, config, input_signature) is False:
-            logger.warning('add fn {} to compile failed '.format(func))
-
-        return decorated
-
     wrap_func = wrap_mindspore
     if mode == "PIJit":
-        wrap_func = pi_wrap_mindspore
+        wrap_func = PIJitCaptureContext(jit_config, input_signature)

     if fn is not None:
         return wrap_func(fn)
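
With the inline `pi_wrap_mindspore` path removed, selecting PIJit now hands the function to a `PIJitCaptureContext` object. A hedged usage sketch of the two modes named in the docstring above (requires a MindSpore install):

    import mindspore as ms

    @ms.jit  # default mode="PSJit": parse the Python AST into a graph
    def add(x, y):
        return x + y

    @ms.jit(mode="PIJit")  # capture Python bytecode at runtime; treated as invalid in static graph mode
    def mul(x, y):
        return x * y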
@@ -1272,7 +1260,7 @@ def jit_class(cls):
     if not inspect.isclass(cls):
         raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
     # Check if cls is nn.Cell.
-    if issubclass(cls, nn.Cell):
+    if issubclass(cls, nn.cell.Cell):
         raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
     setattr(cls, '__ms_class__', True)
     return cls

@@ -1463,23 +1451,22 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))

-    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args):
         """
         Whether the forward graph need to construct.

         Args:
             grad (GradOperation): The gradoperation object.
             obj (Function/Cell): The function or cell instance.
-            grad_hash_id (tuple): The id of objects which
+            grad_hash_id (tuple): The id of objects, which contributes to cache of compiled graph in pynative mode.
             args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.

         Return:
-            bool, specifies whether the forward graph
+            bool, specifies whether the forward graph needs to construct.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)

-    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
+    def grad(self, obj, grad, weights, grad_position, *args):
         """
         Get grad graph.

@@ -1490,12 +1477,11 @@ class _PyNativeExecutor:
             grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
                 If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
             args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.

         Return:
             None.
         """
-        return self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
+        return self._executor.grad(grad, obj, weights, grad_position, *args)

 @@ -1528,9 +1514,23 @@ class _PyNativeExecutor:
         """
         return self._executor.grad_jit(output, *args)

+    def call_custom_bprop(self, obj, output, *args, **kwargs):
+        """
+        Call custom bprop to build variable for cell bprop.
+
+        Args:
+            obj (Cell): The function or cell instance.
+            output (Tensor/tuple/list): Function or cell output object.
+            args (tuple): Function or cell input arguments.
+            kwargs (dict): keyword arguments.
+
+        Return:
+            None.
+        """
+        return self._executor.call_custom_bprop(obj, output, *args, *(kwargs.values()))
+
     def grad_flag(self):
         """
-        The flag of building grad graph.
+        The flag of whether the net building grad graph.

         Return:
             bool, whether building grad graph.

@@ -1563,7 +1563,7 @@ class _PyNativeExecutor:

     def enable_grad(self):
         """
-        The global flag whether
+        The global flag that whether need to calculate gradient use in no_grad.

         Return:
             bool, whether needing to calculate gradient.

@@ -1582,6 +1582,18 @@ class _PyNativeExecutor:
         """
         self._executor.set_enable_grad(flag)

+    def requires_grad(self):
+        """
+        When both enable_grad is true and grad_flag is true, that the flag requires_grad will be true.
+
+        Args:
+            flag (bool): Specifying whether calculating gradient.
+
+        Return:
+            None.
+        """
+        return self._executor.requires_grad()
+
     def set_jit_compile_status(self, status, phase):
         """
         Set jit is compiling

@@ -1605,6 +1617,18 @@ class _PyNativeExecutor:
         """
         self._executor.set_is_run_recompute(status)

+    def set_cell_use_dynamic_shape_process(self, flag):
+        """
+        Set the dynamic shape flag of eval process.
+
+        Args:
+            flag (bool): Specifying whether using a dynamic process.
+
+        Return:
+            None.
+        """
+        self._executor.set_cell_use_dynamic_shape_process(flag)
+
     def set_dynamic_input(self, obj, *args):
         """
         Set dynamic shape tensor of input arguments.

@@ -1630,27 +1654,19 @@ class _PyNativeExecutor:
         """
         return self._executor.get_dynamic_input(*actual_args)

-    def is_first_cell(self):
-        """
-        The flag of first cell instance.
-
-        Return:
-            bool, specifies whether is the first cell.
+    def set_mixed_precision_type(self, mixed_precision_type, is_push=True):
         """
-
-        return self._executor.is_first_cell()
-
-    def set_hook_changed(self, cell):
-        """
-        The flag of registering or removing a hook function on Cell instance.
+        The value of mixed precision type.

         Args:
-
+            type(MixedPrecisionType): Mix precision type.
+            is_push(bool): If called by __enter__, is push will be True

         Return:
             None.
         """
-        self._executor.set_hook_changed(cell)
+
+        return self._executor.set_mixed_precision_type(mixed_precision_type, is_push)

     def constant_folding(self, *args):
         """
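
The new `call_custom_bprop` entry point serves cells that define their own `bprop`. A hedged usage sketch of that user-facing feature, following the standard MindSpore cell-bprop convention (requires a MindSpore install):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    class Square(nn.Cell):
        def construct(self, x):
            return x * x

        def bprop(self, x, out, dout):  # custom backward: d(x*x)/dx = 2x
            return (2 * x * dout,)

    net = Square()
    x = Tensor(np.array([3.0], np.float32))
    print(ms.grad(net)(x))  # expected: [6.]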
@@ -1687,6 +1703,7 @@ class _CellGraphExecutor:
         self._graph_executor = GraphExecutor_.get_instance()
         self._graph_executor.set_py_exe_path(sys.executable)
         self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
+        self._pid = os.getpid()

     def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                      input_indexs, phase='dataset', need_run=True):

@@ -1789,6 +1806,10 @@ class _CellGraphExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
         obj.arguments_key = str(key)
+        # When exist parameter in the top graph inputs, need check if the parameter object has changed.
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            obj.arguments_key = obj.arguments_key + '.' + parameter_ids
         raw_phase = phase
         phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
         obj.phase_cache[raw_phase] = phase

@@ -1825,7 +1846,7 @@ class _CellGraphExecutor:
         if graph is None:
             raise RuntimeError("Compile graph failed for phase {}.".format(phase))

-        auto_parallel_mode = _is_in_auto_parallel_mode()
+        auto_parallel_mode = _is_in_auto_parallel_mode() or _is_parallel_mode()
         if not auto_parallel_mode:
             replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
             self._update_param_node_default_input(phase, replace)

@@ -1913,15 +1934,9 @@ class _CellGraphExecutor:

     def del_net_res(self, obj, net_id):
         """Clear the memory resource of a network."""
-        self._graph_executor.del_net_res(obj, net_id)
-
-    def inc_graph_cell_count(self):
-        """Increase the count of GraphCell instance."""
-        self._graph_executor.inc_graph_cell_count()
-
-    def dec_graph_cell_count(self):
-        """Decrease the count of GraphCell instance."""
-        self._graph_executor.dec_graph_cell_count()
+        # no need to del net res by gc in independent dataset process which is a subprocess forked by main process
+        if self._pid == os.getpid():
+            self._graph_executor.del_net_res(obj, net_id)

     def _get_branch_control_input(self):
         if ('obf_ratio' not in self.obfuscate_config.keys()) or (
mindspore/common/dump.py
CHANGED

@@ -27,18 +27,17 @@ def set_dump(target, enabled=True):
     `target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
     Please note that this API takes effect only when Synchronous Dump is enabled and the `dump_mode`
     field in dump config file is ``"2"`` . See the `dump document
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.
     The default enabled status for
     a :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` is False.

     Note:
-        1. This API
-        2. This API only supports being called before training starts.
+        1. This API only supports being called before training starts.
            If you call this API during training, it may not be effective.
-
+        2. After using `set_dump(Cell, True)` , operators in forward and backward
            computation (computation generated by the grad operations) of the
            cell will be dumped.
-
+        3. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
            computation and backward computation use the same set of
            operators. So you can only see dump data from backward computation.
            Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use

@@ -58,7 +57,7 @@ def set_dump(target, enabled=True):
     .. note::
         Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
         in dump config file to 2 before running this example.
-        See `dump document <https://www.mindspore.cn/
+        See `dump document <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.

     >>> import numpy as np
     >>> import mindspore as ms
mindspore/common/generator.py
CHANGED

@@ -56,12 +56,6 @@ class Generator:
     A generator that manages the state of random numbers and provides seed and offset for random functions.
     When the seed and offset are fixed, the random function generates the same random sequence.

-    Inputs:
-        - **step** (int) - Set the step size for offset update.
-
-    Outputs:
-        Tuple consisting of the seed and offset of generator.
-
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``

@@ -199,7 +193,7 @@ def manual_seed(seed):  # pylint: disable=redefined-outer-name
     >>> print(initial_seed())
     13
     """
-    default_generator.manual_seed(seed)
+    return default_generator.manual_seed(seed)


 def initial_seed():
mindspore/common/hook_handle.py
CHANGED

@@ -77,27 +77,19 @@ class HookHandle:
     It is only supported in pynative mode and works when registering or removing hook function for Cell object.

     Args:
-
-        hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
-            Default value: -1.
-        hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
-            Default value: "".
+        hook_dict (Dict): The hook object with hook function registered on. Default value: None.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
     """
-
-
-
-
-
-
-
-    def __del__(self):
-        self._hook_cell = None
-        self._hook_key = None
-        self._hook_type = None
+    unique_id = 0
+
+    def __init__(self, hook_dict=None):
+        self.hook_dict_ref = None
+        if hook_dict is not None:
+            self.hook_dict_ref = weakref.ref(hook_dict)
+        self.handle_id = HookHandle.unique_id
+        HookHandle.unique_id += 1

@@ -121,7 +113,7 @@ class HookHandle:
     >>> from mindspore import Tensor
     >>> from mindspore.ops import GradOperation
     >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-    >>> def forward_pre_hook_fn(
+    >>> def forward_pre_hook_fn(cell, inputs):
     ...     print("forward inputs: ", inputs)
     ...
     >>> class Net(nn.Cell):

@@ -145,11 +137,7 @@ class HookHandle:
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
         value= [ 2.00000000e+00]))
         """
-        if self.
-
-        if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
-            del hook_cell._forward_pre_hook[self._hook_key]
-        elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
-            del hook_cell._forward_hook[self._hook_key]
-        elif self._hook_type == "_cell_backward_hook":
-            hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
+        if self.hook_dict_ref is not None:
+            hook_dict = self.hook_dict_ref()
+            if hook_dict is not None and self.handle_id in hook_dict:
+                del hook_dict[self.handle_id]
mindspore/common/mindir_util.py
CHANGED

@@ -90,9 +90,9 @@ def save_mindir(model, file_name):
     if not file_name.endswith('.mindir'):
         file_name += ".mindir"

-    current_path = os.path.
+    current_path = os.path.realpath(file_name)
     dirname = os.path.dirname(current_path)
-    os.makedirs(dirname, exist_ok=True)
+    os.makedirs(dirname, mode=0o700, exist_ok=True)
     if os.path.exists(file_name):
         os.chmod(file_name, stat.S_IWUSR)
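
Both changes harden the save path: `realpath` resolves symlinks before the directory is derived, and the directory is created with owner-only permissions. The hardened sequence in isolation (a sketch mirroring the diff; `prepare_mindir_path` is a hypothetical name):

    import os
    import stat

    def prepare_mindir_path(file_name):
        if not file_name.endswith('.mindir'):
            file_name += ".mindir"
        current_path = os.path.realpath(file_name)   # resolve symlinks first
        os.makedirs(os.path.dirname(current_path), mode=0o700, exist_ok=True)
        if os.path.exists(file_name):
            os.chmod(file_name, stat.S_IWUSR)        # ensure the owner can overwrite
        return current_path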
mindspore/common/parameter.py
CHANGED

@@ -41,6 +41,8 @@ from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _i
     _is_ps_mode
 from mindspore.parallel._ps_context import _reinsert_hash_table_size, _insert_accumu_init_info, _cache_enable
 from mindspore.common._decorator import deprecated
+from mindspore.communication._comm_helper import _is_initialized
+from mindspore.communication import get_group_size
 import mindspore.common._monad as monad

 __all__ = ['Parameter', 'ParameterTuple']

@@ -52,11 +54,22 @@ PARAMETER_NAME_PREFIX_MAX_LEN = 1024
 _GLOBAL_PARAMETER_KEY = -1


-def
+def _is_in_auto_parallel_mode():
     """Get parallel mode."""
     return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


+def _is_parallel_mode():
+    """ Whether is parallel mode """
+    if not _is_initialized() or context.get_context('mode') == context.PYNATIVE_MODE:
+        return False
+    if os.getenv("RUN_MODE") != "predict":
+        return False
+    if get_group_size() > 1 and _get_parallel_mode() == "stand_alone":
+        return True
+    return False
+
+
 def init_to_value(init):
     """
     Get value of initializer.

@@ -91,6 +104,15 @@ def _get_unique_parameter_key():
     return _GLOBAL_PARAMETER_KEY


+def _gen_offload_file_path(offload_dir):
+    offload_dir = os.path.relpath(offload_dir)
+    if not os.path.exists(offload_dir):
+        os.makedirs(offload_dir, mode=0o700, exist_ok=True)
+    offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
+        _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
+    return offload_file_path
+
+
 def _offload_if_config(data):
     """
     Offload parameter(data size > 512) to file when enable memory offload and offload parameter to disk.

@@ -111,11 +133,7 @@ def _offload_if_config(data):
     offload_file_path = data.offload_file_path()
     if offload_file_path is None or offload_file_path == "":
         offload_dir = offload_context.get("offload_path", "./offload")
-
-        if not os.path.exists(offload_dir):
-            os.makedirs(offload_dir)
-        offload_file_path = offload_dir + "/" + str(_get_global_rank()) + "_" + str(
-            _get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
+        offload_file_path = _gen_offload_file_path(offload_dir)
     data.offload(offload_file_path)

@@ -191,6 +209,12 @@ class Parameter(Tensor_):
         storage_format (str): Only Ascend device target is supported. It is used to specify the format of the weight
             loaded to the device. By default, the format is not changed. The optional values are ``"FRACTAL_NZ"`` ,
             ``"NC1HWC0"`` , ``"FRACTAL_Z"`` , etc. Default: ``""`` .
+        device(str): Only Ascend device target is supported. It is used to specify the device which the parameter is
+            stored. By default, the parameter will be stored on NPU while computing. When the device is specified as
+            ``"CPU"``, the parameter will be loaded into the device when it needs to be used, and unloaded to the CPU
+            after use. It takes effect only when `memory_offload` is ``"ON"``, `jit_level` is not ``"O2"`` and
+            `memory_optimize_level` is ``O0`` in `mindspore.set_context()`. Less device memory is needed when device is
+            specified as ``"CPU"``.

     Examples:
         >>> import numpy as np

@@ -244,7 +268,7 @@ class Parameter(Tensor_):
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

     def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
-                 storage_format=""):
+                 storage_format="", device=None):
         self.param_info = ParamInfo()
         self.init_in_server = False
         self.name = name

@@ -263,7 +287,7 @@ class Parameter(Tensor_):
         self.requires_aggr = True
         self._cast_type = None
         self._unique = False
-        self.is_in_parallel =
+        self.is_in_parallel = _is_in_auto_parallel_mode()
         self.is_in_shard = False
         self._pipeline_stage_list = []
         self.slice_num = 1

@@ -296,6 +320,10 @@ class Parameter(Tensor_):
                 f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
         self.param_info.parameter_shape = self.shape
         self.param_info.storage_format = storage_format
+        if device is not None:
+            if device != "CPU":
+                raise ValueError(f"Only 'CPU' is supported for device, but got ${device}.")
+            self._set_user_data("parameter_device", device)

         import mindspore.ops.operations.other_ops as other_ops
         self.load = other_ops.Load()

@@ -342,7 +370,8 @@ class Parameter(Tensor_):
             return (Tensor, data.asnumpy(), mstype.qint4x2)
         return (Tensor, data.asnumpy())

-        not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable()
+        not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable()
+                                             ) or _is_in_auto_parallel_mode() or _is_parallel_mode()
         if not_init_data:
             # do not init data while in auto parallel.
             return (Tensor, None, data.dtype, get_slice_shape(data.dtype, data.shape), data.init)

@@ -368,7 +397,7 @@ class Parameter(Tensor_):

         Tutorial Examples:
             - `Parameter Server Mode
-              <https://www.mindspore.cn/
+              <https://www.mindspore.cn/docs/en/master/model_train/parallel/parameter_server_training.html>`_
         """
         if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
             raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"

@@ -616,6 +645,9 @@ class Parameter(Tensor_):
         shape = self.shape if self.slice_num == 1 else self.param_info.origin_shape
         dtype = self.dtype
         x.set_data(initializer(init, shape=shape, dtype=dtype))
+        device = self._get_user_data("parameter_device")
+        if device is not None:
+            x._set_user_data("parameter_device", device)
         return x

     @property

@@ -942,7 +974,7 @@ class Parameter(Tensor_):
         >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
         >>> x.init_data()
         """
-        if self.is_default_input_init and self.is_in_parallel !=
+        if self.is_default_input_init and self.is_in_parallel != _is_in_auto_parallel_mode():
             raise RuntimeError("Must set or change parallel mode before any initializer Tensor created.")
         if self.init_mode is None:
             return self

@@ -1026,8 +1058,9 @@ class ParameterTuple(tuple):
         Tuple, the new Parameter tuple.

         Tutorial Examples:
-            - `
-              <https://mindspore.cn/
+            - `Tensor and Parameter - Parameter Tuple
+              <https://mindspore.cn/docs/en/master/model_train/model_building/tensor_and_parameter.html
+              #parameter-tuple>`_
         """
         Validator.check_str_by_regular(prefix)
         new = []