mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/{nn/layer → experimental/es}/embedding_service_layer.py

@@ -19,8 +19,10 @@ import mindspore as ms
 from mindspore import nn, ops, Tensor, Parameter
 from mindspore.ops.auto_generate import init_partition_map, init_embedding_hashmap, embedding_table_find_and_init,\
     embedding_table_find, fake_remote_lookup_uniqued
-from mindspore.ops.
-    EmbeddingComputeVarImport, EmbeddingComputeVarExport
+from mindspore.ops.auto_generate import EmbeddingTableImport, EmbeddingTableExport, \
+    EmbeddingComputeVarImport, EmbeddingComputeVarExport, EmbeddingTableEvict, EmbeddingFeatureMappingV2, \
+    EmbeddingFeatureMappingTableSize, EmbeddingFeatureMappingFind, EmbeddingFeatureMappingExport, \
+    EmbeddingFeatureMappingFileSize, EmbeddingFeatureMappingImport, EmbeddingFeatureMappingInsert


 class CounterFilter:
@@ -55,12 +57,14 @@ def _get_backward_float_params(optimizer_mode):
         [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
     - when the backward_mode is 'adagrad', it means [lr,]
     """
-    if optimizer_mode == "adagrad":
+    if optimizer_mode == "adagrad" or optimizer_mode == "sgd":
         return [0.001]
     if optimizer_mode == "adam":
         return [0.9, 0.99, 0.001, 0.9, 0.999, 1e-08]
     if optimizer_mode == "ftrl":
         return [0.001, -0.5, 0.0, 0.0]
+    if optimizer_mode == "rmsprop":
+        return [0.001, 0.9, 0.1, 1e-08]
     # adamw
     return [0.9, 0.99, 0.001, 0.01, 0.9, 0.999, 1e-08]

@@ -99,6 +103,9 @@ class ESInitLayer(nn.Cell):
         self.default_value = None

     def construct(self):
+        """
+        ESInitLayer construct: init embedding hashmap
+        """
         init_partition = init_partition_map(self.ps_num_tensor,
                                             self.ps_ids_tensor,
                                             _embedding_dim=self.embedding_dim,
@@ -145,9 +152,36 @@ class ESInitLayer(nn.Cell):


 class EsEmbeddingLookup(nn.Cell):
+    r"""
+    Look up a PS embedding.
+
+    .. warning::
+        This is an experimental EmbeddingService API that is subject to change.
+
+    Args:
+        table_id (int): The table id.
+        es_initializer (EsInitializer): The EsInitialize object for PS embedding with table_id,
+            which can be None when the inference is performed.
+        embedding_dim (int): The embedding dim of keys for PS embedding with table_id.
+        max_key_num (int): The num of keys when lookup.
+        optimizer_mode (str): The type of optimizer. Default is ``None``.
+        optimizer_params (tuple[float]): The parameters of optimizer. Default is ``None``.
+        es_filter (CounterFilter): The option of counter filter for PS embedding with table_id. Default is ``None``.
+        es_padding_key (PaddingParamsOption): The option of padding key for PS embedding with table_id.
+            Default is ``None``.
+        es_completion_key (CompletionKeyOption): The option of completion key for PS embedding with table_id.
+            Default is ``None``.
+
+    Inputs:
+        - **keys** (Tensor): The keys of each feature in PS embedding.
+        - **actual_keys_input** (Tensor): Tensor composed of all unique elements of keys.
+        - **unique_indices** (Tensor): The index value of each element in keys to actual_keys_input.
+        - **key_count** (Tensor): The count of each element in the actual_keys_input to keys.
+
+    Supported Platforms:
+        ``Atlas A2 training series products``
     """
-
-    """
+
     def __init__(self, table_id, es_initializer, embedding_dim, max_key_num, optimizer_mode=None,
                  optimizer_params=None, es_filter=None, es_padding_key=None, es_completion_key=None):
         super(EsEmbeddingLookup, self).__init__()
@@ -182,7 +216,7 @@ class EsEmbeddingLookup(nn.Cell):
         self.filter_freq = 1
         self.default_key_or_value = 1
         self.default_key = 0
-        self.default_value =
+        self.default_value = 1.0

         self.global_step = 1
         if es_padding_key is not None:
@@ -193,7 +227,7 @@ class EsEmbeddingLookup(nn.Cell):
         self.mask_zero = 0
         self.padding_key = 0
         self.padding_key_mask = 1
-        if self.optimizer_mode in ["adam", "ftrl", "adagrad"]:
+        if self.optimizer_mode in ["adam", "ftrl", "adagrad", "sgd", "rmsprop"]:
             self.backward_int_params = ([self.global_step], [self.mask_zero],
                                         [self.padding_key], [self.padding_key_mask])
         else:
@@ -211,6 +245,9 @@ class EsEmbeddingLookup(nn.Cell):
         self.max_grad_norm = Tensor([1.0], ms.float32)

     def construct(self, keys, actual_keys_input=None, unique_indices=None, key_count=None):
+        """
+        Using the corresponding query method to calculate the PS embedding for each key.
+        """
         origin_shape = None
         if len(keys.shape) != 1:
             origin_shape = keys.shape
@@ -227,11 +264,9 @@ class EsEmbeddingLookup(nn.Cell):
             key_count = keys
         if self.training:
             if use_host_unique:
-                output = fake_remote_lookup_uniqued(table_id=self.table_id,
-                                                    keys=keys,
+                output = fake_remote_lookup_uniqued(table_id=self.table_id, keys=keys,
                                                     actual_keys_num=actual_keys_input,
-                                                    unique_indices=unique_indices,
-                                                    key_count=key_count,
+                                                    unique_indices=unique_indices, key_count=key_count,
                                                     max_grad_norm=self.max_grad_norm,
                                                     embedding_dim=self.embedding_dim,
                                                     initializer_mode=self.es_initializer.initializer_mode,
@@ -250,8 +285,7 @@ class EsEmbeddingLookup(nn.Cell):
                                                     default_value=self.default_value,
                                                     optimizer_mode=self.optimizer_mode,
                                                     optimizer_params=self.optimizer_params,
-                                                    _max_key_num=self.max_key_num,
-                                                    _table_id=self._table_id,
+                                                    _max_key_num=self.max_key_num, _table_id=self._table_id,
                                                     _use_counter_filter=use_counter_filter,
                                                     backward_mode=self.optimizer_mode,
                                                     backward_int_params=self.backward_int_params,
@@ -280,8 +314,7 @@ class EsEmbeddingLookup(nn.Cell):
                                                 default_value=self.default_value,
                                                 optimizer_mode=self.optimizer_mode,
                                                 optimizer_params=self.optimizer_params,
-                                                _max_key_num=self.max_key_num,
-                                                _table_id=self._table_id,
+                                                _max_key_num=self.max_key_num, _table_id=self._table_id,
                                                 _use_counter_filter=use_counter_filter,
                                                 backward_mode=self.optimizer_mode,
                                                 backward_int_params=self.backward_int_params,
@@ -290,14 +323,10 @@ class EsEmbeddingLookup(nn.Cell):
                                                 completion_key_mask=self.completion_key_mask,
                                                 parameter=self.b)
         else:
-            output = embedding_table_find(self.table_id, keys,
-                                          embedding_dim=self.embedding_dim,
+            output = embedding_table_find(self.table_id, keys, embedding_dim=self.embedding_dim,
                                           default_value=self.default_value,
-                                          _max_key_num=self.max_key_num,
-                                          _table_id=self._table_id,
+                                          _max_key_num=self.max_key_num, _table_id=self._table_id,
                                           _use_counter_filter=use_counter_filter)
-        # input 20480 2 ->41960
-        # output 41960 embedding_dim -> 20480 2 embedding_dim
         if origin_shape is not None:
             output = self.reshape(output, origin_shape + (-1,))
         return output
@@ -321,10 +350,10 @@ class ESEmbeddingCKPTExport(nn.Cell):
         self.table_id_tensor = Tensor(table_id_list, ms.int32)
         self.depend = ops.Depend()

-    def construct(self):
-        export_op1 = self.embedding_table_export(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+    def construct(self, global_step):
+        export_op1 = self.embedding_table_export(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
         z = self.depend(self.file_path, export_op1)
-        export_op2 = self.embedding_compute_var_export(z, self.ps_id_tensor, self.table_id_tensor)
+        export_op2 = self.embedding_compute_var_export(z, self.ps_id_tensor, self.table_id_tensor, global_step)
         return export_op2


@@ -345,8 +374,31 @@ class ESEmbeddingTableExport(nn.Cell):
         self.ps_id_tensor = Tensor(0, ms.int32)
         self.table_id_tensor = Tensor(table_id_list, ms.int32)

-    def construct(self):
-        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+    def construct(self, global_step):
+        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
+        return y
+
+
+class ESIncrementalEmbeddingTableExport(nn.Cell):
+    """
+    ESIncrementalEmbeddingTableExport.
+    """
+    def __init__(self, embedding_dim_list, value_total_len_list, table_name_list, table_id_list,
+                 file_path, steps_to_live_list):
+        super(ESIncrementalEmbeddingTableExport, self).__init__()
+        self.op = EmbeddingTableExport(
+            embedding_dim_list,
+            value_total_len_list,
+            table_name=table_name_list,
+            steps_to_live_list=steps_to_live_list,
+            export_mode="new",
+            only_var_flag=True)
+        self.file_path = Tensor(np.array(file_path))
+        self.ps_id_tensor = Tensor(0, ms.int32)
+        self.table_id_tensor = Tensor(table_id_list, ms.int32)
+
+    def construct(self, global_step):
+        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
         return y


@@ -366,10 +418,10 @@ class ESEmbeddingCKPTImport(nn.Cell):
         self.table_id_tensor = Tensor(table_id_list, ms.int32)
         self.depend = ops.Depend()

-    def construct(self):
-        export_op1 = self.embedding_table_import(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+    def construct(self, global_step):
+        export_op1 = self.embedding_table_import(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
         z = self.depend(self.file_path, export_op1)
-        export_op2 = self.embedding_compute_var_import(z, self.ps_id_tensor, self.table_id_tensor)
+        export_op2 = self.embedding_compute_var_import(z, self.ps_id_tensor, self.table_id_tensor, global_step)
         return export_op2


@@ -388,6 +440,142 @@ class ESEmbeddingTableImport(nn.Cell):
         self.ps_id_tensor = Tensor(0, ms.int32)
         self.table_id_tensor = Tensor(table_id_list, ms.int32)

+    def construct(self, global_step):
+        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
+        return y
+
+
+class ESEmbeddingTableEvict(nn.Cell):
+    """
+    ESEmbeddingTableEvict.
+    """
+    def __init__(self, var_handle, global_step, steps_to_live):
+        super(ESEmbeddingTableEvict, self).__init__()
+        self.op = EmbeddingTableEvict()
+        self.var_handle = Tensor(var_handle, ms.int32)
+        self.global_step = global_step
+        self.steps_to_live = steps_to_live
+
     def construct(self):
-        y = self.op(self.
+        y = self.op(self.var_handle, self.global_step, self.steps_to_live)
         return y
+
+
+class ESEmbeddingFeatureMappingExport(nn.Cell):
+    """
+    ESEmbeddingFeatureMappingExport.
+    """
+    def __init__(self, file_path, export_value, var, var_name, small_table_embedding_dim):
+        super(ESEmbeddingFeatureMappingExport, self).__init__()
+        self.embedding_feature_mapping_table_size = EmbeddingFeatureMappingTableSize()
+        self.embedding_feature_mapping_find = EmbeddingFeatureMappingFind()
+        self.embedding_feature_mapping_export = EmbeddingFeatureMappingExport()
+        self.file_path = file_path
+        self.export_value = export_value
+        self.gather = ops.Gather()
+        self.var = Tensor(var, ms.float32)
+        self.var_name = Tensor(np.array([var_name]))
+        self.small_table_embedding_dim = [small_table_embedding_dim]
+        self.global_step = Tensor([-1], ms.int64)
+
+    def construct(self):
+        """
+        ESEmbeddingFeatureMappingExport construct: export feature mapping for data_parallel embedding.
+        """
+        feature_size = self.embedding_feature_mapping_table_size(self.var_name)
+        feature_id, offset_id = self.embedding_feature_mapping_find(self.var_name, feature_size, 1)
+        values = self.gather(self.var, offset_id, 0)
+        if self.export_value:
+            embed_values = values
+        else:
+            embed_values = Tensor([0], ms.float32)
+        feature_mapping_export = self.embedding_feature_mapping_export(self.file_path, self.var_name, self.global_step,
+                                                                       embed_values, self.small_table_embedding_dim,
+                                                                       [feature_id], [offset_id])
+        return feature_mapping_export
+
+
+class ESEmbeddingFeatureMappingImport(nn.Cell):
+    """
+    ESEmbeddingFeatureMappingImport.
+    """
+    def __init__(self, file_path, small_table_name, small_table_embedding_dim, only_offset_flag):
+        super(ESEmbeddingFeatureMappingImport, self).__init__()
+        self.embedding_feature_mapping_file_size = EmbeddingFeatureMappingFileSize()
+        self.embedding_feature_mapping_import = EmbeddingFeatureMappingImport()
+        self.embedding_feature_mapping_insert = EmbeddingFeatureMappingInsert()
+        self.file_path = file_path
+        self.small_table_name = Tensor(np.array([small_table_name]))
+        self.small_table_embedding_dim = [small_table_embedding_dim]
+        self.only_offset_flag = only_offset_flag
+        self.global_step = Tensor([-1], ms.int64)
+
+    def construct(self):
+        """
+        ESEmbeddingFeatureMappingImport construct: import feature mapping for data_parallel embedding.
+        """
+        feature_size = self.embedding_feature_mapping_file_size(self.file_path,
+                                                                self.small_table_name,
+                                                                self.global_step,
+                                                                self.small_table_embedding_dim,
+                                                                self.only_offset_flag)
+        feature_id, offset_id = self.embedding_feature_mapping_import(self.file_path,
+                                                                      self.small_table_name,
+                                                                      feature_size, self.global_step,
+                                                                      self.small_table_embedding_dim,
+                                                                      self.only_offset_flag, 1)
+        feature_mapping_insert = self.embedding_feature_mapping_insert(self.small_table_name, 1,
+                                                                       [feature_id], [offset_id])
+        return feature_mapping_insert
+
+
+class ESEmbeddingSmallTableLookup(nn.Cell):
+    r"""
+    Look up a data_parallel embedding.
+
+    .. warning::
+        This is an experimental EmbeddingService API that is subject to change.
+
+    Args:
+        name (str): The data_parallel embedding name.
+        rank_id (int): The rank id when look up data_parallel embedding key.
+        rank_size (int): The rank size when look up data_parallel embedding key.
+        small_table_to_variable (dict[str, parameter]): The dict to restore data_parallel embedding information:
+            key is table name, value is parameter.
+
+    Inputs:
+        - **ids_list** (Tensor) - The keys of each feature in data_parallel embedding.
+
+    Supported Platforms:
+        ``Atlas A2 training series products``
+    """
+
+    def __init__(self, name, rank_id, rank_size, small_table_to_variable):
+        super(ESEmbeddingSmallTableLookup, self).__init__()
+        self.small_table_to_variable = small_table_to_variable[name]
+        self.small_table_to_variable.feature_name = name
+        self.allgather = ops.AllGather()
+        self.gather = ops.Gather()
+        self.embedding_feature_mapping_v2 = EmbeddingFeatureMappingV2()
+        self.name = name
+        self.rank_id = rank_id
+        self.rank_size = rank_size
+
+    def construct(self, ids_list):
+        """
+        Using the EmbeddingFeatureMappingV2 method to mapping hash key to non hash key, and then get embedding value.
+        """
+        hash_key_shape = ids_list.shape
+        if self.rank_size > 1 and (hash_key_shape[0] is not None):
+            hash_key = ops.stop_gradient(self.allgather(ids_list))
+            non_hash_key = self.embedding_feature_mapping_v2(self.name, hash_key, [1], [1])
+            recovery_matrix = []
+            for i in range(hash_key_shape[0]):
+                recovery_matrix.append(self.rank_id * hash_key_shape[0] + i)
+            local_non_hash_keys = self.gather(non_hash_key, Tensor(recovery_matrix), 0)
+        else:
+            hash_key = ids_list
+            local_non_hash_keys = self.embedding_feature_mapping_v2(self.name, hash_key, [1], [1])
+
+        embedding = self.gather(self.small_table_to_variable, local_non_hash_keys, 0)
+        return embedding

mindspore/experimental/llm_boost/__init__.py

@@ -0,0 +1,21 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LlmBoost Register"""
+from __future__ import absolute_import
+
+from mindspore.experimental.llm_boost.atb import *
+from mindspore.experimental.llm_boost.register import LlmBoostRegister
+
+__all__ = ['LlmBoostRegister']

mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py

@@ -13,15 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """
-
-
-The high-level components(Cells) used to construct the neural network.
+Provide llm boost for inference, such as LlamaBoost.
 """
 from __future__ import absolute_import

-from mindspore.
-from mindspore.
-
-__all__ = []
+from mindspore.experimental.llm_boost.atb.llama_boost import LlamaBoost
+from mindspore.experimental.llm_boost.atb.qwen_boost import QwenBoost

-__all__
+__all__ = ['LlamaBoost', 'QwenBoost']

mindspore/experimental/llm_boost/atb/boost_base.py

@@ -0,0 +1,211 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""boost base class"""
+import numpy as np
+import mindspore as ms
+from mindspore import ops, Tensor
+from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+from mindspore._c_expression import _set_format
+
+from mindspore.common.parameter import Parameter
+from mindspore.experimental.llm_boost.utils import get_real_rank, get_real_group_size
+from mindspore.common.initializer import Zero
+
+
+class AttentionMask:
+    """attention mask"""
+
+    @classmethod
+    def static(cls, max_seq_len, dtype=mstype.float16, need_nz=False):
+        """cache mask"""
+        bias_cache = Tensor(np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))).reshape(max_seq_len,
+                                                                                                  max_seq_len)
+        bias_cache = ~bias_cache
+        if dtype == mstype.float16:
+            mask_value = Tensor(np.finfo(np.float32).min, mstype.float16)
+        else:
+            mask_value = Tensor(1)
+        attn_mask = ops.masked_fill(Tensor(np.zeros(
+            (max_seq_len, max_seq_len)), dtype=mstype.float16), bias_cache, mask_value)
+        if need_nz:
+            # ND -> NZ
+            attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len))
+            attn_mask = ops.reshape(
+                attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
+            attn_mask = ops.transpose(attn_mask, (0, 2, 1, 3)).contiguous()
+            attn_mask = _set_format(attn_mask, "FRACTAL_NZ")
+        return attn_mask
+
+
+class AtbBoostBase():
+    """atb boost base class"""
+
+    def __init__(self, config):
+        super().__init__()
+        self.is_first_iteration = False
+        self.config = config
+        self.dtype = config.compute_dtype
+        self.num_heads = config.num_heads
+        self.num_kv_heads = config.n_kv_heads if config.n_kv_heads else self.num_heads
+        self.num_layers = config.num_layers
+        self.n_kv_heads = config.n_kv_heads if config.n_kv_heads else config.num_heads
+        self.head_dim = config.hidden_size // self.num_heads
+        self.need_nz = False
+        if hasattr(config, "need_nz"):
+            self.need_nz = config.need_nz
+        self.placeholder = Tensor(np.zeros(1), dtype=self.dtype)
+        self.lm_head_indices_fake = Tensor([0], dtype=mstype.int64)
+        self.position_embedding_type = "ROPE"
+        self.add_norm_enable = True
+        self.max_decode_length = self.config.max_decode_length
+        self.max_base_len = 128
+        self.attn_mask = AttentionMask.static(
+            self.max_base_len, dtype=self.dtype, need_nz=self.need_nz)
+
+        self.cast = P.Cast()
+        self.reshape = P.Reshape()
+        self.kv_quant = None
+        self.rank_id = get_real_rank()
+        self.device_num = get_real_group_size()
+
+    def _convert_tensor_format_and_dtype(self, tensor, dtype=mstype.float16):
+        tensor = self.cast(tensor, dtype=dtype)
+        if self.need_nz:
+            tensor = _set_format(tensor, "FRACTAL_NZ")
+        return tensor
+
+    def set_weights(self, parm_dict, dtype=mstype.float16):
+        """set weights for llm boost"""
+        embedding_weight_name = "model.tok_embeddings.embedding_weight"
+        attention_norm_name = "attention_norm"
+        qkv_name = "attention.w_qkv"
+        o_name = "attention.wo"
+        mlp_norm_name = "ffn_norm"
+        mlp_gate_name = "feed_forward.w_gate_hidden"
+        mlp_down_name = "feed_forward.w2"
+        norm_out_name = "model.norm_out"
+        lm_head_name = "lm_head"
+        placeholder = Parameter(Tensor(np.zeros(1), dtype=dtype))
+
+        ascend_weight = []
+        ascend_weight.append(
+            self.cast(parm_dict[embedding_weight_name], dtype))
+        for i in range(self.num_layers):
+            ascend_weight.append(self._convert_tensor_format_and_dtype(
+                parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype))
+            ascend_weight.extend([placeholder] * 3)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{qkv_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 16)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{o_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 4)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype))
+            ascend_weight.extend([placeholder] * 3)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{mlp_gate_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 10)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{mlp_down_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 4)
+
+        ascend_weight.append(
+            self._convert_tensor_format_and_dtype(parm_dict[f"{norm_out_name}.weight"], dtype))
+        ascend_weight.append(
+            self._convert_tensor_format_and_dtype(parm_dict[f"{lm_head_name}.weight"], dtype))
+        self.atb_encoder_operation.set_weights(ascend_weight)
+        self.atb_decoder_operation.set_weights(ascend_weight)
+
+    def set_kvcache(self, k_caches=None, v_caches=None):
+        """set kv_cache for llm boost"""
+        if not k_caches or v_caches:
+            if self.need_nz:
+                kv_shape = (self.config.num_blocks, self.num_kv_heads*self.head_dim //
+                            self.device_num // 16, self.config.block_size, 16)
+                k_caches = [_set_format(Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
+                v_caches = [_set_format(Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
+            else:
+                kv_shape = (self.config.num_blocks, self.config.block_size,
+                            self.num_kv_heads // self.device_num, self.head_dim)
+                k_caches = [Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
+                v_caches = [Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
+
+        self.atb_encoder_operation.set_kvcache(k_caches, v_caches)
+        self.atb_decoder_operation.set_kvcache(k_caches, v_caches)
+
+    def add_flags(self, is_first_iteration):
+        """add_flags."""
+        self.is_first_iteration = is_first_iteration
+
+    def _execute_operator(self, acl_inputs, acl_param):
+        """execute operator."""
+        if self.is_first_iteration:
+            acl_model_out = self.atb_encoder_operation.forward(
+                acl_inputs, acl_param)
+        else:
+            acl_model_out = self.atb_decoder_operation.forward(
+                acl_inputs, acl_param)
+        acl_hidden_state = acl_model_out[0]
+        return acl_hidden_state
+
+    def forward(self, boost_inputs):
+        r"""
+        LlmBoost forward.
+        """
+        input_ids = boost_inputs["input_ids"]
+        position_ids = boost_inputs["position_ids"]
+        cos_embed = boost_inputs["cos_embed"]
+        sin_embed = boost_inputs["sin_embed"]
+        block_tables = boost_inputs["block_tables"]
+        slot_mapping = boost_inputs["slot_mapping"]
+        batch_valid_length = boost_inputs["batch_valid_length"]
+        lm_head_indices = boost_inputs["lm_head_indices"]
+        seqLen = boost_inputs["seq_lens"]
+        if self.is_first_iteration:
+            attention_mask = self.attn_mask
+        else:
+            position_ids = batch_valid_length - 1
+            attention_mask = self.placeholder
+            lm_head_indices = self.lm_head_indices_fake
+
+        acl_inputs, acl_param = self._prepare_inputs(prefill=self.is_first_iteration, input_ids=input_ids,
+                                                     position_ids=position_ids, cos_embed=cos_embed,
+                                                     sin_embed=sin_embed, attention_mask=attention_mask,
+                                                     block_tables=block_tables, slots=slot_mapping,
+                                                     input_lengths=batch_valid_length, lm_head_indices=lm_head_indices,
+                                                     seqLen=seqLen)
+        ms.hal.synchronize()
+        logits = self._execute_operator(acl_inputs, acl_param)
+        logits = self.cast(logits, mstype.float32)
+        return logits
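For orientation, a minimal usage sketch of the EsEmbeddingLookup cell documented in the embedding_service_layer.py diff above. The table id, embedding dim, and key values here are hypothetical, and a real run needs the EmbeddingService runtime on Atlas A2 training series hardware with the table already initialized (for example via ESInitLayer); passing es_initializer=None is only valid for inference, as the new docstring notes.

import mindspore as ms
from mindspore import Tensor
from mindspore.experimental.es.embedding_service_layer import EsEmbeddingLookup

# Hypothetical table: id 0, embedding dim 8, at most 1024 keys per lookup.
lookup = EsEmbeddingLookup(table_id=0, es_initializer=None,
                           embedding_dim=8, max_key_num=1024)
lookup.set_train(False)               # inference path -> embedding_table_find
keys = Tensor([11, 42, 7], ms.int64)  # hypothetical feature keys
embeddings = lookup(keys)             # expected shape: (3, 8)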
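The AttentionMask.static helper in the boost_base.py diff above builds a standard causal (lower-triangular) mask and fills the disallowed upper-triangular positions with a large negative value. A standalone NumPy sketch of the same idea, with illustrative values and without the FRACTAL_NZ reformatting:

import numpy as np

max_seq_len = 4
# True on and below the diagonal: position i may attend to positions 0..i.
keep = np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))
# Large negative fill for the masked (future) positions.
mask_value = np.float16(np.finfo(np.float16).min)
attn_mask = np.where(keep, np.float16(0), mask_value)
print(attn_mask)
# Future positions hold roughly -6.55e4, so they are suppressed once the mask
# is added to the attention scores before the softmax.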