mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's release page for more details.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,360 @@
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Module for generating Python primitive operator definitions from specifications.
17
+ """
18
+
19
+ import common.gen_utils as gen_utils
20
+ import common.template_utils as template
21
+ from common.base_generator import BaseGenerator
22
+ from common.op_proto import OpProto
23
+ from common.template_utils import Template
24
+
25
+
26
class BaseOpPrimPyGenerator(BaseGenerator):
    """
    Generates Python code for primitive operators based on provided specifications.

    This base class builds the shared pieces of a generated primitive class:
    argument categorisation, the class description, extra ``__init__`` code,
    the ``__mindspore_signature__`` attribute and the ``__call__`` header.
    Subclasses must implement ``_get_call_method_body_str`` to supply the body
    of the generated ``__call__`` method.
    """

    def __init__(self):
        """
        Initializes the generator with a template for defining operator primitive classes.
        """
        # NOTE(review): BaseGenerator.__init__ is not invoked here; confirm the
        # base class needs no initialization.
        # The template is not read by any method of this class; presumably it
        # is consumed by subclasses.
        self.op_prim_class_define_template = template.OP_PRIM_CLASS_DEFINE_TEMPLATE

    def _process_args(self, op_proto: OpProto):
        """
        Processes operator arguments to categorize them for code generation.

        Arguments with ``is_prim_init`` set are collected as primitive init
        arguments. All other arguments are collected as call inputs.

        Args:
            op_proto (OpProto): The operator prototype.

        Returns:
            tuple: A 6-tuple ``(inputs_name, inputs_default, args_name, args_assign,
            init_args_with_default, args_handlers)``:

            - inputs_name (list): input names, in declaration order.
            - inputs_default (dict): input name -> default, only for inputs whose
              default is not None.
            - args_name (list): init-argument names, in declaration order.
            - args_assign (list): one generated ``self._set_prim_arg(...)`` or
              ``self._set_prim_arg_with_handler(...)`` statement per init argument.
            - init_args_with_default (list): init arguments rendered as ``name``
              or ``name=default``.
            - args_handlers (dict): input name -> arg handler, only for inputs
              that declare one.
        """
        inputs_name = []
        args_name = []
        args_assign = []
        inputs_default = {}
        init_args_with_default = []
        args_handlers = {}

        for arg in op_proto.op_args:
            # step1: get args infos:
            if arg.is_prim_init:
                # step1.1: get args name:
                args_name.append(arg.arg_name)
                # step1.2: get args assign with default value:
                if arg.default is not None:
                    init_args_with_default.append(f"""{arg.arg_name}={arg.default}""")
                else:
                    init_args_with_default.append(f"""{arg.arg_name}""")

                # step1.3: get args set prim arg expression:
                assign_str = self._get_assign_str_by_type_it(op_proto.op_class.name, arg)
                # An init arg that declares an arg_handler is routed through
                # _set_prim_arg_with_handler, with the handler name appended.
                if arg.arg_handler:
                    assign_str = (
                        f' self._set_prim_arg_with_handler('
                        f'"{arg.arg_name}", {assign_str}, {arg.arg_handler})'
                    )
                else:
                    assign_str = f""" self._set_prim_arg("{arg.arg_name}", {assign_str})"""
                args_assign.append(assign_str)
            # step2: get inputs infos:
            else:
                # step2.1: get inputs name:
                inputs_name.append(arg.arg_name)

                # step2.2: get default value of inputs:
                if arg.default is not None:
                    inputs_default[arg.arg_name] = arg.default

                # step2.3: get args_handler functions for inputs
                if arg.arg_handler:
                    args_handlers[arg.arg_name] = arg.arg_handler

        return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers

    def _get_assign_str_by_type_it(self, class_name, arg):
        """
        Generates assignment string with type casting.

        Args:
            class_name (str): The name of the class.
            arg (OpArg): The operator argument.

        Returns:
            str: The bare argument name when ``arg.type_cast`` is empty.
            Otherwise a ``type_it(...)`` call expression of the form
            ``type_it('<class>', '<arg>', <arg>, <cast type or tuple of cast types>, <arg dtype>)``.
        """
        assign_str = ""
        type_cast = arg.type_cast
        if type_cast:
            assign_str += f"type_it('{class_name}', '{arg.arg_name}', {arg.arg_name}, "
            # A single cast type is emitted bare; several are emitted as a tuple literal.
            if len(type_cast) == 1:
                assign_str += gen_utils.get_type_str(type_cast[0]) + ', '
            else:
                assign_str += '(' + ', '.join(gen_utils.get_type_str(ct) for ct in type_cast) + '), '
            assign_str += gen_utils.get_type_str(arg.arg_dtype) + ')'
        else:
            assign_str = arg.arg_name
        return assign_str

    def _generate_class_desc(self, op_proto: OpProto, input_args, init_args, doc_dic):
        """
        Generates a class description based on the operator prototype.

        Args:
            op_proto (OpProto): The operator prototype.
            input_args (list): List of input argument names.
            init_args (list): List of initialization argument names.
            doc_dic (dict): Documentation dictionary.

        Returns:
            str: If the operator's function is disabled, the description looked
            up from ``doc_dic`` via ``gen_utils.get_op_description``. Otherwise
            ``template.PRIMITIVE_CLASS_DESC`` filled with the class name, the
            argument lists and the function name.
        """
        if op_proto.op_function and op_proto.op_function.disable:
            # if function disabled, function name is equal to operator_name
            return gen_utils.get_op_description(op_proto.op_name, doc_dic)

        # If function is a released API, refer to the function doc.
        # NOTE(review): if op_proto.op_function is None, execution still reaches
        # op_proto.op_function.name below and raises AttributeError. Confirm
        # every OpProto reaching this point has op_function set.
        init_args_str = ", ".join(init_args)
        input_args_str = ", ".join(input_args)
        args_str = ", ".join(input_args + init_args)

        description_template = Template(template.PRIMITIVE_CLASS_DESC)
        description_str = description_template.replace(class_name=op_proto.op_class.name,
                                                        init_args_str=init_args_str,
                                                        input_args_str=input_args_str,
                                                        func_name=op_proto.op_function.name,
                                                        args_str=args_str)
        return description_str

    def _get_init_code(self, init_code, op_proto: OpProto):
        """
        Generates additional initialization code for the operator primitive class.

        Appends one ``self.add_prim_attr("<label>", <value>)`` line per entry of
        ``op_proto.op_labels``. Label values are emitted verbatim, not repr'd.

        Args:
            init_code (str): Existing initialization code.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: The combined initialization code, or a ``pass`` line when
            there is nothing to emit.
        """
        labels_dic = op_proto.op_labels
        if labels_dic:
            # Put the label attributes on a new line after any existing init code.
            if init_code:
                init_code += "\n"
            init_code += "\n".join([f""" self.add_prim_attr("{k}", {v})""" for k, v in labels_dic.items()])

        # Emit `pass` when there is no code, presumably so the generated method body is not empty.
        return init_code if init_code else f""" pass"""

    def _generate_py_op_signature(self, op_proto: OpProto, args_name, args_default):
        """
        Generates the __mindspore_signature__ for the operator.

        Args:
            op_proto (OpProto): The operator prototype.
            args_name (list): List of argument names.
            args_default (dict): Dictionary of default argument values.

        Returns:
            str: ``''`` when the operator has neither an args signature nor
            defaults. Otherwise the ``__mindspore_signature__ = ...`` assignment,
            in one of two forms:

            - a compact tuple of ``sig.sig_dtype.T`` entries, when there is a
              single dtype group and no rw/ref or defaults;
            - a tuple of ``sig.make_sig(...)`` entries, otherwise.

        Raises:
            ValueError: If the signature references an argument name not in
                ``args_name`` (raised by ``_check_signature_arg_valid``).
        """
        op_name = op_proto.op_name
        args_signature = op_proto.op_args_signature

        if args_signature is None and not args_default:
            return ''

        signature_code = f"""\n __mindspore_signature__ = """

        # Init rw.
        # NOTE(review): confirm gen_utils.init_args_signature_rw returns
        # (read, ref, write) in exactly this order.
        read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
        _check_signature_arg_valid(op_name, write_list, args_name)
        _check_signature_arg_valid(op_name, read_list, args_name)
        _check_signature_arg_valid(op_name, ref_list, args_name)

        # Init dtype group.
        # same_dtype_groups is used as an arg name -> dtype-group index mapping
        # (see signature_get_dtype_label below).
        same_dtype_groups, dtype_count = gen_utils.get_same_dtype_groups(args_signature, args_name)
        _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)

        # Only one dtype_group is set.
        # NOTE(review): this shortcut tags EVERY argument with dtype T, even if
        # some arguments are not in the group. It also emits
        # "(sig.sig_dtype.T)" (not a tuple) when args_name has a single entry.
        # Confirm neither case occurs in the op specs.
        if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
            signature_code += '('
            for _ in range(len(args_name) - 1):
                signature_code += 'sig.sig_dtype.T, '
            signature_code += 'sig.sig_dtype.T)\n'
            return signature_code

        # Set sig.make_sig.
        signature_code += f""" (\n"""
        for arg_name in args_name:
            signature_code += f""" sig.make_sig('{arg_name}'"""
            signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
            if arg_name in same_dtype_groups:
                signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
            if arg_name in args_default:
                # Defaults are emitted via str(), not repr().
                signature_code += f""", default=""" + str(args_default[arg_name])
            signature_code += f"""),\n"""
        signature_code += f""" )\n"""
        return signature_code

    def _generate_call_code(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
        """
        Generates the __call__ method code for the operator primitive class.

        Args:
            args_handlers (dict): Dictionary of argument handlers.
            init_args (list): List of initialization argument names.
            inputs_args (list): List of input argument names.
            inputs_default (dict): Dictionary of default input values.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: The ``def __call__(self, ...):`` header followed directly by the
            body returned from ``_get_call_method_body_str``.
        """
        call_code_str = ""
        call_args = []
        # NOTE(review): inputs keep declaration order and defaults are emitted
        # verbatim. A non-default input declared after a defaulted one would
        # yield an invalid generated signature.
        for name in inputs_args:
            call_args.append(f"{name}={inputs_default[name]}" if name in inputs_default else name)
        call_method_args_str = ", ".join(call_args)
        call_method_body_str = self._get_call_method_body_str(args_handlers, init_args, inputs_args, inputs_default,
                                                              op_proto)
        call_code_str += f""" def __call__(self, {call_method_args_str}):"""
        # The body is appended with no separator; it presumably starts with its own newline.
        call_code_str += f"""{call_method_body_str}"""
        return call_code_str

    def _get_call_method_body_str(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
        """
        Generates the body of the __call__ method.

        Abstract hook: subclasses must override it.

        Args:
            args_handlers (dict): Dictionary of argument handlers.
            init_args (list): List of initialization argument names.
            inputs_args (list): List of input argument names.
            inputs_default (dict): Dictionary of default input values.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: A string containing the body of the call method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"For '{self.__class__.__name__}', the '_get_call_method_body_str' method must be implemented."
        )
