mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the registry's advisory for this release for more details.

Files changed (458)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +65 -84
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +58 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +178 -53
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +377 -203
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +5 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +117 -131
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +67 -55
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +70 -24
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +27 -7
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +6 -46
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +429 -23
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +140 -104
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +491 -623
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +117 -110
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +4 -6
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/asgd.py +2 -0
  188. mindspore/nn/optim/lamb.py +1 -3
  189. mindspore/nn/optim/optimizer.py +1 -1
  190. mindspore/nn/optim/tft_wrapper.py +2 -3
  191. mindspore/nn/optim/thor.py +2 -2
  192. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  193. mindspore/nn/probability/distribution/exponential.py +2 -1
  194. mindspore/nn/probability/distribution/poisson.py +2 -1
  195. mindspore/nn/sparse/sparse.py +3 -3
  196. mindspore/nn/wrap/cell_wrapper.py +73 -42
  197. mindspore/nn/wrap/grad_reducer.py +37 -52
  198. mindspore/nn/wrap/loss_scale.py +72 -74
  199. mindspore/numpy/array_creations.py +7 -7
  200. mindspore/numpy/fft.py +1 -1
  201. mindspore/numpy/math_ops.py +5 -5
  202. mindspore/numpy/utils_const.py +1 -1
  203. mindspore/opencv_core452.dll +0 -0
  204. mindspore/opencv_imgcodecs452.dll +0 -0
  205. mindspore/opencv_imgproc452.dll +0 -0
  206. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  207. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  208. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  209. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  210. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  211. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  212. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  213. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
  214. mindspore/ops/auto_generate/gen_extend_func.py +27 -145
  215. mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
  216. mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
  217. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  218. mindspore/ops/composite/__init__.py +10 -0
  219. mindspore/ops/composite/base.py +9 -5
  220. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  221. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  222. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  223. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  224. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  225. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  226. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  227. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  228. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  229. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  230. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  231. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  232. mindspore/ops/function/__init__.py +4 -1
  233. mindspore/ops/function/_add_attr_func.py +11 -6
  234. mindspore/ops/function/array_func.py +19 -102
  235. mindspore/ops/function/debug_func.py +8 -5
  236. mindspore/ops/function/grad/grad_func.py +5 -13
  237. mindspore/ops/function/math_func.py +77 -572
  238. mindspore/ops/function/nn_func.py +46 -94
  239. mindspore/ops/function/other_func.py +4 -1
  240. mindspore/ops/function/random_func.py +44 -5
  241. mindspore/ops/function/vmap_func.py +2 -1
  242. mindspore/ops/functional.py +4 -4
  243. mindspore/ops/functional_overload.py +594 -18
  244. mindspore/ops/op_info_register.py +21 -0
  245. mindspore/ops/operations/__init__.py +16 -11
  246. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  247. mindspore/ops/operations/_inner_ops.py +14 -18
  248. mindspore/ops/operations/_sequence_ops.py +1 -1
  249. mindspore/ops/operations/array_ops.py +5 -51
  250. mindspore/ops/operations/comm_ops.py +186 -41
  251. mindspore/ops/operations/custom_ops.py +303 -177
  252. mindspore/ops/operations/debug_ops.py +59 -4
  253. mindspore/ops/operations/image_ops.py +13 -13
  254. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  255. mindspore/ops/operations/math_ops.py +8 -9
  256. mindspore/ops/operations/nn_ops.py +8 -40
  257. mindspore/ops/primitive.py +9 -20
  258. mindspore/ops/tensor_method.py +63 -15
  259. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  260. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  261. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  262. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  263. mindspore/ops_generate/common/base_generator.py +14 -0
  264. mindspore/ops_generate/common/gen_constants.py +8 -3
  265. mindspore/ops_generate/common/gen_utils.py +0 -19
  266. mindspore/ops_generate/common/op_proto.py +11 -4
  267. mindspore/ops_generate/common/template.py +88 -11
  268. mindspore/ops_generate/gen_ops.py +1 -1
  269. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  270. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  271. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  272. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  273. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  274. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  275. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  276. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  277. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  278. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  279. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  280. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  281. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  282. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  283. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  284. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  285. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  286. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  287. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  288. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  289. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  290. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  291. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  292. mindspore/parallel/_auto_parallel_context.py +16 -23
  293. mindspore/parallel/_cell_wrapper.py +113 -45
  294. mindspore/parallel/_parallel_serialization.py +4 -3
  295. mindspore/parallel/_ps_context.py +4 -6
  296. mindspore/parallel/_tensor.py +167 -12
  297. mindspore/parallel/_transformer/moe.py +1 -1
  298. mindspore/parallel/_transformer/transformer.py +17 -12
  299. mindspore/parallel/_utils.py +5 -11
  300. mindspore/parallel/auto_parallel.py +35 -14
  301. mindspore/parallel/checkpoint_convert.py +3 -3
  302. mindspore/parallel/checkpoint_transform.py +13 -7
  303. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  304. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  305. mindspore/parallel/cluster/run.py +48 -7
  306. mindspore/parallel/function/__init__.py +8 -1
  307. mindspore/parallel/function/reshard_func.py +12 -12
  308. mindspore/parallel/nn/__init__.py +15 -2
  309. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  310. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  311. mindspore/parallel/shard.py +10 -25
  312. mindspore/parallel/transform_safetensors.py +469 -174
  313. mindspore/pgodb140.dll +0 -0
  314. mindspore/pgort140.dll +0 -0
  315. mindspore/profiler/__init__.py +2 -1
  316. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  317. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  318. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  319. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  321. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  322. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  323. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  324. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  325. mindspore/profiler/analysis/task_manager.py +1 -1
  326. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  327. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  328. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  329. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  330. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  331. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  332. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  333. mindspore/profiler/common/constant.py +16 -0
  334. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  335. mindspore/profiler/common/path_manager.py +9 -0
  336. mindspore/profiler/common/profiler_context.py +50 -29
  337. mindspore/profiler/common/profiler_info.py +0 -16
  338. mindspore/profiler/common/profiler_meta_data.py +1 -0
  339. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  340. mindspore/profiler/common/profiler_output_path.py +23 -8
  341. mindspore/profiler/common/profiler_parameters.py +128 -35
  342. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  343. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  344. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  345. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  346. mindspore/profiler/dynamic_profiler.py +374 -338
  347. mindspore/profiler/envprofiler.py +42 -12
  348. mindspore/profiler/experimental_config.py +112 -7
  349. mindspore/profiler/mstx.py +33 -12
  350. mindspore/profiler/platform/__init__.py +2 -3
  351. mindspore/profiler/platform/cpu_profiler.py +10 -4
  352. mindspore/profiler/platform/npu_profiler.py +30 -20
  353. mindspore/profiler/profiler.py +218 -154
  354. mindspore/profiler/profiler_action_controller.py +65 -77
  355. mindspore/profiler/profiler_interface.py +2 -2
  356. mindspore/profiler/schedule.py +10 -4
  357. mindspore/rewrite/common/config.py +1 -0
  358. mindspore/rewrite/common/namer.py +1 -0
  359. mindspore/rewrite/common/namespace.py +1 -0
  360. mindspore/rewrite/node/node.py +31 -11
  361. mindspore/rewrite/parsers/assign_parser.py +1 -1
  362. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  363. mindspore/run_check/_check_version.py +7 -10
  364. mindspore/runtime/__init__.py +8 -6
  365. mindspore/runtime/event.py +10 -4
  366. mindspore/runtime/executor.py +87 -45
  367. mindspore/runtime/memory.py +31 -32
  368. mindspore/runtime/thread_bind_core.py +299 -165
  369. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  370. mindspore/swresample-4.dll +0 -0
  371. mindspore/swscale-6.dll +0 -0
  372. mindspore/tbbmalloc.dll +0 -0
  373. mindspore/tinyxml2.dll +0 -0
  374. mindspore/train/_utils.py +17 -7
  375. mindspore/train/amp.py +43 -23
  376. mindspore/train/callback/__init__.py +5 -5
  377. mindspore/train/callback/_callback.py +2 -1
  378. mindspore/train/callback/_checkpoint.py +4 -14
  379. mindspore/train/callback/_flops_collector.py +11 -7
  380. mindspore/train/callback/_landscape.py +0 -1
  381. mindspore/train/callback/_train_fault_tolerance.py +98 -21
  382. mindspore/train/data_sink.py +15 -6
  383. mindspore/train/dataset_helper.py +14 -5
  384. mindspore/train/model.py +133 -69
  385. mindspore/train/serialization.py +168 -126
  386. mindspore/train/summary/summary_record.py +13 -2
  387. mindspore/train/train_thor/model_thor.py +2 -2
  388. mindspore/turbojpeg.dll +0 -0
  389. mindspore/utils/__init__.py +3 -2
  390. mindspore/utils/dryrun.py +0 -6
  391. mindspore/utils/runtime_execution_order_check.py +163 -77
  392. mindspore/utils/sdc_detect.py +68 -0
  393. mindspore/utils/utils.py +14 -17
  394. mindspore/vcmeta.dll +0 -0
  395. mindspore/vcruntime140.dll +0 -0
  396. mindspore/vcruntime140_1.dll +0 -0
  397. mindspore/version.py +1 -1
  398. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  399. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
  400. mindspore/_deprecated/jit.py +0 -198
  401. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  402. mindspore/communication/_hccl_management.py +0 -297
  403. mindspore/experimental/es/embedding_service.py +0 -891
  404. mindspore/experimental/es/embedding_service_layer.py +0 -581
  405. mindspore/profiler/common/validator/__init__.py +0 -14
  406. mindspore/profiler/common/validator/validate_path.py +0 -84
  407. mindspore/profiler/parser/__init__.py +0 -14
  408. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  409. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  410. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  411. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  412. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  413. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  414. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  415. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  416. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  417. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  418. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  419. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  420. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  421. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  422. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  423. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  424. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  425. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  426. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  427. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  428. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  429. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  430. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  431. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  432. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  433. mindspore/profiler/parser/container.py +0 -229
  434. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  435. mindspore/profiler/parser/flops_parser.py +0 -531
  436. mindspore/profiler/parser/framework_enum.py +0 -111
  437. mindspore/profiler/parser/framework_parser.py +0 -464
  438. mindspore/profiler/parser/framework_struct.py +0 -61
  439. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  440. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  441. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  442. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  443. mindspore/profiler/parser/hccl_parser.py +0 -573
  444. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  445. mindspore/profiler/parser/integrator.py +0 -526
  446. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  447. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  448. mindspore/profiler/parser/minddata_parser.py +0 -186
  449. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  450. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  451. mindspore/profiler/parser/optime_parser.py +0 -250
  452. mindspore/profiler/parser/profiler_info.py +0 -213
  453. mindspore/profiler/parser/step_trace_parser.py +0 -666
  454. mindspore/utils/hooks.py +0 -81
  455. mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  456. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  457. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  458. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/train/model.py CHANGED
