mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (290) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -20,14 +20,14 @@ import os
20
20
 
21
21
  import common.gen_constants as K
22
22
  import common.gen_utils as gen_utils
23
- import common.template as template
24
- from common.base_generator import BaseGenerator
23
+ import common.template_utils as template
25
24
  from common.op_proto import OpProto
26
- from common.template import Template
25
+ from common.template_utils import Template
27
26
  from pyboost import pyboost_utils
27
+ from op_def_py.base_op_prim_py_generator import BaseOpPrimPyGenerator, _generate_arg_handler, generate_py_op_deprecated
28
28
 
29
29
 
30
- class OpPrimPyGenerator(BaseGenerator):
30
+ class OpPrimPyGenerator(BaseOpPrimPyGenerator):
31
31
  """
32
32
  Generates Python code for primitive operators based on provided specifications.
33
33
  """
@@ -87,7 +87,7 @@ class OpPrimPyGenerator(BaseGenerator):
87
87
 
88
88
  pyboost_import_header = self.generate_pyboost_import_header(op_protos)
89
89
  res_str = template.PY_LICENSE_STR + \
90
- template.OPS_PY_PRIM_HEADER + pyboost_import_header + gen_py
90
+ template.OPS_PY_PRIM_HEADER + pyboost_import_header + gen_py
91
91
 
92
92
  save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
93
93
  file_name = f"{file_pre}_ops_prim.py"
@@ -111,113 +111,6 @@ class OpPrimPyGenerator(BaseGenerator):
111
111
  pyboost_import_header += header
112
112
  return pyboost_import_header
113
113
 
114
- def _process_args(self, op_proto: OpProto):
115
- """
116
- Processes operator arguments to categorize them for code generation.
117
-
118
- Args:
119
- op_proto (OpProto): The operator prototype.
120
-
121
- Returns:
122
- tuple: A tuple containing processed arguments.
123
- """
124
- inputs_name = []
125
- args_name = []
126
- args_assign = []
127
- inputs_default = {}
128
- init_args_with_default = []
129
- args_handlers = {}
130
-
131
- for arg in op_proto.op_args:
132
- # step1: get args infos:
133
- if arg.is_prim_init:
134
- # step1.1: get args name:
135
- args_name.append(arg.arg_name)
136
- # step1.2: get args assign with default value:
137
- if arg.default is not None:
138
- init_args_with_default.append(f"""{arg.arg_name}={arg.default}""")
139
- else:
140
- init_args_with_default.append(f"""{arg.arg_name}""")
141
-
142
- # step1.3: get args set prim arg expression:
143
- assign_str = self._get_assign_str_by_type_it(op_proto.op_class.name, arg)
144
- if arg.arg_handler:
145
- assign_str = (
146
- f' self._set_prim_arg_with_handler('
147
- f'"{arg.arg_name}", {assign_str}, {arg.arg_handler})'
148
- )
149
- else:
150
- assign_str = f""" self._set_prim_arg("{arg.arg_name}", {assign_str})"""
151
- args_assign.append(assign_str)
152
- # step2: get inputs infos:
153
- else:
154
- # step2.1: get inputs name:
155
- inputs_name.append(arg.arg_name)
156
-
157
- # step2.2: get default value of inputs:
158
- if arg.default is not None:
159
- inputs_default[arg.arg_name] = arg.default
160
-
161
- # step2.3: get args_handler functions for inputs
162
- if arg.arg_handler:
163
- args_handlers[arg.arg_name] = arg.arg_handler
164
-
165
- return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
166
-
167
- def _get_assign_str_by_type_it(self, class_name, arg):
168
- """
169
- Generates assignment string with type casting.
170
-
171
- Args:
172
- class_name (str): The name of the class.
173
- arg (OpArg): The operator argument.
174
-
175
- Returns:
176
- str: A string representing the assignment.
177
- """
178
- assign_str = ""
179
- type_cast = arg.type_cast
180
- if type_cast:
181
- assign_str += f"type_it('{class_name}', '{arg.arg_name}', {arg.arg_name}, "
182
- if len(type_cast) == 1:
183
- assign_str += gen_utils.get_type_str(type_cast[0]) + ', '
184
- else:
185
- assign_str += '(' + ', '.join(gen_utils.get_type_str(ct) for ct in type_cast) + '), '
186
- assign_str += gen_utils.get_type_str(arg.arg_dtype) + ')'
187
- else:
188
- assign_str = arg.arg_name
189
- return assign_str
190
-
191
- def _generate_class_desc(self, op_proto: OpProto, input_args, init_args, doc_dic):
192
- """
193
- Generates a class description based on the operator prototype.
194
-
195
- Args:
196
- op_proto (OpProto): The operator prototype.
197
- input_args (list): List of input argument names.
198
- init_args (list): List of initialization argument names.
199
- doc_dic (dict): Documentation dictionary.
200
-
201
- Returns:
202
- str: A string containing the class description.
203
- """
204
- if op_proto.op_function and op_proto.op_function.disable:
205
- # if function disabled, function name is equal to operator_name
206
- return gen_utils.get_op_description(op_proto.op_name, doc_dic)
207
-
208
- # If function is a released API, refer to the function doc.
209
- init_args_str = ", ".join(init_args)
210
- input_args_str = ", ".join(input_args)
211
- args_str = ", ".join(input_args + init_args)
212
-
213
- description_template = Template(template.PRIMITIVE_CLASS_DESC)
214
- description_str = description_template.replace(class_name=op_proto.op_class.name,
215
- init_args_str=init_args_str,
216
- input_args_str=input_args_str,
217
- func_name=op_proto.op_function.name,
218
- args_str=args_str)
219
- return description_str
220
-
221
114
  def _generate_init_code(self, args_assign, init_args_with_default, op_proto: OpProto):
