mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -13,14 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """debug_ops"""
16
- import stat
17
- from pathlib import Path
18
-
19
- import numpy as np
16
+ import inspect
20
17
  from mindspore import log as logger
21
18
  from mindspore._c_expression import security, HookType
22
19
  from mindspore._c_expression import TensorPy as Tensor_
23
- from mindspore._c_expression import _tensordump_process_file
20
+ from mindspore._c_expression import _tensordump_exec
24
21
  from mindspore import _checkparam as validator
25
22
  from mindspore.common import dtype as mstype
26
23
  from mindspore.common.parameter import Parameter
@@ -309,29 +306,17 @@ class TensorDump(Primitive):
309
306
  """Initialize TensorDump."""
310
307
  if security.enable_security():
311
308
  raise ValueError('The TensorDump is not supported, please without `-s on` and recompile source.')
309
+ if input_output not in ['in', 'out']:
310
+ raise ValueError(f"The 'input_output' argument should be one of ['in', 'out'], but got: {input_output}")
312
311
  self.add_prim_attr("side_effect_io", True)
313
312
  self.add_prim_attr("channel_name", "ms_tensor_dump")
314
313
 
315
- def _save_file(self, file, data):
316
- file = Path(file)
317
- if file.exists():
318
- file.chmod(stat.S_IWUSR)
319
- np.save(file, data)
320
- file.chmod(stat.S_IRUSR)
321
-
322
314
  def __call__(self, file, input_x):
323
315
  validator.check_value_type('file', file, [str], self.__class__.__name__)
324
316
  if not file:
325
317
  raise ValueError("For 'TensorDump', the input argument[file] cannot be an empty string.")
326
318
  validator.check_value_type('input_x', input_x, [Tensor], self.__class__.__name__)
327
-
328
- dtype = input_x.dtype
329
- file = _tensordump_process_file(file, str(dtype))
330
- if not file:
331
- return
332
- if dtype == mstype.bfloat16:
333
- input_x = P.Cast()(input_x, mstype.float32)
334
- self._save_file(file, input_x.asnumpy())
319
+ _tensordump_exec(file, input_x)
335
320
 
336
321
 
337
322
  class HistogramSummary(Primitive):
@@ -501,13 +486,11 @@ class DumpGradient(Primitive):
501
486
  def __init__(self):
502
487
  pass
503
488
 
504
- def _dump_hook(self, dout):
505
- P.TensorDump()(self.bwd_dump_path, dout)
506
- return dout
507
-
508
489
  def __call__(self, path, x, input_output):
509
- self.bwd_dump_path = path
510
- x = P.InsertGradientOf(self._dump_hook)(x)
490
+ def _dump_hook(dout):
491
+ P.TensorDump()(path, dout)
492
+ return dout
493
+ x = P.InsertGradientOf(_dump_hook)(x)
511
494
  return x
512
495
 
513
496
 
@@ -529,14 +512,15 @@ class Morph(PrimitiveWithInfer):
529
512
 
530
513
  .. note::
531
514
  - This primitive is only supported in GRAPH_MODE.
532
- - `fn` must satisfy the syntax constraints of the graph mode.
533
- - Users do not need to implement a custom backward function.
515
+ - A user-defined bprop (by argument: `bprop_fn`) is allowed for `Morph`.
516
+ - `fn` and `bprop_fn` must satisfy the syntax constraints of the graph mode.
534
517
  - `vararg`, `kwarg`, `kwonlyargs` and free variables are not supported in user-defined function.
535
518
 
536
519
  Args:
537
- fn (Function): Mindspore's function, user-defined function.
538
- infer_shape (Function): Mindspore's function, user-defined infer_shape function.
539
- infer_dtype (Function): Mindspore's function, user-defined infer_dtype function.
520
+ fn (Function): MindSpore's function, user-defined function.
521
+ infer_shape (Function): MindSpore's function, user-defined infer_shape function.
522
+ infer_dtype (Function): MindSpore's function, user-defined infer_dtype function.
523
+ bprop_fn (Function, optional): MindSpore's function, user-defined bprop function, default: ``None``.
540
524
 
