mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.1-cp310-cp310-win_amd64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release. This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED

@@ -38,12 +38,13 @@ from mindspore.common.tensor import Tensor as PythonTensor
 from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
+from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle, _bind_device_ctx
+    _ms_memory_recycle, _bind_device_ctx
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
-    _is_in_auto_parallel_mode
+    _is_in_auto_parallel_mode, _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
@@ -51,6 +52,8 @@ from mindspore.common.mutable import mutable
 from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
+from mindspore.common._pijit_context import PIJitCaptureContext
+from mindspore.common.parameter import Parameter
 
 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
@@ -513,6 +516,19 @@ def _generate_dyn_compile_args(compile_args, dyn_args):
     return tuple(new_compile_args)
 
 
+def _get_parameter_ids(args, kwargs):
+    """Get the ids of parameters."""
+    parameter_ids = ""
+    for arg in args:
+        if isinstance(arg, Parameter):
+            parameter_ids += str(id(arg))
+    for _, value in kwargs.items():
+        # The type of key is usually String type.
+        if isinstance(value, Parameter):
+            parameter_ids += str(id(value))
+    return parameter_ids
+
+
 class _MindsporeFunctionExecutor:
     """
     Represents a function compiled by graph compiler.
@@ -625,6 +641,10 @@ class _MindsporeFunctionExecutor:
 
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            key = str(key) + '.' + parameter_ids
         phase = generate_name + '.' + str(key)
 
         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
@@ -783,31 +803,28 @@ def _get_jit_hash(hash_input):
     return _get_obj_id(hash_input)
 
 
-def _update_graph_executor_config(jit_config):
-    """Update GraphExecutor jit_config"""
-    if isinstance(jit_config, JitConfig):
-        jit_config = jit_config.jit_config_dict
-    if not isinstance(jit_config, dict):
-        return
-    valid_config = dict()
-    for k, v in jit_config.items():
-        valid_config[str(k)] = str(v)
-    GraphExecutor_.get_instance().set_jit_config(JitConfig(**valid_config).jit_config_dict)
-
-
 def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
     """
     Create a callable MindSpore graph from a Python function.
 
    This allows the MindSpore runtime to apply optimizations based on graph.
 
+    Note:
+        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
+          will not accept `**kwargs`.
+        - It is not supported to run a function with decoration @jit(mode="PIJit")
+          in static graph mode, in which case the decoration @jit(mode="PIJit") is considered invalid.
+        - Calls to functions with decorated @jit(mode="PIJit") inside functions
+          decorated with @jit(mode="PIJit") are not supported,
+          and the decoration @jit(mode="PIJit") is considered invalid.
+
     Args:
        fn (Function): The Python function that will be run as a graph. Default: ``None`` .
        mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
 
-            - `PSJit <https://www.mindspore.cn/docs/en/master/
+            - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
              Parse python ast to build graph.
-            - `PIJit <https://www.mindspore.cn/docs/en/master/
+            - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
              Parse python bytecode to build graph at runtime.
 
        input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
@@ -831,10 +848,6 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
        it was created again.
        Default: ``False`` .
 
-    Note:
-        If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-        will not accept `**kwargs`.
-
     Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
@@ -938,45 +951,20 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
            # only the function or cell instance wrapped by shard will fall into this branch
            if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
                process_obj = hash_args
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                if isinstance(func, types.MethodType):
+                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+                else:
+                    setattr(func, "amp_strategy", get_curr_amp_strategy())
            out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
            return out
 
        return staging_specialize
 
-    def pi_wrap_mindspore(decorated):
-        func = decorated
-        if isinstance(func, ms.nn.Cell):
-            func = func.construct
-        if isinstance(func, type) and issubclass(func, ms.nn.Cell):
-            func = func.construct
-        if isinstance(func, types.MethodType):
-            func = func.__func__
-        if not isinstance(func, types.FunctionType):
-            logger.warning("only support function and mindspore.nn.Cell instance")
-            return decorated
-
-        # generator, coroutine, awaitable and a function that return them is unsupported
-        UNSUPPORTED_CODE_TYPE = (inspect.CO_GENERATOR | inspect.CO_COROUTINE |
-                                 inspect.CO_ASYNC_GENERATOR | inspect.CO_ITERABLE_COROUTINE)
-        if func.__code__.co_flags & UNSUPPORTED_CODE_TYPE:
-            return decorated
-
-        _update_graph_executor_config(jit_config)
-        config = dict()
-        if isinstance(jit_config, JitConfig):
-            config.update(jit_config.jit_config_dict)
-        elif jit_config is not None:
-            config.update(jit_config)
-        jit_mode_pi_enable()
-
-        if jit_mode_pi_compile(func, config, input_signature) is False:
-            logger.warning('add fn {} to compile failed '.format(func))
-
-        return decorated
-
    wrap_func = wrap_mindspore
    if mode == "PIJit":
-        wrap_func = pi_wrap_mindspore
+        wrap_func = PIJitCaptureContext(jit_config, input_signature)
 
    if fn is not None:
        return wrap_func(fn)
@@ -1272,7 +1260,7 @@ def jit_class(cls):
    if not inspect.isclass(cls):
        raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
    # Check if cls is nn.Cell.
-    if issubclass(cls, nn.Cell):
+    if issubclass(cls, nn.cell.Cell):
        raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
    setattr(cls, '__ms_class__', True)
    return cls
@@ -1463,23 +1451,22 @@ class _PyNativeExecutor:
        """
        self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args):
        """
        Whether the forward graph need to construct.
 
        Args:
            grad (GradOperation): The gradoperation object.
            obj (Function/Cell): The function or cell instance.
-            grad_hash_id (tuple): The id of objects which
+            grad_hash_id (tuple): The id of objects, which contributes to cache of compiled graph in pynative mode.
            args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.
 
        Return:
-            bool, specifies whether the forward graph
+            bool, specifies whether the forward graph needs to construct.
        """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
 
-    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
+    def grad(self, obj, grad, weights, grad_position, *args):
        """
        Get grad graph.
@@ -1490,12 +1477,11 @@ class _PyNativeExecutor:
            grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
                If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
            args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.
 
        Return:
            None.
        """
-        return self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
+        return self._executor.grad(grad, obj, weights, grad_position, *args)
 
    def clear_res(self):
        """
@@ -1528,9 +1514,23 @@ class _PyNativeExecutor:
        """
        return self._executor.grad_jit(output, *args)
 
+    def call_custom_bprop(self, obj, output, *args, **kwargs):
+        """
+        Call custom bprop to build variable for cell bprop.
+
+        Args:
+            obj (Cell): The function or cell instance.
+            output (Tensor/tuple/list): Function or cell output object.
+            args (tuple): Function or cell input arguments.
+            kwargs (dict): keyword arguments.
+
+        Return:
+            None.
+        """
+        return self._executor.call_custom_bprop(obj, output, *args, *(kwargs.values()))
+
    def grad_flag(self):
        """
-        The flag of building grad graph.
+        The flag of whether the net building grad graph.
 
        Return:
            bool, whether building grad graph.
@@ -1563,7 +1563,7 @@ class _PyNativeExecutor:
 
    def enable_grad(self):
        """
-        The global flag whether
+        The global flag that whether need to calculate gradient use in no_grad.
 
        Return:
            bool, whether needing to calculate gradient.
@@ -1582,6 +1582,18 @@ class _PyNativeExecutor:
        """
        self._executor.set_enable_grad(flag)
 
+    def requires_grad(self):
+        """
+        When both enable_grad is true and grad_flag is true, that the flag requires_grad will be true.
+
+        Args:
+            flag (bool): Specifying whether calculating gradient.
+
+        Return:
+            None.
+        """
+        return self._executor.requires_grad()
+
    def set_jit_compile_status(self, status, phase):
        """
        Set jit is compiling
@@ -1605,6 +1617,18 @@ class _PyNativeExecutor:
        """
        self._executor.set_is_run_recompute(status)
 
+    def set_cell_use_dynamic_shape_process(self, flag):
+        """
+        Set the dynamic shape flag of eval process.
+
+        Args:
+            flag (bool): Specifying whether using a dynamic process.
+
+        Return:
+            None.
+        """
+        self._executor.set_cell_use_dynamic_shape_process(flag)
+
    def set_dynamic_input(self, obj, *args):
        """
        Set dynamic shape tensor of input arguments.
@@ -1630,27 +1654,19 @@ class _PyNativeExecutor:
        """
        return self._executor.get_dynamic_input(*actual_args)
 
-    def is_first_cell(self):
-        """
-        The flag of first cell instance.
-
-        Return:
-            bool, specifies whether is the first cell.
-        """
-        return self._executor.is_first_cell()
-
-    def set_hook_changed(self, cell):
-        """
-        The flag of registering or removing a hook function on Cell instance.
+    def set_mixed_precision_type(self, mixed_precision_type, is_push=True):
+        """
+        The value of mixed precision type.
 
        Args:
-            cell (Cell): The cell instance.
+            type(MixedPrecisionType): Mix precision type.
+            is_push(bool): If called by __enter__, is push will be True
 
        Return:
            None.
        """
-        self._executor.set_hook_changed(cell)
+
+        return self._executor.set_mixed_precision_type(mixed_precision_type, is_push)
 
    def constant_folding(self, *args):
        """
@@ -1687,6 +1703,7 @@ class _CellGraphExecutor:
        self._graph_executor = GraphExecutor_.get_instance()
        self._graph_executor.set_py_exe_path(sys.executable)
        self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
+        self._pid = os.getpid()
 
    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset', need_run=True):
@@ -1789,6 +1806,10 @@ class _CellGraphExecutor:
        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
        obj.arguments_key = str(key)
+        # When exist parameter in the top graph inputs, need check if the parameter object has changed.
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            obj.arguments_key = obj.arguments_key + '.' + parameter_ids
        raw_phase = phase
        phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
        obj.phase_cache[raw_phase] = phase
@@ -1825,7 +1846,7 @@ class _CellGraphExecutor:
        if graph is None:
            raise RuntimeError("Compile graph failed for phase {}.".format(phase))
 
-        auto_parallel_mode = _is_in_auto_parallel_mode()
+        auto_parallel_mode = _is_in_auto_parallel_mode() or _is_parallel_mode()
        if not auto_parallel_mode:
            replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
            self._update_param_node_default_input(phase, replace)
@@ -1913,15 +1934,9 @@ class _CellGraphExecutor:
 
    def del_net_res(self, obj, net_id):
        """Clear the memory resource of a network."""
-        self._graph_executor.del_net_res(obj, net_id)
-
-    def inc_graph_cell_count(self):
-        """Increase the count of GraphCell instance."""
-        self._graph_executor.inc_graph_cell_count()
-
-    def dec_graph_cell_count(self):
-        """Decrease the count of GraphCell instance."""
-        self._graph_executor.dec_graph_cell_count()
+        # no need to del net res by gc in independent dataset process which is a subprocess forked by main process
+        if self._pid == os.getpid():
+            self._graph_executor.del_net_res(obj, net_id)
 
    def _get_branch_control_input(self):
        if ('obf_ratio' not in self.obfuscate_config.keys()) or (
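The new `_get_parameter_ids` helper, used by both `_MindsporeFunctionExecutor` and `_CellGraphExecutor.compile` above, appends the `id()` of every `Parameter` argument to the compile-cache key, so passing a different parameter object with an otherwise identical signature no longer hits a stale compiled graph. A minimal standalone sketch of that cache-key idea (the `Param` class and `compile_key` helper here are hypothetical stand-ins, not MindSpore API):

import itertools

class Param:  # hypothetical stand-in for mindspore.Parameter
    def __init__(self, name):
        self.name = name

def compile_key(base_key, args, kwargs):
    # Concatenate the object identity of every parameter-like argument,
    # so two distinct parameter objects never share a cache entry.
    ids = "".join(str(id(a)) for a in itertools.chain(args, kwargs.values())
                  if isinstance(a, Param))
    return f"{base_key}.{ids}" if ids else str(base_key)

p1, p2 = Param("w1"), Param("w2")
# Same shapes/dtypes would give the same base key, but the identities differ:
assert compile_key("sig", (p1,), {}) != compile_key("sig", (p2,), {})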
mindspore/common/dump.py
CHANGED

@@ -27,18 +27,17 @@ def set_dump(target, enabled=True):
    `target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
    Please note that this API takes effect only when Synchronous Dump is enabled and the `dump_mode`
    field in dump config file is ``"2"`` . See the `dump document
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.
    The default enabled status for
    a :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` is False.
 
    Note:
-        1. This API
-        2. This API only supports being called before training starts.
+        1. This API only supports being called before training starts.
           If you call this API during training, it may not be effective.
-
+        2. After using `set_dump(Cell, True)` , operators in forward and backward
           computation (computation generated by the grad operations) of the
           cell will be dumped.
-
+        3. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
           computation and backward computation use the same set of
           operators. So you can only see dump data from backward computation.
           Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use
@@ -58,7 +57,7 @@ def set_dump(target, enabled=True):
    .. note::
        Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
        in dump config file to 2 before running this example.
-        See `dump document <https://www.mindspore.cn/
+        See `dump document <https://www.mindspore.cn/docs/en/master/model_train/debug/dump.html>`_ for details.
 
    >>> import numpy as np
    >>> import mindspore as ms
mindspore/common/generator.py
CHANGED

@@ -56,12 +56,6 @@ class Generator:
    A generator that manages the state of random numbers and provides seed and offset for random functions.
    When the seed and offset are fixed, the random function generates the same random sequence.
 
-    Inputs:
-        - **step** (int) - Set the step size for offset update.
-
-    Outputs:
-        Tuple consisting of the seed and offset of generator.
-
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
 
@@ -199,7 +193,7 @@ def manual_seed(seed):  # pylint: disable=redefined-outer-name
    >>> print(initial_seed())
    13
    """
-    default_generator.manual_seed(seed)
+    return default_generator.manual_seed(seed)
 
 
 def initial_seed():
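The one-line fix makes `manual_seed` return the result of `default_generator.manual_seed(seed)`, i.e. the seeded default generator, instead of silently dropping it, so callers can keep a handle on the generator in one step. A usage sketch, assuming the top-level `manual_seed`/`initial_seed` re-exports present in recent releases:

import mindspore as ms
from mindspore import ops

gen = ms.manual_seed(13)   # after the fix: returns the seeded default Generator
print(ms.initial_seed())   # 13, matching the docstring example above
x = ops.rand(2, 2)         # draws from the freshly seeded default generator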
mindspore/common/hook_handle.py
CHANGED

@@ -77,27 +77,19 @@ class HookHandle:
    It is only supported in pynative mode and works when registering or removing hook function for Cell object.
 
    Args:
-        hook_cell (Cell): The cell object with hook function registered on. Default value: None.
-        hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
-            Default value: -1.
-        hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
-            Default value: "".
+        hook_dict (Dict): The hook object with hook function registered on. Default value: None.
 
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
-
-    def __init__(self, hook_cell=None, hook_key=-1, hook_type=""):
-        self._hook_cell = None
-        if hook_cell is not None:
-            self._hook_cell = weakref.ref(hook_cell)
-        self._hook_key = hook_key
-        self._hook_type = hook_type
-
-    def __del__(self):
-        self._hook_cell = None
-        self._hook_key = None
-        self._hook_type = None
+    unique_id = 0
+
+    def __init__(self, hook_dict=None):
+        self.hook_dict_ref = None
+        if hook_dict is not None:
+            self.hook_dict_ref = weakref.ref(hook_dict)
+        self.handle_id = HookHandle.unique_id
+        HookHandle.unique_id += 1
 
    def remove(self):
        """
@@ -121,7 +113,7 @@ class HookHandle:
        >>> from mindspore import Tensor
        >>> from mindspore.ops import GradOperation
        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-        >>> def forward_pre_hook_fn(cell_id, inputs):
+        >>> def forward_pre_hook_fn(cell, inputs):
        ...     print("forward inputs: ", inputs)
        ...
        >>> class Net(nn.Cell):
@@ -145,11 +137,7 @@ class HookHandle:
        (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
        value= [ 2.00000000e+00]))
        """
-        if self._hook_cell is not None:
-            hook_cell = self._hook_cell()
-            if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
-                del hook_cell._forward_pre_hook[self._hook_key]
-            elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
-                del hook_cell._forward_hook[self._hook_key]
-            elif self._hook_type == "_cell_backward_hook":
-                hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
+        if self.hook_dict_ref is not None:
+            hook_dict = self.hook_dict_ref()
+            if hook_dict is not None and self.handle_id in hook_dict:
+                del hook_dict[self.handle_id]
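The rewritten `HookHandle` drops the per-type `_hook_key`/`_hook_type` bookkeeping in favor of a single weakly referenced hook dict plus a class-level counter. A self-contained sketch of the same pattern (generic names, not the MindSpore class; note a weak reference requires a weakref-capable container such as `OrderedDict`, since plain `dict` does not support it):

import weakref
from collections import OrderedDict

class Handle:
    """Removable registration backed by a weakly referenced hook dict."""
    unique_id = 0

    def __init__(self, hook_dict=None):
        # Holding only a weak reference means the handle never keeps the
        # owner's hook dict (and hence the owner) alive by itself.
        self.hook_dict_ref = weakref.ref(hook_dict) if hook_dict is not None else None
        self.handle_id = Handle.unique_id
        Handle.unique_id += 1

    def remove(self):
        if self.hook_dict_ref is None:
            return
        hook_dict = self.hook_dict_ref()  # None if the owner was collected
        if hook_dict is not None and self.handle_id in hook_dict:
            del hook_dict[self.handle_id]

hooks = OrderedDict()                     # stand-in for a Cell's hook storage
h = Handle(hooks)
hooks[h.handle_id] = lambda cell, inputs: None
h.remove()
assert h.handle_id not in hooks           # removal is idempotent and safe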
mindspore/common/initializer.py
CHANGED

@@ -103,6 +103,12 @@ def _numpy_seed():
    return np.random.randint(low=1, high=(1 << 63), dtype=np.int64)
 
 
+def _init_random_normal_inplace(mean, sigma, arr):
+    if sigma < 0:
+        raise ValueError("sigma < 0")
+    _random_normal(_numpy_seed(), arr, mean, sigma)
+
+
 def _init_random_normal(mean, sigma, shape):
    if sigma < 0:
        raise ValueError("sigma < 0")
@@ -111,12 +117,22 @@ def _init_random_normal(mean, sigma, shape):
    return data
 
 
+def _init_random_uniform_inplace(a, b, arr):
+    _random_uniform(_numpy_seed(), arr, a, b)
+
+
 def _init_random_uniform(a, b, shape):
    data = np.ndarray(shape=shape, dtype=np.float32)
    _random_uniform(_numpy_seed(), data, a, b)
    return data
 
 
+def _init_truncated_normal_inplace(a, b, mean, sigma, arr):
+    if sigma < 0:
+        raise ValueError("sigma < 0")
+    _truncated_normal(_numpy_seed(), arr, a, b, mean, sigma)
+
+
 def _init_truncated_normal(a, b, mean, sigma, shape):
    if sigma < 0:
        raise ValueError("sigma < 0")
@@ -298,9 +314,11 @@ class XavierNormal(Initializer):
        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr.shape)
 
        std = self.gain * math.sqrt(2.0 / float(fan_in + fan_out))
-        data = _init_random_normal(0, std, arr.shape)
-
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_normal_inplace(0, std, arr)
+        else:
+            data = _init_random_normal(0, std, arr.shape)
+            _assignment(arr, data)
 
 
 @_register('xavier_uniform')
@@ -337,8 +355,11 @@ class XavierUniform(Initializer):
    def _initialize(self, arr):
        n_in, n_out = _calculate_fan_in_and_fan_out(arr.shape)
        boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
-        data = _init_random_uniform(-boundary, boundary, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_uniform_inplace(-boundary, boundary, arr)
+        else:
+            data = _init_random_uniform(-boundary, boundary, arr.shape)
+            _assignment(arr, data)
 
 
 @_register('he_uniform')
@@ -386,8 +407,11 @@ class HeUniform(Initializer):
        gain = _calculate_gain(self.nonlinearity, self.negative_slope)
        std = gain / math.sqrt(fan)
        boundary = math.sqrt(3.0) * std
-        data = _init_random_uniform(-boundary, boundary, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_uniform_inplace(-boundary, boundary, arr)
+        else:
+            data = _init_random_uniform(-boundary, boundary, arr.shape)
+            _assignment(arr, data)
 
 
 @_register('he_normal')
@@ -432,8 +456,11 @@ class HeNormal(Initializer):
        fan = _calculate_correct_fan(arr.shape, self.mode)
        gain = _calculate_gain(self.nonlinearity, self.negative_slope)
        std = gain / math.sqrt(fan)
-        data = _init_random_normal(0, std, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_normal_inplace(0, std, arr)
+        else:
+            data = _init_random_normal(0, std, arr.shape)
+            _assignment(arr, data)
 
 
 class Constant(Initializer):
@@ -718,8 +745,11 @@ class Uniform(Initializer):
        self.scale = scale
 
    def _initialize(self, arr):
-        tmp = _init_random_uniform(-self.scale, self.scale, arr.shape)
-        _assignment(arr, tmp)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_uniform_inplace(-self.scale, self.scale, arr)
+        else:
+            tmp = _init_random_uniform(-self.scale, self.scale, arr.shape)
+            _assignment(arr, tmp)
 
 
 @_register()
@@ -749,8 +779,11 @@ class Normal(Initializer):
        self.mean = mean
 
    def _initialize(self, arr):
-        data = _init_random_normal(self.mean, self.sigma, arr.shape)
-        _assignment(arr, data)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_random_normal_inplace(self.mean, self.sigma, arr)
+        else:
+            data = _init_random_normal(self.mean, self.sigma, arr.shape)
+            _assignment(arr, data)
 
 
 @_register()
@@ -780,8 +813,11 @@ class TruncatedNormal(Initializer):
        self.b = b
 
    def _initialize(self, arr):
-        tmp = _init_truncated_normal(self.a, self.b, self.mean, self.sigma, arr.shape)
-        _assignment(arr, tmp)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            _init_truncated_normal_inplace(self.a, self.b, self.mean, self.sigma, arr)
+        else:
+            tmp = _init_truncated_normal(self.a, self.b, self.mean, self.sigma, arr.shape)
+            _assignment(arr, tmp)
 
 
 def initializer(init, shape=None, dtype=mstype.float32):
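The pattern repeated across `XavierNormal`, `HeNormal`, `Uniform`, and the rest is a fast path: when the target buffer is already a float32 `ndarray`, the new `_init_*_inplace` helpers fill it directly, skipping the temporary array and the `_assignment` copy. A NumPy-only sketch of that dispatch (`fill_normal` is a hypothetical stand-in for the C-backed `_random_normal`):

import numpy as np

def fill_normal(seed, arr, mean, sigma):
    # Hypothetical stand-in for _random_normal: writes into arr in place.
    rng = np.random.default_rng(seed)
    arr[...] = rng.normal(mean, sigma, size=arr.shape)

def init_normal(arr, mean=0.0, sigma=0.01):
    if sigma < 0:
        raise ValueError("sigma < 0")
    if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
        # Fast path: fill the existing buffer, no temporary allocation.
        fill_normal(1234, arr, mean, sigma)
    else:
        # Slow path: build a float32 temporary, then cast/copy into arr.
        data = np.empty(arr.shape, dtype=np.float32)
        fill_normal(1234, data, mean, sigma)
        arr[...] = data

buf = np.empty((1024, 1024), dtype=np.float32)
init_normal(buf)  # fills in place, avoiding a second full-size buffer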
mindspore/common/mindir_util.py
CHANGED

@@ -90,9 +90,9 @@ def save_mindir(model, file_name):
    if not file_name.endswith('.mindir'):
        file_name += ".mindir"
 
-    current_path = os.path.
+    current_path = os.path.realpath(file_name)
    dirname = os.path.dirname(current_path)
-    os.makedirs(dirname, exist_ok=True)
+    os.makedirs(dirname, mode=0o700, exist_ok=True)
    if os.path.exists(file_name):
        os.chmod(file_name, stat.S_IWUSR)
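One caveat with the hardened `os.makedirs(dirname, mode=0o700, exist_ok=True)`: the mode argument is filtered through the process umask and is not applied to directories that already exist. A sketch that enforces owner-only permissions in either case (illustrative hardening, not the MindSpore implementation, and only meaningful on POSIX-style permission systems):

import os
import stat

def ensure_private_dir(path):
    # mode= is masked by the umask and ignored for pre-existing dirs,
    # so enforce the permission bits explicitly afterwards.
    os.makedirs(path, mode=0o700, exist_ok=True)
    os.chmod(path, stat.S_IRWXU)  # rwx for the owner, nothing for group/other

ensure_private_dir(os.path.expanduser("~/mindir_out"))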