mindspore-2.7.0-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,212 @@
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Templates for code auto generation."""
16
+ import re
17
+
18
+
19
class Template:
    """
    Template for generating C++/Python code.

    Placeholders have the form ``$name`` or ``${name}``. Inside braces a leading
    comma (``${,name}``) or a trailing comma (``${name,}``) may be added.

    * A placeholder that is preceded on its line only by whitespace (or sits at
      column 0) is expanded line by line: every line of every item of the value
      (a non-list value is treated as a one-item list) is prefixed with that
      leading whitespace, and trailing whitespace of the expansion is stripped.
    * Any other placeholder is replaced by ``str(value)``. A list value is joined
      with ``", "``; when the list is non-empty, ``${,name}`` prepends ``","`` and
      ``${name,}`` appends ``", "``.
    """
    regular_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    regular_match = re.compile(regular_str, re.MULTILINE)

    def __init__(self, code_pattern):
        self.code_pattern = code_pattern

    @staticmethod
    def load_from_file(file_path):
        """
        Load a template from a file.

        :param file_path: path of the template file; it is decoded as UTF-8.
        :return: a ``Template`` wrapping the file content.
        """
        # An explicit encoding makes generation independent of the platform
        # locale (e.g. cp936/cp1252 on Windows), which would otherwise break or
        # garble templates that contain non-ASCII text.
        with open(file_path, "r", encoding="utf-8") as f:
            return Template(f.read())

    def replace(self, **kwargs):
        """
        Substitute every placeholder of the pattern with values from ``kwargs``.

        :param kwargs: mapping from placeholder name to value; list values are
            expanded as described in the class docstring.
        :return: the rendered string.
        :raises TypeError: if a placeholder name is not present in ``kwargs``.
        """

        def find(key: str):
            if key in kwargs:
                return kwargs[key]
            raise TypeError(f"{key} should be in kwargs!")

        def add_indent(indent, var):
            # Prefix each line of each item with ``indent``; strip the trailing
            # newline/whitespace so the text after the placeholder is kept as-is.
            return "".join([indent + line + "\n" for data in var for line in str(data).splitlines()]).rstrip()

        def extract_variable(key):
            # Decode ``name`` / ``{name}`` / ``{,name}`` / ``{name,}`` into the
            # value plus the separators to emit around a non-empty list.
            start = ""
            end = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    start = ","
                    key = key[1:]
                if key[-1] == ",":
                    end = ", "
                    key = key[:-1]
            return find(key), start, end

        def match_rule(match):
            indent = match.group(1)
            key = match.group(2)
            var, start, end = extract_variable(key)
            # group(1) only participates when the placeholder starts a line.
            if indent is not None:
                if not isinstance(var, list):
                    return add_indent(indent, [var])
                return add_indent(indent, var)
            if isinstance(var, list):
                code = ", ".join(str(x) for x in var)
                # An empty list renders as nothing, without stray separators.
                if not var:
                    return code
                return start + code + end
            return str(var)

        return self.regular_match.sub(match_rule, self.code_pattern)
79
+
80
+
+ # NEW_LINE: module-level newline string constant.
81
+ NEW_LINE = "\n"
82
+
+ # PYTHON_PRIM_TEMPLATE: renders a _Pyboost${class_name}Prim subclass of ${class_name}Prim_ whose __call__ runs ${process_func} and passes [${processed_args}] to the base __call__, followed by a module-level ${func_impl_name}_impl instance of that class.
83
+ PYTHON_PRIM_TEMPLATE = Template("""
84
+
85
+ class _Pyboost${class_name}Prim(${class_name}Prim_):
86
+ def __call__(self, ${input_args}):
87
+ ${process_func}
88
+ return super().__call__([${processed_args}])
89
+
90
+
91
+ ${func_impl_name}_impl = _Pyboost${class_name}Prim()
92
+ """)
93
+
+ # IMPORT_PYBOOST_PRIM_HEADER: star-import of mindspore.ops._utils.arg_handler. NOTE(review): the f-prefix on this and the other header/license strings below is unnecessary (they contain no {} fields); harmless today, but any literal "{" added later would have to be doubled.
94
+ IMPORT_PYBOOST_PRIM_HEADER = f"""
95
+ from mindspore.ops._utils.arg_handler import *
96
+ """
97
+
+ # IMPORT_PYBOOST_FUNC_HEADER: imports mstype and star-imports mindspore.ops.auto_generate.pyboost_inner_prim.
98
+ IMPORT_PYBOOST_FUNC_HEADER = f"""
99
+ from mindspore.common import dtype as mstype
100
+ from mindspore.ops.auto_generate.pyboost_inner_prim import *
101
+
102
+ """
103
+
+ # PY_LICENSE_STR: Apache-2.0 license header written as '#' comments (Copyright 2023); presumably prepended to generated Python files.
104
+ PY_LICENSE_STR = f"""# Copyright 2023 Huawei Technologies Co., Ltd
105
+ #
106
+ # Licensed under the Apache License, Version 2.0 (the "License");
107
+ # you may not use this file except in compliance with the License.
108
+ # You may obtain a copy of the License at
109
+ #
110
+ # http://www.apache.org/licenses/LICENSE-2.0
111
+ #
112
+ # Unless required by applicable law or agreed to in writing, software
113
+ # distributed under the License is distributed on an "AS IS" BASIS,
114
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
115
+ # See the License for the specific language governing permissions and
116
+ # limitations under the License.
117
+ # ============================================================================
118
+ """
119
+
+ # OPS_PY_PRIM_HEADER: module docstring plus imports for the generated module of primitive classes (as its docstring states).
120
+ OPS_PY_PRIM_HEADER = f"""
121
+ \"\"\"Operators definition generated by gen_ops.py, includes primitive classes.\"\"\"
122
+
123
+ from mindspore.ops.primitive import Primitive, prim_arg_register
124
+ from mindspore.ops import signature as sig
125
+ from mindspore.common import dtype as mstype
126
+ from mindspore.common._decorator import deprecated
127
+ from mindspore.ops._primitive_cache import _get_cache_prim
128
+ from mindspore.ops._utils.arg_dtype_cast import type_it
129
+ from mindspore.ops._utils.arg_handler import *
130
+ from mindspore._c_expression import OpDtype
131
+ from mindspore.common.jit_context import jit_context
132
+ from mindspore._checkparam import is_stub_tensor
133
+ """
134
+
+ # OPS_PY_DEF_HEADER: module docstring plus imports for the generated module of op functions; star-imports .gen_ops_prim, .pyboost_inner_prim and manually_defined.ops_def.
135
+ OPS_PY_DEF_HEADER = f"""
136
+ \"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
137
+
138
+ from .gen_ops_prim import *
139
+ from .pyboost_inner_prim import *
140
+ from mindspore.ops.operations.manually_defined.ops_def import *
141
+ from mindspore.ops._primitive_cache import _get_cache_prim
142
+ """
143
+
+ # CUSTOM_OPS_PY_DEF_HEADER: like OPS_PY_DEF_HEADER but only star-imports .gen_ops_prim; presumably used for custom-op function modules.
144
+ CUSTOM_OPS_PY_DEF_HEADER = f"""
145
+ \"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
146
+
147
+ from .gen_ops_prim import *
148
+ """
149
+
+ # PRIMITIVE_CLASS_DESC: plain string (not a Template object) holding a primitive-class docstring with $class_name, $init_args_str, $input_args_str, $func_name and $args_str placeholders; it states that calling the primitive is equivalent to ops.$func_name.
150
+ PRIMITIVE_CLASS_DESC = """ r\"\"\"
151
+ .. code-block::
152
+
153
+ prim = ops.$class_name($init_args_str)
154
+ out = prim($input_args_str)
155
+
156
+ is equivalent to
157
+
158
+ .. code-block::
159
+
160
+ ops.$func_name($args_str)
161
+
162
+ Refer to :func:`mindspore.ops.$func_name` for more details.
163
+ \"\"\"
164
+ """
165
+
+ # CC_LICENSE_STR: Apache-2.0 license as a C/C++ block comment (Copyright 2023-2025); prepended to generated .cc files, e.g. by CustomOpsDefCcGenerator.generate.
166
+ CC_LICENSE_STR = f"""/**
167
+ * Copyright 2023-2025 Huawei Technologies Co., Ltd
168
+ *
169
+ * Licensed under the Apache License, Version 2.0 (the "License");
170
+ * you may not use this file except in compliance with the License.
171
+ * You may obtain a copy of the License at
172
+ *
173
+ * http://www.apache.org/licenses/LICENSE-2.0
174
+ *
175
+ * Unless required by applicable law or agreed to in writing, software
176
+ * distributed under the License is distributed on an "AS IS" BASIS,
177
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
178
+ * See the License for the specific language governing permissions and
179
+ * limitations under the License.
180
+ */"""
181
+
+ # OP_PROTO_TEMPLATE: C++ "OpDef g${class_name}" initializer (args, returns, signatures, indexes, func impl, dispatch/view flags) followed by REGISTER_PRIMITIVE_OP_DEF; placeholders alone on a line are expanded line by line by Template.replace.
182
+ OP_PROTO_TEMPLATE = Template("""
183
+ ${func_impl_declaration}
184
+ OpDef g${class_name} = {
185
+ /*.name_=*/"${class_name}",
186
+ /*.args_=*/ {
187
+ ${input_args}
188
+ },
189
+ /* .returns_ = */ {
190
+ ${return_args}
191
+ },
192
+ /*.signatures_ =*/ {
193
+ ${signatures}
194
+ },
195
+ /*.indexes_ =*/ {
196
+ ${indexes}
197
+ },
198
+ /*.func_impl_=*/${func_impl_define},
199
+ /*.enable_dispatch_ =*/${enable_dispatch},
200
+ /*.is_view_ =*/${is_view},
201
+ /*.is_graph_view_ =*/${is_graph_view},
202
+ };
203
+ REGISTER_PRIMITIVE_OP_DEF(${class_name}, &g${class_name});
204
+ """)
205
+
+ # OP_PRIM_CLASS_DEFINE_TEMPLATE: skeleton of a generated Python Primitive subclass: ${class_desc}, ${signature_code} and ${deprecated_code}, then ${init_method} and ${call_method}.
206
+ OP_PRIM_CLASS_DEFINE_TEMPLATE = Template("""
207
+
208
+ class ${class_name}(Primitive):
209
+ ${class_desc}${signature_code}${deprecated_code}
210
+ ${init_method}
211
+
212
+ ${call_method}""")
@@ -0,0 +1,69 @@
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Auto generate custom ops files.
17
+ """
18
+
19
+ import logging
20
+ import argparse
21
+ from resources.resource_manager import ResourceManager
22
+ from resources.resource_list import ResourceType
23
+ from resources.yaml_loader import CustomOpDocYamlLoader
24
+ from op_def.ops_def_cc_generator import CustomOpsDefCcGenerator
25
+ from op_def_py.custom_op_prim_py_generator import CustomOpPrimPyGenerator
26
+ from op_def_py.op_def_py_generator import CustomOpDefPyGenerator
27
+ from common.op_proto import CustomOpProtoLoader
28
+
29
+
30
def get_config():
    """
    Parse the custom-op generator's command-line options.

    Options:
        -m/--module_name: module name forwarded to the Python primitive generator.
        -i/--input_path: custom-op prototype input, read by CustomOpProtoLoader.
        -o/--output_path: directory the generated files are written to.
        -d/--doc_path: optional op-doc YAML input. It defaults to "", which
            generate_custom_op_def treats as "skip doc loading"; previously the
            caller had to pass it explicitly even when it was empty.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="Auto generate custom ops files.")
    parser.add_argument("-m", "--module_name", type=str, required=True)
    parser.add_argument("-i", "--input_path", type=str, required=True)
    parser.add_argument("-o", "--output_path", type=str, required=True)
    # Optional: an empty doc path means "no op docs" downstream.
    parser.add_argument("-d", "--doc_path", type=str, default="",
                        help='path of the custom-op doc YAML; "" (default) skips doc loading')
    return parser.parse_args()
38
+
39
+
40
def generate_custom_op_def(module_name, input_path, doc_path, output_path):
    """
    Automatically generate all necessary files for custom operators.

    Loads the operator prototypes (and, when a doc path is given, the operator
    docs), then runs the C++ op-definition, Python primitive and Python op-def
    generators, all writing into ``output_path``.

    Args:
        module_name (str): forwarded to ``CustomOpPrimPyGenerator.generate``.
        input_path (str): passed to ``CustomOpProtoLoader``.
        doc_path (str): passed to ``CustomOpDocYamlLoader``; an empty string
            (or ``None``) skips doc loading and an empty doc mapping is used.
        output_path (str): directory the generators write into.
    """
    resource_mgr = ResourceManager()
    resource_mgr.register_resource(CustomOpProtoLoader(input_path))
    op_protos = resource_mgr.get_resource(ResourceType.OP_PROTO)

    # Docs are optional; truthiness covers both "" and None.
    doc_dict = {}
    if doc_path:
        resource_mgr.register_resource(CustomOpDocYamlLoader(doc_path))
        doc_dict = resource_mgr.get_resource(ResourceType.OP_DOC_YAML)

    CustomOpsDefCcGenerator().generate(output_path, op_protos)
    CustomOpPrimPyGenerator().generate(output_path, module_name, op_protos, doc_dict, "gen")
    CustomOpDefPyGenerator().generate(output_path, op_protos, doc_dict, "gen")
56
+
57
+
58
def main():
    """Command-line entry point: read the options, then run the generator."""
    options = get_config()
    generate_custom_op_def(module_name=options.module_name,
                           input_path=options.input_path,
                           doc_path=options.doc_path,
                           output_path=options.output_path)
62
+
63
+
64
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Log the failure, then re-raise the active exception unchanged so the
        # original traceback and the non-zero exit status are preserved.
        logging.critical("Auto generate failed, err info: %s", e)
        raise
@@ -24,11 +24,10 @@ import common.gen_utils as gen_utils
24
24
 
25
25
  # refactored
26
26
  from common.op_proto import OpProto
27
- import common.template as template
27
+ import common.template_utils as template
28
28
 
29
29
  from common.base_generator import BaseGenerator
30
30
 
31
-
32
31
  CC_OPS_DEF = """