541
525
  Inputs:
542
526
  The inputs of user-defined `fn`.
@@ -590,21 +574,35 @@ class Morph(PrimitiveWithInfer):
590
574
  >>> weight0_grad = bwd_out[1][0].asnumpy()
591
575
  >>> weight1_grad = bwd_out[1][1].asnumpy()
592
576
  >>> print("x_grad", x_grad)
593
- >>> print("weight0_grad", weight0_grad)
594
- >>> print("weight1_grad", weight1_grad)
595
577
  x_grad [ 400. 1000. 1800.]
578
+ >>> print("weight0_grad", weight0_grad)
596
579
  weight0_grad [2800. 4000. 5400.]
580
+ >>> print("weight1_grad", weight1_grad)
597
581
  weight1_grad [ 700. 1600. 2700.]
598
582
  """
599
583
  @prim_attr_register
600
- def __init__(self, fn, infer_shape, infer_dtype):
584
+ def __init__(self, fn, infer_shape, infer_dtype, bprop_fn=None):
601
585
  self.add_prim_attr('side_effect_backprop', True)
602
586
  self.add_prim_attr('side_effect_mem', True)
603
587
  self.add_prim_attr('side_effect_io', True)
604
- self.add_prim_attr('__metamorphosis__', fn)
605
588
  self._infer_shape = infer_shape
606
589
  self._infer_dtype = infer_dtype
607
590
 
591
+ self.add_prim_attr('__metamorphosis__', True)
592
+ self.__morph_fn__ = fn
593
+ self.__morph_bprop_fn__ = None
594
+ if bprop_fn:
595
+ self._check_fn_supported(fn)
596
+ self.__morph_bprop_fn__ = bprop_fn
597
+
598
+ def _check_fn_supported(self, fn):
599
+ fn_sig = inspect.signature(fn)
600
+ for param in fn_sig.parameters.values():
601
+ if not (param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and param.default is inspect.Parameter.empty):
602
+ raise ValueError(f"When use `bprop` in Morph, Morph `fn` only support positional or keyword parameters "
603
+ f"with default value is empty, but got param '{param.name}' "
604
+ f"of kind '{param.kind.name}' with default value '{param.default}'.")
605
+
608
606
  def infer_shape(self, *args):
609
607
  return self._infer_shape(*args)
610
608
 
@@ -977,7 +977,7 @@ class ScalarToTensor(PrimitiveWithInfer):
977
977
  def __call__(self, x, dtype=mstype.float32):
978
978
  validator.check_value_type("x", x, [bool, int, float], self.name)
979
979
  validator.check_subclass("dtype", dtype, mstype.number, self.name)
980
- data_type = mstype.dtype_to_nptype(dtype)
980
+ data_type = mstype._dtype_to_nptype(dtype) # pylint:disable=protected-access
981
981
  return Tensor(np.array(x, data_type), dtype=dtype)
982
982
 
983
983
 
@@ -1149,7 +1149,7 @@ def scalar_cast(input_x, input_y):
1149
1149
  Args:
1150
1150
  input_x (scalar): The input scalar.
1151
1151
  input_y (mindspore.dtype): The type to be cast. Only constant value is allowed.
1152
- The value should only be mindspore.int64, mindspore.float64, or mindspore.bool\_.
1152
+ The value should only be mindspore.int64, mindspore.float64, or mindspore.bool.
1153
1153
 
1154
1154
  Returns:
1155
1155
  Scalar, the type is the same as the python type corresponding to `input_y`.
@@ -1294,308 +1294,6 @@ class TypeAs(Primitive):
1294
1294
  return pyboost_type_as(self, [input, other])
1295
1295
 
1296
1296
 
1297
- def to_sequence(val):
1298
- """
1299
- to_sequence
1300
- """
1301
- if isinstance(val, (tuple, list)):
1302
- return tuple(val)
1303
- return (val,)
1304
-
1305
-
1306
- class EmbeddingTableExport(Primitive):
1307
- """
1308
- EmbeddingTableExport
1309
- """
1310
-
1311
- @prim_attr_register
1312
- def __init__(self, embedding_dim, value_total_len, export_mode="all",
1313
- only_var_flag=False, file_type="bin", table_name=(),
1314
- filter_export_flag=False, steps_to_live_list=()):
1315
- """Initialize EmbeddingTableExport"""
1316
- self.add_prim_attr("_process_node_engine_id", "PS")
1317
-
1318
-
1319
- class EmbeddingTableImport(Primitive):
1320
- """
1321
- EmbeddingTableImport
1322
- """
1323
-
1324
- @prim_attr_register
1325
- def __init__(self, embedding_dim, value_total_len,
1326
- only_var_flag=False, file_type="bin", table_name=()):
1327
- """Initialize EmbeddingTableImport"""
1328
- self.add_prim_attr("_process_node_engine_id", "PS")
1329
-
1330
-
1331
- class EmbeddingComputeVarImport(Primitive):
1332
- """
1333
- EmbeddingComputeVarImport
1334
- """
1335
-
1336
- @prim_attr_register
1337
- def __init__(self, table_name=()):
1338
- """Initialize EmbeddingComputeVarImport"""
1339
- self.add_prim_attr("_process_node_engine_id", "PS")
1340
-
1341
-
1342
- class EmbeddingComputeVarExport(Primitive):
1343
- """
1344
- EmbeddingComputeVarExport
1345
- """
1346
-
1347
- @prim_attr_register
1348
- def __init__(self, table_name=()):
1349
- """Initialize EmbeddingComputeVarExport"""
1350
- self.add_prim_attr("_process_node_engine_id", "PS")
1351
-
1352
-
1353
- class InitEmbeddingHashmap(Primitive):
1354
- """
1355
- InitEmbeddingHashmap
1356
- """
1357
- @prim_attr_register
1358
- def __init__(self, value_total_len, embedding_dim, _table_id,
1359
- bucket_size=0, dtype=mstype.float32, initializer_mode="",
1360
- constant_valu=0., min=-2., max=2., mu=0., sigma=1., seed=0,
1361
- seed2=0, filter_mode="no_filter", optimizer_mode="",
1362
- optimizer_params=()):
1363
- self.add_prim_attr("_process_node_engine_id", "PS")
1364
-
1365
-
1366
- def init_embedding_hashmap(table_id, value_total_len, embedding_dim, _table_id,
1367
- bucket_size=0, dtype=mstype.float32, initializer_mode='',
1368
- constant_value=0.0, min=-2.0, max=2.0, mu=0.0, sigma=1.0,
1369
- seed=0, seed2=0, filter_mode='no_filter',
1370
- optimizer_mode='', optimizer_params=()):
1371
- """
1372
- init_embedding_hashmap
1373
- """
1374
- op = _get_cache_prim(InitEmbeddingHashmap)(value_total_len, embedding_dim, _table_id,
1375
- bucket_size, dtype, initializer_mode,
1376
- constant_value, min, max, mu, sigma, seed,
1377
- seed2, filter_mode, optimizer_mode, optimizer_params)
1378
- return op(table_id)
1379
-
1380
-
1381
- class InitPartitionMap(Primitive):
1382
- """
1383
- InitPartitionMap
1384
- """
1385
- @prim_attr_register
1386
- def __init__(self, _embedding_dim, _max_key_num,
1387
- _ps_num=1, partition_num=65537):
1388
- self.add_prim_attr("_process_node_engine_id", "PS")
1389
-
1390
-
1391
- def init_partition_map(ps_num, ps_ids, _embedding_dim, _max_key_num,
1392
- _ps_num=1, partition_num=65537):
1393
- """
1394
- init_partition_map
1395
- """
1396
- op = _get_cache_prim(InitPartitionMap)(_embedding_dim, _max_key_num, _ps_num, partition_num)
1397
- return op(ps_num, ps_ids)
1398
-
1399
-
1400
- class EmbeddingApplyAdam(Primitive):
1401
- """
1402
- EmbeddingApplyAdam
1403
- """
1404
- @prim_attr_register
1405
- def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1406
- padding_key=(0,), padding_key_mask=(1,),
1407
- completion_key=(0,), completion_key_mask=(1,)):
1408
- self.add_prim_attr("_process_node_engine_id", "PS")
1409
-
1410
-
1411
- class EmbeddingApplyAdamW(Primitive):
1412
- """
1413
- EmbeddingApplyAdam
1414
- """
1415
- @prim_attr_register
1416
- def __init__(self, embedding_dim, _max_key_num, amsgrad=(0,),
1417
- maximize=(0,), mask_zero=(0,), padding_key=(0,),
1418
- padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,)):
1419
- self.add_prim_attr("_process_node_engine_id", "PS")
1420
-
1421
-
1422
- class EmbeddingApplyAdaGrad(Primitive):
1423
- """
1424
- EmbeddingApplyAdaGrad
1425
- """
1426
- @prim_attr_register
1427
- def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1428
- padding_key=(0,), padding_key_mask=(1,),
1429
- completion_key=(0,), completion_key_mask=(1,)):
1430
- self.add_prim_attr("_process_node_engine_id", "PS")
1431
-
1432
-
1433
- class EmbeddingApplyFtrl(Primitive):
1434
- """
1435
- EmbeddingApplyFtrl
1436
- """
1437
- @prim_attr_register
1438
- def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1439
- padding_key=(0,), padding_key_mask=(1,),
1440
- completion_key=(0,), completion_key_mask=(1,)):
1441
- self.add_prim_attr("_process_node_engine_id", "PS")
1442
-
1443
-
1444
- class EmbeddingTableFind(Primitive):
1445
- """
1446
- EmbeddingTableFind
1447
- """
1448
- @prim_attr_register
1449
- def __init__(self, embedding_dim, _embedding_dim, _max_key_num,
1450
- _table_id, default_value=(-1.), _use_counter_filter=0):
1451
- self.add_prim_attr("_process_node_engine_id", "PS")
1452
- self.add_prim_attr("_execute_times", 2)
1453
-
1454
-
1455
- def embedding_table_find(table_id, keys, embedding_dim, _max_key_num,
1456
- _table_id, default_value=(-1.0,), _use_counter_filter=0):
1457
- r"""
1458
- embedding_table_find
1459
- """
1460
- _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1461
- op = _get_cache_prim(EmbeddingTableFind)(to_sequence(embedding_dim), _embedding_dim,
1462
- _max_key_num, _table_id,
1463
- to_sequence(default_value),
1464
- _use_counter_filter)
1465
- return op(table_id, keys)
1466
-
1467
-
1468
- class EmbeddingTableFindAndInit(Primitive):
1469
- """
1470
- EmbeddingTableFindAndInit
1471
- """
1472
- @prim_attr_register
1473
- def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id,
1474
- _max_key_num, initializer_mode=("random_uniform",),
1475
- constant_value=(0.,), min=(-2.,), max=(2.,), mu=(0.,),
1476
- sigma=(1.,), seed=(0,), seed2=(0,),
1477
- filter_mode=("no_filter",), filter_freq=(0,),
1478
- default_key_or_value=(0,), default_key=(0,),
1479
- default_value=(0.,), completion_key=(0,),
1480
- completion_key_mask=(1,), optimizer_mode=(),
1481
- optimizer_params=(), _use_counter_filter=0,
1482
- backward_mode="adam",
1483
- backward_int_params=((0,), (0,), (0,), (1,)),
1484
- backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1485
- self.add_prim_attr("_process_node_engine_id", "PS")
1486
- self.add_prim_attr("_execute_times", 2)
1487
-
1488
-
1489
- def embedding_table_find_and_init(table_id, keys, max_grad_norm, parameter, embedding_dim,
1490
- value_total_len, _table_id, _max_key_num,
1491
- initializer_mode=('random_uniform',), constant_value=(0.,),
1492
- min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,),
1493
- seed2=(0,), filter_mode=("no_filter",),
1494
- filter_freq=(0,), default_key_or_value=(0,),
1495
- default_key=(0,), default_value=(0.,),
1496
- completion_key=(0,), completion_key_mask=(1,),
1497
- optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1498
- backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)),
1499
- backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1500
- """
1501
- embedding_table_find_and_init
1502
-
1503
- backward_int_params (Union[tuple[tuple[int]], list[list[int]]]):
1504
- - when the backward_mode is 'adam', 'ftrl' or 'adagrad',
1505
- it means [[global_step], mask_zero, padding_key, padding_key_mask]
1506
- - when the backward_mode is 'adamw', it means:
1507
- [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask]
1508
- backward_float_params (Union[tuple[float], list[float]]):
1509
- - when the backward_mode is 'adam', it means:
1510
- [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
1511
- - when the backward_mode is 'ftrl', it means:
1512
- [lr, lr_power, lambda1, lambda2]
1513
- - when the backward_mode is 'adamw', it means:
1514
- [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
1515
- - when the backward_mode is 'adagrad', it means [lr,]
1516
- """
1517
- _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1518
- op = _get_cache_prim(EmbeddingTableFindAndInit)(to_sequence(embedding_dim), to_sequence(value_total_len),
1519
- _embedding_dim, _table_id, _max_key_num,
1520
- to_sequence(initializer_mode),
1521
- to_sequence(constant_value), to_sequence(min),
1522
- to_sequence(max), to_sequence(mu),
1523
- to_sequence(sigma), to_sequence(seed),
1524
- to_sequence(seed2), to_sequence(filter_mode),
1525
- to_sequence(filter_freq), to_sequence(default_key_or_value),
1526
- to_sequence(default_key), to_sequence(default_value),
1527
- to_sequence(completion_key), to_sequence(completion_key_mask),
1528
- to_sequence(optimizer_mode), to_sequence(optimizer_params),
1529
- _use_counter_filter,
1530
- backward_mode, backward_int_params, backward_float_params)
1531
- return op(table_id, keys, max_grad_norm, parameter)
1532
-
1533
-
1534
- class FakeRemoteLookupUniqued(Primitive):
1535
-
1536
- """
1537
- FakeRemoteLookupUniqued
1538
- """
1539
- @prim_attr_register
1540
- def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id,
1541
- _max_key_num, initializer_mode=('random_uniform',), constant_value=(0.,),
1542
- min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,), seed2=(0,),
1543
- filter_mode=("no_filter",), filter_freq=(0,),
1544
- default_key_or_value=(0,), default_key=(0,), default_value=(0.,),
1545
- completion_key=(0,), completion_key_mask=(1,),
1546
- optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1547
- backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)),
1548
- backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1549
- self.add_prim_attr("_process_node_engine_id", "PS")
1550
- self.add_prim_attr("_execute_times", 2)
1551
-
1552
-
1553
- def fake_remote_lookup_uniqued(table_id, keys, actual_keys_num, unique_indices,
1554
- key_count, max_grad_norm, parameter,
1555
- embedding_dim, value_total_len, _table_id, _max_key_num,
1556
- initializer_mode=('random_uniform',), constant_value=(0.,),
1557
- min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,),
1558
- seed2=(0,), filter_mode=("no_filter",),
1559
- filter_freq=(0,), default_key_or_value=(0,),
1560
- default_key=(0,), default_value=(0.,),
1561
- completion_key=(0,), completion_key_mask=(1,),
1562
- optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1563
- backward_mode='adam', backward_int_params=((0,), (0,), (0,), (1,)),
1564
- backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1565
- """
1566
- fake_remote_lookup_uniqued
1567
-
1568
- backward_mode (str): determine the optimizer used by backpropagation,
1569
- valid values are ["adam", "adamw", "adagrad", "ftrl"]
1570
- backward_int_params (Union[tuple[tuple[int]], list[list[int]]]):
1571
- - when the backward_mode is 'adam', 'ftrl' or 'adagrad',
1572
- it means [[global_step], mask_zero, padding_key, padding_key_mask]
1573
- - when the backward_mode is 'adamw', it means:
1574
- [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask]
1575
- backward_float_params (Union[tuple[float], list[float]]):
1576
- - when the backward_mode is 'adam', it means:
1577
- [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
1578
- - when the backward_mode is 'ftrl', it means:
1579
- [lr, lr_power, lambda1, lambda2]
1580
- - when the backward_mode is 'adamw', it means:
1581
- [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
1582
- - when the backward_mode is 'adagrad', it means [lr,]
1583
- """
1584
- _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1585
- op = _get_cache_prim(FakeRemoteLookupUniqued)(to_sequence(embedding_dim), to_sequence(value_total_len),
1586
- _embedding_dim, _table_id, _max_key_num,
1587
- to_sequence(initializer_mode), to_sequence(constant_value),
1588
- to_sequence(min), to_sequence(max), to_sequence(mu),
1589
- to_sequence(sigma), to_sequence(seed), to_sequence(seed2),
1590
- to_sequence(filter_mode), to_sequence(filter_freq),
1591
- to_sequence(default_key_or_value), to_sequence(default_key),
1592
- to_sequence(default_value), to_sequence(completion_key),
1593
- to_sequence(completion_key_mask), to_sequence(optimizer_mode),
1594
- to_sequence(optimizer_params), _use_counter_filter,
1595
- backward_mode, backward_int_params, backward_float_params)
1596
- return op(table_id, keys, actual_keys_num, unique_indices, key_count, max_grad_norm, parameter)
1597
-
1598
-
1599
1297
  # Following is Python Infer Value.
1600
1298
  # A valid infer value function should be:
1601
1299
  #
@@ -1628,7 +1326,13 @@ def infer_value_for_Concat(tensors, axis):
1628
1326
  return None
1629
1327
 
1630
1328
  tensor_to_concat = [x.asnumpy() for x in tensors]
1631
- return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype)
1329
+ out = np.concatenate(tensor_to_concat, axis)
1330
+ if out.dtype != np.float32:
1331
+ return Tensor(out)
1332
+ for x in tensors:
1333
+ if x.dtype in [mstype.float16, mstype.float32]:
1334
+ return Tensor(out)
1335
+ return Tensor(out, dtype=mstype.bfloat16)
1632
1336
 
1633
1337
 
1634
1338
  def infer_value_for_GatherD(input, dim, index):
@@ -1714,7 +1418,7 @@ def infer_value_for_Arange(start, end, step, dtype=None):
1714
1418
  if has_float:
1715
1419
  np_dtype = np.float32
1716
1420
  else:
1717
- np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
1421
+ np_dtype = mstype._dtype_to_nptype(typing.type_id_to_type(dtype)) # pylint:disable=protected-access
1718
1422
  return Tensor(np.arange(start, end, step, dtype=np_dtype))
1719
1423
 
1720
1424
 
@@ -1738,7 +1442,7 @@ def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
1738
1442
  else:
1739
1443
  axis = tuple(range(len(value.shape)))
1740
1444
  if dtype is not None:
1741
- np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
1445
+ np_dtype = mstype._dtype_to_nptype(typing.type_id_to_type(dtype)) # pylint:disable=protected-access
1742
1446
  value = np_reduce_extand_func(value, axis, dtype=np_dtype, keepdims=keep_dims)
1743
1447
  else:
1744
1448
  value = np_reduce_extand_func(value, axis, keepdims=keep_dims)
@@ -1771,7 +1475,7 @@ def infer_value_for_Cast(x, dst_type_enum=None):
1771
1475
  if x is None or dst_type_enum is None:
1772
1476
  return None
1773
1477
  dst_type = typing.type_id_to_type(dst_type_enum)
1774
- src_type = mstype.get_py_obj_dtype(x)
1478
+ src_type = mstype._get_py_obj_dtype(x) # pylint:disable=protected-access
1775
1479
  validator.check_subclass("input_x", src_type, [mstype.tensor_type, mstype.number], "Cast")
1776
1480
  validator.check_subclass("type", dst_type, mstype.number, "Cast")
1777
1481
 
@@ -1781,7 +1485,7 @@ def infer_value_for_Cast(x, dst_type_enum=None):
1781
1485
  dst_type = dst_type.element_type()
1782
1486
 
1783
1487
  value = None
1784
- np_dst_type = mstype.dtype_to_nptype(dst_type)
1488
+ np_dst_type = mstype._dtype_to_nptype(dst_type) # pylint:disable=protected-access
1785
1489
  if isinstance(x, (int, float)):
1786
1490
  value = Tensor(np.array(x).astype(np_dst_type), dtype=dst_type)
1787
1491
  else:
@@ -1959,11 +1663,27 @@ def infer_value_for_BroadcastTo(x, shape):
1959
1663
  validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
1960
1664
  shape = list(shape)
1961
1665
 
1962
- np_data = np.broadcast_to(x.asnumpy(), shape)
1963
- if 0 in shape:
1666
+ # Resolve -1 entries and support input rank < target rank.
1667
+ input_shape = list(x.shape)
1668
+ target_shape = list(shape)
1669
+ in_rank = len(input_shape)
1670
+ out_rank = len(target_shape)
1671
+ for k in range(1, out_rank + 1):
1672
+ t = target_shape[-k]
1673
+ if t == -1:
1674
+ if k <= in_rank:
1675
+ target_shape[-k] = input_shape[-k]
1676
+ else:
1677
+ pass
1678
+
1679
+ resolved_shape = target_shape
1680
+
1681
+ np_data = np.broadcast_to(x.asnumpy(), resolved_shape)
1682
+ if 0 in resolved_shape:
1964
1683
  init_func = Zero()
1965
1684
  init_func.__enable_zero_dim__ = True
1966
- out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1685
+ out = Tensor(shape=resolved_shape, dtype=x.dtype, init=init_func)
1686
+ out.init_data()
1967
1687
  return out
1968
1688
  return Tensor(np_data)
1969
1689
 
@@ -2014,6 +1734,7 @@ def infer_value_for_Reshape(x, shape):
2014
1734
  init_func = Zero()
2015
1735
  init_func.__enable_zero_dim__ = True
2016
1736
  out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1737
+ out.init_data()
2017
1738
  else:
2018
1739
  out = Tensor(x.asnumpy().reshape(shape))
2019
1740
  return out
@@ -2839,8 +2560,8 @@ class WhileLoop(Primitive):
2839
2560
  while cond_func(val):
2840
2561
  val = loop_func(val)
2841
2562
  except Exception as e:
2842
- raise ValueError("Invalid loop_func, please check input arguments and \
2843
- return value, error info: {}".format(e))
2563
+ raise ValueError(f"Invalid loop_func, please check input arguments and "
2564
+ f"return value, error info: {e}")
2844
2565
  return val
2845
2566
 
2846
2567
 
@@ -2935,8 +2656,8 @@ class Scan(Primitive):
2935
2656
  ys.append(y)
2936
2657
  i = i + 1
2937
2658
  except Exception as e:
2938
- raise ValueError("Invalid loop_func, please check input arguments and \
2939
- return value, error info: {}".format(e))
2659
+ raise ValueError(f"Invalid loop_func, please check input arguments and "
2660
+ f"return value, error info: {e}")
2940
2661
  return carry, ys
2941
2662
 
2942
2663
 
@@ -3011,6 +2732,6 @@ class ForiLoop(Primitive):
3011
2732
  for i in range(lower, upper):
3012
2733
  val = loop_func(i, val)
3013
2734
  except Exception as e:
3014
- raise ValueError("Invalid loop_func, please check input arguments and \
3015
- return value, error info: {}".format(e))
2735
+ raise ValueError(f"Invalid loop_func, please check input arguments and "
2736
+ f"return value, error info: {e}")
3016
2737
  return val
@@ -882,7 +882,7 @@ class Sub(_MathBinaryOp):
882
882
  Note:
883
883
  - When the two inputs have different shapes, they must be able to broadcast to a common shape.
884
884
  - The two inputs can not be bool type at the same time,
885
- [True, Tensor(True, bool\_), Tensor(np.array([True]), bool\_)] are all considered bool type.
885
+ [True, Tensor(True), Tensor(np.array([True]))] are all considered bool type.
886
886
  - The two inputs comply with the implicit type conversion rules to make the data types
887
887
  consistent.
888
888
 
@@ -890,7 +890,7 @@ class Sub(_MathBinaryOp):
890
890
  - **x** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
891
891
  a bool or a tensor whose data type is
892
892
  `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
893
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
893
+ `bool <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
894
894
  - **y** (Union[Tensor, number.Number, bool]) - The second input, when the first input is a Tensor,
895
895
  the second input should be a number.Number or bool value, or a Tensor whose data type is number or bool.
896
896
 
@@ -1289,10 +1289,10 @@ class DivNoNan(Primitive):
1289
1289
  - **x1** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
1290
1290
  a bool or a tensor whose data type is
1291
1291
  `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_ or
1292
- `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
1292
+ `bool <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.
1293
1293
  - **x2** (Union[Tensor, number.Number, bool]) - The second input is a number.Number or
1294
- a bool when the first input is a bool or a tensor whose data type is number or bool\_.
1295
- When the first input is Scalar, the second input must be a Tensor whose data type is number or bool\_.
1294
+ a bool when the first input is a bool or a tensor whose data type is number or bool.
1295
+ When the first input is Scalar, the second input must be a Tensor whose data type is number or bool.
1296
1296
 
1297
1297
  Outputs:
1298
1298
  Tensor, the shape is the same as the one after broadcasting,
@@ -41,7 +41,7 @@ from ..auto_generate import (CeLU, Flatten, LogSoftmax, LogSoftmaxExt, GLU, ReLU
41
41
  UpsampleNearest3D, UpsampleTrilinear3D,
42
42
  SoftMarginLoss, UpsampleBilinear2D, UpsampleLinear1D,
43
43
  BinaryCrossEntropy, BCEWithLogitsLoss, SoftShrink, AdaptiveMaxPool2D,
44
- SmoothL1Loss)
44
+ SmoothL1Loss, KvScaleCache)
45
45
  from .manually_defined import BatchNorm
46
46
 
47
47
 
@@ -6868,8 +6868,8 @@ class CTCLossV2(Primitive):
6868
6868
  >>> print(neg_log_hood)
6869
6869
  [-2.2986124]
6870
6870
  >>> print(log_alpha)
6871
- [[[0.3 0.3 -inf -inf -inf]
6872
- [1.2 1.8931472 1.2 -inf -inf]]]
6871
+ [[[0.3 0.3 -inf -inf 1.8931472 1.2 0. 0. ]
6872
+ [0. 0. 0. 0. 0. 0. 0. 0. ]]]
6873
6873
  """
6874
6874
 
6875
6875
  @prim_attr_register