256
+
257
+
258
+ def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
259
+ """
260
+ Validates that all signature arguments are present in the list of argument names.
261
+
262
+ Args:
263
+ op_name (str): The name of the operator.
264
+ sig_arg_names (list): List of signature argument names.
265
+ args_names (list): List of actual argument names.
266
+
267
+ Raises:
268
+ ValueError: If a signature argument is not found in the list of argument names.
269
+ """
270
+ for sig_arg_name in sig_arg_names:
271
+ if sig_arg_name not in args_names:
272
+ raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
273
+
274
+
275
def signature_get_dtype_label(index):
    """
    Build the ``dtype=`` fragment of a ``sig.make_sig(...)`` call.

    Group 0 maps to ``sig.sig_dtype.T``. A positive group index ``k`` maps to
    ``sig.sig_dtype.T<k>``, e.g. ``T1``. Non-positive indices fall back to ``T``.

    Args:
        index (int): The dtype-group index.

    Returns:
        str: The ``dtype=sig.sig_dtype.T...`` label.
    """
    suffix = f"{index}" if index > 0 else ''
    return "dtype=sig.sig_dtype.T" + suffix
289
+
290
+
291
def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
    """
    Return the read/write fragment of a ``sig.make_sig(...)`` call for one argument.

    The lists are consulted in priority order: write, then read, then ref. The
    first list containing ``arg_name`` decides the label.

    Args:
        arg_name (str): The argument being labelled.
        write_list (list): Argument names marked as written.
        read_list (list): Argument names marked as read.
        ref_list (list): Argument names marked as references.

    Returns:
        str: ``', sig.sig_rw.RW_WRITE'``, ``', sig.sig_rw.RW_READ'``,
        ``', sig.sig_rw.RW_REF'``, or ``''`` when the argument is in none of them.
    """
    priority = (
        (write_list, ', sig.sig_rw.RW_WRITE'),
        (read_list, ', sig.sig_rw.RW_READ'),
        (ref_list, ', sig.sig_rw.RW_REF'),
    )
    for names, label in priority:
        # Plain equality test per element, matching the original element-wise scan.
        if any(name == arg_name for name in names):
            return label
    return ''