@@ -57,8 +57,10 @@ from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_trai
57
57
  from mindspore.train import amp
58
58
  from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
59
59
  from mindspore._c_expression import _get_optimzer_timestamps
60
+ from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo
60
61
 
61
62
  from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
63
+ from .serialization import load_param_into_net
62
64
 
63
65
  def _transfer_tensor_to_tuple(inputs):
64
66
  """
@@ -130,7 +132,8 @@ def _handle_exception_info(obj, uce_env, tft, e):
130
132
  if not uce_env:
131
133
  logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
132
134
  exc_info=True)
133
- tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
135
+ if tft:
136
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
134
137
  raise e
135
138
  e_str = str(e)
136
139
  logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
@@ -151,6 +154,9 @@ def _handle_exception_info(obj, uce_env, tft, e):
151
154
  tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
152
155
  raise e
153
156
  tft.tft_report_error(tft.ReportState.RS_UCE.value)
157
+ elif "HCCEError" in e_str:
158
+ logger.warning("uce wrapper caught HCCEError")
159
+ tft.tft_report_error(tft.ReportState.RS_HCCL_FAILED.value)
154
160
  elif "ForceStopError" in e_str:
155
161
  logger.warning("uce wrapper caught RuntimeError ForceStopError")
156
162
  force_stop_err = tft.ReportState.RS_NORMAL.value
@@ -165,6 +171,69 @@ def _handle_exception_info(obj, uce_env, tft, e):
165
171
  raise e
166
172
 
167
173
 
174
+ def _handle_training_result_error(model, tft_obj):
175
+ """
176
+ Handle training result error for resuming training.
177
+ """
178
+ ckpt_load_fn = tft_obj.ckpt_load_func
179
+ train_network = tft_obj.cb_params.train_network
180
+ logger.warning("Process training result error start.")
181
+ # 1. Clear tdt channel
182
+ logger.warning("Clean tdt channel.")
183
+ clean_tdt_channel()
184
+
185
+ # 2. Load checkpoint
186
+ logger.warning("Load checkpoint.")
187
+ new_param_dict, remove_redundancy = ckpt_load_fn()
188
+ param_not_load, ckpt_not_load = load_param_into_net(train_network, new_param_dict, True, remove_redundancy)
189
+ logger.warning(f"param_not_load: {param_not_load}")
190
+ logger.warning(f"ckpt_not_load: {ckpt_not_load}")
191
+ resume_epoch = new_param_dict.get('epoch_num')
192
+ resume_step = new_param_dict.get('step_num')
193
+ model._initial_step = int(resume_step.asnumpy())
194
+ logger.warning("Process training result error end.")
195
+ return (resume_epoch, resume_step)
196
+
197
+
198
+ def _calc_cb_initial_step(org_epoch, org_step, *args, **kwargs):
199
+ """calculate initial step for callback"""
200
+ train_dataset = args[1]
201
+ dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
202
+ sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
203
+
204
+ cb_initial_step = 0
205
+ if dataset_sink_mode:
206
+ train_dataset.set_init_step(org_epoch)
207
+ dataset_size = train_dataset.get_dataset_size()
208
+ if sink_size != -1:
209
+ cb_initial_step = org_epoch * sink_size + org_step
210
+ else:
211
+ cb_initial_step = org_epoch * dataset_size + org_step
212
+ else:
213
+ train_dataset.set_init_step(org_step)
214
+ cb_initial_step = org_step
215
+ if hasattr(train_dataset, '_dataset_helper'):
216
+ dataset_helper = train_dataset._dataset_helper
217
+ _reset_training_dataset(cb_initial_step, dataset_helper.iter.dataset.get_dataset_size())
218
+ return cb_initial_step
219
+
220
+
221
+ def _update_ckpt_callback_info(resume_train_step, **kwargs):
222
+ """
223
+ Update checkpoint callback internal state
224
+ """
225
+ ckpt_obj = None
226
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
227
+ ckpt_obj = kwargs.get('callbacks')
228
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
229
+ for item in kwargs.get('callbacks'):
230
+ if isinstance(item, ModelCheckpoint):
231
+ ckpt_obj = item
232
+ if ckpt_obj is not None:
233
+ ckpt_obj._last_triggered_step = 0
234
+ ckpt_obj._append_step_num = resume_train_step
235
+
236
+
168
237
  def _handle_tft(func):
169
238
  """