222
115
  """
223
116
  Generates the __init__ method code for the operator primitive class.
@@ -242,50 +135,6 @@ class OpPrimPyGenerator(BaseGenerator):
242
135
  init_code_str += f"\n"
243
136
  return init_code_str
244
137
 
245
- def _get_init_code(self, init_code, op_proto: OpProto):
246
- """
247
- Generates additional initialization code for the operator primitive class.
248
-
249
- Args:
250
- init_code (str): Existing initialization code.
251
- op_proto (OpProto): The operator prototype.
252
-
253
- Returns:
254
- str: A string containing additional initialization code.
255
- """
256
- labels_dic = op_proto.op_labels
257
- if labels_dic:
258
- if init_code:
259
- init_code += "\n"
260
- init_code += "\n".join([f""" self.add_prim_attr("{k}", {v})""" for k, v in labels_dic.items()])
261
-
262
- return init_code if init_code else f""" pass"""
263
-
264
- def _generate_call_code(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
265
- """
266
- Generates the __call__ method code for the operator primitive class.
267
-
268
- Args:
269
- args_handlers (dict): Dictionary of argument handlers.
270
- init_args (list): List of initialization argument names.
271
- inputs_args (list): List of input argument names.
272
- inputs_default (dict): Dictionary of default input values.
273
- op_proto (OpProto): The operator prototype.
274
-
275
- Returns:
276
- str: A string containing the __call__ method code.
277
- """
278
- call_code_str = ""
279
- call_args = []
280
- for name in inputs_args:
281
- call_args.append(f"{name}={inputs_default[name]}" if name in inputs_default else name)
282
- call_method_args_str = ", ".join(call_args)
283
- call_method_body_str = self._get_call_method_body_str(args_handlers, init_args, inputs_args, inputs_default,
284
- op_proto)
285
- call_code_str += f""" def __call__(self, {call_method_args_str}):"""
286
- call_code_str += f"""{call_method_body_str}"""
287
- return call_code_str
288
-
289
138
  def _get_call_method_body_str(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
290
139
  """
291
140
  Generates the body of the __call__ method.
@@ -334,159 +183,3 @@ class OpPrimPyGenerator(BaseGenerator):
334
183
  call_method_body_str += f"""
335
184
  return super().__call__({call_args_list_str})\n"""
336
185
  return call_method_body_str