314
+
315
+
316
def generate_py_op_deprecated(deprecated):
    """
    Build the ``@deprecated(...)`` decorator line for a generated operator.

    Args:
        deprecated (dict): Deprecation info, or None when the operator is not
            deprecated. Its "version", "substitute" and "use_substitute"
            entries must all be present and not None. "use_substitute" must be
            exactly True or False.

    Returns:
        str: ``''`` when ``deprecated`` is None. Otherwise the decorator source
        line, terminated by a newline.

    Raises:
        ValueError: If a required entry is missing or None, or if
            ``use_substitute`` is not a bool.
    """
    if deprecated is None:
        return ''

    # Checked in this order, so the first missing key determines the error.
    required = (
        ("version", "The version of deprecated can't be None."),
        ("substitute", "The substitute of deprecated can't be None."),
        ("use_substitute", "The use_substitute of deprecated can't be None."),
    )
    values = []
    for key, missing_msg in required:
        value = deprecated.get(key)
        if value is None:
            raise ValueError(missing_msg)
        values.append(value)
    version, substitute, use_substitute = values

    # bool cannot be subclassed, so this accepts exactly the True/False singletons.
    if not isinstance(use_substitute, bool):
        raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")

    return f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
342
+
343
+
344
+ def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
345
+ """
346
+ Generates the argument handler call for an argument.
347
+
348
+ Args:
349
+ class_name (str): The name of the class.
350
+ arg (str): The name of the argument.
351
+ arg_handler (str): The handler function for the argument.
352
+ is_optional (bool): Indicates whether the argument is optional.
353
+
354
+ Returns:
355
+ str: The argument handler call string.
356
+ """
357
+ arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
358
+ if is_optional:
359
+ arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
360
+ return arg_handler_call
@@ -0,0 +1,140 @@
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Module for generating Python primitive operator definitions from specifications.
17
+ """
18
+ import common.gen_utils as gen_utils
19
+ import common.template_utils as template
20
+ from common.op_proto import OpProto
21
+ from op_def_py.base_op_prim_py_generator import BaseOpPrimPyGenerator, generate_py_op_deprecated, _generate_arg_handler
22
+
23
+
24
class CustomOpPrimPyGenerator(BaseOpPrimPyGenerator):
    """
    Generates Python primitive-class code for custom operators.

    Each enabled operator becomes a ``Custom_<op_name>`` class whose ``__init__`` stores an
    ``op_func`` and whose ``__call__`` forwards the call arguments to it.
    """

    def __init__(self):
        """
        Initializes the generator with a template for defining operator primitive classes.
        """
        # NOTE(review): BaseOpPrimPyGenerator.__init__ is not invoked here; confirm the base
        # class needs no initialization of its own.
        self.op_prim_class_define_template = template.OP_PRIM_CLASS_DEFINE_TEMPLATE

    def generate(self, work_path, module_name, op_protos, doc_dict, file_pre):
        """
        Generates Python code for custom operator primitives and saves it to a file.

        Args:
            work_path (str): The directory to save the generated file in.
            module_name (str): Name of the module providing the operator implementations. It is
                imported at the top of the generated file, and for operators without init args
                ``<module_name>.<op_name>`` is passed to the generated class to create a
                module-level ``<op_name>_op`` instance.
            op_protos (list): A list of operator prototypes; ones whose ``op_class.disable`` is
                set are skipped.
            doc_dict (dict): A dictionary containing documentation strings.
            file_pre (str): The prefix for the generated file name (``<file_pre>_ops_prim.py``).
        """
        gen_py = ""
        for op_proto in op_protos:
            if op_proto.op_class.disable:
                continue

            inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = (
                self._process_args(op_proto))

            # add class description
            class_desc = self._generate_class_desc(op_proto, inputs_args, init_args, doc_dict)

            # add signature
            signature_code = self._generate_py_op_signature(op_proto, inputs_args, inputs_default)

            # add deprecated
            deprecated_code = generate_py_op_deprecated(op_proto.op_deprecated)

            init_method = self._generate_init_code(args_assign, init_args_with_default, op_proto)

            # add __call__ method code
            call_method = self._generate_call_code(args_handlers, init_args, inputs_args, inputs_default, op_proto)

            class_name = "Custom_" + op_proto.op_name
            # generate op prim class define
            op_prim_class_define = self.op_prim_class_define_template.replace(class_name=class_name,
                                                                              class_desc=class_desc,
                                                                              signature_code=signature_code,
                                                                              deprecated_code=deprecated_code,
                                                                              init_method=init_method,
                                                                              call_method=call_method)
            op_prim_class_define += "\n" if call_method.endswith("\n") else ""
            gen_py += op_prim_class_define

            # Operators without init args get a ready-made module-level instance bound to the
            # implementation function from `module_name`.
            if not init_args:
                gen_py += f"\n\n{op_proto.op_name}_op={class_name}({module_name}.{op_proto.op_name})\n"

        custom_import_header = f"import {module_name}"
        res_str = template.PY_LICENSE_STR + \
            template.OPS_PY_PRIM_HEADER + custom_import_header + gen_py

        file_name = f"{file_pre}_ops_prim.py"
        gen_utils.save_file(work_path, file_name, res_str)

    def _generate_init_code(self, args_assign, init_args_with_default, op_proto: OpProto):
        """
        Generates the __init__ method code for the custom operator primitive class.

        The generated ``__init__`` always takes a single ``op_func`` argument and stores it
        as ``self.custom_op_func``; extra init code comes from ``self._get_init_code``.

        Args:
            args_assign (list): List of argument assignment strings (not used by this generator).
            init_args_with_default (list): Initialization arguments with defaults (not used by
                this generator).
            op_proto (OpProto): The operator prototype.

        Returns:
            str: A string containing the __init__ method code.
        """
        # NOTE(review): args_assign / init_args_with_default are ignored, so init args are never
        # assigned to `self` here even though `__call__` references `self.<init_arg>`; confirm
        # that `_get_init_code` or `prim_arg_register` covers this.
        # NOTE(review): leading whitespace in these generated lines looks collapsed by the diff
        # page this chunk came from; confirm the intended indent against upstream.
        init_code = self._get_init_code("\n self.custom_op_func = op_func", op_proto)
        return (" @prim_arg_register\n"
                " def __init__(self, op_func):\n"
                f"{init_code}\n"
                "\n")

    def _get_call_method_body_str(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
        """
        Generates the body of the __call__ method.

        The call forwards the inputs (each wrapped by its argument handler when one is
        registered) followed by ``self.<arg>`` for every init argument to ``self.custom_op_func``.

        Args:
            args_handlers (dict): Dictionary of argument handlers.
            init_args (list): List of initialization argument names.
            inputs_args (list): List of input argument names.
            inputs_default (dict): Dictionary of default input values.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: A string containing the body of the call method.
        """
        call_args = []
        for arg in inputs_args or ():
            if arg in args_handlers:
                # An input whose default is the literal "None" is optional: skip the handler for None.
                is_optional = inputs_default.get(arg) == "None"
                call_args.append(
                    _generate_arg_handler(op_proto.op_class.name, arg, args_handlers[arg], is_optional))
            else:
                call_args.append(arg)
        call_args.extend(f'self.{arg}' for arg in init_args or ())

        # Joining once avoids a dangling leading ", " when there are init args but no inputs.
        call_args_list_str = ", ".join(call_args)
        # NOTE(review): leading whitespace of this generated line looks collapsed by the diff
        # page this chunk came from; confirm the intended indent against upstream.
        call_method_body_str = f"\n return self.custom_op_func({call_args_list_str})"
        return call_method_body_str
@@ -22,7 +22,7 @@ import common.gen_constants as K
22
22
  import common.gen_utils as gen_utils
23
23
 
24
24
  # refactored
25
- import common.template as template
25
+ import common.template_utils as template
26
26
 
27
27
  from common.base_generator import BaseGenerator
28
28
 
@@ -57,11 +57,23 @@ class OpDefPyGenerator(BaseGenerator):
57
57
  the provided operation prototypes and documentation. It saves the code in a file
58
58
  with the given prefix in the specified work path.
59
59
  """