170
239
  Decorator function, which starts uce handle process when an exception occurs during training.
@@ -180,42 +249,34 @@ def _handle_tft(func):
180
249
  if isinstance(item, TrainFaultTolerance):
181
250
  obj = item
182
251
  if obj:
183
- tft = obj.tft
184
252
  tft_env = os.getenv("MS_ENABLE_TFT", "")
185
- uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
253
+ uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env or "HCCE:1" in tft_env
254
+ tre_env = "TRE:1" in tft_env
186
255
  while True:
187
256
  try:
188
257
  return func(self, *args, **kwargs)
189
258
  except RuntimeError as e:
190
- _handle_exception_info(obj, uce_env, tft, e)
191
- ret = tft.tft_wait_next_action()
192
- if ret == tft.Action.EXIT.value:
193
- raise e
194
- repair_step = tft.tft_get_repair_step()
195
- logger.warning(
196
- "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
197
- self.batch_num))
259
+ if tre_env and 'TREError' in str(e):
260
+ _, resume_step = _handle_training_result_error(self, obj)
261
+ repair_step = int(resume_step.asnumpy())
262
+ _update_ckpt_callback_info(repair_step, **kwargs)
263
+ logger.warning(f'Resume training after TREError from step {repair_step}.')
264
+ else:
265
+ _handle_exception_info(obj, uce_env, obj.tft, e)
266
+ ret = obj.tft.tft_wait_next_action()
267
+ if ret == obj.tft.Action.EXIT.value:
268
+ raise e
269
+ repair_step = obj.tft.tft_get_repair_step()
270
+ logger.warning(
271
+ "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
272
+ self.batch_num))
198
273
  initial_epoch = int(repair_step / self.batch_num)