337
-
338
- def _generate_py_op_signature(self, op_proto: OpProto, args_name, args_default):
339
- """
340
- Generates the __mindspore_signature__ for the operator.
341
-
342
- Args:
343
- op_proto (OpProto): The operator prototype.
344
- args_name (list): List of argument names.
345
- args_default (dict): Dictionary of default argument values.
346
-
347
- Returns:
348
- str: A string containing the __mindspore_signature__ code.
349
- """
350
- op_name = op_proto.op_name
351
- args_signature = op_proto.op_args_signature
352
-
353
- if args_signature is None and not args_default:
354
- return ''
355
-
356
- signature_code = f"""\n __mindspore_signature__ = """
357
-
358
- # Init rw.
359
- read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
360
- _check_signature_arg_valid(op_name, write_list, args_name)
361
- _check_signature_arg_valid(op_name, read_list, args_name)
362
- _check_signature_arg_valid(op_name, ref_list, args_name)
363
-
364
- # Init dtype group.
365
- same_dtype_groups, dtype_count = gen_utils.get_same_dtype_groups(args_signature, args_name)
366
- _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
367
-
368
- # Only one dtype_group is set.
369
- if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
370
- signature_code += '('
371
- for _ in range(len(args_name) - 1):
372
- signature_code += 'sig.sig_dtype.T, '
373
- signature_code += 'sig.sig_dtype.T)\n'
374
- return signature_code
375
-
376
- # Set sig.make_sig.
377
- signature_code += f""" (\n"""
378
- for arg_name in args_name:
379
- signature_code += f""" sig.make_sig('{arg_name}'"""
380
- signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
381
- if arg_name in same_dtype_groups:
382
- signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
383
- if arg_name in args_default:
384
- signature_code += f""", default=""" + str(args_default[arg_name])
385
- signature_code += f"""),\n"""
386
- signature_code += f""" )\n"""
387
- return signature_code
388
-
389
-
390
- def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
391
- """
392
- Validates that all signature arguments are present in the list of argument names.
393
-
394
- Args:
395
- op_name (str): The name of the operator.
396
- sig_arg_names (list): List of signature argument names.
397
- args_names (list): List of actual argument names.
398
-
399
- Raises:
400
- ValueError: If a signature argument is not found in the list of argument names.
401
- """
402
- for sig_arg_name in sig_arg_names:
403
- if sig_arg_name not in args_names:
404
- raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
405
-
406
-
407
- def signature_get_dtype_label(index):
408
- """
409
- Generates the label for the data type in the signature.
410
-
411
- Args:
412
- index (int): The index of the data type.
413
-
414
- Returns:
415
- str: The label string for the data type.
416
- """
417
- dtype_index = ''
418
- if index > 0:
419
- dtype_index = f"""{index}"""
420
- return f"""dtype=sig.sig_dtype.T{dtype_index}"""
421
-
422
-
423
- def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
424
- """
425
- Determines the read-write label for an argument in the signature.
426
-
427
- Args:
428
- arg_name (str): The name of the argument.
429
- write_list (list): List of arguments that are writable.
430
- read_list (list): List of arguments that are readable.
431
- ref_list (list): List of arguments that are references.
432
-
433
- Returns:
434
- str: The read-write label for the argument.
435
- """
436
- for rw_arg_name in write_list:
437
- if rw_arg_name == arg_name:
438
- return ', sig.sig_rw.RW_WRITE'
439
- for read_arg_name in read_list:
440
- if read_arg_name == arg_name:
441
- return ', sig.sig_rw.RW_READ'
442
- for ref_arg_name in ref_list:
443
- if ref_arg_name == arg_name:
444
- return ', sig.sig_rw.RW_REF'
445
- return ''
446
-
447
-
448
- def generate_py_op_deprecated(deprecated):
449
- """
450
- Generates the deprecated decorator for an operator.
451
-
452
- Args:
453
- deprecated (dict): The deprecation information.
454
-
455
- Returns:
456
- str: A string containing the deprecated decorator.
457
- """
458
- if deprecated is None:
459
- return ''
460
- version = deprecated.get("version")
461
- if version is None:
462
- raise ValueError("The version of deprecated can't be None.")
463
- substitute = deprecated.get("substitute")
464
- if substitute is None:
465
- raise ValueError("The substitute of deprecated can't be None.")
466
- use_substitute = deprecated.get("use_substitute")
467
- if use_substitute is None:
468
- raise ValueError("The use_substitute of deprecated can't be None.")
469
- if use_substitute is not True and use_substitute is not False:
470
- raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
471
-
472
- deprecated = f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
473
- return deprecated
474
-
475
-
476
- def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
477
- """
478
- Generates the argument handler call for an argument.
479
-
480
- Args:
481
- class_name (str): The name of the class.
482
- arg (str): The name of the argument.
483
- arg_handler (str): The handler function for the argument.
484
- is_optional (bool): Indicates whether the argument is optional.
485
-
486
- Returns:
487
- str: The argument handler call string.
488
- """
489
- arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
490
- if is_optional:
491
- arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
492
- return arg_handler_call
@@ -23,7 +23,7 @@ from common.template import Template
23
23
  import common.gen_constants as K