60
+
61
+ gen_py = self._generate_func_code(op_protos, doc_dict)
62
+ res_str = template.PY_LICENSE_STR + \
63
+ template.OPS_PY_DEF_HEADER + gen_py[:-len(template.NEW_LINE)]
64
+ save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
65
+ file_name = f"{file_pre}_ops_def.py"
66
+ gen_utils.save_file(save_path, file_name, res_str)
67
+
68
+ def _generate_func_code(self, op_protos, doc_dict):
69
+ """
70
+ Generate Python source code for operator functions based on a list of
71
+ operator protocols and their documentation.
72
+ """
60
73
  gen_py = "\n"
61
74
  for op_proto in op_protos:
62
75
  if op_proto.op_function.disable:
63
76
  continue
64
-
65
77
  class_name = op_proto.op_class.name
66
78
  func_name = op_proto.op_function.name
67
79
  op_args = op_proto.op_args
@@ -93,11 +105,7 @@ class OpDefPyGenerator(BaseGenerator):
93
105
  gen_py += func_code
94
106
  gen_py += "\n"
95
107
 
96
- res_str = template.PY_LICENSE_STR + \
97
- template.OPS_PY_DEF_HEADER + gen_py[:-len(template.NEW_LINE)]
98
- save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
99
- file_name = f"{file_pre}_ops_def.py"
100
- gen_utils.save_file(save_path, file_name, res_str)
108
+ return gen_py
101
109
 
