mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. (The original page linked to more details here.)
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -27,13 +27,129 @@ from mindspore.ops.composite.base import GradOperation, _Grad
|
|
|
27
27
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
28
28
|
from mindspore.common.api import jit
|
|
29
29
|
from mindspore.common.tensor import Tensor
|
|
30
|
-
from mindspore.common._register_for_tensor import Registry
|
|
31
|
-
from mindspore._c_expression import MetaFuncGraph_, function_id
|
|
30
|
+
from mindspore.common._register_for_tensor import Registry
|
|
31
|
+
from mindspore._c_expression import MetaFuncGraph_, function_id
|
|
32
32
|
from mindspore._c_expression import Tensor as Tensor_
|
|
33
33
|
from mindspore._extends.parse.resources import convert_object_map
|
|
34
34
|
from mindspore import _checkparam as validator
|
|
35
35
|
from mindspore import Parameter, ParameterTuple
|
|
36
36
|
from mindspore.common.initializer import Zero
|
|
37
|
+
from mindspore.ops.function import array_func
|
|
38
|
+
from mindspore.ops import operations as P
|
|
39
|
+
from mindspore.ops import functional as F
|
|
40
|
+
from mindspore._c_expression.np_dtypes import np_version_valid
|
|
41
|
+
from mindspore.common.dtype import type_size_in_bytes
|
|
42
|
+
from mindspore.communication._comm_helper import _is_initialized, _get_rank_helper, _get_local_rank_helper, \
|
|
43
|
+
_get_size_helper, _get_local_size_helper, _get_world_rank_from_group_rank_helper, _get_group_ranks, \
|
|
44
|
+
_get_group_rank_from_world_rank_helper, _set_elegant_exit_handle
|
|
45
|
+
from mindspore import SummaryCollector
|
|
46
|
+
from mindspore.train import ModelCheckpoint, LossMonitor
|
|
47
|
+
from mindspore.train.model import _FrameworkProfilerCallback
|
|
48
|
+
from mindspore.train.data_sink import _init_sink_dataset
|
|
49
|
+
from mindspore.train.summary import SummaryRecord
|
|
50
|
+
from mindspore.train._utils import _exec_datagraph
|
|
51
|
+
from mindspore.train.summary.writer import BaseWriter
|
|
52
|
+
from mindspore.train.serialization import _exec_save, load, export_split_mindir, obfuscate_model, _parse_ckpt_proto, \
|
|
53
|
+
_generate_front_info_for_param_data_file, _get_data_file, _encrypt_data, _split_save, _save_mindir_together, \
|
|
54
|
+
_load_into_param_dict
|
|
55
|
+
from mindspore.parallel import _cost_model_context
|
|
56
|
+
from mindspore.parallel._offload_context import offload_context
|
|
57
|
+
from mindspore.run_check._check_version import check_version_and_env_config
|
|
58
|
+
from mindspore.dataset.callback.ds_callback import DSCallback, WaitedDSCallback
|
|
59
|
+
from mindspore.dataset.transforms.c_transforms import TensorOperation as CTensorOperation, OneHot as COneHot, \
|
|
60
|
+
Fill as CFill, TypeCast as CTypeCast, Slice as CSlice, Mask as CMask, PadEnd as CPadEnd, \
|
|
61
|
+
Concatenate as CConcatenate, Duplicate as CDuplicate, Unique as CUnique, Compose as CCompose, \
|
|
62
|
+
RandomApply as CRandomApply, RandomChoice as CRandomChoice, Plugin as CPlugin
|
|
63
|
+
from mindspore.dataset.transforms.transforms import TensorOperation, Compose, Concatenate, Duplicate, Fill, Mask, \
|
|
64
|
+
OneHot, PadEnd, Plugin, RandomApply, RandomChoice, Slice, TypeCast, Unique
|
|
65
|
+
from mindspore.dataset.text.transforms import AddToken, JiebaTokenizer, Lookup, Ngram, SentencePieceTokenizer, \
|
|
66
|
+
SlidingWindow, ToNumber, ToVectors, Truncate, TruncateSequencePair, UnicodeCharTokenizer, WordpieceTokenizer, \
|
|
67
|
+
BasicTokenizer, BertTokenizer, CaseFold, FilterWikipediaXML, NormalizeUTF8, RegexReplace, RegexTokenizer, \
|
|
68
|
+
UnicodeScriptTokenizer, WhitespaceTokenizer
|
|
69
|
+
from mindspore.dataset.core.datatypes import nptype_to_detype, mstype_to_detype, mstypelist_to_detypelist
|
|
70
|
+
from mindspore.dataset.audio.utils import create_dct, linear_fbanks, melscale_fbanks
|
|
71
|
+
from mindspore.dataset.audio.transforms import AllpassBiquad, AmplitudeToDB, Angle, BandBiquad, BandpassBiquad, \
|
|
72
|
+
BandrejectBiquad, BassBiquad, Biquad, ComplexNorm, ComputeDeltas, Contrast, DBToAmplitude, DCShift, DeemphBiquad, \
|
|
73
|
+
DetectPitchFrequency, Dither, EqualizerBiquad, Fade, Filtfilt, Flanger, FrequencyMasking, Gain, GriffinLim, \
|
|
74
|
+
HighpassBiquad, InverseMelScale, InverseSpectrogram, LFCC, LFilter, LowpassBiquad, Magphase, MaskAlongAxis, \
|
|
75
|
+
MaskAlongAxisIID, MelScale, MelSpectrogram, MFCC, MuLawDecoding, MuLawEncoding, Overdrive, Phaser, PhaseVocoder, \
|
|
76
|
+
PitchShift, Resample, RiaaBiquad, SlidingWindowCmn, SpectralCentroid, Spectrogram, TimeMasking, TimeStretch, \
|
|
77
|
+
TrebleBiquad, Vad, Vol
|
|
78
|
+
from mindspore.dataset.engine.datasets_audio import CMUArcticDataset, GTZANDataset, LibriTTSDataset, LJSpeechDataset, \
|
|
79
|
+
SpeechCommandsDataset, TedliumDataset, YesNoDataset
|
|
80
|
+
from mindspore.dataset.engine.cache_client import DatasetCache
|
|
81
|
+
from mindspore.dataset.engine.iterators import Iterator
|
|
82
|
+
from mindspore.dataset.engine.datasets_standard_format import CSVDataset, MindDataset, TFRecordDataset
|
|
83
|
+
from mindspore.dataset.engine.datasets_text import AGNewsDataset, AmazonReviewDataset, CLUEDataset, CoNLL2000Dataset, \
|
|
84
|
+
DBpediaDataset, EnWik9Dataset, IMDBDataset, IWSLT2016Dataset, IWSLT2017Dataset, Multi30kDataset, WikiTextDataset, \
|
|
85
|
+
PennTreebankDataset, SogouNewsDataset, SQuADDataset, SST2Dataset, TextFileDataset, UDPOSDataset, \
|
|
86
|
+
YelpReviewDataset, YahooAnswersDataset
|
|
87
|
+
from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset
|
|
88
|
+
from mindspore.dataset.engine.datasets_vision import Caltech256Dataset, CelebADataset, Cifar100Dataset, \
|
|
89
|
+
Cifar10Dataset, CityscapesDataset, CocoDataset, DIV2KDataset, EMnistDataset, FakeImageDataset, \
|
|
90
|
+
FashionMnistDataset, FlickrDataset, Food101Dataset, ImageFolderDataset, KITTIDataset, KMnistDataset, \
|
|
91
|
+
LFWDataset, LSUNDataset, ManifestDataset, MnistDataset, OmniglotDataset, PhotoTourDataset, \
|
|
92
|
+
Places365Dataset, QMnistDataset, RandomDataset, RenderedSST2Dataset, SBUDataset, SemeionDataset, \
|
|
93
|
+
STL10Dataset, SUN397Dataset, USPSDataset, VOCDataset, WIDERFaceDataset
|
|
94
|
+
from mindspore.dataset.engine.queue import _SharedQueue
|
|
95
|
+
from mindspore.dataset.engine.datasets import Dataset, BucketBatchByLengthDataset, BatchDataset, \
|
|
96
|
+
PaddedBatchDataset, SyncWaitDataset, ShuffleDataset, MapDataset, FilterDataset, RepeatDataset, SkipDataset, \
|
|
97
|
+
TakeDataset, ZipDataset, ConcatDataset, RenameDataset, ProjectDataset, _ToDevice, TransferDataset, Schema
|
|
98
|
+
from mindspore.dataset.engine.samplers import Sampler, DistributedSampler, PKSampler, RandomSampler, SubsetSampler, \
|
|
99
|
+
SequentialSampler, SubsetRandomSampler, WeightedRandomSampler
|
|
100
|
+
from mindspore.dataset.vision.c_transforms import ImageTensorOperation, AdjustGamma, AutoAugment, AutoContrast, \
|
|
101
|
+
BoundingBoxAugment, CenterCrop, ConvertColor, Crop, CutMixBatch, CutOut, Decode, Equalize, GaussianBlur, \
|
|
102
|
+
HorizontalFlip, HWC2CHW, Invert, MixUpBatch, Normalize, NormalizePad, Pad, RandomAdjustSharpness, RandomAffine, \
|
|
103
|
+
RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, RandomCropDecodeResize, RandomCropWithBBox, \
|
|
104
|
+
RandomEqualize, RandomHorizontalFlip, RandomHorizontalFlipWithBBox, RandomInvert, RandomLighting, RandomRotation, \
|
|
105
|
+
RandomPosterize, RandomResizedCrop, RandomResizedCropWithBBox, RandomResize, RandomResizeWithBBox, \
|
|
106
|
+
RandomSelectSubpolicy, RandomSharpness, RandomSolarize, RandomVerticalFlip, RandomVerticalFlipWithBBox, Rescale, \
|
|
107
|
+
Resize, ResizeWithBBox, RgbToBgr, Rotate, SlicePatches, UniformAugment, VerticalFlip
|
|
108
|
+
from mindspore.dataset.vision.utils import encode_jpeg, encode_png, get_image_num_channels, get_image_size, \
|
|
109
|
+
read_file, read_image, read_video, read_video_timestamps, write_file, write_jpeg, write_png
|
|
110
|
+
from mindspore.dataset.vision.transforms import AdjustBrightness, AdjustContrast, AdjustGamma as VAdjustGamma, \
|
|
111
|
+
AdjustHue, AdjustSaturation, AdjustSharpness, Affine, AutoAugment as VAutoAugment, AutoContrast as VAutoContrast, \
|
|
112
|
+
BoundingBoxAugment as VBoundingBoxAugment, CenterCrop as VCenterCrop, ConvertColor as VConvertColor, \
|
|
113
|
+
Crop as VCrop, CutMixBatch as VCutMixBatch, CutOut as VCutOut, Decode as VDecode, DecodeVideo, \
|
|
114
|
+
Equalize as VEqualize, Erase, GaussianBlur as VGaussianBlur, HorizontalFlip as VHorizontalFlip, \
|
|
115
|
+
HWC2CHW as VHWC2CHW, Invert as VInvert, MixUpBatch as VMixUpBatch, Normalize as VNormalize, \
|
|
116
|
+
NormalizePad as VNormalizePad, Pad as VPad, PadToSize, Perspective, Posterize, RandAugment, \
|
|
117
|
+
RandomAdjustSharpness as VRandomAdjustSharpness, RandomAffine as VRandomAffine, \
|
|
118
|
+
RandomAutoContrast as VRandomAutoContrast, RandomColor as VRandomColor, RandomColorAdjust as VRandomColorAdjust, \
|
|
119
|
+
RandomCrop as VRandomCrop, RandomCropDecodeResize as VRandomCropDecodeResize, \
|
|
120
|
+
RandomCropWithBBox as VRandomCropWithBBox, RandomEqualize as VRandomEqualize, RandomResize as VRandomResize, \
|
|
121
|
+
RandomHorizontalFlip as VRandomHorizontalFlip, RandomHorizontalFlipWithBBox as VRandomHorizontalFlipWithBBox, \
|
|
122
|
+
RandomInvert as VRandomInvert, RandomLighting as VRandomLighting, RandomPosterize as VRandomPosterize, \
|
|
123
|
+
RandomResizedCrop as VRandomResizedCrop, RandomResizedCropWithBBox as VRandomResizedCropWithBBox, \
|
|
124
|
+
RandomResizeWithBBox as VRandomResizeWithBBox, RandomRotation as VRandomRotation, \
|
|
125
|
+
RandomSelectSubpolicy as VRandomSelectSubpolicy, RandomSharpness as VRandomSharpness, \
|
|
126
|
+
RandomSolarize as VRandomSolarize, RandomVerticalFlip as VRandomVerticalFlip, \
|
|
127
|
+
RandomVerticalFlipWithBBox as VRandomVerticalFlipWithBBox, Rescale as VRescale, Resize as VResize, ResizedCrop, \
|
|
128
|
+
ResizeWithBBox as VResizeWithBBox, Rotate as VRotate, SlicePatches as VSlicePatches, Solarize, ToTensor,\
|
|
129
|
+
TrivialAugmentWide, UniformAugment as VUniformAugment, VerticalFlip as VVerticalFlip
|
|
130
|
+
from mindspore.profiler.profiling import Profiler
|
|
131
|
+
from mindspore.communication._hccl_management import get_rank_size, get_rank_id
|
|
132
|
+
from mindspore.communication._comm_helper import _create_group_helper, _destroy_group_helper
|
|
133
|
+
from mindspore.communication.management import _set_rank_from_mpi, init as cinit, release as crelease
|
|
134
|
+
from mindspore.hal.stream import Stream, synchronize, set_cur_stream, current_stream, default_stream
|
|
135
|
+
from mindspore.hal.event import Event
|
|
136
|
+
from mindspore.hal.memory import memory_stats, memory_reserved, max_memory_allocated, reset_peak_memory_stats, \
|
|
137
|
+
memory_summary, memory_allocated, max_memory_reserved, reset_max_memory_allocated, reset_max_memory_reserved
|
|
138
|
+
from mindspore.multiprocessing import Process
|
|
139
|
+
from mindspore.mindrecord.shardsegment import ShardSegment
|
|
140
|
+
from mindspore.mindrecord.shardreader import ShardReader
|
|
141
|
+
from mindspore.mindrecord.shardindexgenerator import ShardIndexGenerator
|
|
142
|
+
from mindspore.mindrecord.shardwriter import ShardWriter
|
|
143
|
+
from mindspore.mindrecord.shardheader import ShardHeader
|
|
144
|
+
from mindspore.mindrecord.config import encrypt, decrypt
|
|
145
|
+
from mindspore.parallel.mpi._mpi_config import _MpiConfig
|
|
146
|
+
from mindspore.parallel._ps_context import ps_context
|
|
147
|
+
from mindspore.parallel.algo_parameter_config import _AlgoParameterConfig
|
|
148
|
+
from mindspore.parallel._utils import _reset_op_id, _reset_op_id_with_offset
|
|
149
|
+
from mindspore.parallel._recovery_context import recovery_context
|
|
150
|
+
from mindspore.parallel._auto_parallel_context import _AutoParallelContext
|
|
151
|
+
from mindspore.common.api import ms_memory_recycle
|
|
152
|
+
from mindspore.context import _Context
|
|
37
153
|
|
|
38
154
|
|
|
39
155
|
def _get_after_grad_code():
|
|
@@ -51,6 +167,107 @@ def _get_after_grad_code():
|
|
|
51
167
|
return codes
|
|
52
168
|
|
|
53
169
|
|
|
170
|
+
def _get_dataset_forbidden_code():
|
|
171
|
+
"""Get the forbidden function which should be broken in graph"""
|
|
172
|
+
codes = []
|
|
173
|
+
codes.extend([DSCallback.__init__, WaitedDSCallback.__init__])
|
|
174
|
+
codes.extend([CTensorOperation.__call__, COneHot.parse, CFill.__init__, CFill.parse, CTypeCast.parse, \
|
|
175
|
+
CSlice.parse, CMask.__init__, CMask.parse, CPadEnd.__init__, CPadEnd.parse, CConcatenate.__init__, \
|
|
176
|
+
CConcatenate.parse, CDuplicate.parse, CUnique.parse, CCompose.parse, CRandomApply.parse, \
|
|
177
|
+
CRandomChoice.parse, CPlugin.parse])
|
|
178
|
+
codes.extend([TensorOperation.__call__, Compose.parse, Concatenate.__init__, Concatenate.parse, Duplicate.parse, \
|
|
179
|
+
Fill.__init__, Fill.parse, Mask.__init__, Mask.parse, OneHot.parse, PadEnd.__init__, PadEnd.parse, \
|
|
180
|
+
Plugin.parse, RandomApply.parse, RandomChoice.parse, Slice.parse, TypeCast.parse, Unique.parse])
|
|
181
|
+
codes.extend([AddToken.parse, JiebaTokenizer.parse, Lookup.parse, Ngram.parse, SentencePieceTokenizer.parse, \
|
|
182
|
+
SlidingWindow.parse, ToNumber.parse, ToVectors.parse, Truncate.parse, TruncateSequencePair.parse, \
|
|
183
|
+
UnicodeCharTokenizer.parse, WordpieceTokenizer.parse, BasicTokenizer.parse, BertTokenizer.parse, \
|
|
184
|
+
CaseFold.parse, FilterWikipediaXML.parse, NormalizeUTF8.parse, RegexReplace.parse, \
|
|
185
|
+
RegexTokenizer.parse, UnicodeScriptTokenizer.parse, WhitespaceTokenizer.parse])
|
|
186
|
+
codes.extend([create_dct, linear_fbanks, melscale_fbanks])
|
|
187
|
+
codes.extend([AllpassBiquad.parse, AmplitudeToDB.parse, Angle.parse, BandBiquad.parse, BandpassBiquad.parse, \
|
|
188
|
+
BandrejectBiquad.parse, BassBiquad.parse, Biquad.parse, ComplexNorm.parse, ComputeDeltas.parse, \
|
|
189
|
+
Contrast.parse, DBToAmplitude.parse, DCShift.parse, DeemphBiquad.parse, DetectPitchFrequency.parse, \
|
|
190
|
+
Dither.parse, EqualizerBiquad.parse, Fade.parse, Filtfilt.parse, Flanger.parse, \
|
|
191
|
+
FrequencyMasking.parse, Gain.parse, GriffinLim.parse, HighpassBiquad.parse, InverseMelScale.parse, \
|
|
192
|
+
InverseSpectrogram.parse, LFCC.parse, LFilter.parse, LowpassBiquad.parse, Magphase.parse, \
|
|
193
|
+
MaskAlongAxis.parse, MaskAlongAxisIID.parse, MelScale.parse, MelSpectrogram.parse, MFCC.parse, \
|
|
194
|
+
MuLawDecoding.parse, MuLawEncoding.parse, Overdrive.parse, Phaser.parse, PhaseVocoder.parse, \
|
|
195
|
+
PhaseVocoder.__init__, PitchShift.parse, Resample.parse, RiaaBiquad.parse, SlidingWindowCmn.parse, \
|
|
196
|
+
SpectralCentroid.parse, Spectrogram.parse, TimeMasking.parse, TimeStretch.parse, \
|
|
197
|
+
TrebleBiquad.parse, Vad.parse, Vol.parse])
|
|
198
|
+
codes.extend([CMUArcticDataset.parse, GTZANDataset.parse, LibriTTSDataset.parse, LJSpeechDataset.parse, \
|
|
199
|
+
SpeechCommandsDataset.parse, TedliumDataset.parse, YesNoDataset.parse])
|
|
200
|
+
codes.extend([DatasetCache.__init__, Iterator.__init__])
|
|
201
|
+
codes.extend([CSVDataset.parse, MindDataset.parse, TFRecordDataset.parse])
|
|
202
|
+
codes.extend([AGNewsDataset.parse, AmazonReviewDataset.parse, CLUEDataset.parse, CoNLL2000Dataset.parse, \
|
|
203
|
+
DBpediaDataset.parse, EnWik9Dataset.parse, IMDBDataset.parse, IWSLT2016Dataset.parse, \
|
|
204
|
+
IWSLT2017Dataset.parse, Multi30kDataset.parse, PennTreebankDataset.parse, SogouNewsDataset.parse, \
|
|
205
|
+
SQuADDataset.parse, SST2Dataset.parse, TextFileDataset.parse, UDPOSDataset.parse, \
|
|
206
|
+
WikiTextDataset.parse, YahooAnswersDataset.parse, YelpReviewDataset.parse])
|
|
207
|
+
codes.extend([GeneratorDataset.parse, Caltech256Dataset.parse, CelebADataset.parse, Cifar10Dataset.parse, \
|
|
208
|
+
Cifar100Dataset.parse, CityscapesDataset.parse, CocoDataset.parse, DIV2KDataset.parse, \
|
|
209
|
+
EMnistDataset.parse, FakeImageDataset.parse, FashionMnistDataset.parse, FlickrDataset.parse, \
|
|
210
|
+
Food101Dataset.parse, ImageFolderDataset.parse, KITTIDataset.parse, KMnistDataset.parse, \
|
|
211
|
+
LFWDataset.parse, LSUNDataset.parse, ManifestDataset.parse, MnistDataset.parse, VOCDataset.parse, \
|
|
212
|
+
OmniglotDataset.parse, PhotoTourDataset.parse, Places365Dataset.parse, QMnistDataset.parse, \
|
|
213
|
+
RandomDataset.parse, RenderedSST2Dataset.parse, SBUDataset.parse, SemeionDataset.parse, \
|
|
214
|
+
STL10Dataset.parse, SUN397Dataset.parse, USPSDataset.parse, WIDERFaceDataset.parse])
|
|
215
|
+
codes.extend([_SharedQueue.put, _SharedQueue.get])
|
|
216
|
+
codes.extend([BucketBatchByLengthDataset.parse, BatchDataset.parse, \
|
|
217
|
+
PaddedBatchDataset.parse, SyncWaitDataset.parse, ShuffleDataset.parse, MapDataset.parse, \
|
|
218
|
+
FilterDataset.parse, RepeatDataset.parse, SkipDataset.parse, TakeDataset.parse, ZipDataset.parse, \
|
|
219
|
+
ConcatDataset.parse, RenameDataset.parse, ProjectDataset.parse, _ToDevice.__init__, \
|
|
220
|
+
TransferDataset.parse, Schema.__init__, Schema.add_column, Dataset.save])
|
|
221
|
+
codes.extend([Sampler.parse, DistributedSampler.parse, DistributedSampler.parse_for_minddataset, PKSampler.parse, \
|
|
222
|
+
PKSampler.parse_for_minddataset, RandomSampler.parse, RandomSampler.parse_for_minddataset, \
|
|
223
|
+
SequentialSampler.parse, SequentialSampler.parse_for_minddataset, SubsetSampler.parse, \
|
|
224
|
+
SubsetSampler.parse_for_minddataset, SubsetRandomSampler.parse, \
|
|
225
|
+
SubsetRandomSampler.parse_for_minddataset, WeightedRandomSampler.parse])
|
|
226
|
+
codes.extend([ImageTensorOperation.__call__, AdjustGamma.parse, AutoAugment.parse, AutoContrast.parse, Pad.parse, \
|
|
227
|
+
BoundingBoxAugment.parse, CenterCrop.parse, ConvertColor.parse, Crop.parse, CutMixBatch.parse, \
|
|
228
|
+
CutOut.parse, Decode.parse, Equalize.parse, GaussianBlur.parse, HorizontalFlip.parse, \
|
|
229
|
+
Invert.parse, MixUpBatch.parse, Normalize.parse, NormalizePad.parse, HWC2CHW.parse, \
|
|
230
|
+
RandomAdjustSharpness.parse, RandomAffine.parse, RandomAutoContrast.parse, RandomColor.parse, \
|
|
231
|
+
RandomColorAdjust.parse, RandomCrop.parse, RandomCropDecodeResize.parse, RandomCropWithBBox.parse, \
|
|
232
|
+
RandomEqualize.parse, RandomHorizontalFlip.parse, RandomHorizontalFlipWithBBox.parse, \
|
|
233
|
+
RandomInvert.parse, RandomLighting.parse, RandomPosterize.parse, RandomResizedCrop.parse, \
|
|
234
|
+
RandomResizedCropWithBBox.parse, RandomResize.parse, RandomResizeWithBBox.parse, Resize.parse, \
|
|
235
|
+
RandomRotation.parse, RandomSelectSubpolicy.parse, RandomSharpness.parse, RandomSolarize.parse, \
|
|
236
|
+
RandomVerticalFlip.parse, RandomVerticalFlipWithBBox.parse, Rescale.parse, VerticalFlip.parse, \
|
|
237
|
+
ResizeWithBBox.parse, RgbToBgr.parse, Rotate.parse, SlicePatches.parse, UniformAugment.parse])
|
|
238
|
+
codes.extend([encode_jpeg, encode_png, get_image_num_channels, get_image_size, read_file, read_image, read_video, \
|
|
239
|
+
read_video_timestamps, write_file, write_jpeg, write_png])
|
|
240
|
+
codes.extend([AdjustBrightness.parse, AdjustContrast.parse, VAdjustGamma.parse, AdjustHue.parse, \
|
|
241
|
+
AdjustSaturation.parse, AdjustSharpness.parse, Affine.parse, VAutoAugment.parse, \
|
|
242
|
+
VAutoContrast.parse, VBoundingBoxAugment.parse, VCenterCrop.parse, VConvertColor.parse, \
|
|
243
|
+
VCrop.parse, VCutMixBatch.parse, VCutOut.parse, VDecode.parse, DecodeVideo.parse, \
|
|
244
|
+
VEqualize.parse, Erase.parse, VGaussianBlur.parse, VHorizontalFlip.parse, \
|
|
245
|
+
VHWC2CHW.parse, VInvert.parse, VMixUpBatch.parse, VNormalize.parse, VNormalizePad.parse, \
|
|
246
|
+
VPad.parse, PadToSize.parse, Perspective.parse, Posterize.parse, RandAugment.parse, \
|
|
247
|
+
VRandomAdjustSharpness.parse, VRandomAffine.parse, \
|
|
248
|
+
VRandomAutoContrast.parse, VRandomColor.parse, VRandomColorAdjust.parse, \
|
|
249
|
+
VRandomCrop.parse, VRandomCropDecodeResize.parse, \
|
|
250
|
+
VRandomCropWithBBox.parse, VRandomEqualize.parse, VRandomResize.parse, \
|
|
251
|
+
VRandomHorizontalFlip.parse, VRandomHorizontalFlipWithBBox.parse, \
|
|
252
|
+
VRandomInvert.parse, VRandomLighting.parse, VRandomPosterize.parse, \
|
|
253
|
+
VRandomResizedCrop.parse, VRandomResizedCropWithBBox.parse, \
|
|
254
|
+
VRandomResizeWithBBox.parse, VRandomRotation.parse, \
|
|
255
|
+
VRandomSelectSubpolicy.parse, VRandomSharpness.parse, \
|
|
256
|
+
VRandomSolarize.parse, VRandomVerticalFlip.parse, \
|
|
257
|
+
VRandomVerticalFlipWithBBox.parse, VRescale.parse, VResize.parse, ResizedCrop.parse, \
|
|
258
|
+
VResizeWithBBox.parse, VRotate.parse, VSlicePatches.parse, Solarize.parse, ToTensor.parse,\
|
|
259
|
+
TrivialAugmentWide.parse, VUniformAugment.parse, VVerticalFlip])
|
|
260
|
+
return codes
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _get_math_code():
|
|
264
|
+
"""Get the math builtin function which should be guarded in graph"""
|
|
265
|
+
codes = []
|
|
266
|
+
for i in dir(math):
|
|
267
|
+
codes.append(getattr(math, i))
|
|
268
|
+
return codes
|
|
269
|
+
|
|
270
|
+
|
|
54
271
|
def _get_psjit_code():
|
|
55
272
|
"""Get the code object of 'staging_specialize'"""
|
|
56
273
|
@jit
|
|
@@ -95,30 +312,13 @@ def _get_pijit_constexpr_code():
|
|
|
95
312
|
return codes
|
|
96
313
|
|
|
97
314
|
|
|
98
|
-
def _get_ms_api():
|
|
99
|
-
"""Get ms api"""
|
|
100
|
-
target_types = Cell, types.FunctionType, Primitive_, PrimitiveFunction_
|
|
101
|
-
results = []
|
|
102
|
-
from mindspore.ops import operations as P
|
|
103
|
-
from mindspore.ops import functional as F
|
|
104
|
-
from mindspore.ops import composite as C
|
|
105
|
-
mods = P, F, C
|
|
106
|
-
for mod in mods:
|
|
107
|
-
for i in mod.__all__:
|
|
108
|
-
f = getattr(mod, i)
|
|
109
|
-
if isinstance(f, target_types):
|
|
110
|
-
results.append(f)
|
|
111
|
-
for f in tensor_operator_registry.values():
|
|
112
|
-
if isinstance(f, target_types):
|
|
113
|
-
results.append(f)
|
|
114
|
-
return results
|
|
115
|
-
|
|
116
|
-
|
|
117
315
|
psjit_code = _get_psjit_code()
|
|
118
316
|
constexpr_code = _get_constexpr_code()
|
|
119
317
|
primexpr_code = _get_primexpr_code()
|
|
120
318
|
|
|
121
319
|
primitive_key = id(Primitive.__call__)
|
|
320
|
+
primitive_assign_key = id(P.Assign.__call__)
|
|
321
|
+
|
|
122
322
|
constexpr_key = id(constexpr_code)
|
|
123
323
|
primexpr_key = id(primexpr_code)
|
|
124
324
|
meta_func_graph_key = id(MetaFuncGraph_)
|
|
@@ -153,6 +353,13 @@ FUNC_KEY_PSJIT_CONVERTMAP = 15 # "mindspore._extends.parse.resources.convert_obj
|
|
|
153
353
|
FUNC_KEY_GRAPH_CELL = 16 # "mindspore.nn.cell.GraphCell"
|
|
154
354
|
FUNC_KEY_MS_API = 17 # mindspore common api
|
|
155
355
|
FUNC_KEY_MAPPING_GET = 18 # collections.abc.Mapping.get
|
|
356
|
+
FUNC_KEY_LIST_POP = 19 # list.pop
|
|
357
|
+
FUNC_KEY_LIST_REMOVE = 20 # list.remove
|
|
358
|
+
FUNC_KEY_LIST_REVERSE = 21 # list.reverse
|
|
359
|
+
FUNC_KEY_DICT_ITEMS = 22 # dict.items
|
|
360
|
+
FUNC_KEY_PRIMITIVE_ASSIGN = 23 # mindspore.ops.assign, Primitive("Assign")
|
|
361
|
+
FUNC_KEY_TENSOR_SETITEM = 24 # Tensor.__setitem__
|
|
362
|
+
FUNC_KEY_TENSOR_ASSIGN_VALUE = 25 # Tensor.assign_value
|
|
156
363
|
|
|
157
364
|
# Initialized only once. This map is initialized by C++ when pijit starts.
|
|
158
365
|
# key is customer if fuzzy match. (Primitive, constexpr, primexpr, MetaFuncGraph)
|
|
@@ -162,6 +369,8 @@ FUNC_KEY_MAPPING_GET = 18 # collections.abc.Mapping.get
|
|
|
162
369
|
_func_map = {
|
|
163
370
|
# special function
|
|
164
371
|
pijit_constexpr_key: FUNC_KEY_PIJIT_CONSTEXPR,
|
|
372
|
+
id(getattr(array_func, "_get_max_type")): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
373
|
+
id(Cell.__getattr__): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
165
374
|
pijit_forbidden_key: FUNC_KEY_PIJIT_FORBIDDEN,
|
|
166
375
|
primitive_key: FUNC_KEY_PRIMITIVE,
|
|
167
376
|
constexpr_key: FUNC_KEY_CONSTEXPR,
|
|
@@ -172,6 +381,12 @@ _func_map = {
|
|
|
172
381
|
id(_get_cache_prim): FUNC_KEY_GET_CACHE_PRIM,
|
|
173
382
|
id(Registry.get): FUNC_KEY_REGISTRY_GET,
|
|
174
383
|
|
|
384
|
+
# tensor side-effect
|
|
385
|
+
primitive_assign_key: FUNC_KEY_PRIMITIVE_ASSIGN,
|
|
386
|
+
id(F.assign): FUNC_KEY_PRIMITIVE_ASSIGN,
|
|
387
|
+
id(Tensor.assign_value): FUNC_KEY_TENSOR_ASSIGN_VALUE,
|
|
388
|
+
id(Tensor.__setitem__): FUNC_KEY_TENSOR_SETITEM,
|
|
389
|
+
|
|
175
390
|
# Tensor method
|
|
176
391
|
id(Tensor.astype): FUNC_KEY_TENSOR_ASTYPE,
|
|
177
392
|
|
|
@@ -181,6 +396,7 @@ _func_map = {
|
|
|
181
396
|
function_id(len): FUNC_KEY_BUILTIN_FUNC,
|
|
182
397
|
function_id(abs): FUNC_KEY_BUILTIN_FUNC,
|
|
183
398
|
function_id(max): FUNC_KEY_BUILTIN_FUNC,
|
|
399
|
+
function_id(min): FUNC_KEY_BUILTIN_FUNC,
|
|
184
400
|
function_id(all): FUNC_KEY_BUILTIN_FUNC,
|
|
185
401
|
function_id(any): FUNC_KEY_BUILTIN_FUNC,
|
|
186
402
|
function_id(hash): FUNC_KEY_BUILTIN_FUNC,
|
|
@@ -189,6 +405,10 @@ _func_map = {
|
|
|
189
405
|
function_id(callable): FUNC_KEY_BUILTIN_FUNC,
|
|
190
406
|
function_id(getattr): FUNC_KEY_BUILTIN_FUNC,
|
|
191
407
|
function_id(hasattr): FUNC_KEY_BUILTIN_FUNC,
|
|
408
|
+
function_id(chr): FUNC_KEY_BUILTIN_FUNC,
|
|
409
|
+
function_id(divmod): FUNC_KEY_BUILTIN_FUNC,
|
|
410
|
+
function_id(repr): FUNC_KEY_BUILTIN_FUNC,
|
|
411
|
+
function_id(type): FUNC_KEY_BUILTIN_FUNC,
|
|
192
412
|
|
|
193
413
|
# types.MethodDescriptorType, types.WrapperDescriptorType
|
|
194
414
|
function_id(tuple.__getitem__): FUNC_KEY_BUILTIN_FUNC,
|
|
@@ -232,7 +452,11 @@ _func_map = {
|
|
|
232
452
|
function_id(str.format_map): FUNC_KEY_BUILTIN_FUNC,
|
|
233
453
|
function_id(str.__format__): FUNC_KEY_BUILTIN_FUNC,
|
|
234
454
|
function_id(list.append): FUNC_KEY_LIST_APPEND,
|
|
455
|
+
function_id(list.pop): FUNC_KEY_LIST_POP,
|
|
456
|
+
function_id(list.remove): FUNC_KEY_LIST_REMOVE,
|
|
457
|
+
function_id(list.reverse): FUNC_KEY_LIST_REVERSE,
|
|
235
458
|
function_id(dict.pop): FUNC_KEY_DICT_POP,
|
|
459
|
+
function_id(dict.items): FUNC_KEY_DICT_ITEMS,
|
|
236
460
|
|
|
237
461
|
# instancemethod
|
|
238
462
|
function_id(Tensor_._flatten_tensors): FUNC_KEY_BUILTIN_FUNC, # pylint: disable=protected-access
|
|
@@ -254,23 +478,125 @@ _func_map = {
|
|
|
254
478
|
|
|
255
479
|
# other builtin function
|
|
256
480
|
function_id(collections.abc.Mapping.get): FUNC_KEY_MAPPING_GET,
|
|
257
|
-
function_id(math.log): FUNC_KEY_BUILTIN_FUNC,
|
|
258
|
-
|
|
259
481
|
function_id(numpy.isinf): FUNC_KEY_BUILTIN_FUNC,
|
|
260
482
|
function_id(numpy.isnan): FUNC_KEY_BUILTIN_FUNC,
|
|
261
483
|
function_id(numpy.abs): FUNC_KEY_BUILTIN_FUNC,
|
|
262
484
|
function_id(numpy.log): FUNC_KEY_BUILTIN_FUNC,
|
|
485
|
+
function_id(os.getenv): FUNC_KEY_BUILTIN_FUNC,
|
|
263
486
|
|
|
264
487
|
# const function
|
|
265
|
-
function_id(os.getenv): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
266
488
|
function_id(validator.check_number_range): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
267
489
|
function_id(validator.check_is_int): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
268
490
|
function_id(validator.check_is_number): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
491
|
+
function_id(np_version_valid): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
492
|
+
function_id(_is_initialized): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
493
|
+
function_id(_set_elegant_exit_handle): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
494
|
+
function_id(_cost_model_context.get_cost_model_context): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
495
|
+
function_id(Stream.__repr__): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
496
|
+
function_id(get_rank_size): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
497
|
+
function_id(get_rank_id): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
498
|
+
function_id(offload_context): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
499
|
+
function_id(check_version_and_env_config): FUNC_KEY_PIJIT_CONSTEXPR,
|
|
500
|
+
|
|
501
|
+
# inner function
|
|
502
|
+
function_id(type_size_in_bytes): FUNC_KEY_BUILTIN_FUNC,
|
|
503
|
+
function_id(_get_rank_helper): FUNC_KEY_BUILTIN_FUNC,
|
|
504
|
+
function_id(_get_local_rank_helper): FUNC_KEY_BUILTIN_FUNC,
|
|
505
|
+
function_id(_get_size_helper): FUNC_KEY_BUILTIN_FUNC,
|
|
506
|
+
function_id(_get_local_size_helper): FUNC_KEY_BUILTIN_FUNC,
|
|
507
|
+
function_id(_get_world_rank_from_group_rank_helper): FUNC_KEY_BUILTIN_FUNC,
|
|
508
|
+
function_id(_get_group_ranks): FUNC_KEY_BUILTIN_FUNC,
|
|
509
|
+
function_id(_get_group_rank_from_world_rank_helper): FUNC_KEY_BUILTIN_FUNC,
|
|
510
|
+
function_id(nptype_to_detype): FUNC_KEY_BUILTIN_FUNC,
|
|
511
|
+
function_id(mstype_to_detype): FUNC_KEY_BUILTIN_FUNC,
|
|
512
|
+
function_id(mstypelist_to_detypelist): FUNC_KEY_BUILTIN_FUNC,
|
|
513
|
+
|
|
514
|
+
# no need to capture function in black list
|
|
515
|
+
function_id(SummaryCollector.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
516
|
+
function_id(SummaryCollector.begin): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
517
|
+
function_id(SummaryCollector.step_end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
518
|
+
function_id(SummaryCollector.epoch_end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
519
|
+
function_id(SummaryCollector.end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
520
|
+
function_id(ModelCheckpoint.step_end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
521
|
+
function_id(ModelCheckpoint.end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
522
|
+
function_id(LossMonitor.step_end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
523
|
+
function_id(LossMonitor.on_train_epoch_end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
524
|
+
function_id(SummaryRecord.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
525
|
+
function_id(_exec_datagraph): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
526
|
+
function_id(BaseWriter.writer): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
527
|
+
function_id(_FrameworkProfilerCallback.step_begin): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
528
|
+
function_id(_FrameworkProfilerCallback.step_end): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
529
|
+
function_id(_init_sink_dataset): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
530
|
+
function_id(_exec_save): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
531
|
+
function_id(load): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
532
|
+
function_id(export_split_mindir): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
533
|
+
function_id(obfuscate_model): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
534
|
+
function_id(_parse_ckpt_proto): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
535
|
+
function_id(_generate_front_info_for_param_data_file): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
536
|
+
function_id(_get_data_file): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
537
|
+
function_id(_encrypt_data): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
538
|
+
function_id(_split_save): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
539
|
+
function_id(_save_mindir_together): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
540
|
+
function_id(_load_into_param_dict): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
541
|
+
function_id(Profiler.start): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
542
|
+
function_id(_create_group_helper): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
543
|
+
function_id(_destroy_group_helper): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
544
|
+
function_id(_set_rank_from_mpi): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
545
|
+
function_id(cinit): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
546
|
+
function_id(crelease): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
547
|
+
function_id(Stream.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
548
|
+
function_id(Stream.synchronize): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
549
|
+
function_id(Stream.query): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
550
|
+
function_id(Stream.__eq__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
551
|
+
function_id(synchronize): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
552
|
+
function_id(set_cur_stream): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
553
|
+
function_id(current_stream): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
554
|
+
function_id(default_stream): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
555
|
+
function_id(Event.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
556
|
+
function_id(Event.record): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
557
|
+
function_id(Event.wait): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
558
|
+
function_id(Event.synchronize): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
559
|
+
function_id(Event.query): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
560
|
+
function_id(Event.elapsed_time): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
561
|
+
function_id(memory_stats): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
562
|
+
function_id(memory_reserved): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
563
|
+
function_id(max_memory_reserved): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
564
|
+
function_id(reset_peak_memory_stats): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
565
|
+
function_id(reset_peak_memory_stats): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
566
|
+
function_id(memory_summary): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
567
|
+
function_id(memory_allocated): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
568
|
+
function_id(max_memory_allocated): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
569
|
+
function_id(reset_max_memory_reserved): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
570
|
+
function_id(reset_max_memory_allocated): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
571
|
+
function_id(Process.run): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
572
|
+
function_id(Process.start): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
573
|
+
function_id(ShardSegment.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
574
|
+
function_id(ShardReader.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
575
|
+
function_id(ShardIndexGenerator.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
576
|
+
function_id(ShardWriter.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
577
|
+
function_id(ShardHeader.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
578
|
+
function_id(encrypt): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
579
|
+
function_id(decrypt): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
580
|
+
function_id(_MpiConfig.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
581
|
+
function_id(ps_context): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
582
|
+
function_id(_AlgoParameterConfig.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
583
|
+
function_id(_reset_op_id): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
584
|
+
function_id(_reset_op_id_with_offset): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
585
|
+
function_id(recovery_context): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
586
|
+
function_id(_AutoParallelContext.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
587
|
+
function_id(ms_memory_recycle): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
588
|
+
function_id(_Context.__init__): FUNC_KEY_PIJIT_FORBIDDEN,
|
|
269
589
|
}
|
|
270
590
|
|
|
271
591
|
for after_grad in _get_after_grad_code():
|
|
272
592
|
_func_map[id(after_grad)] = FUNC_KEY_GRAD_OPERATIONS_CODE
|
|
273
593
|
|
|
594
|
+
for func in _get_dataset_forbidden_code():
|
|
595
|
+
_func_map[function_id(func)] = FUNC_KEY_PIJIT_FORBIDDEN
|
|
596
|
+
|
|
597
|
+
for math_func in _get_math_code():
|
|
598
|
+
_func_map[function_id(math_func)] = FUNC_KEY_BUILTIN_FUNC
|
|
599
|
+
|
|
274
600
|
for k, v in convert_object_map.items():
|
|
275
601
|
key = id(k)
|
|
276
602
|
if key not in _func_map and isinstance(v, Primitive):
|
mindspore/amp.py
CHANGED
|
@@ -21,8 +21,10 @@ from mindspore.common import mutable
|
|
|
21
21
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
22
22
|
from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
|
|
23
23
|
from mindspore.ops.operations.nn_ops import AllFinite
|
|
24
|
+
from mindspore.run_check._check_version import AscendEnvChecker
|
|
24
25
|
from mindspore import _checkparam as validator
|
|
25
26
|
from mindspore._c_expression import MSContext
|
|
27
|
+
from mindspore import log as logger
|
|
26
28
|
from .common import dtype as mstype
|
|
27
29
|
from . import context
|
|
28
30
|
from . import ops
|
|
@@ -31,10 +33,9 @@ from .common.api import jit_class, jit
|
|
|
31
33
|
from .common.parameter import Parameter
|
|
32
34
|
from .common.tensor import Tensor
|
|
33
35
|
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
|
|
34
|
-
from .train.amp import build_train_network, auto_mixed_precision, custom_mixed_precision
|
|
36
|
+
from .train.amp import build_train_network, auto_mixed_precision, custom_mixed_precision, \
|
|
35
37
|
get_white_list, get_black_list
|
|
36
38
|
|
|
37
|
-
|
|
38
39
|
_hypermap = ops.HyperMap()
|
|
39
40
|
_partial = ops.Partial()
|
|
40
41
|
|
|
@@ -50,8 +51,8 @@ def _ascend_910a_target():
|
|
|
50
51
|
|
|
51
52
|
|
|
52
53
|
@constexpr
|
|
53
|
-
def
|
|
54
|
-
return MSContext.get_instance().get_ascend_soc_version() in ["ascend910b", "
|
|
54
|
+
def _ascend_910b_target():
    """Tell whether the Ascend SoC version reported by MSContext is "ascend910b" or "ascend910_93"."""
    soc_version = MSContext.get_instance().get_ascend_soc_version()
    return soc_version in ["ascend910b", "ascend910_93"]
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
@constexpr
|
|
@@ -62,15 +63,24 @@ def _gpu_target():
|
|
|
62
63
|
@constexpr
def _enable_all_finite():
    """Decide whether the AllFinite check is enabled.

    The decision is taken in this order:

    1. On an Ascend target, a failing ``AscendEnvChecker(None).check_custom_version()``
       disables it.
    2. The environment variable ``MS_DEV_RUNTIME_CONF`` enables it when it contains
       ``all_finite:True``/``all_finite:true`` and disables it when it contains
       ``all_finite:False``/``all_finite:false`` (the enabling spelling is tested first).
    3. Otherwise, if a global jit config is set, it is enabled only for
       ``jit_level`` "O0" or "O1".

    Returns:
        bool, True if AllFinite is enabled, False otherwise.
    """
    logger.debug("Start enable all finite.")
    if _ascend_target() and not AscendEnvChecker(None).check_custom_version():
        logger.debug("Disable AllFinite due to version check failure.")
        return False

    # Both settings are read up front, before any of the decisions below.
    runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
    global_jit_config = context.get_jit_config()

    if runtime_conf is not None:
        if "all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf:
            logger.debug("Enable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
            return True
        if "all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf:
            logger.debug("Disable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
            return False

    if not global_jit_config:
        return False

    jit_level = global_jit_config["jit_level"]
    logger.debug("Current global jit config is: {}".format(jit_level))
    return jit_level in ("O0", "O1")
|
|
76
86
|
|
|
@@ -105,7 +115,7 @@ def _all_finite(inputs, check_overflow_mode, enable_allfinite):
|
|
|
105
115
|
"""all finite check"""
|
|
106
116
|
if _ascend_target():
|
|
107
117
|
if (_ascend_910a_target()) or \
|
|
108
|
-
|
|
118
|
+
(_ascend_910b_target() and check_overflow_mode == "SATURATION_MODE"):
|
|
109
119
|
status = Tensor([0] * 8, mstype.int32)
|
|
110
120
|
status = ops.depend(status, inputs)
|
|
111
121
|
get_status = _get_cache_prim(NPUGetFloatStatusV2)()(status)
|
|
@@ -153,7 +163,7 @@ def all_finite(inputs):
|
|
|
153
163
|
|
|
154
164
|
Tutorial Examples:
|
|
155
165
|
- `Automatic Mix Precision - Loss Scaling
|
|
156
|
-
<https://mindspore.cn/tutorials/en/master/
|
|
166
|
+
<https://mindspore.cn/tutorials/en/master/beginner/mixed_precision.html#loss-scaling>`_
|
|
157
167
|
"""
|
|
158
168
|
inputs = mutable(inputs)
|
|
159
169
|
_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
|
|
@@ -169,7 +179,7 @@ class LossScaler(ABC):
|
|
|
169
179
|
to scale and unscale the loss value and gradients to avoid overflow, `adjust` is used to update the
|
|
170
180
|
loss scale value.
|
|
171
181
|
|
|
172
|
-
For more information, refer to the `tutorials <https://mindspore.cn/tutorials/en/master/
|
|
182
|
+
For more information, refer to the `tutorials <https://mindspore.cn/tutorials/en/master/beginner/
|
|
173
183
|
mixed_precision.html#loss-scaling>`_.
|
|
174
184
|
|
|
175
185
|
.. warning::
|
|
@@ -200,6 +210,7 @@ class LossScaler(ABC):
|
|
|
200
210
|
>>>
|
|
201
211
|
>>> loss_scaler = MyLossScaler(1024)
|
|
202
212
|
"""
|
|
213
|
+
|
|
203
214
|
@abstractmethod
|
|
204
215
|
def scale(self, inputs):
|
|
205
216
|
"""
|
|
@@ -262,6 +273,7 @@ class StaticLossScaler(LossScaler):
|
|
|
262
273
|
(Tensor(shape=[2], dtype=Float16, value= [ 1.4648e-03, 9.7656e-04]),
|
|
263
274
|
Tensor(shape=[1], dtype=Float16, value= [ 1.1721e-03]))
|
|
264
275
|
"""
|
|
276
|
+
|
|
265
277
|
def __init__(self, scale_value):
|
|
266
278
|
scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
|
|
267
279
|
if scale_value < 1.0:
|
|
@@ -340,6 +352,7 @@ class DynamicLossScaler(LossScaler):
|
|
|
340
352
|
>>> print(loss_scaler.scale_value.asnumpy())
|
|
341
353
|
512.0
|
|
342
354
|
"""
|
|
355
|
+
|
|
343
356
|
def __init__(self, scale_value, scale_factor, scale_window):
|
|
344
357
|
scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
|
|
345
358
|
if scale_value < 1.0:
|
|
@@ -361,7 +374,7 @@ class DynamicLossScaler(LossScaler):
|
|
|
361
374
|
|
|
362
375
|
Tutorial Examples:
|
|
363
376
|
- `Automatic Mix Precision - Loss Scaling
|
|
364
|
-
<https://mindspore.cn/tutorials/en/master/
|
|
377
|
+
<https://mindspore.cn/tutorials/en/master/beginner/mixed_precision.html#loss-scaling>`_
|
|
365
378
|
"""
|
|
366
379
|
inputs = mutable(inputs)
|
|
367
380
|
return _grad_scale_map(self.scale_value, inputs)
|
|
@@ -378,7 +391,7 @@ class DynamicLossScaler(LossScaler):
|
|
|
378
391
|
|
|
379
392
|
Tutorial Examples:
|
|
380
393
|
- `Automatic Mix Precision - Loss Scaling
|
|
381
|
-
<https://mindspore.cn/tutorials/en/master/
|
|
394
|
+
<https://mindspore.cn/tutorials/en/master/beginner/mixed_precision.html#loss-scaling>`_
|
|
382
395
|
"""
|
|
383
396
|
inputs = mutable(inputs)
|
|
384
397
|
return _grad_unscale_map(self.scale_value, inputs)
|
|
@@ -392,7 +405,7 @@ class DynamicLossScaler(LossScaler):
|
|
|
392
405
|
|
|
393
406
|
Tutorial Examples:
|
|
394
407
|
- `Automatic Mix Precision - Loss Scaling
|
|
395
|
-
<https://mindspore.cn/tutorials/en/master/
|
|
408
|
+
<https://mindspore.cn/tutorials/en/master/beginner/mixed_precision.html#loss-scaling>`_
|
|
396
409
|
"""
|
|
397
410
|
one = ops.ones((), self.scale_value.dtype)
|
|
398
411
|
scale_mul_factor = self.scale_value * self.scale_factor
|
|
@@ -411,6 +424,7 @@ class DynamicLossScaler(LossScaler):
|
|
|
411
424
|
ops.assign(self.counter, counter)
|
|
412
425
|
return True
|
|
413
426
|
|
|
427
|
+
|
|
414
428
|
__all__ = [
|
|
415
429
|
"DynamicLossScaleManager", "LossScaleManager", "FixedLossScaleManager",
|
|
416
430
|
"build_train_network", "DynamicLossScaler", "StaticLossScaler", "LossScaler",
|
mindspore/avcodec-59.dll
CHANGED
|
Binary file
|
mindspore/avdevice-59.dll
CHANGED
|
Binary file
|
mindspore/avfilter-8.dll
CHANGED
|
Binary file
|
mindspore/avformat-59.dll
CHANGED
|
Binary file
|
mindspore/avutil-57.dll
CHANGED
|
Binary file
|
mindspore/common/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -37,6 +37,7 @@ from mindspore.common.recompute import recompute
|
|
|
37
37
|
from mindspore.common import generator
|
|
38
38
|
from mindspore.common.generator import (
|
|
39
39
|
Generator, default_generator, seed, manual_seed, initial_seed, get_rng_state, set_rng_state)
|
|
40
|
+
from mindspore.ops.function.array_func import is_tensor, from_numpy
|
|
40
41
|
|
|
41
42
|
# symbols from dtype
|
|
42
43
|
__all__ = [
|
|
@@ -67,11 +68,11 @@ __all__ = [
|
|
|
67
68
|
]
|
|
68
69
|
|
|
69
70
|
__all__.extend([
|
|
70
|
-
"tensor", "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor",
|
|
71
|
+
"tensor", "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
|
71
72
|
"ms_function", "ms_class", 'jit', 'jit_class', '_no_grad', # api
|
|
72
73
|
"Parameter", "ParameterTuple", # parameter
|
|
73
74
|
"dtype",
|
|
74
|
-
"set_seed", "get_seed",
|
|
75
|
+
"set_seed", "get_seed", "manual_seed", # random seed
|
|
75
76
|
"set_dump",
|
|
76
77
|
"ms_memory_recycle",
|
|
77
78
|
"mutable", "JitConfig",
|
|
@@ -79,6 +80,7 @@ __all__.extend([
|
|
|
79
80
|
"lazy_inline", "load_mindir", "save_mindir",
|
|
80
81
|
"no_inline",
|
|
81
82
|
"Symbol",
|
|
82
|
-
"recompute"
|
|
83
|
+
"recompute",
|
|
84
|
+
"is_tensor", "from_numpy",
|
|
83
85
|
])
|
|
84
86
|
__all__.extend(generator.__all__)
|