mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/nn/extend/pooling.py
DELETED
|
@@ -1,117 +0,0 @@
|
|
|
1
|
-
#Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""pooling"""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
|
-
|
|
18
|
-
from mindspore.ops.auto_generate.gen_ops_prim import MaxPoolWithIndices, MaxPoolWithMask
|
|
19
|
-
from mindspore.nn.cell import Cell
|
|
20
|
-
|
|
21
|
-
__all__ = ['MaxPool2d']
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class MaxPool2d(Cell):
|
|
25
|
-
r"""
|
|
26
|
-
Applies a 2D max pooling over an input Tensor which can be regarded as a composition of 2D planes.
|
|
27
|
-
|
|
28
|
-
Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool2d outputs
|
|
29
|
-
regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
|
|
30
|
-
:math:`(h_{ker}, w_{ker})` and stride :math:`(s_0, s_1)`, the operation is as follows.
|
|
31
|
-
|
|
32
|
-
.. math::
|
|
33
|
-
\text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
|
|
34
|
-
\text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
|
|
35
|
-
|
|
36
|
-
.. warning::
|
|
37
|
-
Only support on Atlas training series.
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the max value,
|
|
41
|
-
is an int number or a single element tuple that represents height and width are both kernel_size,
|
|
42
|
-
or a tuple of two int numbers that represent height and width respectively.
|
|
43
|
-
Default: ``1`` .
|
|
44
|
-
stride (Union[int, tuple[int], None]): The distance of kernel moving, an int number or a single element tuple
|
|
45
|
-
that represents the height and width of movement are both stride, or a tuple of two int numbers that
|
|
46
|
-
represent height and width of movement respectively.
|
|
47
|
-
Default: ``None`` , which indicates the moving step is `kernel_size` .
|
|
48
|
-
padding (Union(int, tuple[int], list[int])): Specifies the padding value of the pooling operation.
|
|
49
|
-
Default: ``0`` . `padding` can only be an integer or a tuple/list containing one or two integers. If
|
|
50
|
-
`padding` is an integer or a tuple/list containing one integer, it will be padded `padding` times in the
|
|
51
|
-
four directions of the input. If `padding` is a tuple/list containing two integers, it will be padded
|
|
52
|
-
`padding[0]` times in the up-down direction of the input and `padding[1]` times in the left-right direction
|
|
53
|
-
of the input.
|
|
54
|
-
dilation (Union(int, tuple[int])): The spacing between the elements of the kernel in convolution,
|
|
55
|
-
used to increase the receptive field of the pooling operation. If it is a tuple, it must contain one or two
|
|
56
|
-
integers. Default: ``1`` .
|
|
57
|
-
return_indices (bool): If ``True`` , the function will return both the result of max pooling and the indices of
|
|
58
|
-
the max elements. Default: ``False`` .
|
|
59
|
-
ceil_mode (bool): If ``True`` , use ceil to compute the output shape instead of floor. Default: ``False`` .
|
|
60
|
-
|
|
61
|
-
Inputs:
|
|
62
|
-
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
|
63
|
-
|
|
64
|
-
Outputs:
|
|
65
|
-
If `return_indices` is ``False`` , return a Tensor `output`, else return a tuple (`output`, `argmax`).
|
|
66
|
-
|
|
67
|
-
- **output** (Tensor) - Maxpooling result, with shape :math:`(N_{out}, C_{out}, H_{out}, W_{out})`. It has the
|
|
68
|
-
same data type as `input`.
|
|
69
|
-
- **argmax** (Tensor) - Index corresponding to the maximum value. Data type is int32.
|
|
70
|
-
|
|
71
|
-
.. math::
|
|
72
|
-
H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
|
|
73
|
-
\times (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
|
|
74
|
-
|
|
75
|
-
.. math::
|
|
76
|
-
W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
|
|
77
|
-
\times (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
|
|
78
|
-
|
|
79
|
-
Raises:
|
|
80
|
-
TypeError: If `input` is not a Tensor.
|
|
81
|
-
ValueError: If length of shape of `input` is not equal to 4.
|
|
82
|
-
TypeError: If `kernel_size` , `stride` , `padding` or `dilation` is not int or tuple.
|
|
83
|
-
ValueError: If `kernel_size`, `stride` or `dilation` is less than 1.
|
|
84
|
-
ValueError: If `dilation` is not all 1.
|
|
85
|
-
ValueError: If `padding` is less than 0.
|
|
86
|
-
ValueError: If `padding` is more than half of `kernel_size`.
|
|
87
|
-
TypeError: If `ceil_mode` is not bool.
|
|
88
|
-
|
|
89
|
-
Supported Platforms:
|
|
90
|
-
``Ascend``
|
|
91
|
-
|
|
92
|
-
Examples:
|
|
93
|
-
>>> import mindspore as ms
|
|
94
|
-
>>> import numpy as np
|
|
95
|
-
>>> pool = ms.nn.extend.MaxPool2d(kernel_size=3, stride=1)
|
|
96
|
-
>>> input = ms.Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), ms.float32)
|
|
97
|
-
>>> output = pool(input)
|
|
98
|
-
>>> print(output.shape)
|
|
99
|
-
(1, 2, 2, 2)
|
|
100
|
-
"""
|
|
101
|
-
|
|
102
|
-
def __init__(self, kernel_size=1, stride=None, padding=0, dilation=1, return_indices=False,
|
|
103
|
-
ceil_mode=False):
|
|
104
|
-
"""Initialize MaxPool2d."""
|
|
105
|
-
super(MaxPool2d, self).__init__()
|
|
106
|
-
self.return_indices = return_indices
|
|
107
|
-
strides = stride if (stride is not None) else kernel_size
|
|
108
|
-
if return_indices:
|
|
109
|
-
self.max_pool_func_ = MaxPoolWithIndices(kernel_size, strides, padding, dilation, ceil_mode)
|
|
110
|
-
else:
|
|
111
|
-
self.max_pool_func_ = MaxPoolWithMask(kernel_size, strides, padding, dilation, ceil_mode)
|
|
112
|
-
|
|
113
|
-
def construct(self, input):
|
|
114
|
-
out, indices = self.max_pool_func_(input)
|
|
115
|
-
if self.return_indices:
|
|
116
|
-
return out, indices
|
|
117
|
-
return out
|
|
mindspore/nn/layer/embedding_service.py
DELETED
|
@@ -1,531 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""embedding service"""
|
|
16
|
-
import json
|
|
17
|
-
import os
|
|
18
|
-
import math
|
|
19
|
-
|
|
20
|
-
from mindspore.nn.layer.embedding_service_layer import ESInitLayer
|
|
21
|
-
from mindspore.common.initializer import Uniform, TruncatedNormal, Constant
|
|
22
|
-
from mindspore.nn.layer.embedding_service_layer import ESEmbeddingTableImport, ESEmbeddingTableExport, \
|
|
23
|
-
ESEmbeddingCKPTImport, ESEmbeddingCKPTExport
|
|
24
|
-
|
|
25
|
-
_INT32_MAX_VALUE = 2147483647
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class CounterFilter:
|
|
29
|
-
""" Counter filter for embedding table. """
|
|
30
|
-
def __init__(self, filter_freq, default_key_or_value, default_key=None, default_value=None):
|
|
31
|
-
self.filter_freq = filter_freq
|
|
32
|
-
self.default_key = default_key
|
|
33
|
-
self.default_value = default_value
|
|
34
|
-
self.default_key_or_value = default_key_or_value
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class PaddingParamsOption:
|
|
38
|
-
""" padding key option for embedding service table. """
|
|
39
|
-
def __init__(self, padding_key=None,
|
|
40
|
-
mask=True,
|
|
41
|
-
mask_zero=False):
|
|
42
|
-
self.padding_key = padding_key
|
|
43
|
-
self.mask = mask
|
|
44
|
-
self.mask_zero = mask_zero
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class CompletionKeyOption:
|
|
48
|
-
""" completion key option for embedding service table. """
|
|
49
|
-
def __init__(self, completion_key=None, mask=1):
|
|
50
|
-
self.completion_key = completion_key
|
|
51
|
-
self.mask = mask
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class EvictOption:
|
|
55
|
-
""" Evict option for embedding table. """
|
|
56
|
-
def __init__(self, steps_to_live):
|
|
57
|
-
self.steps_to_live = steps_to_live
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class EmbeddingVariableOption:
|
|
61
|
-
""" option for embedding service table. """
|
|
62
|
-
def __init__(self, filter_option=None,
|
|
63
|
-
padding_option=None,
|
|
64
|
-
evict_option=None,
|
|
65
|
-
completion_option=None,
|
|
66
|
-
storage_option=None,
|
|
67
|
-
feature_freezing_option=None,
|
|
68
|
-
communication_option=None):
|
|
69
|
-
self.filter_option = filter_option
|
|
70
|
-
self.padding_option = padding_option
|
|
71
|
-
self.evict_option = evict_option
|
|
72
|
-
self.completion_option = completion_option
|
|
73
|
-
self.storage_option = storage_option
|
|
74
|
-
self.feature_freezing_option = feature_freezing_option
|
|
75
|
-
self.communication_option = communication_option
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class EsInitializer:
|
|
79
|
-
"""Initializer for embedding service table."""
|
|
80
|
-
def __init__(self, initializer_mode, min_scale=-0.01, max_scale=0.01,
|
|
81
|
-
constant_value=1.0, mu=0.0, sigma=1.0, seed=0):
|
|
82
|
-
self.initializer_mode = initializer_mode
|
|
83
|
-
self.min = min_scale
|
|
84
|
-
self.max = max_scale
|
|
85
|
-
self.constant_value = constant_value
|
|
86
|
-
self.mu = mu
|
|
87
|
-
self.sigma = sigma
|
|
88
|
-
self.seed = seed
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
class EsOptimizer:
|
|
92
|
-
"""Optimizer for embedding service table."""
|
|
93
|
-
def __init__(self, name, initial_accumulator_value=0., ms=0., mom=0.):
|
|
94
|
-
self.name = name
|
|
95
|
-
self.initial_accumulator_value = initial_accumulator_value
|
|
96
|
-
self.ms = ms
|
|
97
|
-
self.mom = mom
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def check_common_init_params(name, init_vocabulary_size, embedding_dim):
|
|
101
|
-
"""
|
|
102
|
-
Check init params.
|
|
103
|
-
"""
|
|
104
|
-
if (name is None) or (init_vocabulary_size is None) or (embedding_dim is None):
|
|
105
|
-
raise ValueError("table name, init_vocabulary_size and embedding_dim can not be None.")
|
|
106
|
-
if not isinstance(name, str):
|
|
107
|
-
raise TypeError("embedding table name must be string.")
|
|
108
|
-
if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)):
|
|
109
|
-
raise ValueError("init_vocabulary_size and embedding_dim must be int.")
|
|
110
|
-
if init_vocabulary_size < 0:
|
|
111
|
-
raise ValueError("init_vocabulary_size can not be smaller than zero.")
|
|
112
|
-
if embedding_dim <= 0:
|
|
113
|
-
raise ValueError("embedding_dim must be greater than zero.")
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
class EmbeddingServiceOut:
|
|
117
|
-
"""
|
|
118
|
-
EmbeddingServiceOut
|
|
119
|
-
"""
|
|
120
|
-
def __init__(self, table_id_dict, es_initializer=None, es_counter_filter=None,
|
|
121
|
-
es_padding_keys=None, es_completion_keys=None):
|
|
122
|
-
self.table_id_dict = table_id_dict
|
|
123
|
-
self.es_initializer = es_initializer
|
|
124
|
-
self.es_counter_filter = es_counter_filter
|
|
125
|
-
self.es_padding_keys = es_padding_keys
|
|
126
|
-
self.es_completion_keys = es_completion_keys
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
class EmbeddingService:
|
|
130
|
-
"""
|
|
131
|
-
EmbeddingService
|
|
132
|
-
"""
|
|
133
|
-
def __init__(self):
|
|
134
|
-
"""
|
|
135
|
-
Init EmbeddingService
|
|
136
|
-
"""
|
|
137
|
-
env_dist = os.environ
|
|
138
|
-
es_cluster_config = env_dist.get("ESCLUSTER_CONFIG_PATH")
|
|
139
|
-
if es_cluster_config is None:
|
|
140
|
-
raise ValueError("EsClusterConfig env is null.")
|
|
141
|
-
self._server_ip_to_ps_num = {}
|
|
142
|
-
with open(es_cluster_config, encoding='utf-8') as a:
|
|
143
|
-
es_cluster_config_json = json.load(a)
|
|
144
|
-
self._es_cluster_conf = json.dumps(es_cluster_config_json)
|
|
145
|
-
self._ps_num = int(es_cluster_config_json["psNum"])
|
|
146
|
-
self._ps_ids = []
|
|
147
|
-
self._ps_ids_list = es_cluster_config_json["psCluster"]
|
|
148
|
-
for each_ps in self._ps_ids_list:
|
|
149
|
-
self._server_ip_to_ps_num[each_ps["ctrlPanel"]["ipaddr"]] = 0
|
|
150
|
-
|
|
151
|
-
for each_ps in self._ps_ids_list:
|
|
152
|
-
self._ps_ids.append(each_ps["id"])
|
|
153
|
-
ctrl_panel = each_ps["ctrlPanel"]
|
|
154
|
-
self._server_ip_to_ps_num[ctrl_panel["ipaddr"]] += 1
|
|
155
|
-
|
|
156
|
-
for each_server_ps_num in self._server_ip_to_ps_num:
|
|
157
|
-
if self._server_ip_to_ps_num[each_server_ps_num] > 4:
|
|
158
|
-
raise ValueError("PS num of one server can not exceed 4, please check config params.")
|
|
159
|
-
if self._ps_num > 4:
|
|
160
|
-
raise ValueError("PS num of one server can not exceed 4, please check config params.")
|
|
161
|
-
|
|
162
|
-
# storage each ps table's params
|
|
163
|
-
self._table_to_embedding_dim = {}
|
|
164
|
-
self._table_to_max_num = {}
|
|
165
|
-
self._table_to_optimizer = {}
|
|
166
|
-
self._table_to_slot_var_num = {}
|
|
167
|
-
self._table_to_counter_filter = {}
|
|
168
|
-
self._table_id_to_padding_key = {}
|
|
169
|
-
self._table_id_to_completion_key = {}
|
|
170
|
-
self._train_mode = True
|
|
171
|
-
self._train_level = False
|
|
172
|
-
self._optimizer = None
|
|
173
|
-
self._init_table_flag = False
|
|
174
|
-
|
|
175
|
-
self._small_table_name_list = []
|
|
176
|
-
self._ps_table_count = 0
|
|
177
|
-
self._table_name_to_id = {}
|
|
178
|
-
self._table_id_to_name = {}
|
|
179
|
-
self._table_id_to_initializer = {}
|
|
180
|
-
self._table_id_to_steps_to_live = {}
|
|
181
|
-
|
|
182
|
-
self._ps_table_id_list = []
|
|
183
|
-
# storage lookup: table_id list, lookup result list, lookup key list
|
|
184
|
-
self._ps_lookup_index = 0
|
|
185
|
-
# storage all inited table names
|
|
186
|
-
self._table_name_has_init = []
|
|
187
|
-
# only storage all inited PS table names
|
|
188
|
-
self._ps_table_name_list = []
|
|
189
|
-
# now only use for adagrad accum
|
|
190
|
-
self._ps_table_id_to_optimizer_params = {}
|
|
191
|
-
|
|
192
|
-
# use for counter filter
|
|
193
|
-
self._table_use_counter_filter = {}
|
|
194
|
-
self._use_counter_filter = False
|
|
195
|
-
self._use_evict = False
|
|
196
|
-
self._use_padding_key = False
|
|
197
|
-
self._use_completion_key = False
|
|
198
|
-
|
|
199
|
-
def embedding_init(self, name, init_vocabulary_size, embedding_dim, max_feature_count,
|
|
200
|
-
initializer=Uniform(scale=0.01), ev_option=None, optimizer=None, optimizer_param=None,
|
|
201
|
-
mode="train"):
|
|
202
|
-
"""
|
|
203
|
-
Init embedding
|
|
204
|
-
:param name: big table name
|
|
205
|
-
:param init_vocabulary_size: vocab size
|
|
206
|
-
:param embedding_dim: embedding dim
|
|
207
|
-
:param max_feature_count: max feature count
|
|
208
|
-
:param initializer: mindspore common initializer
|
|
209
|
-
:param ev_option: output of embedding_variable_option
|
|
210
|
-
:param optimizer: optimizer
|
|
211
|
-
:param optimizer_param: optimizer param
|
|
212
|
-
:param mode: mode, train or predict
|
|
213
|
-
:return: table_id_dict, es_initializer_dict, es_filter_dict
|
|
214
|
-
"""
|
|
215
|
-
check_common_init_params(name=name, init_vocabulary_size=init_vocabulary_size, embedding_dim=embedding_dim)
|
|
216
|
-
table_id = self._check_and_update_ps_init_params(name=name, init_vocabulary_size=init_vocabulary_size,
|
|
217
|
-
max_feature_count=max_feature_count, ev_option=ev_option)
|
|
218
|
-
self._ps_lookup_index = self._ps_table_count
|
|
219
|
-
self._table_to_embedding_dim[table_id] = embedding_dim
|
|
220
|
-
self._table_to_max_num[table_id] = max_feature_count
|
|
221
|
-
# storage the table id for embedding PS table
|
|
222
|
-
self._ps_table_id_list.append(table_id)
|
|
223
|
-
self._ps_table_name_list.append(name)
|
|
224
|
-
|
|
225
|
-
if len(self._ps_table_id_list) > 10:
|
|
226
|
-
raise ValueError("Now only 10 PS embedding tables can be init.")
|
|
227
|
-
bucket_size = math.ceil(init_vocabulary_size / self._ps_num)
|
|
228
|
-
if optimizer is None:
|
|
229
|
-
self._train_mode = False
|
|
230
|
-
self._table_to_slot_var_num[table_id] = 0
|
|
231
|
-
else:
|
|
232
|
-
self._check_ps_opt_and_initializer(optimizer=optimizer, initializer=initializer, table_id=table_id)
|
|
233
|
-
self._optimizer = optimizer
|
|
234
|
-
self._table_to_optimizer[table_id] = self._optimizer
|
|
235
|
-
self._ps_table_id_to_optimizer_params[table_id] = []
|
|
236
|
-
self._update_optimizer_slot_var_num(table_id=table_id)
|
|
237
|
-
# new train or continue train from a checkpoint
|
|
238
|
-
if initializer is not None:
|
|
239
|
-
self._train_level = True
|
|
240
|
-
filter_mode = self._init_counter_filter(table_id, ev_option)
|
|
241
|
-
self._init_padding_key(table_id, ev_option)
|
|
242
|
-
self._init_completion_key(table_id, ev_option)
|
|
243
|
-
self._init_optimizer_mode_and_params(table_id, optimizer_param)
|
|
244
|
-
es_init_layer = ESInitLayer(self._ps_num, self._ps_ids, self._train_mode, self._train_level, table_id,
|
|
245
|
-
bucket_size, embedding_dim, self._table_to_slot_var_num.get(table_id),
|
|
246
|
-
self._table_id_to_initializer.get(table_id), filter_mode, optimizer,
|
|
247
|
-
self._ps_table_id_to_optimizer_params.get(table_id), max_feature_count, mode)
|
|
248
|
-
es_init_layer()
|
|
249
|
-
return EmbeddingServiceOut(self._table_name_to_id, self._table_id_to_initializer,
|
|
250
|
-
self._table_to_counter_filter, self._table_id_to_padding_key,
|
|
251
|
-
self._table_id_to_completion_key)
|
|
252
|
-
|
|
253
|
-
def padding_param(self, padding_key, mask=True, mask_zero=False):
    """
    Init padding key param

    :param padding_key: padding key, must be int
    :param mask: padding key mask, must be bool
    :param mask_zero: mask zero, must be bool
    :return: PaddingParamsOption obj
    :raises TypeError: if padding_key is not int, or mask / mask_zero is not bool
    """
    if not isinstance(padding_key, int):
        raise TypeError("padding_key must be int, please check.")
    if not isinstance(mask, bool):
        raise TypeError("mask must be bool, please check.")
    # mask_zero is forwarded as-is to PaddingParamsOption, so it is validated
    # the same way as the sibling flag `mask`.
    if not isinstance(mask_zero, bool):
        raise TypeError("mask_zero must be bool, please check.")
    # Only mark padding keys as used once every argument has been accepted.
    self._use_padding_key = True
    return PaddingParamsOption(padding_key=padding_key, mask=mask, mask_zero=mask_zero)
|
|
267
|
-
|
|
268
|
-
def completion_key(self, completion_key, mask=True):
    """
    Build the completion key option

    :param completion_key: completion key, must be int
    :param mask: completion key mask, must be bool
    :return: CompletionKeyOption obj whose mask is 1 when `mask` is True, otherwise 0
    """
    key_is_int = isinstance(completion_key, int)
    if not key_is_int:
        raise TypeError("completion_key must be int, please check.")
    mask_is_bool = isinstance(mask, bool)
    if not mask_is_bool:
        raise TypeError("mask must be bool, please check.")
    self._use_completion_key = True
    # mask is a validated bool here, so int() maps True -> 1 and False -> 0.
    mask_flag = int(mask)
    return CompletionKeyOption(completion_key=completion_key, mask=mask_flag)
|
|
282
|
-
|
|
283
|
-
def counter_filter(self, filter_freq, default_key=None, default_value=None):
    """
    Set filter_option

    Exactly one of default_key and default_value must be given.

    :param filter_freq: filter freq, must be an int that is not smaller than 0
    :param default_key: default key, must be int when given
    :param default_value: default value, must be int or float when given
    :return: CounterFilter obj
    :raises TypeError: if filter_freq, default_key or default_value has a wrong type
    :raises ValueError: if filter_freq is negative, or default_key and default_value
        are both None or both set
    """
    if not isinstance(filter_freq, int):
        raise TypeError("filter_freq must be int, please check.")
    if filter_freq < 0:
        raise ValueError("filter_freq can not be smaller than 0.")
    if (default_key is None) and (default_value is None):
        raise ValueError("default_key and default_value can not be both None.")
    if (default_key is not None) and (default_value is not None):
        raise ValueError("default_key and default_value can not be both set.")

    # From here on exactly one of default_key / default_value is set.
    if default_key is None:
        if not isinstance(default_value, (int, float)):
            raise TypeError("When default_value is not None, it must be float or int, please check.")
        self._use_counter_filter = True
        # default_key_or_value=0 marks that default_value is the one in use.
        return CounterFilter(filter_freq=filter_freq, default_key_or_value=0,
                             default_key=0, default_value=default_value)

    if not isinstance(default_key, int):
        raise TypeError("When default_key is not None, it must be int, please check.")
    self._use_counter_filter = True
    # default_key_or_value=1 marks that default_key is the one in use.
    return CounterFilter(filter_freq=filter_freq, default_key_or_value=1,
                         default_key=default_key, default_value=1)
|
|
309
|
-
|
|
310
|
-
def evict_option(self, steps_to_live):
    """
    Set evict_option

    :param steps_to_live: steps to live, must be an int greater than 0
    :return: EvictOption obj
    :raises TypeError: if steps_to_live is not int
    :raises ValueError: if steps_to_live is not greater than 0
    """
    if not isinstance(steps_to_live, int):
        raise TypeError("steps_to_live must be int, please check.")
    if steps_to_live <= 0:
        raise ValueError("steps_to_live must be greater than 0.")
    # Only mark eviction as used once the argument has been accepted.
    self._use_evict = True
    return EvictOption(steps_to_live=steps_to_live)
|
|
322
|
-
|
|
323
|
-
def embedding_variable_option(self, filter_option=None, padding_option=None, evict_option=None,
                              completion_option=None, storage_option=None, feature_freezing_option=None,
                              communication_option=None):
    """
    Set embedding variable option

    :param filter_option: filter policy, is the output of counter_filter
    :param padding_option: padding policy, is the output of padding_param
    :param evict_option: evict policy, is the output of evict_option
    :param completion_option: completion key policy, is the output of completion_key
    :param storage_option: not support, forwarded without validation
    :param feature_freezing_option: not support, forwarded without validation
    :param communication_option: not support, forwarded without validation
    :return: EmbeddingVariableOption obj
    :raises ValueError: if filter_option is given but is not a CounterFilter
    :raises TypeError: if padding_option, completion_option or evict_option is given
        with a wrong type
    """
    if (filter_option is not None) and (not isinstance(filter_option, CounterFilter)):
        # ValueError (not TypeError) is kept so existing callers catching it keep working;
        # the message now names the parameter that is actually being checked.
        raise ValueError("If filter_option isn't None, it must be CounterFilter type.")
    if filter_option is not None:
        self._use_counter_filter = True
    if (padding_option is not None) and (not isinstance(padding_option, PaddingParamsOption)):
        raise TypeError("If padding_option isn't None, it must be EmbeddingPaddingParamsOption type.")
    if (completion_option is not None) and (not isinstance(completion_option, CompletionKeyOption)):
        raise TypeError("If completion_option isn't None, it must be EmbeddingPaddingCompletionKeyOption type.")
    if (evict_option is not None) and (not isinstance(evict_option, EvictOption)):
        raise TypeError("When evict_option is not None, it must be EvictOption type.")
    return EmbeddingVariableOption(filter_option=filter_option, padding_option=padding_option,
                                   evict_option=evict_option, completion_option=completion_option,
                                   storage_option=storage_option, feature_freezing_option=feature_freezing_option,
                                   communication_option=communication_option)
|
|
351
|
-
|
|
352
|
-
def embedding_ckpt_export(self, file_path):
    """
    Export big table ckpt

    For every registered PS table this collects the embedding dim, the total value
    length (embedding_dim * (slot_var_num + 1) + 2) and the steps to live (0 when
    none was recorded), then runs an ESEmbeddingCKPTExport layer with them.

    :param file_path: the file path to store the ckpt ret
    :return:
    """
    table_ids = self._ps_table_id_list
    dims = [self._table_to_embedding_dim.get(tid) for tid in table_ids]
    value_lens = [
        self._table_to_embedding_dim.get(tid) * (self._table_to_slot_var_num.get(tid) + 1) + 2
        for tid in table_ids
    ]
    steps = [self._table_id_to_steps_to_live.get(tid, 0) for tid in table_ids]
    export_layer = ESEmbeddingCKPTExport(dims, value_lens,
                                         self._ps_table_name_list, self._ps_table_id_list,
                                         file_path, steps)
    export_layer()
|
|
370
|
-
|
|
371
|
-
def embedding_table_export(self, file_path):
    """
    Export big table embedding

    Collects the embedding dim and the steps to live (0 when none was recorded) of
    every registered PS table and runs an ESEmbeddingTableExport layer with them.

    :param file_path: the file path to store the embedding ret
    :return:
    """
    table_ids = self._ps_table_id_list
    dims = [self._table_to_embedding_dim.get(tid) for tid in table_ids]
    steps = [self._table_id_to_steps_to_live.get(tid, 0) for tid in table_ids]

    # The same dim list is passed for both of the first two arguments.
    export_layer = ESEmbeddingTableExport(dims, dims,
                                          self._ps_table_name_list, self._ps_table_id_list,
                                          file_path, steps)
    export_layer()
|
|
387
|
-
|
|
388
|
-
def embedding_ckpt_import(self, file_path):
    """
    Import big table ckpt

    For every registered PS table this collects the embedding dim and the total
    value length (embedding_dim * (slot_var_num + 1) + 2), then runs an
    ESEmbeddingCKPTImport layer with them.

    :param file_path: the file path to import the ckpt ret from
    :return:
    """
    table_ids = self._ps_table_id_list
    dims = [self._table_to_embedding_dim.get(tid) for tid in table_ids]
    value_lens = [
        self._table_to_embedding_dim.get(tid) * (self._table_to_slot_var_num.get(tid) + 1) + 2
        for tid in table_ids
    ]

    import_layer = ESEmbeddingCKPTImport(dims, value_lens,
                                         self._ps_table_name_list, self._ps_table_id_list,
                                         file_path)
    import_layer()
|
|
405
|
-
|
|
406
|
-
def embedding_table_import(self, file_path):
    """
    Import big table embedding

    Collects the embedding dim of every registered PS table and runs an
    ESEmbeddingTableImport layer with them.

    :param file_path: the file path to import the embedding ret from
    :return:
    """
    dims = [self._table_to_embedding_dim.get(tid) for tid in self._ps_table_id_list]
    # The same dim list is passed for both of the first two arguments.
    import_layer = ESEmbeddingTableImport(dims, dims,
                                          self._ps_table_name_list, self._ps_table_id_list,
                                          file_path)
    import_layer()
|
|
419
|
-
|
|
420
|
-
def _check_and_update_ps_init_params(self, name, init_vocabulary_size, max_feature_count, ev_option):
    """
    Check parameter server params and init table id

    Validates the arguments, then registers `name` under the next free table id
    (the current value of self._ps_table_count) together with its steps to live.

    :param name: table name, must not have been initialized before
    :param init_vocabulary_size: must be smaller than _INT32_MAX_VALUE
    :param max_feature_count: must be an int greater than zero
    :param ev_option: None or an EmbeddingVariableOption; its evict_option, when set,
        provides the steps to live (0 otherwise)
    :return: the table id assigned to `name`
    """
    if max_feature_count is None:
        raise ValueError("For ps table, max_feature_count can not be None.")

    has_ev_option = ev_option is not None
    if has_ev_option and not isinstance(ev_option, EmbeddingVariableOption):
        raise TypeError("For ps table, ev_option must be EmbeddingVariableOption type.")

    steps_to_live = 0
    if has_ev_option and ev_option.evict_option is not None:
        steps_to_live = ev_option.evict_option.steps_to_live

    if not isinstance(max_feature_count, int):
        raise ValueError("For ps table, max_feature_count must be int.")
    if init_vocabulary_size >= _INT32_MAX_VALUE:
        raise ValueError("init_vocabulary_size exceeds int32 max value.")
    if max_feature_count <= 0:
        raise ValueError("For ps table, max_feature_count must be greater than zero.")

    # The duplicate-name check stays last, after all argument validation.
    if name in self._table_name_has_init:
        raise ValueError("This table has been initialized.")

    new_table_id = self._ps_table_count
    self._table_name_to_id[name] = new_table_id
    self._table_id_to_name[new_table_id] = name
    self._table_id_to_steps_to_live[new_table_id] = steps_to_live
    self._ps_table_count += 1
    self._table_name_has_init.append(name)
    return new_table_id
|
|
447
|
-
|
|
448
|
-
def _check_ps_opt_and_initializer(self, optimizer, initializer, table_id):
    """
    Check args of parameter server

    When an initializer is given it is stored in self._table_id_to_initializer for
    `table_id`; mindspore TruncatedNormal / Uniform / Constant initializers are
    converted to an EsInitializer first.

    :param optimizer: the optimizer type, one of "adam", "adagrad", "adamw", "ftrl"
    :param initializer: None, an EsInitializer, or a mindspore TruncatedNormal,
        Uniform or Constant initializer
    :param table_id: table id
    :return:
    :raises ValueError: if optimizer is not one of the supported names
    :raises TypeError: if initializer is given with an unsupported type
    """
    if optimizer not in ["adam", "adagrad", "adamw", "ftrl"]:
        raise ValueError("optimizer should be one of adam, adagrad, adamw, ftrl")
    if initializer is None:
        return

    if isinstance(initializer, EsInitializer):
        es_initializer = initializer
    elif isinstance(initializer, TruncatedNormal):
        es_initializer = EsInitializer(initializer_mode="truncated_normal", mu=initializer.mean,
                                       sigma=initializer.sigma, seed=initializer.seed[0])
    elif isinstance(initializer, Uniform):
        # The uniform range is symmetric around zero: [-scale, scale].
        es_initializer = EsInitializer(initializer_mode="random_uniform",
                                       min_scale=-initializer.scale,
                                       max_scale=initializer.scale, seed=initializer.seed[0])
    elif isinstance(initializer, Constant):
        es_initializer = EsInitializer(initializer_mode="constant", constant_value=initializer.value)
    else:
        # Adjacent string literals are joined with no separator, so the first
        # literal must end with a space to keep the words apart.
        raise TypeError("initializer must be EsInitializer or mindspore initializer, and only support "
                        "Uniform, TruncatedNormal and Constant value.")
    self._table_id_to_initializer[table_id] = es_initializer
|
|
476
|
-
|
|
477
|
-
def _update_optimizer_slot_var_num(self, table_id):
    """
    Update _table_to_slot_var_num by diff optimizer

    The count is chosen from self._optimizer: "adagrad" -> 1 (accumulator),
    "sgd" -> 0, every other optimizer -> 2 (e.g. m and v for adam / adamw).

    :param table_id: table id whose slot var num is recorded
    """
    current_optimizer = self._optimizer
    if current_optimizer == "adagrad":
        slot_var_num = 1
    else:
        slot_var_num = 0 if current_optimizer == "sgd" else 2
    self._table_to_slot_var_num[table_id] = slot_var_num
|
|
488
|
-
|
|
489
|
-
def _init_counter_filter(self, table_id, ev_option):
    """
    Init counter filter params

    :param table_id: table id
    :param ev_option: None or an option object exposing `filter_option`
    :return: "counter" when a filter option is set (it is then stored for the table
        and the table is flagged with 1), otherwise "no_filter" (flagged with 0)
    """
    chosen_filter = None if ev_option is None else ev_option.filter_option
    if chosen_filter is None:
        self._table_use_counter_filter[table_id] = 0
        return "no_filter"
    self._table_to_counter_filter[table_id] = chosen_filter
    self._table_use_counter_filter[table_id] = 1
    return "counter"
|
|
501
|
-
|
|
502
|
-
def _init_padding_key(self, table_id, ev_option):
    """
    Init padding key params

    Stores ev_option.padding_option for `table_id` when both ev_option and its
    padding_option are set; otherwise nothing is changed.

    :param table_id: table id
    :param ev_option: None or an option object exposing `padding_option`
    """
    if ev_option is None:
        return
    chosen_padding = ev_option.padding_option
    if chosen_padding is not None:
        self._table_id_to_padding_key[table_id] = chosen_padding
|
|
508
|
-
|
|
509
|
-
def _init_completion_key(self, table_id, ev_option):
    """
    Init completion key params

    Stores ev_option.completion_option for `table_id` when both ev_option and its
    completion_option are set; otherwise nothing is changed.

    :param table_id: table id
    :param ev_option: None or an option object exposing `completion_option`
    """
    if ev_option is None:
        return
    chosen_completion = ev_option.completion_option
    if chosen_completion is not None:
        self._table_id_to_completion_key[table_id] = chosen_completion
|
|
515
|
-
|
|
516
|
-
def _init_optimizer_mode_and_params(self, table_id, optimizer_param):
|
|
517
|
-
"""
|
|
518
|
-
Init _ps_table_id_to_optimizer_params by diff optimizer
|
|
519
|
-
"""
|
|
520
|
-
optimizer = self._table_to_optimizer.get(table_id)
|
|
521
|
-
if optimizer is None:
|
|
522
|
-
return
|
|
523
|
-
if optimizer in ["adagrad", "ftrl"]:
|
|
524
|
-
if optimizer_param is not None:
|
|
525
|
-
self._ps_table_id_to_optimizer_params[table_id].append(optimizer_param)
|
|
526
|
-
else:
|
|
527
|
-
raise ValueError("For adagrad optimizer, optimizer_param should have 1 param, "
|
|
528
|
-
"initial_accumulator_value.")
|
|
529
|
-
|
|
530
|
-
if optimizer in ["adam", "adamw", "sgd", "ftrl"]:
|
|
531
|
-
self._ps_table_id_to_optimizer_params[table_id].append(0.)
|