102
110
  def get_op_args(self, op_args):
103
111
  """
@@ -130,3 +138,42 @@ class OpDefPyGenerator(BaseGenerator):
130
138
  else:
131
139
  prim_call_args.append(op_arg.arg_name)
132
140
  return func_args, prim_call_args, prim_init_args
141
+
142
+
143
class CustomOpDefPyGenerator(OpDefPyGenerator):
    """
    Generates the Python operator-function definitions file for custom operators.

    It reuses ``OpDefPyGenerator._generate_func_code`` to build the per-operator functions,
    but prepends ``template.CUSTOM_OPS_PY_DEF_HEADER`` and writes the result directly into
    ``work_path`` rather than into ``work_path/K.PY_AUTO_GEN_PATH``.
    """

    def __init__(self):
        """
        Initializes the parent generator and the template for primitive class definitions.
        """
        # Zero-argument super() binds to this instance, so the parent initializer actually runs.
        # The one-argument form `super(CustomOpDefPyGenerator)` yields an unbound super object
        # whose `__init__()` never reaches OpDefPyGenerator.
        super().__init__()
        # NOTE(review): not referenced by the methods defined in this class; kept for parity.
        self.op_prim_class_define_template = template.OP_PRIM_CLASS_DEFINE_TEMPLATE

    def generate(self, work_path, op_protos, doc_dict, file_pre):
        """
        Generates Python code for custom operator definitions and saves it to a file.

        Args:
            work_path (str): The directory where the generated file will be saved.
            op_protos (list): A list of operation prototypes to generate Python code for.
            doc_dict (dict): A dictionary containing documentation strings for the operators.
            file_pre (str): The prefix for the generated file name (``<file_pre>_ops_def.py``).

        Returns:
            None
        """
        gen_py = self._generate_func_code(op_protos, doc_dict)
        # Drop the trailing NEW_LINE appended after the last generated function.
        res_str = template.PY_LICENSE_STR + \
            template.CUSTOM_OPS_PY_DEF_HEADER + gen_py[:-len(template.NEW_LINE)]
        file_name = f"{file_pre}_ops_def.py"
        gen_utils.save_file(work_path, file_name, res_str)