24
24
  from common.gen_utils import save_file
25
25
  from common.base_generator import BaseGenerator
26
- from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output
26
+ from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output, get_output_dtype
27
27
 
28
28
 
29
29
  class AutoGradImplGenerator(BaseGenerator):
@@ -38,6 +38,8 @@ class AutoGradImplGenerator(BaseGenerator):
38
38
  self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE
39
39
  self.AUTO_GRAD_IMPL_CC_TEMPLATE = template.AUTO_GRAD_IMPL_CC_TEMPLATE
40
40
  self.DO_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_GRAD_FUNCTION_BODY_TEMPLATE
41
+ self.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE
42
+ self.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE
41
43
  self.auto_grad_reg_template = Template("const_cast<kernel::pyboost::${class_name}GradFunc&>(" + \
42
44
  "kernel::pyboost::AutoGradFactory::Get()." + \
43
45
  "ops_auto_grad_registers().${class_name}GradFuncObj) = " + \
@@ -45,6 +47,9 @@ class AutoGradImplGenerator(BaseGenerator):
45
47
  self.do_grad_op_args_with_type = Template(
46
48
  "const kernel::pyboost::OpPtr &op, ${input_args_with_type}"
47
49
  )
50
+ self.do_grad_view_op_args_with_type = Template(
51
+ "${output_args_with_type}, ${input_args_with_type}"
52
+ )
48
53
 
49
54
  def generate(self, work_path, op_protos):
50
55
  """
@@ -60,8 +65,13 @@ class AutoGradImplGenerator(BaseGenerator):
60
65
  for op_proto in op_protos:
61
66
  if op_proto.op_dispatch is None:
62
67
  continue
68
+ # the backward func of flatten_ext and t_ext are implemented by other view ops, just continue
69
+ if op_proto.op_view and not op_proto.bprop_expander:
70
+ continue
63
71
  auto_grad_reg_list.append(self.auto_grad_reg_template.replace(class_name=op_proto.op_class.name))
64
- do_grad_op_list.append(self._get_single_do_grad_op(op_proto))
72
+ do_single_grad_op_str = self._get_single_do_grad_view_op(op_proto)\
73
+ if op_proto.op_view else self._get_single_do_grad_op(op_proto)
74
+ do_grad_op_list.append(do_single_grad_op_str)
65
75
  ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_proto.op_class.name[0].lower()))
66
76
  pyboost_func_h_str = self.AUTO_GRAD_IMPL_CC_TEMPLATE.replace(do_grad_op=do_grad_op_list,
67
77
  auto_grad_reg=auto_grad_reg_list,
@@ -80,12 +90,11 @@ class AutoGradImplGenerator(BaseGenerator):
80
90
  Returns:
81
91
  str: The generated DoGrad function string.
82
92
  """
83
- input_args_str = self._get_input_args(op_proto, False, False, op_proto.op_view)
84
- input_args_with_optional_str = self._get_input_args(op_proto, False, True, op_proto.op_view)
85
- input_args_with_type_str = self._get_input_args(op_proto, True, False, op_proto.op_view)
93
+ input_args_str = self._get_input_args(op_proto, False, False, False)
94
+ input_args_with_optional_str = self._get_input_args(op_proto, False, True, False)
95
+ input_args_with_type_str = self._get_input_args(op_proto, True, False, False)
86
96
  inner_grad_args_with_type = self._get_input_args(op_proto, True, False, False)
87
97
  multi_output_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
88
- view_arg_str = self._get_view_str(op_proto.op_view, input_args_str)
89
98
  grad_args_with_type_str = self.do_grad_op_args_with_type.replace(input_args_with_type=input_args_with_type_str)
90
99
  inner_grad_args_with_type =\
91
100
  self.do_grad_op_args_with_type.replace(input_args_with_type=inner_grad_args_with_type)
@@ -94,22 +103,62 @@ class AutoGradImplGenerator(BaseGenerator):
94
103
  FALSE = "false"
95
104
  bprop_expander = TRUE if op_proto.bprop_expander else FALSE
96
105
  non_differentiable = TRUE if op_proto.non_differentiable else FALSE
97
- if not op_proto.op_view:
98
- convert_basic_to_value = ''
99
- else:
100
- input_args_with_optional_str, convert_basic_to_value = self._get_convert_str(op_proto,
101
- input_args_with_optional_str)
106
+
102
107
  return self.DO_GRAD_FUNCTION_BODY_TEMPLATE.replace(class_name=op_proto.op_class.name,
103
108
  inner_grad_args_with_type=inner_grad_args_with_type,
104
109
  grad_args_with_type=grad_args_with_type_str,
105
110
  grad_input_args=input_args_str,
106
111
  grad_input_args_with_optional=input_args_with_optional_str,
107
112
  is_multi=multi_output_str,
108
- view_arg=view_arg_str,
109
113
  op_def_name=op_def_name_str,
110
114
  bprop_expander=bprop_expander,
111
- non_differentiable=non_differentiable,
112
- convert_basic_to_value=convert_basic_to_value)
115
+ non_differentiable=non_differentiable)
116
+
117
+ def _get_single_do_grad_view_op(self, op_proto):
118
+ """
119
+ Generate the DoGrad function for a single view operator prototype.
120
+
121
+ Args:
122
+ op_proto: The operator prototype for which the DoGrad function is generated.
123
+
124
+ Returns:
125
+ str: The generated DoGrad function string.
126
+ """
127
+ input_args_str = self._get_input_args(op_proto, False, False, True)
128
+ input_args_with_optional_str = self._get_input_args(op_proto, False, True, True)
129
+ input_args_with_type_str = self._get_input_args(op_proto, True, False, True)
130
+ inner_grad_args_with_type = self._get_input_args(op_proto, True, False, False)
131
+ view_arg_str = self._get_view_str(input_args_str)
132
+ grad_args_with_type_str = self.do_grad_view_op_args_with_type\
133
+ .replace(input_args_with_type=input_args_with_type_str,
134
+ output_args_with_type=self._get_output_arg(op_proto))
135
+ inner_grad_args_with_type =\
136
+ self.do_grad_view_op_args_with_type.replace(output_args_with_type="const ValuePtr &output_value",
137
+ input_args_with_type=inner_grad_args_with_type)
138
+ op_def_name_str = "g" + op_proto.op_class.name
139
+ TRUE = "true"
140
+ FALSE = "false"
141
+ bprop_expander = TRUE if op_proto.bprop_expander else FALSE
142
+ non_differentiable = TRUE if op_proto.non_differentiable else FALSE
143
+ if op_proto.op_name in ["reshape", "expand_dims", "transpose", "slice_ext_view",\
144
+ "select_ext_view", "transpose_ext_view"]:
145
+ do_view_grad_function_body_tpl = self.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE
146
+ convert_basic_to_value = ""
147
+ else:
148
+ do_view_grad_function_body_tpl = self.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE
149
+ input_args_with_optional_str, convert_basic_to_value = self._get_convert_str(op_proto,
150
+ input_args_with_optional_str)
151
+ return do_view_grad_function_body_tpl.replace(class_name=op_proto.op_class.name,
152
+ inner_grad_args_with_type=inner_grad_args_with_type,
153
+ grad_args_with_type=grad_args_with_type_str,
154
+ grad_input_args=input_args_str,
155
+ grad_input_args_with_optional=input_args_with_optional_str,
156
+ view_arg=view_arg_str,
157
+ op_def_name=op_def_name_str,
158
+ bprop_expander=bprop_expander,
159
+ non_differentiable=non_differentiable,
160
+ convert_basic_to_value=convert_basic_to_value)
161
+
113
162
 