199
274
  initial_step = repair_step % self.batch_num
200
275
  kwargs["initial_epoch"] = initial_epoch
201
-
202
- train_dataset = args[1]
203
- dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
204
- sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
205
-
206
- cb_initial_step = 0
207
- if dataset_sink_mode:
208
- train_dataset.set_init_step(initial_epoch)
209
- dataset_size = train_dataset.get_dataset_size()
210
- if sink_size != -1:
211
- cb_initial_step = initial_epoch * sink_size + initial_step
212
- else:
213
- cb_initial_step = initial_epoch * dataset_size + initial_step
214
- else:
215
- train_dataset.set_init_step(initial_step)
216
- cb_initial_step = initial_step
217
-
218
- kwargs["initial_step"] = cb_initial_step
276
+ cb_initial_step = _calc_cb_initial_step(initial_epoch, initial_step, *args, **kwargs)
277
+ if not self.enable_tre:
278
+ kwargs["initial_step"] = cb_initial_step
279
+ self._initial_step = 0
219
280
  # reset all accu grads to zero
220
281
  obj._reset_acc_grads()
221
282
  logger.warning(
@@ -223,8 +284,9 @@ def _handle_tft(func):
223
284
  cb_initial_step))
224
285
  continue
225
286
  except BaseException as e:
226
- logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
227
- tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
287
+ if obj.tft:
288
+ logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
289
+ obj.tft.tft_report_error(obj.tft.ReportState.RS_UNKNOWN.value)
228
290
  raise e
229
291
  else:
230
292
  return func(self, *args, **kwargs)
@@ -241,9 +303,6 @@ def _check_tft():
241
303
  ascend_target = MSContext.get_instance().get_ascend_soc_version()
242
304
  if ascend_target == 'ascend910':
243
305
  raise ValueError("TFT is not supported when using ascend910")
244
- ms_mode = context.get_context("mode")
245
- if ms_mode != mindspore.GRAPH_MODE:
246
- raise ValueError("TFT is only supported in GRAPH_MODE")
247
306
  jit_level = context.get_context("jit_level")
248
307
  if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
249
308
  raise ValueError("TFT is not supported when using jit_level == O2")
@@ -384,6 +443,11 @@ def _set_with_processed_inputs(network, inputs):
384
443
  "Reset inputs from a process inputs, should be a list/tuple or a dict, but got %s!" % str(inputs))
385
444
 
386
445
 
446
+ def _check_tft_reset_dataset():
447
+ env_tft = os.getenv("MS_ENABLE_TFT", "")
448
+ return any([v in env_tft for v in ["TRE:1", "UCE:1", "HCCE:1", "ARF:1"]])
449
+
450
+
387
451
  class Model:
388
452
  """
389
453
  High-Level API for training or inference.
@@ -501,6 +565,10 @@ class Model:
501
565
  self._lite_infer = True # if backend lite infer fails, set False
502
566
  self._mindspore_lite_model_group_id = id(self) & 0xFFFF
503
567
  self.batch_num = -1
568
+ self.enable_tre = "TRE:1" in os.getenv("MS_ENABLE_TFT", "")
569
+ self.enable_hcce = "HCCE:1" in os.getenv("MS_ENABLE_TFT", "")
570
+ self._initial_step = None
571
+ self._need_reset_data = _check_tft_reset_dataset()
504
572
  _clear_auto_parallel_context(self._network)
505
573
 
506
574
  def _check_for_graph_cell(self, kwargs):
@@ -700,7 +768,7 @@ class Model:
700
768
  logger.info("Begin to connect network with dataset.")
701
769
  network = connect_network_with_dataset(network, dataset_helper)
702
770
 
703
- if _get_recovery_context("enable_recovery") and is_train:
771
+ if (_get_recovery_context("enable_recovery") or self._need_reset_data) and is_train:
704
772
  _set_training_dataset(dataset_helper)
705
773
 
706
774
  network.set_train(is_train)
@@ -744,7 +812,7 @@ class Model:
744
812
  """