33
32
 
34
33
  #include "$auto_generate_path/gen_ops_def.h"
@@ -72,8 +71,12 @@ class OpsDefCcGenerator(BaseGenerator):
72
71
  operator_name = op_proto.op_name
73
72
  class_name = op_proto.op_class.name
74
73
  if not op_proto.func_op:
75
- gen_include_list.append(self.include_template.replace(path=K.MS_OPS_FUNC_IMPL_PATH,
76
- operator_name=operator_name))
74
+ if op_proto.op_dispatch and op_proto.op_dispatch.is_comm_op:
75
+ gen_include_list.append(self.include_template.replace(path=K.MS_OPS_COMM_FUNC_IMPL_PATH,
76
+ operator_name=operator_name))
77
+ else:
78
+ gen_include_list.append(self.include_template.replace(path=K.MS_OPS_FUNC_IMPL_PATH,
79
+ operator_name=operator_name))
77
80
  func_impl_declaration_str = self.func_impl_declaration_template.replace(class_name=class_name)
78
81
  else:
79
82
  func_impl_declaration_str = self.empty_func_impl_declaration_template.replace(class_name=class_name)
@@ -85,7 +88,6 @@ class OpsDefCcGenerator(BaseGenerator):
85
88
  # Process outputs.
86
89
  return_args_str = get_cc_op_def_return(args_dict, op_proto)
