mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/ops/primitive.py
CHANGED
|
@@ -171,6 +171,9 @@ class Primitive(Primitive_):
|
|
|
171
171
|
if not isinstance(in_value, int) and self.name not in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY:
|
|
172
172
|
raise TypeError(f'The {log_info}: {strategy} of {self.name} is not valid,'
|
|
173
173
|
f' the value of strategy must be int type, but got:{type(in_value)}')
|
|
174
|
+
if isinstance(in_value, Layout) and (self.name in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY):
|
|
175
|
+
is_layout.append(True)
|
|
176
|
+
continue
|
|
174
177
|
is_layout.append(False)
|
|
175
178
|
continue
|
|
176
179
|
is_layout.append(True)
|
|
@@ -188,9 +191,17 @@ class Primitive(Primitive_):
|
|
|
188
191
|
raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
|
|
189
192
|
layout_value = ()
|
|
190
193
|
for in_ele in layout:
|
|
191
|
-
if
|
|
192
|
-
|
|
193
|
-
|
|
194
|
+
if isinstance(in_ele, Layout):
|
|
195
|
+
layout_value += (in_ele.to_dict(),)
|
|
196
|
+
elif isinstance(in_ele, tuple):
|
|
197
|
+
new_layout_list = ()
|
|
198
|
+
for ele in in_ele:
|
|
199
|
+
if not isinstance(ele, Layout):
|
|
200
|
+
raise TypeError(f"The {log_info} item should be a object of class Layout.")
|
|
201
|
+
new_layout_list += (ele.to_dict(),)
|
|
202
|
+
layout_value += (new_layout_list,)
|
|
203
|
+
else:
|
|
204
|
+
raise TypeError(f"The {log_info} item should be a object of class Layout or a tuple.")
|
|
194
205
|
return layout_value
|
|
195
206
|
|
|
196
207
|
def _check_shard_strategy_in_out_match(self, in_strategy, out_strategy):
|
|
@@ -299,7 +310,8 @@ class Primitive(Primitive_):
|
|
|
299
310
|
if out_strategy is not None:
|
|
300
311
|
out_is_layout = self._check_shard_strategy(out_strategy, "out_strategy")
|
|
301
312
|
self._check_shard_strategy_in_out_match(in_strategy, out_strategy)
|
|
302
|
-
if in_is_layout is not None and out_is_layout is not None and
|
|
313
|
+
if in_is_layout is not None and out_is_layout is not None and (
|
|
314
|
+
(in_is_layout[0] != out_is_layout[0]) and (self.name not in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY)):
|
|
303
315
|
raise ValueError(f'The in_strategy type must equal to the out_strategy type, '
|
|
304
316
|
f'one using tuple(tuple) and the other using tuple(Layout) is not allowed.')
|
|
305
317
|
in_layout_value = None
|
|
@@ -549,7 +561,7 @@ class PrimitiveWithCheck(Primitive):
|
|
|
549
561
|
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
|
|
550
562
|
|
|
551
563
|
More on how to customize a Op, please refer to `Custom Operators
|
|
552
|
-
<https://www.mindspore.cn/
|
|
564
|
+
<https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_.
|
|
553
565
|
|
|
554
566
|
Args:
|
|
555
567
|
name (str): Name of the current Primitive.
|
|
@@ -643,7 +655,7 @@ class PrimitiveWithInfer(Primitive):
|
|
|
643
655
|
logic of the shape and type. The infer_value() is used for constant propagation.
|
|
644
656
|
|
|
645
657
|
More on how to customize a Op, please refer to `Custom Operators
|
|
646
|
-
<https://www.mindspore.cn/
|
|
658
|
+
<https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_.
|
|
647
659
|
|
|
648
660
|
Args:
|
|
649
661
|
name (str): Name of the current Primitive.
|
|
@@ -23,6 +23,7 @@ import pathlib
|
|
|
23
23
|
import logging
|
|
24
24
|
import gen_utils
|
|
25
25
|
from pyboost_utils import AclnnUtils, get_dtypes
|
|
26
|
+
from gen_constants import MS_OPS_KERNEL_PATH
|
|
26
27
|
|
|
27
28
|
auto_gen = ''
|
|
28
29
|
|
|
@@ -35,7 +36,7 @@ def gen_h(op_name, aclnn_name, op_yaml, kernelmod_h_path, need_update_shape):
|
|
|
35
36
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_{op_name.upper()}_ACLNN{auto_gen.upper()}_KERNEL_MOD_H_
|
|
36
37
|
#include <vector>
|
|
37
38
|
#include "ops/base_operator.h"
|
|
38
|
-
#include "
|
|
39
|
+
#include "{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_mod.h"
|
|
39
40
|
#include "transform/acl_ir/acl_convert.h"
|
|
40
41
|
"""
|
|
41
42
|
update_shape = f"""
|
|
@@ -78,7 +79,7 @@ def gen_cc(op_name, class_name, op_yaml, kernelmod_cc_path, need_update_shape):
|
|
|
78
79
|
"""generate cc files"""
|
|
79
80
|
kernelmod_name = op_yaml.get('dispatch').get("Ascend")
|
|
80
81
|
cc_head = f"""
|
|
81
|
-
#include "
|
|
82
|
+
#include "{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn{auto_gen}/{op_name}_aclnn_kernel.h"
|
|
82
83
|
#include <algorithm>
|
|
83
84
|
#include <vector>
|
|
84
85
|
#include <memory>
|
|
@@ -135,8 +136,7 @@ bool {kernelmod_name}::Launch(const std::vector<KernelTensor *> &inputs, const s
|
|
|
135
136
|
const std::vector<KernelTensor *> &outputs, void *stream_ptr) {{
|
|
136
137
|
MS_EXCEPTION_IF_NULL(stream_ptr);
|
|
137
138
|
{input_templete}
|
|
138
|
-
|
|
139
|
-
RunOp(stream_ptr, workspace);
|
|
139
|
+
RunOp(stream_ptr, workspace, {inputs});
|
|
140
140
|
return true;
|
|
141
141
|
}}
|
|
142
142
|
"""
|
|
@@ -177,10 +177,10 @@ def gen_aclnn_kernel(op_name, yaml_str, need_update_shape=False, auto=False):
|
|
|
177
177
|
if check_op_registed(op_name) and not auto:
|
|
178
178
|
logging.warning("Kernel {%s} is already registered.", op_name)
|
|
179
179
|
return
|
|
180
|
-
current_path = os.path.dirname(os.path.
|
|
180
|
+
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
181
181
|
work_path = os.path.join(current_path, '../../../../')
|
|
182
182
|
|
|
183
|
-
aclnn_path = '
|
|
183
|
+
aclnn_path = '{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn/'
|
|
184
184
|
# merge inner ops
|
|
185
185
|
op_yaml = yaml_str.get(op_name)
|
|
186
186
|
class_name = ''.join(word.capitalize() for word in op_name.split('_'))
|
|
@@ -196,7 +196,7 @@ def gen_aclnn_kernel(op_name, yaml_str, need_update_shape=False, auto=False):
|
|
|
196
196
|
return
|
|
197
197
|
auto_gen = "_auto_gen"
|
|
198
198
|
dispatch['Ascend'] = class_name + "Ascend"
|
|
199
|
-
aclnn_path = '
|
|
199
|
+
aclnn_path = f'{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_auto_gen/'
|
|
200
200
|
pathlib.Path(os.path.join(work_path, aclnn_path)).mkdir(parents=True, exist_ok=True)
|
|
201
201
|
if dispatch.get("Ascend") is None:
|
|
202
202
|
raise ValueError("KernelMod {} is auto generated. If need achieve it, "
|
|
@@ -208,10 +208,10 @@ def gen_aclnn_kernel(op_name, yaml_str, need_update_shape=False, auto=False):
|
|
|
208
208
|
generate(op_name, class_name, op_yaml, kernelmod_h_and_cc_path, need_update_shape)
|
|
209
209
|
|
|
210
210
|
|
|
211
|
-
def get_registed_ops(file_path='
|
|
211
|
+
def get_registed_ops(file_path=f'{MS_OPS_KERNEL_PATH}/ascend/opapi/'):
|
|
212
212
|
'''get registered ops by search files'''
|
|
213
|
-
# default search in '
|
|
214
|
-
current_path = os.path.dirname(os.path.
|
|
213
|
+
# default search in 'ops/kernel/ascend/opapi/'
|
|
214
|
+
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
215
215
|
work_path = os.path.join(current_path, '../../../../')
|
|
216
216
|
search_path = os.path.join(work_path, file_path)
|
|
217
217
|
ret = []
|
|
@@ -230,7 +230,7 @@ def get_registed_ops(file_path='mindspore/ccsrc/plugin/device/ascend/kernel/opap
|
|
|
230
230
|
|
|
231
231
|
|
|
232
232
|
registed_ops = get_registed_ops()
|
|
233
|
-
manual_registed_ops = get_registed_ops('
|
|
233
|
+
manual_registed_ops = get_registed_ops(f'{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn/')
|
|
234
234
|
|
|
235
235
|
|
|
236
236
|
def check_op_registed(op_name, manual=False):
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Generate constants
|
|
17
|
+
"""
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
# op_def
|
|
21
|
+
OP_DEF_AUTO_GENERATE_PATH = "op_def/auto_generate"
|
|
22
|
+
MS_OP_DEF_AUTO_GENERATE_PATH = "mindspore/ops/op_def/auto_generate"
|
|
23
|
+
YAML_PATH = "op_def/yaml"
|
|
24
|
+
MS_YAML_PATH = "mindspore/ops/" + YAML_PATH
|
|
25
|
+
PY_AUTO_GEN_PATH = "mindspore/python/mindspore/ops/auto_generate"
|
|
26
|
+
PY_OPS_GEN_PATH = "mindspore/python/mindspore/ops_generate"
|
|
27
|
+
|
|
28
|
+
# infer
|
|
29
|
+
MS_OPS_FUNC_IMPL_PATH = "mindspore/ops/infer/ops_func_impl"
|
|
30
|
+
|
|
31
|
+
# view
|
|
32
|
+
MS_OPS_VIEW_PATH = "mindspore/ops/view"
|
|
33
|
+
|
|
34
|
+
# kernel
|
|
35
|
+
MS_OPS_KERNEL_PATH = "mindspore/ops/kernel"
|
|
36
|
+
MS_COMMON_PYBOOST_KERNEL_PATH = os.path.join(MS_OPS_KERNEL_PATH, "common/pyboost")
|
|
@@ -28,6 +28,7 @@ import template
|
|
|
28
28
|
from template import CppTemplate
|
|
29
29
|
from gen_pyboost_func import gen_pyboost_code
|
|
30
30
|
from gen_aclnn_implement import gen_aclnn_kernel
|
|
31
|
+
import gen_constants as K
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
def _get_op_name(yaml_key, yaml_value):
|
|
@@ -185,6 +186,7 @@ def generate_py_op_signature(op_name, args_signature, args_name, args_default):
|
|
|
185
186
|
"""
|
|
186
187
|
Generate __mindspore_signature__
|
|
187
188
|
"""
|
|
189
|
+
|
|
188
190
|
def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
|
|
189
191
|
for sig_arg_name in sig_arg_names:
|
|
190
192
|
if sig_arg_name not in args_names:
|
|
@@ -635,7 +637,7 @@ def generate_op_prim_opdef(yaml_data):
|
|
|
635
637
|
#include <memory>
|
|
636
638
|
#include "ir/anf.h"
|
|
637
639
|
#include "ir/primitive.h"
|
|
638
|
-
#include "
|
|
640
|
+
#include "{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
|
|
639
641
|
#include "mindapi/base/macros.h"
|
|
640
642
|
|
|
641
643
|
namespace mindspore::prim {{
|
|
@@ -665,7 +667,7 @@ def generate_lite_ops(yaml_data):
|
|
|
665
667
|
|
|
666
668
|
#include <vector>
|
|
667
669
|
#include "ops/base_operator.h"
|
|
668
|
-
#include "
|
|
670
|
+
#include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
|
|
669
671
|
|
|
670
672
|
namespace mindspore::ops {{
|
|
671
673
|
"""
|
|
@@ -674,14 +676,14 @@ namespace mindspore::ops {{
|
|
|
674
676
|
#endif // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
|
|
675
677
|
"""
|
|
676
678
|
|
|
677
|
-
lite_ops_cc_head = """
|
|
678
|
-
#include "
|
|
679
|
-
#include "mindapi/
|
|
679
|
+
lite_ops_cc_head = f"""
|
|
680
|
+
#include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_lite_ops.h"
|
|
681
|
+
#include "mindapi/helper.h"
|
|
680
682
|
#include "ops/primitive_c.h"
|
|
681
683
|
#include "ops/base_operator.h"
|
|
682
684
|
#include "abstract/abstract_value.h"
|
|
683
685
|
|
|
684
|
-
namespace mindspore::ops {
|
|
686
|
+
namespace mindspore::ops {{
|
|
685
687
|
"""
|
|
686
688
|
|
|
687
689
|
lite_ops_cc_end = f"""}} // namespace mindspore::ops
|
|
@@ -694,7 +696,7 @@ namespace mindspore::ops {
|
|
|
694
696
|
lite_ops_cc_gen += lite_ops_cc_head
|
|
695
697
|
for operator_name, operator_data in yaml_data.items():
|
|
696
698
|
op_name = _get_op_name(operator_name, operator_data)
|
|
697
|
-
lite_ops_h_gen += f"""class
|
|
699
|
+
lite_ops_h_gen += f"""class OPS_API {op_name} : public BaseOperator {{
|
|
698
700
|
public:
|
|
699
701
|
MIND_API_BASE_MEMBER({op_name});
|
|
700
702
|
{op_name}() : BaseOperator(kName{op_name}) {{}}\n"""
|
|
@@ -740,12 +742,10 @@ def generate_cc_opdef(yaml_data):
|
|
|
740
742
|
"""
|
|
741
743
|
gen_cc_code = f"""\n
|
|
742
744
|
namespace mindspore::ops {{"""
|
|
743
|
-
gen_opdef_map = f"""
|
|
744
|
-
std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
|
|
745
745
|
gen_include = f"""\n
|
|
746
|
-
#include \"
|
|
746
|
+
#include \"{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_def.h\""""
|
|
747
747
|
gen_include += f"""
|
|
748
|
-
#include \"
|
|
748
|
+
#include \"ir/signature.h\""""
|
|
749
749
|
|
|
750
750
|
for operator_name, operator_data in yaml_data.items():
|
|
751
751
|
args = operator_data.get('args')
|
|
@@ -764,15 +764,13 @@ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
|
|
|
764
764
|
|
|
765
765
|
is_view = operator_data.get('view')
|
|
766
766
|
if is_view:
|
|
767
|
-
|
|
767
|
+
is_view_s = "true"
|
|
768
768
|
else:
|
|
769
|
-
|
|
770
|
-
is_view_str = f"""{
|
|
771
|
-
|
|
769
|
+
is_view_s = "false"
|
|
770
|
+
is_view_str = f"""{is_view_s}"""
|
|
772
771
|
|
|
773
|
-
gen_include += f"""\n#include "
|
|
772
|
+
gen_include += f"""\n#include "{K.MS_OPS_FUNC_IMPL_PATH}/{operator_name}.h\""""
|
|
774
773
|
cc_index_str = ''
|
|
775
|
-
gen_opdef_map += f"""\n {{"{class_name}", &g{class_name}}},"""
|
|
776
774
|
input_args_str = ''
|
|
777
775
|
args_dict = {}
|
|
778
776
|
for i, (arg_name, arg_info) in enumerate(args.items()):
|
|
@@ -814,8 +812,9 @@ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
|
|
|
814
812
|
indexes=cc_index_str, enable_dispatch=enable_dispatch_str,
|
|
815
813
|
is_view=is_view_str)
|
|
816
814
|
gen_cc_code += op_def_cc
|
|
817
|
-
|
|
818
|
-
|
|
815
|
+
if is_view:
|
|
816
|
+
view_op_def = op_def_cc.replace(class_name, class_name+"View")
|
|
817
|
+
gen_cc_code += view_op_def
|
|
819
818
|
|
|
820
819
|
cc_opdef_end = f"""\n}} // namespace mindspore::ops\n"""
|
|
821
820
|
return gen_include + gen_cc_code + cc_opdef_end
|
|
@@ -835,7 +834,6 @@ from mindspore._c_expression import OpDtype
|
|
|
835
834
|
from mindspore.common._stub_tensor import _convert_stub
|
|
836
835
|
"""
|
|
837
836
|
|
|
838
|
-
|
|
839
837
|
ops_py_def_header = f"""
|
|
840
838
|
\"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
|
|
841
839
|
|
|
@@ -847,8 +845,8 @@ from mindspore.ops._primitive_cache import _get_cache_prim
|
|
|
847
845
|
|
|
848
846
|
|
|
849
847
|
def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
|
|
850
|
-
py_path = os.path.join(work_path, f'
|
|
851
|
-
tmp_py_path = os.path.join(work_path, f'
|
|
848
|
+
py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_prim.py')
|
|
849
|
+
tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_prim.py')
|
|
852
850
|
pyboost_import_header = generate_pyboost_import_header(yaml_str)
|
|
853
851
|
py_prim = generate_py_primitive(yaml_str, doc_str)
|
|
854
852
|
write_file(tmp_py_path, py_licence_str + ops_py_prim_header + pyboost_import_header + py_prim)
|
|
@@ -856,8 +854,8 @@ def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
|
|
|
856
854
|
|
|
857
855
|
|
|
858
856
|
def generate_ops_def_file(work_path, yaml_str, doc_str, file_pre):
|
|
859
|
-
py_path = os.path.join(work_path, f'
|
|
860
|
-
tmp_py_path = os.path.join(work_path, f'
|
|
857
|
+
py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_def.py')
|
|
858
|
+
tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_def.py')
|
|
861
859
|
py_func = generate_py_op_func(yaml_str, doc_str)
|
|
862
860
|
write_file(tmp_py_path, py_licence_str + ops_py_def_header + py_func)
|
|
863
861
|
check_change_and_replace_file(py_path, tmp_py_path)
|
|
@@ -869,6 +867,8 @@ def generate_ops_py_files(work_path, yaml_str, doc_str, file_pre):
|
|
|
869
867
|
"""
|
|
870
868
|
generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre)
|
|
871
869
|
generate_ops_def_file(work_path, yaml_str, doc_str, file_pre)
|
|
870
|
+
shutil.copy(os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_auto_generate_init.txt'),
|
|
871
|
+
os.path.join(work_path, K.PY_AUTO_GEN_PATH, "__init__.py"))
|
|
872
872
|
|
|
873
873
|
|
|
874
874
|
def generate_ops_cc_files(work_path, yaml_str):
|
|
@@ -876,35 +876,35 @@ def generate_ops_cc_files(work_path, yaml_str):
|
|
|
876
876
|
Generate ops c++ file from yaml.
|
|
877
877
|
"""
|
|
878
878
|
# ops_def
|
|
879
|
-
op_cc_path = os.path.join(work_path, '
|
|
880
|
-
tmp_op_cc_path = os.path.join(work_path, '
|
|
879
|
+
op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_def.cc')
|
|
880
|
+
tmp_op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_def.cc')
|
|
881
881
|
cc_def_code = generate_cc_opdef(yaml_str)
|
|
882
882
|
write_file(tmp_op_cc_path, cc_license_str + cc_def_code)
|
|
883
883
|
check_change_and_replace_file(op_cc_path, tmp_op_cc_path)
|
|
884
884
|
|
|
885
885
|
# ops_primitive
|
|
886
|
-
op_prim_path = os.path.join(work_path, '
|
|
887
|
-
tmp_op_prim_path = os.path.join(work_path, '
|
|
886
|
+
op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_primitive.h')
|
|
887
|
+
tmp_op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_primitive.h')
|
|
888
888
|
op_prim_code = generate_op_prim_opdef(yaml_str)
|
|
889
889
|
write_file(tmp_op_prim_path, cc_license_str + op_prim_code)
|
|
890
890
|
check_change_and_replace_file(op_prim_path, tmp_op_prim_path)
|
|
891
891
|
|
|
892
892
|
# lite_h_ops
|
|
893
|
-
lite_ops_h_path = os.path.join(work_path, '
|
|
894
|
-
tmp_lite_ops_h_path = os.path.join(work_path, '
|
|
893
|
+
lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.h')
|
|
894
|
+
tmp_lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.h')
|
|
895
895
|
lite_ops_h_code, lite_ops_cc_code = generate_lite_ops(yaml_str)
|
|
896
896
|
write_file(tmp_lite_ops_h_path, cc_license_str + lite_ops_h_code)
|
|
897
897
|
check_change_and_replace_file(lite_ops_h_path, tmp_lite_ops_h_path)
|
|
898
898
|
|
|
899
899
|
# lite_cc_ops
|
|
900
|
-
lite_ops_cc_path = os.path.join(work_path, '
|
|
901
|
-
tmp_lite_ops_cc_path = os.path.join(work_path, '
|
|
900
|
+
lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.cc')
|
|
901
|
+
tmp_lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.cc')
|
|
902
902
|
write_file(tmp_lite_ops_cc_path, cc_license_str + lite_ops_cc_code)
|
|
903
903
|
check_change_and_replace_file(lite_ops_cc_path, tmp_lite_ops_cc_path)
|
|
904
904
|
|
|
905
905
|
# ops_names
|
|
906
|
-
op_name_path = os.path.join(work_path, '
|
|
907
|
-
tmp_op_name_path = os.path.join(work_path, '
|
|
906
|
+
op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_name.h')
|
|
907
|
+
tmp_op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_name.h')
|
|
908
908
|
op_name_code = generate_op_name_opdef(yaml_str)
|
|
909
909
|
write_file(tmp_op_name_path, cc_license_str + op_name_code)
|
|
910
910
|
check_change_and_replace_file(op_name_path, tmp_op_name_path)
|
|
@@ -958,7 +958,7 @@ def generate_create_instance_helper_file(work_path, yaml_str):
|
|
|
958
958
|
"""
|
|
959
959
|
Generate C++ helper file from yaml.
|
|
960
960
|
"""
|
|
961
|
-
dst_dir = os.path.join(work_path,
|
|
961
|
+
dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
|
|
962
962
|
op_py_path = os.path.join(dst_dir, 'cpp_create_prim_instance_helper.py')
|
|
963
963
|
tmp_op_py_path = os.path.join(dst_dir, 'tmp_cpp_create_prim_instance_helper.py')
|
|
964
964
|
py_labels = generate_op_labels(yaml_str)
|
|
@@ -969,13 +969,13 @@ def generate_create_instance_helper_file(work_path, yaml_str):
|
|
|
969
969
|
|
|
970
970
|
def generate_aclnn_reg_code(yaml_data):
|
|
971
971
|
"""generate aclnn register code"""
|
|
972
|
-
current_path = os.path.dirname(os.path.
|
|
972
|
+
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
973
973
|
work_path = os.path.join(current_path, '../../../../')
|
|
974
|
-
ops_yaml_path = os.path.join(work_path,
|
|
974
|
+
ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, "ops.yaml")
|
|
975
975
|
yaml_str = gen_utils.safe_load_yaml(ops_yaml_path)
|
|
976
976
|
|
|
977
977
|
reg_code = f"""
|
|
978
|
-
#include "
|
|
978
|
+
#include "{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_mod.h"
|
|
979
979
|
|
|
980
980
|
namespace mindspore {{
|
|
981
981
|
namespace kernel {{
|
|
@@ -1010,8 +1010,8 @@ def generate_aclnn_reg_file(work_path, yaml_str):
|
|
|
1010
1010
|
"""
|
|
1011
1011
|
Generate nnacl kernelmod register
|
|
1012
1012
|
"""
|
|
1013
|
-
tmp_register_file = work_path + '
|
|
1014
|
-
register_file = work_path + '
|
|
1013
|
+
tmp_register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/tmp_aclnn_kernel_register.cc'
|
|
1014
|
+
register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_register_auto.cc'
|
|
1015
1015
|
reg_code = generate_aclnn_reg_code(yaml_str)
|
|
1016
1016
|
write_file(tmp_register_file, cc_license_str + reg_code)
|
|
1017
1017
|
check_change_and_replace_file(register_file, tmp_register_file)
|
|
@@ -1021,39 +1021,52 @@ def generate_arg_handler_files(work_path):
|
|
|
1021
1021
|
"""
|
|
1022
1022
|
Generate arg handler files.
|
|
1023
1023
|
"""
|
|
1024
|
-
dst_dir = os.path.join(work_path,
|
|
1025
|
-
src_arg_handler_path = os.path.join(work_path, '
|
|
1024
|
+
dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
|
|
1025
|
+
src_arg_handler_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'arg_handler.py')
|
|
1026
1026
|
dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
|
|
1027
1027
|
tmp_dst_arg_handler_path = os.path.join(dst_dir, 'tmp_gen_arg_handler.py')
|
|
1028
1028
|
if not os.path.exists(dst_dir):
|
|
1029
|
-
os.makedirs(dst_dir)
|
|
1029
|
+
os.makedirs(dst_dir, mode=0o700)
|
|
1030
1030
|
shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
|
|
1031
1031
|
check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)
|
|
1032
1032
|
|
|
1033
|
-
src_arg_dtype_cast_path = os.path.join(work_path, '
|
|
1033
|
+
src_arg_dtype_cast_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'arg_dtype_cast.py')
|
|
1034
1034
|
dst_arg_dtype_cast_path = os.path.join(dst_dir, 'gen_arg_dtype_cast.py')
|
|
1035
1035
|
tmp_arg_dtype_cast_path = os.path.join(dst_dir, 'tmp_arg_dtype_cast.py')
|
|
1036
1036
|
shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
|
|
1037
1037
|
check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
|
|
1038
1038
|
|
|
1039
1039
|
|
|
1040
|
+
def get_view_ops(yaml_data):
|
|
1041
|
+
"""
|
|
1042
|
+
Get ops with view: True
|
|
1043
|
+
"""
|
|
1044
|
+
view_ops = []
|
|
1045
|
+
for operator_name, operator_data in yaml_data.items():
|
|
1046
|
+
class_name = _get_op_name(operator_name, operator_data)
|
|
1047
|
+
view = operator_data.get("view")
|
|
1048
|
+
if view:
|
|
1049
|
+
view_ops.append(class_name + "View")
|
|
1050
|
+
return view_ops
|
|
1051
|
+
|
|
1052
|
+
|
|
1040
1053
|
def main():
|
|
1041
|
-
current_path = os.path.dirname(os.path.
|
|
1054
|
+
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
1042
1055
|
work_path = os.path.join(current_path, '../../../../')
|
|
1043
1056
|
|
|
1044
1057
|
# merge ops yaml
|
|
1045
|
-
ops_yaml_path = os.path.join(work_path, '
|
|
1046
|
-
doc_yaml_path = os.path.join(work_path, '
|
|
1058
|
+
ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops.yaml')
|
|
1059
|
+
doc_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_doc.yaml')
|
|
1047
1060
|
|
|
1048
|
-
ops_yaml_dir_path = os.path.join(work_path,
|
|
1049
|
-
infer_ops_yaml_dir_path = os.path.join(
|
|
1050
|
-
doc_yaml_dir_path = os.path.join(
|
|
1061
|
+
ops_yaml_dir_path = os.path.join(work_path, K.MS_YAML_PATH)
|
|
1062
|
+
infer_ops_yaml_dir_path = os.path.join(ops_yaml_dir_path, "infer")
|
|
1063
|
+
doc_yaml_dir_path = os.path.join(ops_yaml_dir_path, "doc")
|
|
1051
1064
|
merge_files(ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
|
|
1052
1065
|
merge_files_append(infer_ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
|
|
1053
1066
|
merge_files(doc_yaml_dir_path, doc_yaml_path, '*doc.yaml')
|
|
1054
1067
|
|
|
1055
1068
|
# make auto_generate dir
|
|
1056
|
-
cc_path = os.path.join(work_path,
|
|
1069
|
+
cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
|
|
1057
1070
|
pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True)
|
|
1058
1071
|
|
|
1059
1072
|
# generate arg_handler files
|
|
@@ -1070,8 +1083,10 @@ def main():
|
|
|
1070
1083
|
generate_ops_cc_files(work_path, ops_yaml_str)
|
|
1071
1084
|
# generate create prim instance helper file
|
|
1072
1085
|
generate_create_instance_helper_file(work_path, ops_yaml_str)
|
|
1086
|
+
# get view extra ops
|
|
1087
|
+
extra_ops = get_view_ops(ops_yaml_str)
|
|
1073
1088
|
# generate pyboost code
|
|
1074
|
-
gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str)
|
|
1089
|
+
gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str, extra_ops)
|
|
1075
1090
|
# generate aclnn kernelmod register
|
|
1076
1091
|
generate_aclnn_reg_file(work_path, ops_yaml_str)
|
|
1077
1092
|
|
|
@@ -42,7 +42,7 @@ class DtypeToEnum(Primitive):
|
|
|
42
42
|
def __call__(self, op_name, arg_name, dtype):
|
|
43
43
|
"""Run in PyNative mode"""
|
|
44
44
|
if not isinstance(dtype, typing.Type):
|
|
45
|
-
raise TypeError(f"For '{op_name}', the input '{arg_name}' should be
|
|
45
|
+
raise TypeError(f"For '{op_name}', the input '{arg_name}' should be mindspore dtype, but got {dtype}.")
|
|
46
46
|
return typing.type_to_type_id(dtype)
|
|
47
47
|
|
|
48
48
|
|