745
813
  if os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") != "1":
746
814
  return
747
- if (context.get_context("mode") == context.GRAPH_MODE) and (context.get_context("device_target") == "Ascend"):
815
+ if context.get_context("device_target") == "Ascend":
748
816
  cb_params.need_ckpt, cb_params.save_checkpoint_steps, \
749
817
  cb_params.last_triggered_step = self._check_need_ckpt(cb_params.list_callback)
750
818
  logger.info(f"need_ckpt:{cb_params.need_ckpt},"
@@ -812,8 +880,8 @@ class Model:
812
880
  sink_size (int): Control the amount of data in each sink. Default: -1.
813
881
  epoch (int): Total number of iterations on the data. Default: 1.
814
882
  """
815
- if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
816
- raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
883
+ if context.get_context("device_target") != "Ascend":
884
+ raise RuntimeError('Pre-init process only supports Ascend target currently.')
817
885
 
818
886
  if not train_dataset and not valid_dataset:
819
887
  raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")
@@ -957,6 +1025,7 @@ class Model:
957
1025
  cb_params.latest_ckpt_file = None
958
1026
  cb_params.loss_scale_mananger = self._loss_scale_manager
959
1027
  cb_params.is_arf = _get_recovery_context("is_arf")
1028
+ cb_params.initial_step = self._initial_step
960
1029
 
961
1030
  # build callback list
962
1031
  with _CallbackManager(callbacks) as list_callback:
@@ -995,7 +1064,7 @@ class Model:
995
1064
  initial_epoch (int): Epoch at which to start train, it used for resuming a previous training run.
996
1065
  Default: 0.
997
1066
  """
998
- is_graph = (context.get_context("mode") == context.GRAPH_MODE)
1067
+ is_graph = context.get_context("mode") == context.GRAPH_MODE
999
1068
  dataset_size = train_dataset.get_dataset_size()
1000
1069
  if dataset_size % sink_size != 0:
1001
1070
  logger.info("In dataset_sink mode (dataset_size % sink_size) should equal to 0, "
@@ -1064,6 +1133,7 @@ class Model:
1064
1133
  if cb_params.is_arf:
1065
1134
  cb_params.is_arf = False
1066
1135
  _set_recovery_context(is_arf=False)
1136
+ _clean_rootinfo()
1067
1137
 
1068
1138
  # Embedding cache server only run one step.
1069
1139
  if is_embedding_cache_server:
@@ -1142,8 +1212,6 @@ class Model:
1142
1212
  if not enable_recovery:
1143
1213
  self.enable_recovery = False
1144
1214
  else:
1145
- if context.get_context("mode") != context.GRAPH_MODE:
1146
- raise RuntimeError("Recovery for training only support graph mode currently.")
1147
1215
  self.enable_recovery = enable_recovery and _is_role_worker()
1148
1216
 
1149
1217
  def _check_need_load_ckpt(self, cb_params, dataset_size, sink_size=-1):
@@ -1278,6 +1346,7 @@ class Model:
1278
1346
  if cb_params.is_arf:
1279
1347
  cb_params.is_arf = False
1280
1348
  _set_recovery_context(is_arf=False)
1349
+ _clean_rootinfo()
1281
1350
  # Embedding cache server only run one step.
1282
1351
  if is_embedding_cache_server:
1283
1352
  break
@@ -2120,9 +2189,6 @@ class Model:
2120
2189
  dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
2121
2190
  sink_size (int): Control the amount of data in each sink.
2122
2191
  """
2123
- if context.get_context("mode") != context.GRAPH_MODE:
2124
- raise RuntimeError("Pre-compile process that generate parameter layout for the train network "
2125
- "only supports GRAPH MODE and Ascend target currently.")
2126
2192
  if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
2127
2193
  raise RuntimeError("'infer_train_layout' only supports 'semi_auto_parallel' and 'auto_parallel' "
2128
2194
  "mode, but got {}.".format(_get_parallel_mode()))
@@ -2241,6 +2307,7 @@ class Model:
2241
2307
 
2242
2308
  Examples:
2243
2309
  >>> import numpy as np
2310
+ >>> import mindspore as ms
2244
2311
  >>> import mindspore.nn as nn
2245
2312
  >>> from mindspore import Tensor
2246
2313
  >>> from mindspore.train import Model
@@ -2250,28 +2317,28 @@ class Model:
2250
2317
  >>> from mindspore.parallel.auto_parallel import AutoParallel
2251
2318
  >>>
2252
2319
  >>> class Net(nn.Cell):
2253
- >>> def __init__(self):
2254
- >>> super(Net, self).__init__()
2255
- >>> self.fc1 = nn.Dense(128, 768, activation='relu')
2256
- >>> self.fc2 = nn.Dense(128, 768, activation='relu')
2257
- >>> self.fc3 = nn.Dense(128, 768, activation='relu')
2258
- >>> self.fc4 = nn.Dense(768, 768, activation='relu')
2259
- >>> self.relu4 = nn.ReLU()
2260
- >>> self.relu5 = nn.ReLU()
2261
- >>> self.transpose = P.Transpose()
2262
- >>> self.matmul1 = P.MatMul()
2263
- >>> self.matmul2 = P.MatMul()
2264
- >>>
2265
- >>> def construct(self, x):
2266
- >>> q = self.fc1(x)
2267
- >>> k = self.fc2(x)
2268
- >>> v = self.fc3(x)
2269
- >>> k = self.transpose(k, (1, 0))
2270
- >>> c = self.relu4(self.matmul1(q, k))
2271
- >>> s = self.relu5(self.matmul2(c, v))
2272
- >>> s = self.fc4(s)
2273
- >>> return s
2274
- >>>
2320
+ ... def __init__(self):
2321
+ ... super(Net, self).__init__()
2322
+ ... self.fc1 = nn.Dense(128, 768, activation='relu')
2323
+ ... self.fc2 = nn.Dense(128, 768, activation='relu')
2324
+ ... self.fc3 = nn.Dense(128, 768, activation='relu')
2325
+ ... self.fc4 = nn.Dense(768, 768, activation='relu')
2326
+ ... self.relu4 = nn.ReLU()
2327
+ ... self.relu5 = nn.ReLU()
2328
+ ... self.transpose = P.Transpose()
2329
+ ... self.matmul1 = P.MatMul()
2330
+ ... self.matmul2 = P.MatMul()
2331
+ ...
2332
+ ... def construct(self, x):
2333
+ ... q = self.fc1(x)
2334
+ ... k = self.fc2(x)
2335
+ ... v = self.fc3(x)
2336
+ ... k = self.transpose(k, (1, 0))
2337
+ ... c = self.relu4(self.matmul1(q, k))
2338
+ ... s = self.relu5(self.matmul2(c, v))
2339
+ ... s = self.fc4(s)
2340
+ ... return s
2341
+ ...
2275
2342
  >>> ms.set_context(mode=ms.GRAPH_MODE)
2276
2343
  >>> init()
2277
2344
  >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
@@ -2281,9 +2348,6 @@ class Model:
2281
2348
  >>> predict_map = model.infer_predict_layout(inputs)
2282
2349
  """
2283
2350
  _init_auto_parallel_context(self._network)
2284
- if context.get_context("mode") != context.GRAPH_MODE:
2285
- raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
2286
- "only supports GRAPH MODE and Ascend target currently.")
2287
2351
  if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
2288
2352
  raise RuntimeError('Infer predict layout only supports semi auto parallel and auto parallel mode.')
2289
2353
  _parallel_predict_check()