87
90
 
88
-
89
91
  inputs_args = self.process_args(op_proto.op_args)
90
92
  signature_code = generate_cc_op_signature(op_proto.op_args_signature, inputs_args)
91
93
  enable_dispatch = "true" if op_proto.op_dispatch and op_proto.op_dispatch.enable else "false"
@@ -112,9 +114,9 @@ class OpsDefCcGenerator(BaseGenerator):
112
114
  save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
113
115
  for numbering in range(math.ceil(op_size / max_op_size_in_one_file)):
114
116
  gen_include = ''.join(
115
- gen_include_list[numbering*max_op_size_in_one_file: (numbering+1)*max_op_size_in_one_file])
117
+ gen_include_list[numbering * max_op_size_in_one_file: (numbering + 1) * max_op_size_in_one_file])
116
118
  gen_cc = ''.join(
117
- gen_cc_list[numbering*max_op_size_in_one_file: (numbering+1)*max_op_size_in_one_file])
119
+ gen_cc_list[numbering * max_op_size_in_one_file: (numbering + 1) * max_op_size_in_one_file])
118
120
  cc_ops_def = self.CC_OPS_DEF_TEMPLATE.replace(auto_generate_path=K.MS_OP_DEF_AUTO_GENERATE_PATH,
119
121
  gen_include=gen_include,
120
122
  gen_cc_code=gen_cc)