114
163
  def _get_input_args(self, op_proto, has_type, with_optional, use_basic_type=False):
115
164
  """
@@ -134,6 +183,15 @@ class AutoGradImplGenerator(BaseGenerator):
134
183
  args_list.append(f"{op_arg.arg_name}_tensor")
135
184
  return args_list
136
185
 
186
+ def _get_output_arg(self, op_proto):
187
+ # for view operators, the output is tensor or vector<tensor>
188
+ if len(op_proto.op_returns) != 1:
189
+ raise ValueError(f"the output of {op_proto.op_name} is not tensor, ",
190
+ "tuple[tensor] or list[tensor], which is not not as expected")
191
+ output_dtype = get_output_dtype(op_proto.op_returns[0].arg_dtype)
192
+ output_arg = f"const {output_dtype} &output"
193
+ return output_arg
194
+
137
195
  def _get_convert_str(self, op_proto, args_name):
138
196
  """
139
197
  Get the input convert func for the DoGrad function.
@@ -161,12 +219,11 @@ class AutoGradImplGenerator(BaseGenerator):
161
219
  args_name_list.append(out_arg_name)
162
220
  return args_name_list, convert_funcs
163
221
 
164
- def _get_view_str(self, is_view_op: bool, grad_args: list):
222
+ def _get_view_str(self, grad_args: list):
165
223
  """
166
224
  Get the view argument string for a DoGrad function.
167
225
 
168
226
  Args:
169
- is_view_op (bool): Whether the operator is a view operator.
170
227
  grad_args (list): A list of gradient arguments.
171
228
 
172
229
  Returns:
@@ -174,7 +231,7 @@ class AutoGradImplGenerator(BaseGenerator):
174
231
  """
