mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff compares two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/.commit_id CHANGED
@@ -1 +1 @@
- __commit_id__ = ''[sha1]:a4230c71,[branch]:(HEAD->master,origin/master,origin/HEAD)''
+ __commit_id__ = ''[sha1]:8c86f33f,[branch]:(HEAD->master,origin/master,origin/HEAD)''
mindspore/__init__.py CHANGED
@@ -23,6 +23,7 @@ from mindspore.mindrecord import *
  from mindspore.ops import _op_impl, grad, value_and_grad, vjp, jvp, jacfwd, jacrev, vmap, get_grad, constexpr, reshard
  from mindspore.train import *
  from mindspore.log import *
+ from mindspore.utils import *
  from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_auto_parallel_context, \
  get_auto_parallel_context, reset_auto_parallel_context, ParallelMode, set_ps_context, \
  get_ps_context, reset_ps_context, set_offload_context, get_offload_context, STRICT, COMPATIBLE, LAX
@@ -30,7 +31,8 @@ from mindspore.version import __version__
  from mindspore.profiler import Profiler, EnvProfiler
  from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
  rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard, \
- Layout, sync_pipeline_shared_parameters, parameter_broadcast, load_segmented_checkpoints
+ Layout, sync_pipeline_shared_parameters, parameter_broadcast, load_segmented_checkpoints, \
+ safetensors_to_ckpt, ckpt_to_safetensors, unified_safetensors
  from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType
  from mindspore.safeguard import obfuscate_ckpt, load_obf_params_into_net
  from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module, \
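Note on the hunk above: 2.4.0 re-exports the new safetensors/checkpoint conversion helpers at the top level. A minimal usage sketch follows; only the names safetensors_to_ckpt, ckpt_to_safetensors and unified_safetensors come from this diff, while the argument names and values are hypothetical placeholders, not the documented signatures.

# Sketch only: argument names/paths below are assumptions, not taken from this diff.
import mindspore as ms

ms.safetensors_to_ckpt("model.safetensors", "ckpt_out/")      # hypothetical arguments
ms.ckpt_to_safetensors("model.ckpt", "safetensors_out/")      # hypothetical arguments
ms.unified_safetensors("rank_shards/", "strategy.ckpt", "unified_out/")  # hypothetical arguments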
Binary file changed: mindspore/_c_dataengine.cp39-win_amd64.pyd
Binary file changed: mindspore/_c_expression.cp39-win_amd64.pyd
Binary file changed: mindspore/_c_mindrecord.cp39-win_amd64.pyd
mindspore/_checkparam.py CHANGED
@@ -29,7 +29,6 @@ from mindspore import log as logger
  from mindspore.common import dtype as mstype
  from mindspore._c_expression import Tensor as Tensor_

-
  EQ = 1 # ==
  NE = 2 # !=
  LT = 3 # <
@@ -148,7 +147,7 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
  ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
  elif len(arg_value) == 3:
  ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
- else: # case: len(arg_value) == 5
+ else:  # case: len(arg_value) == 5
  ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4])

  return ret
@@ -240,6 +239,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
  else:
  raise TypeError(f"{prim_name} type of {arg_name} must be '{arg_type.__name__}', " \
  f"but got '{type(arg_value).__name__}'.")
+
  _check_param()
  return arg_value

@@ -265,6 +265,7 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
  rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
  raise ValueError(f"{prim_name} {arg_name} must be in range of {rel_str}, " \
  f"but got {arg_value} with type '{type(arg_value).__name__}'.")
+
  _check_param()
  return arg_value

@@ -274,6 +275,7 @@ def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_c
  Method for judging relation between two int values or list/tuple made up of ints.
  This method is not suitable for judging relation between floats, since it does not consider float error.
  """
+
  def _check():
  if not _check_binary_rel(arg_value, value, rel):
  rel_str = _format_str_one_value(f'{value_name}: {value}', rel)
@@ -475,20 +477,24 @@ def check_non_negative_float(arg_value, arg_name=None, prim_name=None):

  def check_number(arg_name, arg_value, value, rel, prim_name):
  """Number value judgment."""
+
  def _check():
  if not _check_binary_rel(arg_value, value, rel):
  rel_str = _format_str_one_value(value, rel)
  raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' ' \
  f'must {rel_str}, but got {arg_value}.')
+
  _check()
  return arg_value


  def check_isinstance(arg_name, arg_value, classes):
  """Check arg isinstance of classes"""
+
  def _check():
  if not isinstance(arg_value, classes):
  raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.')
+
  _check()
  return arg_value

@@ -507,6 +513,7 @@ def check_bool(arg_value, arg_name=None, prim_name=None):
  def _check():
  if not isinstance(arg_value, bool):
  raise TypeError(f"{prim_name} {arg_name} must be a bool, but got {type(arg_value).__name__}.")
+
  _check()
  return arg_value

@@ -547,6 +554,7 @@ def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
  if not (isinstance(arg_value, str) and arg_value in valid_values):
  raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}'," \
  f" but got '{arg_value}'.")
+
  _check()
  return arg_value

@@ -626,10 +634,12 @@ def check_subclass(arg_name, type_, template_types, prim_name, addition_error_in

  def check_valid_input(arg_name, arg_value, prim_name):
  """Checks valid value."""
+
  def _check():
  if arg_value is None:
  raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}'" \
  f"can not be None, but got {arg_value}.")
+
  _check()
  return arg_value

@@ -786,6 +796,7 @@ def check_astype_dtype(dtype):

  def check_transpose_axis(axes, ndim):
  """Check the axis argument for tensor.transpose"""
+
  def _check_dim():
  # if multiple arguments provided, it must be `ndim` number of ints
  if len(axes) != ndim:
@@ -793,7 +804,7 @@ def check_transpose_axis(axes, ndim):
  f"but got {len(axes)} in the number of axes.")

  if not axes or (len(axes) == 1 and axes[0] is None):
- return tuple(range(ndim-1, -1, -1))
+ return tuple(range(ndim - 1, -1, -1))

  if len(axes) == 1:
  perm = axes[0]
@@ -912,6 +923,7 @@ def prepare_shape_for_squeeze(shape, axes):

  def check_axis_in_range(axis, ndim):
  """Checks axes are with the bounds of ndim"""
+
  def _check():
  if not isinstance(axis, int):
  raise TypeError(f'The axes must be integers, but got {type(axis)}')
@@ -928,6 +940,7 @@ def check_axis_valid(axes, ndim):
  Checks axes are valid given ndim, and returns axes that can be passed
  to the built-in operator (non-negative, int or tuple)
  """
+
  def _check_range(axes):
  for axis in axes:
  check_axis_in_range(axis, ndim)
@@ -977,16 +990,17 @@ def infer_out_shape(*shapes):
  """
  Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
  """
+
  def _check(items, max_size, shapes):
  for item in items:
  if item not in (1, max_size):
  raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis' \
  f'to support broadcasting, but got shapes {shapes,}')
+
  shape_out = ()
  max_len = max([len(it) for it in shapes])
  for i in range(max_len):
- items = [it[i-(max_len-len(it))] if i - (max_len - len(it))
- >= 0 else 1 for it in shapes]
+ items = [it[i - (max_len - len(it))] if i - (max_len - len(it)) >= 0 else 1 for it in shapes]
  max_size = 0 if 0 in items else max(items)
  _check(items, max_size, shapes)
  shape_out = shape_out + (max_size,)
@@ -1015,6 +1029,7 @@ def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):

  def check_and_canonicalize_axes(axes, ndim):
  """Check whether the types and values of input axes are valid."""
+
  def _check(axes, ax, ndim):
  if not isinstance(ax, int):
  raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
@@ -1091,8 +1106,8 @@ def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
  f"{len(csr_shp)}")
  if values_shp[1:] != csr_shp[2:]:
  raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ]," \
- f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]" \
- f"got: {values_shp[1: ]}")
+ f"but CSRTensor's shape[2: ] got: {csr_shp[2:]} and value's shape[1: ]" \
+ f"got: {values_shp[1:]}")


  def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
@@ -1370,9 +1385,35 @@ def check_hook_fn(hook_type, hook_fn):
  if hook_fn.__code__.co_name == "staging_specialize":
  raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")

- if hook_type == "register_hook" and hook_fn.__code__.co_argcount != 1:
- raise TypeError(f"Tensor hook function {hook_fn.__name__} arg num is not equal to 1.")
+ tensor_hook_func_args_num = 1
+ pre_hook_func_args_num = 2
+ forward_hook_and_backward_hook_func_args_num = 3
+ # Real args number, exclude class method self param
+ hook_fn_args_num = len(inspect.signature(hook_fn).parameters)
+
+ if hook_type == "register_hook" and hook_fn_args_num != tensor_hook_func_args_num:
+ raise TypeError(f"Tensor hook function {hook_fn.__name__} arg num should be {tensor_hook_func_args_num}, but "
+ f"got {hook_fn_args_num}")
+
+ if hook_type == "register_forward_pre_hook" and hook_fn_args_num != pre_hook_func_args_num:
+ raise TypeError(f"forward_pre_hook function {hook_fn.__name__} args num should be {pre_hook_func_args_num}, "
+ f"but got {hook_fn_args_num}")
+
+ if (hook_type == "register_forward_hook" and
+ hook_fn_args_num != forward_hook_and_backward_hook_func_args_num):
+ raise TypeError(f"forward_hook function {hook_fn.__name__} args num should be "
+ f"{forward_hook_and_backward_hook_func_args_num}, but got {hook_fn_args_num}")
+
+ if hook_type == "register_backward_pre_hook" and hook_fn_args_num != pre_hook_func_args_num:
+ raise TypeError(f"backward_pre_hook function {hook_fn.__name__} args num should be {pre_hook_func_args_num},"
+ f" but got {hook_fn_args_num}")
+
+ if (hook_type == "register_backward_hook" and
+ hook_fn_args_num != forward_hook_and_backward_hook_func_args_num):
+ raise TypeError(f"backward_hook function {hook_fn.__name__} args num should be "
+ f"{forward_hook_and_backward_hook_func_args_num}, but got {hook_fn_args_num}")

  return True

+
  _set_record = {}
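Note on the last hunk above: check_hook_fn now validates hooks by counting parameters via inspect.signature instead of co_argcount, and covers all five hook types. A minimal sketch of hook functions that satisfy these counts follows; only the argument counts come from the diff, the parameter names are conventional placeholders.

# Sketch only: parameter names are illustrative; the diff constrains only how
# many parameters each hook type must accept.
def tensor_hook(grad):                             # "register_hook": 1 argument
    print("gradient:", grad)

def forward_pre_hook(cell, inputs):                # "register_forward_pre_hook": 2 arguments
    print("about to run:", cell)

def forward_hook(cell, inputs, outputs):           # "register_forward_hook": 3 arguments
    print("finished:", cell)

def backward_pre_hook(cell, grad_output):          # "register_backward_pre_hook": 2 arguments
    return None

def backward_hook(cell, grad_input, grad_output):  # "register_backward_hook": 3 arguments
    return None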
mindspore/_extends/parse/compile_config.py CHANGED
@@ -12,6 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ============================================================================
+ """
+ Name: AUTO_PASSES_OPTIMIZE_PATH
+ Function: Whether to do optimize the passes configure.
+ Value Range:
+ string: The passes configure file path.
+ Default: '' .empty string. Disable to do optimize the passes.
+ """
+ AUTO_PASSES_OPTIMIZE_PATH = ''
+
  """
  Name: COMPILE_PROFILE
  Function: Whether to do profile and print profile log.
@@ -29,6 +38,16 @@ Value Range:
  """
  COMPILE_PROFILE_FINISH_ACTION = ''

+ """
+ Name: DEBUG_MODE
+ Function: Whether to compile in debug mode.
+ Value Range:
+ "debug": Debug mode
+ "release": Release mode
+ Default: "debug"
+ """
+ COMPILE_DEBUG_MODE = ''
+
  """
  Name: FALLBACK_SUPPORT_LIST_DICT_INPLACE
  Function: Whether to support the inplace operation of list and dict.
@@ -230,9 +249,28 @@ Value Range:
  """
  DUMP_VALIDATE_BEFORE_RESET_ID = ''

+ """
+ Name: ENABLE_RECOMPUTE_BEFORE_INLINE
+ Function: Whether to do recomputation before fprop and bprop being inlined.
+ Value Range:
+ 1: Enable
+ Default: Disable.
+ """
+ ENABLE_RECOMPUTE_BEFORE_INLINE = ''
+
+ """
+ Name: STRICT_CHECK_PARENT_CONTEXT
+ Function: Whether to check parent context strictly.
+ Value Range:
+ 1: Enable
+ Default: Disable.
+ """
+ STRICT_CHECK_PARENT_CONTEXT = ''
+
  __all__ = [
  "COMPILE_PROFILE",
  "COMPILE_PROFILE_FINISH_ACTION",
+ "COMPILE_DEBUG_MODE",
  "FALLBACK_SUPPORT_LIST_DICT_INPLACE",
  "FALLBACK_FORCE_ANY",
  "IF_PARALLEL_CALL",
@@ -255,4 +293,7 @@ __all__ = [
  "DUMP_IR_DDE_DETAIL",
  "COMBINE_LIKE_GRAPHS",
  "DUMP_VALIDATE_BEFORE_RESET_ID",
+ "ENABLE_RECOMPUTE_BEFORE_INLINE",
+ "STRICT_CHECK_PARENT_CONTEXT",
+ "AUTO_PASSES_OPTIMIZE_PATH",
  ]
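Note on the hunks above: the new flags (AUTO_PASSES_OPTIMIZE_PATH, COMPILE_DEBUG_MODE, ENABLE_RECOMPUTE_BEFORE_INLINE, STRICT_CHECK_PARENT_CONTEXT) are plain module-level strings. A minimal sketch of toggling one of them before compilation follows, assuming the flags can simply be assigned on the module; how the graph compiler consumes them is not shown in this diff, and the file path is a placeholder.

# Sketch only: assumes direct assignment is picked up by the compiler later.
from mindspore._extends.parse import compile_config

# '1' enables recomputation before fprop/bprop inlining; '' leaves it disabled.
compile_config.ENABLE_RECOMPUTE_BEFORE_INLINE = '1'

# Point the pass-optimization step at a passes configure file (hypothetical path).
compile_config.AUTO_PASSES_OPTIMIZE_PATH = '/path/to/passes.cfg'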
mindspore/_extends/parse/parser.py CHANGED
@@ -127,7 +127,7 @@ _modules_from_mindspore = (
  "mindspore_rl", "mindformers", "mindpet", "mindpose", "mindface", "mindsearch", "mindinsight", "mindelec",
  "mindflow", "mindsponge", "mindearth", "sciai", "mindquantum", "mindarmour", "mindpandas", "mindvision",
  "mindspore_gl", "mindspore_federated", "mindspore_gs", "mindspore_serving", "mindspore_xai", "mindspore_hub",
- "ringmo_framework", "troubleshooter", "mindtorch",
+ "ringmo_framework", "troubleshooter", "mindtorch", "mindchemistry",
  )

  _global_params = {}
@@ -203,7 +203,7 @@ def get_parse_method_of_class(obj, parse_method=None):
  if parse_method is not None:
  method_name = parse_method
  elif isinstance(obj, nn.Cell):
- if obj._enable_backward_hook:
+ if obj._backward_hook:
  method_name = "_backward_hook_construct"
  else:
  method_name = "construct"
@@ -486,7 +486,7 @@ def convert_class_to_function(cls_str, cls_obj):
  f"supported in 'construct' or @jit decorated function. Try to create {cls_str} "
  f"instances external such as initialized in the method '__init__' before assigning. "
  f"For more details, please refer to "
- f"https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n")
+ f"https://www.mindspore.cn/docs/zh-CN/master/model_train/program_form/overview.html \n")
  return convert_class_to_function_map.get(cls_str)


@@ -931,7 +931,7 @@ class ThirdPartyLibraryChecker:
  """
  def __init__(self):
  self.user_workspace_dir = self.get_top_level_module_path(os.getcwd())
- self.python_builtin_dir = os.path.abspath(os.path.dirname(os.__file__))
+ self.python_builtin_dir = os.path.realpath(os.path.dirname(os.__file__))

  @staticmethod
  def get_jit_modules():
@@ -963,8 +963,8 @@ class ThirdPartyLibraryChecker:

  def get_top_level_module_path(self, module_path):
  """Get the path of the top level package of the current working directory."""
- module_abspath = os.path.abspath(module_path)
- upper_path = os.path.abspath(os.path.dirname(module_abspath))
+ module_abspath = os.path.realpath(module_path)
+ upper_path = os.path.realpath(os.path.dirname(module_abspath))
  if module_abspath == upper_path:
  return module_abspath
  # Check whether __init__.py exists in the upper directory.
@@ -990,7 +990,7 @@ class ThirdPartyLibraryChecker:
  # A modules without __file__ attribute is considered to be in user workspace.
  if not hasattr(module, '__file__'):
  return False
- module_path = os.path.abspath(module.__file__)
+ module_path = os.path.realpath(module.__file__)
  # Python builtin modules are treated as third-party libraries.
  if module_path.startswith(self.python_builtin_dir):
  logger.debug(f"Found python builtin module '{module.__name__}', which is a third-party module.")
@@ -1180,6 +1180,7 @@ class Parser:
  return SYNTAX_SUPPORTED

  def check_lambda(self, src):
+ """Check if the lamda expressions is correct."""
  obj_type = get_obj_type(self.fn)
  if (obj_type != RESOLVE_TYPE_FUNCTION or src[:4] == "def ") and is_lambda_function(self.fn):
  logger.debug("fn is lambda: %r", self.fn)
@@ -1242,6 +1243,7 @@ class Parser:
  return None, None

  def get_name_from_namespace(self, value):
+ """Get the name of value from namespace"""
  try:
  value_str = value.__name__
  logger.debug(
mindspore/_extends/parse/standard_method.py CHANGED
@@ -26,6 +26,7 @@ from mindspore.common.sparse_tensor import RowTensorInner
  from mindspore.ops.composite.base import _append, _insert, _pop, _list_clear, _reverse, \
  _extend, _dict_setitem, _dict_clear, _haskey, _update, _fromkeys
  from mindspore.ops.operations._sequence_ops import TensorToTuple
+ from mindspore.ops.auto_generate import trace_v2_op, inplace_addmm_op

  from ... import _checkparam as validator
  from ..._checkparam import check_is_number, check_reshape_shp, check_axis_in_range, \
@@ -69,6 +70,38 @@ itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,

  nan_tensor = Tensor(float('nan'), dtype=mstype.float32)

+ _map = composite.HyperMap()
+
+
+ def hypermap_dynamic_tuple(func, *inputs):
+ """Make hypermap for dynamic shape tuple."""
+ iter_len = len(inputs[0])
+ i = 0
+ ret = F.make_tuple()
+ while i < iter_len:
+ cur_input = F.make_tuple()
+ for m in inputs:
+ cur_input = cur_input + F.make_tuple(m[i])
+ new_out = _map(func, *cur_input)
+ ret = ret + F.make_tuple(new_out)
+ i = i + 1
+ return ret
+
+
+ def hypermap_dynamic_list(func, *inputs):
+ """Make hypermap for dynamic shape list."""
+ iter_len = len(inputs[0])
+ i = 0
+ ret = F.make_list()
+ while i < iter_len:
+ cur_input = F.make_tuple()
+ for m in inputs:
+ cur_input = cur_input + F.make_tuple(m[i])
+ new_out = _map(func, *cur_input)
+ ret = ret + F.make_list(new_out)
+ i = i + 1
+ return ret
+

  def mean(x, axis=None, keep_dims=False):
  """
@@ -1598,17 +1631,7 @@ def trace(x, offset=0, axis1=0, axis2=1, dtype=None):
  >>> print(x.trace())
  3.0
  """
- if offset == 0 and axis1 == 0 and axis2 == 1 and dtype is None:
- return F.trace(x)
- d = x.diagonal(offset, axis1=axis1, axis2=axis2)
- shape = d.shape
- if dtype is None:
- dtype = d.dtype
- dtype = check_astype_dtype_const(dtype)
- if shape[-1] == 0:
- return F.fill(dtype, shape[:-1], 0)
- res = F.reduce_sum(d.astype(mstype.float32), -1)
- return res.astype(dtype)
+ return trace_v2_op(x, offset, axis1, axis2, dtype)


  def take(x, indices, axis=None, mode='clip'):
@@ -1794,7 +1817,7 @@ def searchsorted(x, v, side='left', sorter=None):
  no suitable index, return either 0 or N (where N is the length of `a`).
  sorter (Union[int, float, bool, list, tuple, Tensor]): 1-D optional array of
  integer indices that sort array `a` into ascending order. They are typically
- the result of argsort.
+ the result of argsort. CPU and GPU can only use default values

  Returns:
  Tensor, array of insertion points with the same shape as `v`.
@@ -2435,6 +2458,7 @@ def list_func(data):
  ret = ret + F.make_list(i)
  return ret

+
  def tuple_func(data):
  """Implementation of `tuple`."""
  if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
@@ -2453,7 +2477,7 @@ def tuple_func(data):


  def ms_zip(*data):
- """Implementation of `tuple`."""
+ """Packs elements in the corresponding positions in multiple sequences into tuples."""
  x = ()
  for i in data:
  if isinstance(i, Tensor):
@@ -3002,7 +3026,7 @@ def tensor_scatter_mul(input_x, indices, updates):
  `indices`, with values from `updates`. When multiple value are given for the same index,
  the output result will be the division of values.
  """
- return F.tensor_sactter_mul(input_x, indices, updates)
+ return F.tensor_scatter_mul(input_x, indices, updates)


  def tensor_sactter_div(input_x, indices, updates):
@@ -3813,6 +3837,20 @@ def addmm(x, mat1, mat2, *, beta=1, alpha=1):
  return F.addmm(x, mat1, mat2, beta=beta, alpha=alpha)


+ def addmm_(self, mat1, mat2, *, beta=1, alpha=1):
+ r"""
+ For details, please refer to :func:`mindspore.ops.addmm`.
+
+ .. note::
+ The output results are directly updated in the Tensor.
+
+ .. warning::
+ This is an experimental API that is subject to change or deletion.
+
+ """
+ return inplace_addmm_op(self, mat1, mat2, beta=beta, alpha=alpha)
+
+
  def addmv(x, mat, vec, beta=1, alpha=1):
  r"""
  Multiplies matrix `mat` and vector `vec`. The vector `x` is added to the final result.
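Note on the hunks above: Tensor.trace now delegates to the generated trace_v2_op, and a new in-place addmm_ method delegates to inplace_addmm_op. A minimal usage sketch follows; the method names and keyword arguments come from the diff, while the shapes and values are illustrative, and addmm_ is flagged experimental in its own docstring.

# Illustrative sketch of the tensor methods touched above; values are made up.
import mindspore as ms
import numpy as np

x = ms.Tensor(np.eye(3), ms.float32)
print(x.trace())  # sum of the main diagonal of the identity -> 3.0

# In-place addmm: x <- beta * x + alpha * (mat1 @ mat2), updating x directly.
mat1 = ms.Tensor(np.ones((3, 4)), ms.float32)
mat2 = ms.Tensor(np.ones((4, 3)), ms.float32)
x.addmm_(mat1, mat2, beta=1, alpha=1)  # experimental API per the diff's docstring
print(x)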