@@ -148,6 +150,75 @@ class OpsDefCcGenerator(BaseGenerator):
148
150
  return inputs_name
149
151
 
150
152
 
153
class CustomOpsDefCcGenerator(OpsDefCcGenerator):
    """
    Generates the C++ definition file for custom operators.
    """

    def __init__(self):
        """
        Initializes templates for generating C++ custom-operator definitions.
        """
        super().__init__()

        # include_template and empty_func_impl_declaration_template are kept so
        # instances expose the same attributes as OpsDefCcGenerator; generate()
        # below does not use them.
        self.include_template = template.Template("""#include "${path}/${operator_name}.h\"\n""")
        self.func_impl_declaration_template = template.Template("extern OpFuncImpl &g${class_name}FuncImpl;")
        self.empty_func_impl_declaration_template = template.Template("static OpFuncImpl g${class_name}FuncImpl;")
        self.func_impl_define_template = template.Template("g${class_name}FuncImpl")
        self.OP_PROTO_TEMPLATE = template.OP_PROTO_TEMPLATE
        self.CC_OPS_DEF_TEMPLATE = template.Template(CC_OPS_DEF)

    def generate(self, work_path, op_protos):
        """
        Generates C++ code for custom operator definitions and saves it to
        ``gen_custom_ops_def.cc`` in ``work_path``.

        Each operator is emitted under the class name ``Custom_<op_name>`` and
        refers to an ``extern`` declared ``gCustom_<op_name>FuncImpl``.

        Args:
            work_path (str): The directory to save the generated files.
            op_protos (list): A list of operator prototypes.
        """
        gen_cc_list = []

        for op_proto in op_protos:
            class_name = "Custom_" + op_proto.op_name
            func_impl_declaration_str = self.func_impl_declaration_template.replace(class_name=class_name)
            func_impl_define = self.func_impl_define_template.replace(class_name=class_name)

            # Process inputs.
            args_dict, cc_index_str, input_args_str = process_input_args(op_proto)

            # Process outputs.
            return_args_str = get_cc_op_def_return(args_dict, op_proto)

            inputs_args = self.process_args(op_proto.op_args)
            signature_code = generate_cc_op_signature(op_proto.op_args_signature, inputs_args)
            enable_dispatch = "true" if op_proto.op_dispatch and op_proto.op_dispatch.enable else "false"
            is_view = "true" if op_proto.op_view else "false"
            is_graph_view = "true" if op_proto.op_graph_view else "false"
            op_def_cc = self.OP_PROTO_TEMPLATE.replace(class_name=class_name,
                                                       input_args=input_args_str,
                                                       return_args=return_args_str,
                                                       signatures=signature_code,
                                                       indexes=cc_index_str,
                                                       enable_dispatch=enable_dispatch,
                                                       is_view=is_view,
                                                       is_graph_view=is_graph_view,
                                                       func_impl_declaration=func_impl_declaration_str,
                                                       func_impl_define=func_impl_define)

            gen_cc_list.append(op_def_cc)

        # No per-operator #include lines are emitted for custom operators: their
        # FuncImpl objects are only declared ``extern`` in the generated code.
        cc_ops_def = self.CC_OPS_DEF_TEMPLATE.replace(auto_generate_path=K.MS_OP_DEF_AUTO_GENERATE_PATH,
                                                      gen_include="",
                                                      gen_cc_code="".join(gen_cc_list))

        file_name = "gen_custom_ops_def.cc"
        ops_def_cc_file_str = template.CC_LICENSE_STR + cc_ops_def
        gen_utils.save_file(work_path, file_name, ops_def_cc_file_str)
220
+
221
+
151
222
  def process_input_args(op_proto: OpProto):
152
223
  """
153
224
  Processes input arguments for C++ code generation.