175
232
  view_arg_str = ''
176
233
  for i, grad_arg in enumerate(grad_args):
177
- if is_view_op and i == 0:
234
+ if i == 0:
178
235
  view_arg_str = ", " + grad_arg
179
236
  break
180
237
  return view_arg_str
@@ -23,7 +23,7 @@ from common.template import Template
23
23
  import common.gen_constants as K
24
24
  from common.gen_utils import save_file
25
25
  from common.base_generator import BaseGenerator
26
- from pyboost.pyboost_utils import is_optional_param, get_input_dtype
26
+ from pyboost.pyboost_utils import is_optional_param, get_input_dtype, get_output_dtype
27
27
 
28
28
 
29
29
  class AutoGradRegHeaderGenerator(BaseGenerator):
@@ -42,6 +42,9 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
42
42
  self.op_grad_func_args_template = Template(
43
43
  "const kernel::pyboost::OpPtr &, ${input_tensor_prt_args}"
44
44
  )
45
+ self.op_view_grad_func_args_template = Template(
46
+ "${output_tensor_prt_args}, ${input_tensor_prt_args}"
47
+ )
45
48
 
46
49
  def generate(self, work_path, op_protos):
47
50
  """
@@ -60,9 +63,13 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
60
63
  continue
61
64
  op_type_enum_list.append(self.op_type_enum_template.replace(class_name=op_proto.op_class.name,
62
65
  enum_val=index))
66
+ # the backward func of flatten_ext and t_ext are implemented by other view ops, just continue
67
+ if op_proto.op_view and not op_proto.bprop_expander:
68
+ continue
63
69
  grad_func_args_with_type_str = self._get_grad_func_args_with_type_str(op_proto)
64
- op_grad_func_list.append(self.op_grad_func_template.replace(class_name=op_proto.op_class.name,
65
- grad_func_args=grad_func_args_with_type_str))
70
+ op_grad_func_list.append(
71
+ self.op_grad_func_template.replace(class_name=op_proto.op_class.name,
72
+ grad_func_args=grad_func_args_with_type_str))
66
73
  op_grad_func_obj_list.append(self.op_grad_func_obj_template.replace(class_name=op_proto.op_class.name))
67
74
  index += 1
68
75
 
@@ -89,5 +96,15 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
89
96
  is_optional = is_optional_param(op_arg)
90
97
  input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional, op_proto.op_view)
91
98
  input_tensor_prt_args_str += f"const {input_dtype} &, "
92
-
93
- return self.op_grad_func_args_template.replace(input_tensor_prt_args=input_tensor_prt_args_str.rstrip(', '))
99
+ input_tensor_prt_args_str = input_tensor_prt_args_str.rstrip(', ')
100
+ if not op_proto.op_view:
101
+ return self.op_grad_func_args_template.replace(input_tensor_prt_args=\
102
+ input_tensor_prt_args_str)
103
+ # for view operators, the output is tensor or vector<tensor>
104
+ if len(op_proto.op_returns) != 1:
105
+ raise ValueError(f"the output of {op_proto.op_name} is not tensor,",
106
+ "tuple[tensor] or list[tensor], which is not not as expected")
107
+ output_dtype = get_output_dtype(op_proto.op_returns[0].arg_dtype)
108
+ output_tensor_prt_args_str = f"const {output_dtype} &"
109
+ return self.op_view_grad_func_args_template.replace(input_tensor_prt_args=input_tensor_prt_args_str,
110
+ output_tensor_prt_args=output_tensor_prt_args_str)
@@ -47,14 +47,15 @@ class OpTemplateParser:
47
47
  self.op_proto = op_proto
48
48
  self.tensor_arg_handler_prt_template = Template(
49
49
  "parse_args.arg_list_[${idx}] = "
50
- "py::cast((*pynative::${func_str}(\"${func_name}\", \"${op_arg_name}\", "
50
+ "PyLong_FromLong((*pynative::${func_str}(\"${func_name}\", \"${op_arg_name}\", "
51
51
  "parse_args.arg_list_[${idx}]))->value());\n"
52
52
  "parse_args.src_types_[${idx}] = ops::OP_DTYPE::DT_BEGIN;\n"
53
53
  "parse_args.dst_types_[${idx}] = ${new_type};\n"
54
54
  )
55
55
  self.function_arg_handler_prt_template = Template(
56
56
  "parse_args.arg_list_[${idx}] = "
57
- "py::cast((*${func_str}(\"${func_name}\", \"${op_arg_name}\", parse_args.arg_list_[${idx}]))->value());\n"
57
+ "PyLong_FromLong((*${func_str}(\"${func_name}\", \"${op_arg_name}\", "
58
+ "parse_args.arg_list_[${idx}]))->value());\n"
58
59
  "parse_args.src_types_[${idx}] = ops::OP_DTYPE::DT_BEGIN;\n"
59
60
  "parse_args.dst_types_[${idx}] = ${new_type};\n"
60
61
  )
@@ -55,7 +55,7 @@ class PyboostFunctionsGenerator(BaseGenerator):
55
55
  self.PYBOOST_REGISTRY_CC_TEMPLATE = template.PYBOOST_REGISTRY_CC_TEMPLATE
56
56
  self.TENSOR_FUNC_CLASS_REG = template.TENSOR_FUNC_CLASS_REG
57
57
  self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE
58
-
58
+ self.MARK_SIDE_EFFECT_STR = "PyNativeAlgo::PyBoost::MarkSideEffect(PyList_GetItem(args, 0));"
59
59
  self.pyboost_api_body_template = template.PYBOOST_API_BODY_CC_TEMPLATE
60
60
 
61
61
  def generate(self, work_path, op_protos):
@@ -91,8 +91,8 @@ class PyboostFunctionsGenerator(BaseGenerator):
91
91
  pyboost_op_name=pyboost_op_name,
92
92
  pyboost_cfunc_name=pyboost_func_name,
93
93
  class_name=op_proto.op_class.name)
94
- pyboost_func_include_headers_str += self.pyboost_func_include_header_template.replace(
95
- operator_name=op_proto.op_name)
94
+ pyboost_func_include_headers_str +=\
95
+ self.pyboost_func_include_header_template.replace(operator_name=op_proto.op_name)
96
96
  ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_proto.op_class.name[0].lower()))
97
97
  register_func_str = self.REGISTER_TEMPLATE.replace(register_func=pyboost_func_pybind_def)
98
98
  function_class_register = self._get_function_class_register(op_protos)
@@ -132,11 +132,13 @@ class PyboostFunctionsGenerator(BaseGenerator):
132
132
  op_def_name_str = op_parser.get_op_def_name_str()
133
133
  parser_body_str = self._generate_parser_func(op_proto)
134
134
  op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
135
+ side_effect_str = self._generate_mark_side_effect_str(op_proto)
135
136
  pyboost_api_body_str += self.pyboost_api_body_template.replace(func_name=op_pyboost_func_name,
136
137
  op_def_name=op_def_name_str,
137
138
  parser_body=parser_body_str,
138
139
  class_name=op_proto.op_class.name,
139
- op_args=op_args_str)
140
+ op_args=op_args_str,
141
+ mark_side_effect=side_effect_str)
140
142
 
141
143
  ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_proto.op_class.name[0].lower()))
142
144
 
@@ -148,7 +150,7 @@ class PyboostFunctionsGenerator(BaseGenerator):
148
150
  op_def_name_str = op_parser.get_op_def_name_str()
149
151
  parser_body_str = self._generate_parser_func(op_proto)
150
152
  op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
151
- registry_body_tpl = self.get_pyboost_registry_body_cc_tpl(op_proto)
153
+ registry_body_tpl = self.PYBOOST_REGISTRY_BODY_CC_TEMPLATE
152
154
  return registry_body_tpl.replace(func_name=op_pyboost_func_name,
153
155
  op_def_name=op_def_name_str,
154
156
  parser_body=parser_body_str,
@@ -200,5 +202,19 @@ class PyboostFunctionsGenerator(BaseGenerator):
200
202
  arg_index=pyboost_utils.get_index(index))
201
203
  return parser_func_str
202
204
 
205
+ def _generate_mark_side_effect_str(self, op_proto: OpProto) -> str:
206
+ """
207
+ Generates the mark side effect str for the inplace operator.
208
+
209
+ Args:
210
+ op_proto (OpProto): The operator prototype containing the argument information.
211
+
212
+ Returns:
213
+ str: The generated mark side effect flag as a string.
214
+ """
215
+ if op_proto.op_inplace or op_proto.op_view:
216
+ return self.MARK_SIDE_EFFECT_STR
217
+ return ""
218
+
203
219
  def get_pyboost_registry_body_cc_tpl(self, op_proto: OpProto):
204
220
  return self.PYBOOST_REGISTRY_BODY_CC_TEMPLATE
@@ -45,10 +45,10 @@ class PyboostFunctionsHeaderGenerator(BaseGenerator):
45
45
  self.PYBOOST_CORE_HEADER_TEMPLATE = template.PYBOOST_CORE_HEADER_TEMPLATE
46
46
 
47
47
  self.pyboost_func_template = Template(
48
- 'py::object PYNATIVE_EXPORT ${func_name}_Base(const PrimitivePtr &prim, const py::list &args);'
48
+ 'PYNATIVE_EXPORT PyObject* ${func_name}_Base(const PrimitivePtr &prim, PyObject* args);'
49
49
  )
50
50
  self.pyboost_op_func_template = Template(
51
- 'py::object PYNATIVE_EXPORT ${func_name}_OP(const PrimitivePtr &prim, '
51
+ 'PYNATIVE_EXPORT PyObject* ${func_name}_OP(const PrimitivePtr &prim, '
52
52
  'const std::vector<ops::OP_DTYPE>& source_type, ${input_args});'
53
53
  )
54
54
  self.input_args_template = Template(" const ${arg_type}& ${arg_name},")