mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release: see the package's registry listing for details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/experimental/llm_boost/atb/llama_boost.py
DELETED
@@ -1,137 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""llm boost"""
-import json
-import mindspore.common.dtype as mstype
-from mindspore.experimental.llm_boost.atb.boost_base import (
-    AtbBoostBase,
-    PositionEmbeddingType,
-    NormType,
-)
-from mindspore._c_expression import LlmBoostBinder
-from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
-
-CPP_LLAMA_MODEL_CLASS_NAME = "llama_LlamaDecoderModel"
-
-
-@LlmBoostRegister.register(LlmBoostType.BUILDIN, "Llama")
-class LlamaBoost(AtbBoostBase):
-    """LlamaBoost class"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.in_tensor_length = 13
-        self.acl_encoder_operation_inputs = [None] * self.in_tensor_length
-        self.acl_decoder_operation_inputs = [None] * self.in_tensor_length
-        self.atb_encoder_operation = LlmBoostBinder(
-            self.backend_name, CPP_LLAMA_MODEL_CLASS_NAME
-        )
-        self.atb_decoder_operation = LlmBoostBinder(
-            self.backend_name, CPP_LLAMA_MODEL_CLASS_NAME
-        )
-
-    def init(self):
-        """
-        Initialize the object
-        returns True if object needs input manipulation by mindformers
-        """
-
-        coder_param = {
-            "normEps": self.config.rms_norm_eps,
-            "normType": NormType.RMS_NORM,
-            "numAttentionHeadsPerRank": self.config.num_heads // self.device_num,
-            "hiddenSizePerAttentionHead": self.head_dim,
-            "numHiddenLayers": self.num_layers,
-            "numKeyValueHeadsPerRank": self.n_kv_heads // self.device_num,
-            "skipWordEmbedding": False,
-            "isFA": False,
-            "isBF16": self.dtype == mstype.bfloat16,
-            "packQuantType": [[1, 1] for _ in range(self.num_layers)],
-            "linearQuantType": [
-                [0, -1, -1, 0, 0, -1, 0] for _ in range(self.num_layers)
-            ],
-            "linearTransposeType": [
-                [1, -1, -1, 1, 1, -1, 1] for i in range(self.num_layers)
-            ],
-            "isEmbeddingParallel": False,
-            "isLmHeadParallel": not self.config.parallel_config.vocab_emb_dp,
-            "lmHeadTransposeType": 1,
-            "enableSwiGLU": True,
-            "enablekvQuant": self.kv_quant is not None,
-            "rank": self.rank_id,
-            "worldSize": self.device_num,
-            "backend": self.config.communication_backend,
-            "rankTableFile": "",
-            "positionEmbeddingType": PositionEmbeddingType.ROPE,
-            "hiddenSize": self.config.hidden_size,
-            "gemma": False,
-            "enableAddNorm": False,
-            "enableCompressHead": False,
-            "isUnpadInputs": True,
-        }
-        encoder_param = {
-            **coder_param,
-            "isPrefill": True,
-            "enableLcoc": True,
-            "enableSpeculate": False,
-            "skipWordEmbedding": False,
-            "enableSplitFuse": False,
-        }
-        decoder_param = {
-            **coder_param,
-            "isPrefill": False,
-            "enableLcoc": False,
-            "enableSpeculate": False,
-        }
-        self.atb_encoder_operation.init(json.dumps({**encoder_param}))
-        self.atb_decoder_operation.init(json.dumps({**decoder_param}))
-        return True
-
-    def _prepare_inputs(
-        self,
-        prefill=None,
-        input_ids=None,
-        position_ids=None,
-        cos_embed=None,
-        sin_embed=None,
-        attention_mask=None,
-        block_tables=None,
-        slots=None,
-        input_lengths=None,
-        lm_head_indices=None,
-        seqLen=None,
-        **kwargs
-    ):
-        """prepare inputs"""
-        self.acl_param = json.dumps(
-            {
-                "seqLen": seqLen,
-            }
-        )
-
-        self.acl_decoder_operation_inputs[0] = input_ids
-        self.acl_decoder_operation_inputs[1] = self.placeholder
-        self.acl_decoder_operation_inputs[2] = position_ids
-        self.acl_decoder_operation_inputs[3] = cos_embed
-        self.acl_decoder_operation_inputs[4] = sin_embed
-        self.acl_decoder_operation_inputs[5] = attention_mask
-        self.acl_decoder_operation_inputs[6] = block_tables
-        self.acl_decoder_operation_inputs[7] = slots
-        self.acl_decoder_operation_inputs[8] = self.placeholder
-        self.acl_decoder_operation_inputs[9] = self.placeholder
-        self.acl_decoder_operation_inputs[10] = self.placeholder
-        self.acl_decoder_operation_inputs[11] = input_lengths
-        self.acl_decoder_operation_inputs[12] = lm_head_indices
-        return self.acl_decoder_operation_inputs, self.acl_param
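Editor's note: the deleted `init()` above assembles its encoder and decoder configs by dict-merging a shared base and serializing the result to JSON. A minimal pure-Python illustration of that merge pattern (keys abbreviated; not the real config):

import json

coder_param = {"normEps": 1e-5, "enableSwiGLU": True}                     # shared base
encoder_param = {**coder_param, "isPrefill": True, "enableLcoc": True}    # overrides win
decoder_param = {**coder_param, "isPrefill": False, "enableLcoc": False}

# The deleted source calls json.dumps({**encoder_param}); the extra {**...}
# only makes a shallow copy, so json.dumps(encoder_param) is equivalent.
payload = json.dumps(encoder_param)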
mindspore/experimental/llm_boost/atb/qwen_boost.py
DELETED
@@ -1,124 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""llm boost"""
-import json
-import mindspore.common.dtype as mstype
-from mindspore.experimental.llm_boost.atb.boost_base import AtbBoostBase, NormType
-from mindspore._c_expression import LlmBoostBinder
-from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
-
-
-CPP_QWEN_MODEL_CLASS_NAME = "qwen_QwenDecoderModel"
-
-
-@LlmBoostRegister.register(LlmBoostType.BUILDIN, "Qwen")
-class QwenBoost(AtbBoostBase):
-    """QwenBoost class"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.in_tensor_length = 12
-        self.acl_encoder_operation_inputs = [None] * self.in_tensor_length
-        self.acl_decoder_operation_inputs = [None] * self.in_tensor_length
-        self.atb_encoder_operation = LlmBoostBinder(
-            self.backend_name, CPP_QWEN_MODEL_CLASS_NAME
-        )
-        self.atb_decoder_operation = LlmBoostBinder(
-            self.backend_name, CPP_QWEN_MODEL_CLASS_NAME
-        )
-
-    def init(self):
-        """set param"""
-        param_dict = {
-            "isFA": False,
-            "isBF16": self.dtype == mstype.bfloat16,
-            "withEmbedding": True,
-            "isEmbeddingParallel": True,
-            "isLmHeadParallel": True,
-            "linearTransposeType": [
-                [1, -1, -1, 1, 1, -1, 1] for i in range(self.num_layers)
-            ],
-            "lmHeadTransposeType": 1,
-            "enableSwiGLU": not self.need_nz,
-            "normEps": self.config.rms_norm_eps,
-            "normType": NormType.RMS_NORM,
-            "numAttentionHeadsPerRank": self.config.num_heads // self.device_num,
-            "hiddenSizePerAttentionHead": self.head_dim,
-            "numHiddenLayers": self.num_layers,
-            "numKeyValueHeadsPerRank": self.n_kv_heads // self.device_num,
-            "rank": self.rank_id,
-            "worldSize": self.device_num,
-            "backend": self.config.communication_backend,
-            "packQuantType": [[1, 1] for _ in range(self.num_layers)],
-            "linearQuantType": [
-                [0, -1, -1, 0, 0, -1, 0] for _ in range(self.num_layers)
-            ],
-            "linearHasBias": [[True, False, False, False]] * self.num_layers,
-            "enableKvQuant": self.kv_quant is not None,
-            "enableLora": False,
-            "isUnpadInputs": True,
-            "enableAddNorm": False,
-        }
-        encoder_param = {
-            **param_dict,
-            "isPrefill": True,
-            "enableLcoc": False,
-            "enableSplitFuse": False,
-        }
-        decoder_param = {
-            **param_dict,
-            "isPrefill": False,
-            "enableLcoc": False,
-            "enableSpeculate": False,
-            "enablePrefixCache": False,
-        }
-        self.atb_encoder_operation.init(json.dumps({**encoder_param}))
-        self.atb_decoder_operation.init(json.dumps({**decoder_param}))
-
-    def _prepare_inputs(
-        self,
-        prefill=None,
-        input_ids=None,
-        position_ids=None,
-        cos_embed=None,
-        sin_embed=None,
-        attention_mask=None,
-        block_tables=None,
-        slots=None,
-        input_lengths=None,
-        lm_head_indices=None,
-        seqLen=None,
-        **kwargs
-    ):
-        """prepare inputs"""
-        self.acl_param = json.dumps(
-            {
-                "seqLen": seqLen,
-            }
-        )
-
-        self.acl_decoder_operation_inputs[0] = input_ids
-        self.acl_decoder_operation_inputs[1] = position_ids
-        self.acl_decoder_operation_inputs[2] = cos_embed
-        self.acl_decoder_operation_inputs[3] = sin_embed
-        self.acl_decoder_operation_inputs[4] = attention_mask
-        self.acl_decoder_operation_inputs[5] = block_tables
-        self.acl_decoder_operation_inputs[6] = slots
-        self.acl_decoder_operation_inputs[7] = self.placeholder
-        self.acl_decoder_operation_inputs[8] = self.placeholder
-        self.acl_decoder_operation_inputs[9] = self.placeholder
-        self.acl_decoder_operation_inputs[10] = input_lengths
-        self.acl_decoder_operation_inputs[11] = lm_head_indices
-        return self.acl_decoder_operation_inputs, self.acl_param
mindspore/experimental/llm_boost/register.py
DELETED
@@ -1,130 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""LlmBoostRegister"""
-import inspect
-
-
-class LlmBoostType:
-    """Class module type for vision pretrain"""
-
-    def __init__(self):
-        pass
-
-    BUILDIN = 'BuildIn'
-    ASCEND_NATIVE = 'LLMBoost'
-
-
-class LlmBoostRegister:
-    """
-    Module class factory.
-    """
-
-    def __init__(self):
-        pass
-
-    registry = {}
-
-    @classmethod
-    def register(cls, boost_type=LlmBoostType.BUILDIN, alias=None):
-        """Register class into registry
-        Args:
-            boost_type:
-                boost type name, default LlmBoostType.BUILDIN
-            alias (str) : model_name
-
-        Returns:
-            wrapper
-        """
-
-        def wrapper(register_class):
-            """Register-Class with wrapper function.
-
-            Args:
-                register_class : class need to register
-
-            Returns:
-                wrapper of register_class
-            """
-            model_name = alias if alias is not None else register_class.__name__
-            if boost_type not in cls.registry:
-                cls.registry[boost_type] = {model_name: register_class}
-            else:
-                cls.registry[boost_type][model_name] = register_class
-            return register_class
-
-        return wrapper
-
-    @classmethod
-    def is_exist(cls, boost_type, model_name=None):
-        """Determine whether class name is in the current type group.
-
-        Args:
-            boost_type : Module type
-            model_name : model name
-
-        Returns:
-            True/False
-        """
-        if not model_name:
-            return boost_type in cls.registry
-        registered = boost_type in cls.registry and model_name in cls.registry.get(
-            boost_type)
-        return registered
-
-    @classmethod
-    def get_cls(cls, boost_type, model_name=None):
-        """Get class
-
-        Args:
-            boost_type : Module type
-            model_name : model name
-
-        Returns:
-            register_class
-        """
-        if not cls.is_exist(boost_type, model_name):
-            raise ValueError("Can't find class type {} class name {} \
-                in class registry".format(boost_type, model_name))
-
-        if not model_name:
-            raise ValueError(
-                "Can't find model. model name = {}".format(model_name))
-        register_class = cls.registry.get(boost_type).get(model_name)
-        return register_class
-
-    @classmethod
-    def get_instance(cls, boost_type=LlmBoostType.BUILDIN, model_name=None, **kwargs):
-        """Get instance.
-        Args:
-            boost_type : module type
-            model_name : model type
-        Returns:
-            object : The constructed object
-        """
-        if model_name is None:
-            raise ValueError("Class name cannot be None.")
-
-        if isinstance(model_name, str):
-            obj_cls = cls.get_cls(boost_type, model_name)
-        elif inspect.isclass(model_name):
-            obj_cls = model_name
-        else:
-            raise ValueError("Can't find boost type {} model name {} \
-                in class registry.".format(boost_type, model_name))
-
-        try:
-            return obj_cls(**kwargs)
-        except Exception as e:
-            raise type(e)('{}: {}'.format(obj_cls.__name__, e))
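Editor's note: the registry deleted above was a decorator-plus-factory pair. A minimal sketch of its 2.7.0 usage, grounded in the deleted source (the `MyBoost` class and its `config` kwarg are hypothetical):

# Works only on 2.7.0; this module is removed in 2.7.1.
from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType

@LlmBoostRegister.register(LlmBoostType.BUILDIN, alias="MyModel")
class MyBoost:
    """Hypothetical boost class, purely for illustration."""
    def __init__(self, config):
        self.config = config

# Lookup and construction key on the same (boost_type, model_name) pair;
# extra kwargs are forwarded to the registered class's constructor.
assert LlmBoostRegister.is_exist(LlmBoostType.BUILDIN, "MyModel")
boost = LlmBoostRegister.get_instance(LlmBoostType.BUILDIN, "MyModel", config={})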
mindspore/experimental/llm_boost/utils.py
DELETED
@@ -1,31 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""LlmBoostRegister"""
-import os
-from mindspore.communication import get_group_size, get_rank
-
-
-def get_real_rank():
-    try:
-        return get_rank()
-    except RuntimeError:
-        return int(os.getenv("RANK_ID", "0"))
-
-
-def get_real_group_size():
-    try:
-        return get_group_size()
-    except RuntimeError:
-        return int(os.getenv("RANK_SIZE", "1"))
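Editor's note: callers that depended on these helpers can inline the same fallback pattern. A hedged drop-in sketch, assuming `get_rank`/`get_group_size` (long-standing public APIs) still raise `RuntimeError` when no communication group has been initialized:

import os
from mindspore.communication import get_group_size, get_rank

def get_real_rank():
    """Rank id; falls back to the RANK_ID env var outside a launched job."""
    try:
        return get_rank()
    except RuntimeError:
        return int(os.getenv("RANK_ID", "0"))

def get_real_group_size():
    """World size; falls back to the RANK_SIZE env var (default 1)."""
    try:
        return get_group_size()
    except RuntimeError:
        return int(os.getenv("RANK_SIZE", "1"))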
mindspore/include/OWNERS
DELETED

mindspore/mindspore_cpu_res_manager.dll
DELETED (binary file)

mindspore/mindspore_ops_kernel_common.dll
DELETED (binary file)

mindspore/mindspore_res_manager.dll
DELETED (binary file)
mindspore/nn/optim/_dist_optimizer_registry.py
DELETED
@@ -1,111 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""_dist_optimizer_registry"""
-from __future__ import absolute_import
-
-from inspect import isfunction
-
-from mindspore.parallel._ps_context import _get_ps_context, _is_ps_mode
-
-
-_create_func_map = {}
-
-
-def _register_dist_optimizer(optimizer_type, creating_func):
-    """
-    Register distributed optimizers.
-    This method should be called by original optimizers.
-    """
-    if optimizer_type in _create_func_map:
-        return
-    if not isfunction(creating_func):
-        raise TypeError("creating_func is not a function type!")
-    _create_func_map[optimizer_type] = creating_func
-
-
-def empty_creating_func(*args, **kwargs):
-    """Empty function as placeholder."""
-    return
-
-
-_pserver_optmizer_attrs = {
-    "ms_role": "MS_PSERVER",
-    "primitive_target": "CPU",
-    "update_parameter": True
-}
-
-
-def create_optimizers_on_pserver(optimizer_type, parameters, *args, **kwargs):
-    """
-    Create the optimizers on parameter server.
-    This method should be called only in Parameter Server training mode.
-    Return distributed optimizer list and the flag list which indicates whether the parameters use them.
-    The size of the two lists returned should be the same as the size of input 'parameters'
-    """
-    distributed_optimizer_list = []
-    use_flag_list = []
-    for index, param in enumerate(parameters):
-        if param.is_param_ps and (not param.cache_enable):
-            if optimizer_type not in _create_func_map:
-                raise ValueError("Optimizer type %s is not recognized!" % optimizer_type)
-            distributed_optimizer = _create_func_map.get(optimizer_type)(*args, **kwargs)
-
-            server_rank_id = index % _get_ps_context("server_num")
-            distributed_optimizer.add_prim_attr("rank_id", server_rank_id)
-            for key, value in _pserver_optmizer_attrs.items():
-                distributed_optimizer.add_prim_attr(key, value)
-            distributed_optimizer_list.append(distributed_optimizer)
-            use_flag_list.append(True)
-        else:
-            distributed_optimizer_list.append(empty_creating_func)
-            use_flag_list.append(False)
-    return distributed_optimizer_list, use_flag_list
-
-
-def no_distributed_optimizer(optimizer_type, parameters, *args, **kwargs):
-    """
-    In some cases, no distributed optimizers are needed.
-    But we still need to return lists so optimizer subclasses can build the network using HyperMap.
-    """
-    empty_list = []
-    use_flag_list = []
-    for _ in parameters:
-        empty_list.append(empty_creating_func)
-        use_flag_list.append(False)
-    return empty_list, use_flag_list
-
-
-def get_creating_func():
-    """
-    Returns creating functions for distributed optimizers.
-    """
-    # Only support optimizers in parameter server mode for now.
-    if _is_ps_mode():
-        return create_optimizers_on_pserver
-    return no_distributed_optimizer
-
-
-def generate_dist_optimizer_list(optimizer_type, parameters, *args, **kwargs):
-    """
-    Generate the distributed optimizers according to the execution mode.
-    Only Parameter Server training mode is supported for now.
-    """
-    func = get_creating_func()
-    opt_list, use_flag_list = func(optimizer_type, parameters, *args, **kwargs)
-    if len(opt_list) != len(parameters) or len(use_flag_list) != len(parameters):
-        raise ValueError(f"Size of distributed optimizer list should be the same as parameter list. "
-                         f"But got len(opt_list):{len(opt_list)}"
-                         f", len(parameters):{len(parameters)}")
-    return opt_list, tuple(use_flag_list)
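Editor's note: a short sketch of how this removed registry was driven on 2.7.0, based only on the deleted source above. "Adam" and `_create_dist_adam` are hypothetical names; outside Parameter Server mode the generator returns one placeholder per parameter:

# Works only on 2.7.0; this module is removed in 2.7.1.
from mindspore.nn.optim._dist_optimizer_registry import (
    _register_dist_optimizer,
    generate_dist_optimizer_list,
)

def _create_dist_adam(*args, **kwargs):
    # In PS mode this would build and return the server-side optimizer primitive.
    pass

_register_dist_optimizer("Adam", _create_dist_adam)

# One (optimizer-or-placeholder, use_flag) entry per parameter:
params = net.trainable_params()  # `net` is a hypothetical nn.Cell
opt_list, use_flags = generate_dist_optimizer_list("Adam", params)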
mindspore/nn/reinforcement/_batch_read_write.py
DELETED
@@ -1,142 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-BatchReadWrite
-"""
-from __future__ import absolute_import
-
-from mindspore.nn.cell import Cell
-from mindspore.ops.operations._rl_inner_ops import BatchAssign
-
-
-class BatchWrite(Cell):
-    r"""BatchWrite: write a list of parameters to assign the target.
-
-    .. warning::
-        This is an experiential prototype that is subject to change and/or deletion.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> from mindspore import nn
-        >>> from mindspore.common.parameter import Parameter, ParameterTuple
-        >>> from mindspore.nn.reinforcement import BatchWrite
-        >>> class SourceNet(nn.Cell):
-        ...     def __init__(self):
-        ...         super(SourceNet, self).__init__()
-        ...         self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
-        ...         self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
-        >>> class DstNet(nn.Cell):
-        ...     def __init__(self):
-        ...         super(DstNet, self).__init__()
-        ...         self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
-        ...         self.dense = nn.Dense(in_channels=16, out_channels=1)
-        >>> class Write(nn.Cell):
-        ...     def __init__(self, dst, src):
-        ...         super(Write, self).__init__()
-        ...         self.w = BatchWrite()
-        ...         self.dst = ParameterTuple(dst.trainable_params())
-        ...         self.src = ParameterTuple(src.trainable_params())
-        ...     def construct(self):
-        ...         success = self.w(self.dst, self.src)
-        ...         return success
-        >>> dst_net = DstNet()
-        >>> source_net = SourceNet()
-        >>> nets = nn.CellList()
-        >>> nets.append(dst_net)
-        >>> nets.append(source_net)
-        >>> success = Write(nets[0], nets[1])()
-    """
-    def __init__(self):
-        """Initialize BatchWrite"""
-        super(BatchWrite, self).__init__()
-        self.write = BatchAssign(lock=True)
-
-    def construct(self, dst, src):
-        """
-        Write the source parameter list to assign the dst.
-
-        Inputs:
-            - **dst** (tuple) - A paramameter tuple of the dst model.
-            - **src** (tuple) - A paramameter tuple of the source model.
-
-        Returns:
-            Bool, true.
-        """
-        self.write(dst, src)
-        return True
-
-
-class BatchRead(Cell):
-    r"""BatchRead: read a list of parameters to assign the target.
-
-    .. warning::
-        This is an experiential prototype that is subject to change and/or deletion.
-
-    Supported Platforms:
-        ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> from mindspore import nn
-        >>> from mindspore.common.parameter import Parameter, ParameterTuple
-        >>> from mindspore.nn.reinforcement import BatchRead
-        >>> class SNet(nn.Cell):
-        ...     def __init__(self):
-        ...         super(SNet, self).__init__()
-        ...         self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
-        ...         self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
-        >>> class DNet(nn.Cell):
-        ...     def __init__(self):
-        ...         super(DNet, self).__init__()
-        ...         self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
-        ...         self.dense = nn.Dense(in_channels=16, out_channels=1)
-        >>> class Read(nn.Cell):
-        ...     def __init__(self, dst, src):
-        ...         super(Read, self).__init__()
-        ...         self.read = BatchRead()
-        ...         self.dst = ParameterTuple(dst.trainable_params())
-        ...         self.src = ParameterTuple(src.trainable_params())
-        ...     def construct(self):
-        ...         success = self.read(self.dst, self.src)
-        ...         return success
-        >>> dst_net = DNet()
-        >>> source_net = SNet()
-        >>> nets = nn.CellList()
-        >>> nets.append(dst_net)
-        >>> nets.append(source_net)
-        >>> success = Read(nets[0], nets[1])()
-
-    """
-    def __init__(self):
-        """Initialize BatchRead"""
-        super(BatchRead, self).__init__()
-        self.read = BatchAssign(lock=False)
-
-    def construct(self, dst, src):
-        """
-        Read the source parameter list to assign the dst.
-
-        Inputs:
-            - **dst** (tuple) - A paramameter tuple of the dst model.
-            - **src** (tuple) - A paramameter tuple of the source model.
-
-        Returns:
-            Bool, true.
-        """
-        self.read(dst, src)
-        return True
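Editor's note: the doctest examples in the deleted docstrings above reference `Tensor` and `mstype` without importing them, so a runnable version of those examples additionally needs:

from mindspore import Tensor
import mindspore.common.dtype as mstype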