mindspore-2.4.0-cp311-cp311-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3325 @@
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ """Model and parameters serialization."""
+ from __future__ import absolute_import
+ from __future__ import division
+
+ import binascii
+ import copy
+ import json
+ import os
+ import re
+ import shutil
+ import stat
+ import threading
+ from threading import Thread, RLock
+ from multiprocessing import Process
+ from collections import defaultdict, OrderedDict
+ from io import BytesIO
+
+ import math
+ import sys
+ import time
+ import google
+ import numpy as np
+
+ from mindspore.train.checkpoint_pb2 import Checkpoint
+ from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
+ from mindspore.train.print_pb2 import Print
+
+ import mindspore
+ import mindspore.nn as nn
+ from mindspore import context
+ from mindspore import log as logger
+ from mindspore._checkparam import check_input_data, check_input_dataset
+ from mindspore import _checkparam as Validator
+ from mindspore.common import dtype as mstype
+ from mindspore.common.api import _cell_graph_executor as _executor
+ from mindspore.common.api import _MindsporeFunctionExecutor
+ from mindspore.common.api import _get_parameter_layout
+ from mindspore.common.api import _generate_branch_control_input
+ from mindspore.common.initializer import initializer, One
+ from mindspore.common.parameter import Parameter, _offload_if_config
+ from mindspore.common.tensor import Tensor
+ from mindspore._c_expression import Tensor as Tensor_
+ from mindspore.common._utils import is_shape_unknown
+ from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
+ from mindspore.communication.management import get_rank, get_group_size
+ from mindspore.experimental import MapParameter
+ from mindspore.ops import Cast
+ from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
+ from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
+ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
+ from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
+     _get_device_num, _is_parallel_mode
+ from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
+ from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
+     _restore_group_info_list, _get_param_list_when_first_dim_sharded
+ from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
+     _store_warm_up_ptr_by_tensor_list, _cache_enable
+ from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
+ from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
+     _extract_pipeline_stage_num
+ from mindspore.train._utils import read_proto, get_parameter_redundancy
+ from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
+     split_mindir, split_dynamic_mindir
+ from mindspore.common.generator import Generator
+ from safetensors.numpy import save_file
+ from safetensors import safe_open
+ from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
+
+ tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
+                      "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
+                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
+                      "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
+
+ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
+                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
+                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
+
+ np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
+
+ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
+                          5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
+                          11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
+
+ _ckpt_mutex = RLock()
+
+ # unit is KB
+ SLICE_SIZE = 512 * 1024
+ PROTO_LIMIT_SIZE = 1024 * 1024 * 2
+ TOTAL_SAVE = 1024 * 1024
+ PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
+ ENCRYPT_BLOCK_SIZE = 64 * 1024
+ INT_64_MAX = 9223372036854775807
+
+ cpu_cast = Cast().set_device("CPU")
+
+ _ckpt_fs = FileSystem()
+
+
+ def init_ckpt_file_system(fs: FileSystem):
+     """Initialize checkpoint file system"""
+     if _register_mindio_file_system(fs):
+         return
+     _register_basic_file_system(fs)
+
+
+ # Initialize checkpoint file system
+ init_ckpt_file_system(_ckpt_fs)
+
+
+ def _get_cur_rank_dp(parameter_layout_dict):
+     """Get the smallest parameter-redundancy group (data-parallel group) containing the current rank."""
+     pp_num = _get_auto_parallel_context("pipeline_stages")
+     dev_num = _get_device_num()
+     global_rank = get_rank()
+     pipe_size = dev_num // pp_num
+     initial_rank = (global_rank // pipe_size) * pipe_size
+     parameter_redundancy_dict = get_parameter_redundancy(
+         parameter_layout_dict, initial_rank)
+     value_len = sys.maxsize
+     min_value = ()
+     for key, value in parameter_redundancy_dict.items():
+         if "accu_grads" in key or "inputs" in key:
+             continue
+         for item in value:
+             if len(item) < value_len and global_rank in item:
+                 value_len = len(item)
+                 min_value = item
+     return min_value
+
+
+ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
+     """
+     Find an available checkpoint file path among all backup checkpoint files of the current rank.
+     It assumes that the checkpoint path contains the substring 'rank_{rank_id}', which is used to
+     distinguish between different paths. If cur_ckpt_path doesn't contain the 'rank_{rank_id}' substring,
+     it returns cur_ckpt_path itself when cur_ckpt_path exists; otherwise it returns None.
+
+     Note:
+         This API must be called after the communication is initialized because the cluster information
+         needs to be obtained internally.
+
+     Args:
+         cur_ckpt_path (str): the checkpoint file path which the current rank needs.
+         cur_strategy_path (str): the strategy file path for the current rank.
+
+     Returns:
+         - new_ckpt_file (str), the available checkpoint file path, if one is found.
+         - None, if no available checkpoint file is found.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> from mindspore.communication import init
+         >>> from mindspore import get_ckpt_path_with_strategy
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+         >>> init()
+         >>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
+         >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
+         >>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
+         >>> print(ckpt_file_new)
+     """
+     dp = _get_cur_rank_dp(cur_strategy_path)
+     pattern = r'rank_\d+'
+     for i in dp:
+         new_ckpt_path = re.sub(pattern, f"rank_{str(i)}", cur_ckpt_path)
+         if not os.path.isfile(new_ckpt_path):
+             continue
+         return new_ckpt_path
+     return None
+
+
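For intuition, a minimal sketch of the rank-substitution step above, using only the standard-library `re` module; the path and the backup group are hypothetical stand-ins for what `_get_cur_rank_dp` would return:

    import re

    dp_group = (3, 5, 7)  # hypothetical backup ranks for the current rank
    cur_ckpt_path = "./rank_5/iteration-1_40.ckpt"

    for rank_id in dp_group:
        # same substitution as the loop above
        print(re.sub(r'rank_\d+', f"rank_{rank_id}", cur_ckpt_path))
    # ./rank_3/iteration-1_40.ckpt
    # ./rank_5/iteration-1_40.ckpt
    # ./rank_7/iteration-1_40.ckpt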
+ class ParamDictFuture:
+     def __init__(self, executor, param_dict_future):
+         self.executor = executor
+         self.param_dict_future = param_dict_future
+
+     def result(self):
+         param_dict = self.param_dict_future.result()
+         self.executor.shutdown()
+         return param_dict
+
+
+ def _special_process_par(par, new_par):
+     """
+     Processes the special condition.
+
+     Like (12,2048,1,1)->(12,2048), this case is caused by a GE 4-dimensional tensor.
+     """
+     par_shape_len = len(par.data.shape)
+     new_par_shape_len = len(new_par.data.shape)
+     if new_par_shape_len <= par_shape_len:
+         return False
+
+     for i in range(new_par_shape_len - par_shape_len):
+         if new_par.data.shape[par_shape_len + i] != 1:
+             return False
+
+     if new_par.data.dtype == mstype.bfloat16:
+         new_val = cpu_cast(new_par.data, mstype.float32).asnumpy()
+     else:
+         new_val = new_par.data.asnumpy()
+
+     new_val = new_val.reshape(par.data.shape)
+     par.set_data(Tensor(new_val, par.data.dtype))
+     return True
+
+
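A worked illustration of the check above, as a standalone sketch with made-up shapes: a checkpoint value of shape (12, 2048, 1, 1) is accepted for a network parameter of shape (12, 2048) because every extra trailing dimension equals 1, and it is then reshaped onto the parameter's shape:

    import numpy as np

    net_shape = (12, 2048)                # shape of `par` in the network
    ckpt_val = np.ones((12, 2048, 1, 1))  # shape of `new_par` from the checkpoint

    extra_dims = ckpt_val.shape[len(net_shape):]
    assert all(d == 1 for d in extra_dims)     # the condition the loop above enforces
    print(ckpt_val.reshape(net_shape).shape)   # (12, 2048)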
+ def _update_param(param, new_param, strict_load):
+     """Updates param's data from new_param's data."""
+     if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
+         if param.data.shape != new_param.data.shape:
+             if not _special_process_par(param, new_param):
+                 logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
+                 msg = (f"For 'load_param_into_net', {param.name} in the argument 'net' should have the same shape "
+                        f"as {param.name} in the argument 'parameter_dict'. But got its shape {param.data.shape} in"
+                        f" the argument 'net' and shape {new_param.data.shape} in the argument 'parameter_dict'. "
+                        f"You may need to check whether the checkpoint you loaded is correct and whether the batch "
+                        f"size and so on in the 'net' and 'parameter_dict' are the same.")
+                 raise RuntimeError(msg)
+
+         if param.data.dtype != new_param.data.dtype:
+             if _type_convert(param, new_param, strict_load):
+                 if new_param.data.dtype == mstype.bfloat16:
+                     new_tensor = cpu_cast(new_param.data, param.data.dtype)
+                 else:
+                     new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
+                 param.set_data(new_tensor, param.sliced)
+                 return
+
+             logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
+             msg = (f"For 'load_param_into_net', {param.name} in the argument 'net' should have the same type as "
+                    f"{param.name} in the argument 'parameter_dict'. But got its type {param.data.dtype} in the "
+                    f"argument 'net' and type {new_param.data.dtype} in the argument 'parameter_dict'. "
+                    f"You may need to check whether the checkpoint you loaded is correct.")
+             raise RuntimeError(msg)
+
+         param.set_data(new_param.data, param.sliced)
+         return
+
+     if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
+         if param.data.shape != (1,) and param.data.shape != ():
+             logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
+             msg = (f"For 'load_param_into_net', {param.name} in the argument 'parameter_dict' is "
+                    f"scalar, then the shape of {param.name} in the argument 'net' should be "
+                    f"(1,) or (), but got shape {param.data.shape}. "
+                    f"You may need to check whether the checkpoint you loaded is correct.")
+             raise RuntimeError(msg)
+         param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))
+
+     elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
+         logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
+         msg = (f"For 'load_param_into_net', {param.name} in the argument 'parameter_dict' is Tensor, "
+                f"then {param.name} in the argument 'net' also should be Tensor, but got {type(param.data)}. "
+                f"You may need to check whether the checkpoint you loaded is correct.")
+         raise RuntimeError(msg)
+
+     else:
+         param.set_data(type(param.data)(new_param.data))
+
+
+ def _type_convert(param, new_param, strict_load):
+     """Whether to convert parameter's type during load checkpoint into network."""
+     float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16)
+     int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
+     if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
+                             {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
+         logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
+                        f"its type in 'net':{param.data.dtype}, so its type will be converted from "
+                        f"{new_param.data.dtype} to {param.data.dtype} in the network.")
+         return True
+     return False
+
+
+ def _save_weight(checkpoint_dir, model_name, iteration, params):
+     """Save model weight into checkpoint."""
+     logger.debug(f"Checkpoint dir is: '{checkpoint_dir}'")
+     exist_ckpt_file_list = []
+     if os.path.exists(checkpoint_dir):
+         for exist_ckpt_name in os.listdir(checkpoint_dir):
+             file_prefix = model_name + "_iteration_"
+             if exist_ckpt_name.startswith(file_prefix):
+                 exist_ckpt_file_list.append(exist_ckpt_name)
+
+         param_dict = OrderedDict()
+         for key in params.keys():
+             value = params[key]
+             weight_type = value[0]
+             weight_shape = value[1]
+             weight_data = value[2]
+             weight_size = value[3]
+             weight_np = np.array(weight_data, dtype=weight_type.lower())
+             logger.debug(f"weight_type: '{weight_type}', weight_shape: '{weight_shape}', weight_size: "
+                          f"'{weight_size}', weight_np.nbytes: '{weight_np.nbytes}'")
+
+             param_dict[key] = [weight_shape, weight_type, weight_np]
+         ckpt_file_save_name = model_name + "_iteration_" + iteration + ".ckpt"
+         ckpt_file_save_path = os.path.join(checkpoint_dir, ckpt_file_save_name)
+
+         _exec_save(ckpt_file_save_path, param_dict)
+
+         for exist_ckpt_name in exist_ckpt_file_list:
+             os.remove(os.path.join(checkpoint_dir, exist_ckpt_name))
+         logger.info(f"Save weight to checkpoint file path '{ckpt_file_save_path}' success.")
+     else:
+         logger.warning(f"Checkpoint dir: '{checkpoint_dir}' does not exist.")
+
+
+ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
+                format="ckpt"):
+     """Execute the process of saving checkpoint into file."""
+     try:
+         with _ckpt_mutex:
+             file_name_list = list(os.path.splitext(ckpt_file_name))
+             file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
+             tmp_name = ''.join(file_name_list)
+             if os.path.exists(ckpt_file_name):
+                 os.chmod(ckpt_file_name, stat.S_IWUSR)
+                 os.remove(ckpt_file_name)
+             if os.path.exists(tmp_name):
+                 os.chmod(tmp_name, stat.S_IWUSR)
+                 os.remove(tmp_name)
+             if format == "ckpt":
+                 with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
+                     plain_data = None
+                     if enc_key is not None:
+                         plain_data = BytesIO()
+
+                     crc_num = 0
+                     for name, value in data_list.items():
+                         if name == "random_op":
+                             _write_random_seed(name, value, f)
+                             continue
+                         if value[0] == "mapparameter":
+                             _write_mapparameter(name, value, f, map_param_inc)
+                             continue
+                         if value[0] == "offload_parameter":
+                             new_value = value[1:]
+                             new_value[2] = value[3]
+                             _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
+                             _offload_if_config(value[3])
+                             continue
+                         if value[1] == "str":
+                             crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+                             continue
+                         if isinstance(value[2], np.ndarray):
+                             crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+                             continue
+                         if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
+                             _write_hugeparameter(name, value, f)
+                             continue
+
+                         crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+
+                     if enc_key is not None:
+                         plain_data.seek(0)
+                         max_block_size = ENCRYPT_BLOCK_SIZE * 1024
+                         block_data = plain_data.read(max_block_size)
+                         while block_data:
+                             f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
+                             block_data = plain_data.read(max_block_size)
+                     if crc_check:
+                         f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
+             elif format == "safetensors":
+                 save_dict = {}
+                 for name, value in data_list.items():
+                     save_dict[name] = value[2].asnumpy()
+                 save_file(save_dict, tmp_name)
+             if not os.path.exists(tmp_name):
+                 logger.warning(f"Rename failed, can't find {tmp_name}. It is possible that multiple processes have "
+                                f"modified the file simultaneously.")
+             else:
+                 os.rename(tmp_name, ckpt_file_name)
+                 os.chmod(ckpt_file_name, stat.S_IRUSR)
+     except BaseException as e:
+         logger.critical("Failed to save the checkpoint file %s. Maybe you don't have permission to write files, "
+                         "or the disk space is insufficient, and so on.", ckpt_file_name)
+         raise e
+
+
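Note the save path above: the checkpoint is first written under a `*.tmp` name and only renamed to the final name after the write completes, so a crash mid-write cannot leave a truncated file under the real checkpoint name. A minimal sketch of that pattern in isolation, with a hypothetical file name and payload:

    import os
    import stat

    def atomic_save(path: str, payload: bytes):
        tmp_name = os.path.splitext(path)[0] + ".tmp"
        with open(tmp_name, "wb") as f:
            f.write(payload)
        os.rename(tmp_name, path)      # the rename is the commit point
        os.chmod(path, stat.S_IRUSR)   # owner read-only, as _exec_save does

    atomic_save("demo.ckpt", b"\x00" * 16)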
+ def _write_random_seed(name, value, f):
+     """Write random op into protobuf file."""
+     checkpoint_list = Checkpoint()
+     param_value = checkpoint_list.value.add()
+     param_value.tag = name
+     param_tensor = param_value.tensor
+     param_tensor.dims.extend([0])
+     param_tensor.tensor_type = "random_op"
+     param_tensor.tensor_content = value
+     f.write(checkpoint_list.SerializeToString())
+
+
+ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
+     """Write parameter data into protobuf file."""
+     data_size = value[2].nbytes / 1024
+     if data_size > SLICE_SIZE:
+         slice_count = math.ceil(data_size / SLICE_SIZE)
+         param_slice_list = np.array_split(value[2], slice_count)
+     else:
+         param_slice_list = [value[2]]
+
+     for param_slice in param_slice_list:
+         checkpoint_list = Checkpoint()
+         param_value = checkpoint_list.value.add()
+         param_value.tag = name
+         param_tensor = param_value.tensor
+         param_tensor.dims.extend(value[0])
+         param_tensor.tensor_type = value[1]
+         param_tensor.tensor_content = param_slice.tobytes()
+
+         if enc_key is None:
+             output_data = checkpoint_list.SerializeToString()
+             if crc_check:
+                 crc_num = binascii.crc32(output_data, crc_num)
+             f.write(output_data)
+         else:
+             plain_data.write(checkpoint_list.SerializeToString())
+
+     return crc_num
+
+
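`data_size` above is measured in KB, so with `SLICE_SIZE = 512 * 1024` (KB, i.e. 512 MB) a tensor is only split once it exceeds 512 MB. A quick worked example of the arithmetic, using a hypothetical float32 tensor of about 1.5 GB:

    import math

    SLICE_SIZE = 512 * 1024       # KB, as defined near the top of the file
    nbytes = 402_653_184 * 4      # ~1.5 GB: 402,653,184 float32 elements

    data_size = nbytes / 1024     # KB, mirroring `value[2].nbytes / 1024`
    slice_count = math.ceil(data_size / SLICE_SIZE)
    print(slice_count)            # 3 -> np.array_split emits 3 protobuf records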
+ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
+     """Write parameter bytes data into protobuf file."""
+     bytes_value = value[2].get_bytes()
+     chunk_size = 1024 * SLICE_SIZE
+
+     for i in range(0, len(bytes_value), chunk_size):
+         checkpoint_list = Checkpoint()
+         param_value = checkpoint_list.value.add()
+         param_value.tag = name
+         param_tensor = param_value.tensor
+         param_tensor.dims.extend(value[0])
+         param_tensor.tensor_type = value[1]
+         param_tensor.tensor_content = bytes_value[i:i + chunk_size]
+
+         if enc_key is None:
+             output_data = checkpoint_list.SerializeToString()
+             if crc_check:
+                 crc_num = binascii.crc32(output_data, crc_num)
+             f.write(output_data)
+         else:
+             plain_data.write(checkpoint_list.SerializeToString())
+
+     return crc_num
+
+
+ def _write_mapparameter(name, value, f, map_param_inc=False):
+     """Write map parameter into protobuf file."""
+     while True:
+         logger.info("Checkpoint save map_parameter.")
+         data_map_slice = value[1].export_slice_data(map_param_inc)
+         checkpoint_list = Checkpoint()
+         param_value = checkpoint_list.value.add()
+         param_value.tag = name
+         map_tensor = param_value.maptensor
+         for numpy_data in data_map_slice[:3]:
+             tensor_pro = map_tensor.tensor.add()
+             tensor_pro.dims.extend(numpy_data.shape)
+             tensor_pro.tensor_type = str(numpy_data.dtype)
+             tensor_pro.tensor_content = numpy_data.reshape(-1).tobytes()
+         f.write(checkpoint_list.SerializeToString())
+         if data_map_slice[3]:
+             break
+
+
+ def _write_hugeparameter(name, value, f):
+     """Write huge parameter into protobuf file."""
+     slice_num = value[2].slice_num
+     offset = 0
+     max_size = value[0][0]
+     for param_slice in range(slice_num):
+         checkpoint_list = Checkpoint()
+         param_value = checkpoint_list.value.add()
+         param_value.tag = name
+         param_tensor = param_value.tensor
+         param_tensor.dims.extend(value[0])
+         param_tensor.tensor_type = value[1]
+         param_key = value[3]
+         numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
+         if offset + numpy_data.shape[0] > max_size:
+             numpy_data = numpy_data[:max_size - offset]
+         param_tensor.tensor_content = numpy_data.tobytes()
+         f.write(checkpoint_list.SerializeToString())
+         offset += numpy_data.shape[0]
+
+
+ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
+     """Check save_obj and ckpt_file_name for save_checkpoint."""
+     if format not in ["safetensors", "ckpt"]:
+         raise ValueError(f"For 'save_checkpoint', the format must be "
+                          f"'safetensors' or 'ckpt', but got {format}.")
+     if not isinstance(save_obj, (nn.Cell, list, dict)):
+         raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
+                         "but got {}.".format(type(save_obj)))
+     if not isinstance(ckpt_file_name, str):
+         raise TypeError("For 'save_checkpoint', the checkpoint file name {} is invalid: "
+                         "'ckpt_file_name' must be a string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
+     ckpt_file_name = os.path.realpath(ckpt_file_name)
+     if os.path.isdir(ckpt_file_name):
+         raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
+                                 "it must be a file name.".format(ckpt_file_name))
+     if not ckpt_file_name.endswith(format):
+         ckpt_file_name += f".{format}"
+     return ckpt_file_name
+
+
+ def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
+                                    global_step_num=None):
+     param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
+                          or map_param_inc or global_step_num is not None)
+     if format == "safetensors" and param_not_default:
+         raise ValueError("For 'save_checkpoint', when format is 'safetensors', all other optional parameters "
+                          "must keep their default values.")
+
+
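In effect, the guard above means `format="safetensors"` can only be combined with the defaults of every other option. A sketch that exercises the helper directly (argument values are hypothetical):

    # Passes: everything except `format` is left at its default.
    _check_format_and_other_params("safetensors", enc_key=None, enc_mode="AES-GCM")

    # Raises: crc_check is not a default value.
    try:
        _check_format_and_other_params("safetensors", enc_key=None, enc_mode="AES-GCM", crc_check=True)
    except ValueError as err:
        print(err)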
+ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
+                     async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
+                     crc_check=False, format="ckpt", **kwargs):
+     r"""
+     Save checkpoint to a specified file.
+
+     Note:
+         The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.
+
+     Args:
+         save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
+             list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
+             elements (each element is a dictionary, like [{"name": param_name, "data": param_data},...]; the type of
+             `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
+             it can be the returned value of `mindspore.load_checkpoint()`.
+         ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
+         integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene.
+             Default: ``True`` .
+         async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
+         append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
+             of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
+         enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
+             is not required. Default: ``None`` .
+         enc_mode (str): This parameter is valid only when enc_key is not set to ``None`` . Specifies the encryption
+             mode, currently supports ``"AES-GCM"`` , ``"AES-CBC"`` and ``"SM4-CBC"`` .
+             Default: ``"AES-GCM"`` .
+         choice_func (function): A function for saving custom selected parameters. The input value of `choice_func`
+             is a parameter name in string type, and the returned value is a bool.
+             If it returns ``True`` , the Parameter matching the custom condition will be saved.
+             If it returns ``False`` , the Parameter not matching the custom condition will not
+             be saved. Default: ``None`` .
+         crc_check (bool): Whether to perform crc32 calculation when saving checkpoint and save the calculation
+             result to the file. Default: ``False`` .
+         format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
+         kwargs (dict): Configuration options dictionary.
+
+     Raises:
+         TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
+         TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
+         TypeError: If the parameter `ckpt_file_name` is not string type.
+
+     Examples:
+         >>> import mindspore as ms
+         >>>
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+         >>> net = LeNet5()
+         >>> ms.save_checkpoint(net, "./lenet.ckpt",
+         ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
+         >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
+         >>> print(param_dict1)
+         {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
+         >>> params_list = net.trainable_params()
+         >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
+         ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
+         >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
+         >>> print(param_dict2)
+         {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
+         >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
+         >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
+         >>> print(param_dict3)
+         {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
+
+     Tutorial Examples:
+         - `Saving and Loading the Model - Saving and Loading the Model Weight
+           <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
+     """
+     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
+     integrated_save = Validator.check_bool(integrated_save)
+     async_save = Validator.check_bool(async_save)
+     append_dict = _check_append_dict(append_dict)
+     enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
+     enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
+     crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+     map_param_inc = kwargs.get('incremental', False)
+     logger.info("Execute the process of saving checkpoint files.")
+     global_step_num = kwargs.get('global_step_num', None)
+     _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
+
+     if append_dict and "__exception_save__" in append_dict:
+         s1 = mindspore.hal.Stream()
+         with mindspore.hal.StreamCtx(s1):
+             save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+             s1.synchronize()
+     else:
+         save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+
+     if append_dict:
+         if "__exception_save__" in append_dict:
+             del append_dict["__exception_save__"]
+         append_info_list = []
+         for k_name, value in append_dict.items():
+             if isinstance(value, Generator):
+                 value = value.get_state()
+             elif not isinstance(value, str):
+                 value = Tensor(value)
+             append_info_list.append({"name": k_name, "data": value})
+         save_obj.extend(append_info_list)
+
+     data_list = OrderedDict()
+     data_list_np = OrderedDict()
+     with _ckpt_mutex:
+         for param in save_obj:
+             if param["name"] == "random_op":
+                 if os.getenv("AITURBO") == "1":
+                     data_list_np["random_op"] = []
+                     data_list_np["random_op"].append(param["data"])
+                     if crc_check:
+                         bytes_value = bytes(data_list_np["random_op"][0])
+                         data_list_np["random_op"].append(binascii.crc32(bytes_value))
+                 else:
+                     data_list["random_op"] = param["data"]
+                 continue
+             key = param["name"]
+             data_list[key] = []
+             data_list_np[key] = []
+             if isinstance(param["data"], MapParameter):
+                 data_list[param["name"]].append("mapparameter")
+                 data_list[param["name"]].append(param["data"])
+                 continue
+             if isinstance(param["data"], list):
+                 if param["data"][0] == "persistent_data":
+                     _save_param_list_data(data_list, key, param)
+                 elif param["data"][0] == "offload_parameter":
+                     data_list[key].append("offload_parameter")
+                     _save_param_list_data(data_list, key, param)
+
+             if isinstance(param["data"], str):
+                 if os.getenv("AITURBO") == "1":
+                     data_list_np[key].append(np.array(param["data"]))
+                     if crc_check:
+                         bytes_value = data_list_np[key][0].tobytes()
+                         data_list_np[key].append(binascii.crc32(bytes_value))
+                 else:
+                     data_list[key].append([0])
+                     data_list[key].append('str')
+                     data = np.array(param["data"])
+                     data_list[key].append(data)
+             else:
+                 if isinstance(param["data"], Parameter):
+                     param["data"].init_data()
+                 if os.getenv("AITURBO") == "1":
+                     data_list_np[key].append(param["data"].asnumpy())
+                     if crc_check:
+                         bytes_value = data_list_np[key][0].tobytes()
+                         data_list_np[key].append(binascii.crc32(bytes_value))
+                 else:
+                     dims = []
+                     for dim in param['data'].shape:
+                         dims.append(dim)
+                     data_list[key].append(dims)
+                     tensor_type = str(param["data"].dtype)
+                     data_list[key].append(tensor_type)
+                     data = param["data"]
+                     data_list[key].append(data)
+
+     if os.getenv("AITURBO") == "1":
+         from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
+         ckpt_name = os.path.basename(ckpt_file_name)
+         aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
+     elif async_save:
+         data_copy = copy.deepcopy(data_list)
+         thr = Thread(target=_exec_save,
+                      args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
+                      name="asyn_save_ckpt")
+         thr.start()
+     else:
+         _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
+
+     logger.info("Saving checkpoint process is finished.")
+
+
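For reference, a small usage sketch of the branches above: with `async_save=True` the serialization work runs on a named background thread, and with `crc_check=True` a trailing crc32 record is appended to the file. It assumes `net` is an already-constructed `mindspore.nn.Cell`:

    import mindspore as ms

    # `net` is assumed to be any constructed nn.Cell instance.
    ms.save_checkpoint(net, "./model_async.ckpt", async_save=True)  # returns before the file is complete
    ms.save_checkpoint(net, "./model_crc.ckpt", crc_check=True)     # appends a crc32 record for integrity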
701
+ def _convert_list_to_param_list(save_obj, choice_func):
+     """Convert a list of Parameter to param_list."""
+     param_list = []
+     if not save_obj:
+         return param_list
+     if isinstance(save_obj[0], dict):
+         for param in save_obj:
+             if isinstance(param, dict) and "name" in param and "data" in param:
+                 if not isinstance(param["name"], str):
+                     raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
+                                     f"should be string, but got {type(param['name'])}.")
+                 if not isinstance(param["data"], Tensor):
+                     raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
+                                     f"should be Tensor, but got {type(param['data'])}.")
+                 if choice_func is not None and not choice_func(param["name"]):
+                     continue
+                 each_param = {"name": param["name"], "data": param["data"]}
+                 param_list.append(each_param)
+             else:
+                 raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
+                                 f"have the keys 'name' and 'data', but got {type(param)} and {param}.")
+     else:
+         for param in save_obj:
+             if isinstance(param, Parameter):
+                 if choice_func is not None and not choice_func(param.name):
+                     continue
+                 each_param = {"name": param.name, "data": param}
+                 param_list.append(each_param)
+             else:
+                 raise TypeError(f"For save_checkpoint, when save_obj is a list of Parameter, "
+                                 f"each element should be a Parameter, but got {type(param)}.")
+     return param_list
+
+
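# [Editor's example, not part of this file] The two list forms accepted above, sketched:
# a list of Parameter, or a list of {"name": str, "data": Tensor} dicts.
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor

w = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w")
ms.save_checkpoint([w], "as_params.ckpt")                                            # list of Parameter
ms.save_checkpoint([{"name": "w", "data": Tensor(w.asnumpy())}], "as_dicts.ckpt")    # list of dicts
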
+ def _convert_dict_to_param_dict(save_obj, choice_func):
+     """Convert a dict of Parameter to param_list."""
+     param_list = []
+     for (key, value) in save_obj.items():
+         if isinstance(key, str) and isinstance(value, (Parameter, str)):
+             if choice_func is not None and not choice_func(key):
+                 continue
+             each_param = {"name": key, "data": value}
+             param_list.append(each_param)
+         else:
+             raise TypeError(f"For save_checkpoint, when save_obj is a dict, the key should be str and "
+                             f"the value should be Parameter, but got the type of key {type(key)} and "
+                             f"the type of value {type(value)}.")
+     return param_list
+
+
+ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
+     """Convert cell.parameters_and_names to OrderedDict."""
+     param_dict = OrderedDict()
+     for _, param in save_obj.parameters_and_names():
+         not_sliced = not param.sliced
+         is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
+         # All parameters are initialized immediately under PyNative mode, so this judgment is skipped there.
+         judgment = not_sliced or param.has_init
+         if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
+             continue
+         if choice_func is not None and not choice_func(param.name):
+             continue
+         # Add a suffix to a cache-enabled parameter so that the parameter carries its key info.
+         # Notice that the suffix needs to be removed when loading into the net.
+         if param.cache_enable:
+             param_dict[param.name + ".__param_key__" + str(param.key)] = param
+         else:
+             param_dict[param.name] = param
+     return param_dict
+
+
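# [Editor's example, not part of this file] A sketch of the ".__param_key__" suffix convention above:
# a cache-enabled parameter named "embedding.weight" with key 3 is stored under
# "embedding.weight.__param_key__3"; _warm_up_host_cache further below splits on the same marker
# to recover the original name when loading.
name = "embedding.weight" + ".__param_key__" + str(3)
assert name[:name.find(".__param_key__")] == "embedding.weight"
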
+ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+     """Convert nn.Cell to param_list."""
+     sync_pipeline_shared_parameters(save_obj)
+     param_list = []
+     parameter_layout_dict = save_obj.parameter_layout_dict
+     if _is_in_auto_parallel_mode() and not parameter_layout_dict:
+         parameter_layout_dict = _get_parameter_layout()
+     if not _is_in_auto_parallel_mode():
+         save_obj.init_parameters_data()
+     param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
+     if append_dict and "random_op" in append_dict:
+         phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
+         if phase in save_obj.compile_cache and _executor.has_compiled(phase):
+             random_byte = _executor._graph_executor.get_random_status(phase)
+             param_list.append({"name": "random_op", "data": random_byte})
+         append_dict.pop("random_op")
+     for (key, value) in param_dict.items():
+         each_param = {"name": key}
+         if isinstance(value, MapParameter):
+             each_param["data"] = value
+             param_list.append(each_param)
+             continue
+
+         if value.data.is_persistent_data():
+             # list save persistent_data: ["persistent_data", Tensor, shape, type, param.key]
+             param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
+         elif value.data.offload_file_path() != "":
+             # list save offload data: ["offload_parameter", Tensor, shape, type, param.key]
+             param_data = ["offload_parameter"]
+             param_tensor = value.data
+             if key in parameter_layout_dict:
+                 param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
+                                                       integrated_save)
+             param_data.append(param_tensor)
+             param_data.append(param_tensor.shape)
+             param_data.append(str(param_tensor.dtype))
+             param_data.append(value.key)
+         else:
+             param_data = value.data
+             if append_dict and "__exception_save__" in append_dict:
+                 param_data = Tensor(Tensor_.move_to(value, "CPU", False))
+
+         # In the automatic model parallel scenario, some parameters were split across all the devices;
+         # they should be combined before saving.
+         if key in parameter_layout_dict:
+             if not append_dict or "__exception_save__" not in append_dict:
+                 param_data = Tensor(value.data)
+             param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
+                                                 integrated_save)
+
+         each_param["data"] = param_data
+         param_list.append(each_param)
+     return param_list
+
+
+ def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+     """Convert a save_obj to param_list."""
+     if isinstance(save_obj, list):
+         return _convert_list_to_param_list(save_obj, choice_func)
+
+     if isinstance(save_obj, dict):
+         return _convert_dict_to_param_dict(save_obj, choice_func)
+
+     return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+
+
+ def _save_param_list_data(data_list, key, param):
+     """Save persistent data into save_obj."""
+     dims = []
+     # The shape of persistent_data can not be ().
+     for dim in param['data'][2]:
+         dims.append(dim)
+     data_list[key].append(dims)
+     data_list[key].append(param['data'][3])
+     data_list[key].append(param['data'][1])
+     data_list[key].append(param['data'][4])
+
+
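# [Editor's note, not part of this file] Index map consumed by _save_param_list_data above, shown
# with a hypothetical entry: param['data'] holds [tag, tensor, shape, dtype string, param key],
# and the function appends them to data_list[key] in the order shape list, dtype, tensor, key.
entry = ["persistent_data", "tensor-placeholder", (1024, 80), "Float32", 5]  # hypothetical
shape_list = [dim for dim in entry[2]]  # -> [1024, 80], matching the dims loop above
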
+ def _check_append_dict(append_dict):
+     """Check the argument append_dict for save_checkpoint."""
+     if append_dict is None:
+         return append_dict
+     if not isinstance(append_dict, dict):
+         raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
+                         "{}.".format(type(append_dict)))
+     for key, value in append_dict.items():
+         if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)):
+             raise TypeError(f"For 'save_checkpoint', the dict 'append_info' must have a string key and a value "
+                             f"of type int, float, bool, str, Parameter, Tensor or Generator, "
+                             f"but got key: {type(key)}, value: {type(value)}")
+     return append_dict
+
+
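# [Editor's example, not part of this file] A sketch of the append_dict contract checked above:
# string keys, values of int/float/bool/str/Parameter/Tensor/Generator type.
import mindspore as ms
from mindspore import nn

net = nn.Dense(2, 2)
append_info = {"epoch": 3, "lr": 0.01, "tag": "exp-1"}
ms.save_checkpoint(net, "net_with_info.ckpt", append_dict=append_info)
loaded = ms.load_checkpoint("net_with_info.ckpt")  # string values come back as str, others as Parameter
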
+ def _check_load_obfuscate(**kwargs):
+     if 'obf_func' in kwargs.keys():
+         customized_func = _check_customized_func(kwargs.get('obf_func'))
+         clean_funcs()
+         add_opaque_predicate(customized_func.__name__, customized_func)
+         return True
+     return False
+
+
+ def load(file_name, **kwargs):
+     """
+     Load MindIR.
+
+     The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
+
+     Args:
+         file_name (str): MindIR file name.
+
+         kwargs (dict): Configuration options dictionary.
+
+             - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
+             - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
+
+               - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
+               - For details of using the customized decryption, please check the `tutorial
+                 <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
+
+             - obf_func (function): A python function used for loading an obfuscated MindIR model, which can refer to
+               `obfuscate_model()
+               <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
+
+     Returns:
+         GraphCell, a compiled graph that can be executed by `GraphCell`.
+
+     Raises:
+         ValueError: MindIR file does not exist or `file_name` is not a string.
+         RuntimeError: Failed to parse MindIR file.
+
+     Examples:
+         >>> import numpy as np
+         >>> import mindspore as ms
+         >>> import mindspore.nn as nn
+         >>> from mindspore import Tensor
+         >>> from mindspore import context
+         >>> context.set_context(mode=context.GRAPH_MODE)
+         >>>
+         >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
+         >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+         >>> ms.export(net, input_tensor, file_name="net", file_format="MINDIR")
+         >>> graph = ms.load("net.mindir")
+         >>> net = nn.GraphCell(graph)
+         >>> output = net(input_tensor)
+         >>> print(output)
+         [[[[4. 6. 4.]
+            [6. 9. 6.]
+            [4. 6. 4.]]]]
+
+     Tutorial Examples:
+         - `Saving and Loading the Model - Saving and Loading MindIR
+           <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
+     """
+     if not isinstance(file_name, str):
+         raise ValueError("For 'load', the argument 'file_name' must be string, but "
+                          "got {}.".format(type(file_name)))
+     if not file_name.endswith(".mindir"):
+         raise ValueError("For 'load', the argument 'file_name'(MindIR file) should end with '.mindir', "
+                          "please input the correct 'file_name'.")
+     if not os.path.exists(file_name):
+         raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
+                          "please check whether the 'file_name' is correct.")
+     file_name = os.path.realpath(file_name)
+
+     # set customized functions for dynamic obfuscation
+     obfuscated = _check_load_obfuscate(**kwargs)
+
+     logger.info("Execute the process of loading mindir.")
+     if 'dec_key' in kwargs.keys():
+         dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
+         dec_mode = "AES-GCM"
+         dec_func = None
+         if 'dec_mode' in kwargs.keys():
+             if callable(kwargs.get('dec_mode')):
+                 dec_mode = "Customized"
+                 dec_func = kwargs.get('dec_mode')
+             else:
+                 dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
+         graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
+                             decrypt=dec_func, obfuscated=obfuscated)
+     else:
+         graph = load_mindir(file_name, obfuscated=obfuscated)
+
+     if graph is None:
+         if _is_cipher_file(file_name):
+             raise RuntimeError("Load MindIR failed. The file may be encrypted and decryption failed; you "
+                                "can check whether the values of the arguments 'dec_key' and 'dec_mode'"
+                                " are the same as when the MindIR file was exported, or check the file integrity.")
+         raise RuntimeError("Load MindIR failed.")
+     return graph
+
+
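# [Editor's example, not part of this file] Loading an encrypted MindIR through the kwargs above;
# the 16-byte key is illustrative and must match the key used at export time.
import mindspore as ms
from mindspore import nn

key = b"0123456789ABCDEF"
graph = ms.load("net.mindir", dec_key=key, dec_mode="AES-GCM")
net = nn.GraphCell(graph)
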
+ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
+     """
+     Auto Split MindIR.
+
+     The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
+
+     Args:
+         file_name (str): MindIR file name.
+         device_num (int): Device number. Default: ``8``.
+         rank_id (int): Rank id. Default: ``0``.
+         dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: ``True``.
+         sapp (bool): Indicates whether to automatically generate the split strategy through SAPP. Default: ``True``.
+
+     Raises:
+         ValueError: MindIR file does not exist or `file_name` is not a string.
+         RuntimeError: Failed to split MindIR file.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>>
+         >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
+
+     """
+     if not isinstance(file_name, str):
+         raise ValueError("For 'Split MindIR', the argument 'file_name' must be string, but "
+                          "got {}.".format(type(file_name)))
+     if not file_name.endswith(".mindir"):
+         raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) should end with '.mindir', "
+                          "please input the correct 'file_name'.")
+     if not os.path.exists(file_name):
+         raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
+                          "please check whether the 'file_name' is correct.")
+     file_name = os.path.realpath(file_name)
+
+     logger.info("Execute the process of export and split mindir.")
+     if dynamic:
+         graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
+     else:
+         graph = split_mindir(file_name)
+
+     if graph is None:
+         if _is_cipher_file(file_name):
+             raise RuntimeError("Export and split MindIR failed. The file may be encrypted and decryption failed; "
+                                "you can check whether the values of the arguments 'dec_key' and 'dec_mode'"
+                                " are the same as when the MindIR file was exported, or check the file integrity.")
+         raise RuntimeError("Export and split MindIR failed.")
+     return graph
+
+
+ def _check_param_type(param_config, key, target_type, requested):
+     """Check the type of parameters."""
+     if key in param_config:
+         if not isinstance(param_config[key], target_type):
+             raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type,
+                                                                             type(param_config[key])))
+         if key == 'obf_random_seed':
+             if param_config[key] > INT_64_MAX or param_config[key] <= 0:
+                 raise ValueError(
+                     "'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
+                                                                                            param_config[key]))
+         return param_config[key]
+     if requested:
+         raise ValueError("The parameter {} is required, but was not provided.".format(key))
+     if key == "obf_random_seed":
+         return 0
+     return None
+
+
+ def _check_customized_func(customized_func):
+     """Check the customized function of dynamic obfuscation."""
+     if not callable(customized_func):
+         raise TypeError(
+             "'customized_func' must be a function, but got {}.".format(type(customized_func)))
+     # test customized_func
+     try:
+         func_result = customized_func(1.0, 1.0)
+     except Exception as ex:
+         raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
+     else:
+         if not isinstance(func_result, bool):
+             raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
+     return customized_func
+
+
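# [Editor's example, not part of this file] A customized_func must accept two inputs and return a
# constant bool (an opaque predicate); the probe call above is customized_func(1.0, 1.0).
def my_func(x1, x2):
    # The branch structure hides the fact that the result is always True.
    if x1 + x2 < 0:
        return True
    return True
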
+ def _check_obfuscate_params(obf_config):
+     """Check obfuscation parameters, including obf_random_seed, obf_ratio and customized_func."""
+     if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
+         raise ValueError(
+             "At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got none of them.")
+     obfuscate_type = _check_param_type(obf_config, "type", str, False)
+     if obfuscate_type not in (None, "dynamic"):
+         raise ValueError("Only 'dynamic' type is supported for now, but got {}.".format(obfuscate_type))
+     if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
+         if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
+             raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
+                 obf_config['obf_ratio']))
+         ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
+         obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
+     obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
+     if (obf_ratio <= 0) or (obf_ratio > 1):
+         raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
+     customized_funcs = []
+     if 'customized_func' in obf_config.keys():
+         device_target = context.get_context('device_target')
+         if device_target in ["GPU", "Ascend"]:
+             raise ValueError(
+                 "Customized func mode only supports 'device_target'='CPU', but got {}.".format(device_target))
+         customized_funcs.append(_check_customized_func(obf_config['customized_func']))
+     obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
+     return obf_ratio, customized_funcs, obf_random_seed
+
+
+ def obfuscate_model(obf_config, **kwargs):
+     """
+     Obfuscate a model of MindIR format. Obfuscation means changing the structure of a network without affecting its
+     predict correctness. The obfuscated model can prevent attackers from stealing the model.
+
+     Args:
+         obf_config (dict): obfuscation config.
+
+             - type (str): The type of obfuscation, only 'dynamic' is supported until now.
+             - original_model_path (str): The path of the MindIR format model that needs to be obfuscated. If the
+               original model is encrypted, then enc_key and enc_mode should be provided.
+             - save_model_path (str): The path to save the obfuscated model.
+             - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random,
+               which is the same as using :func:`mindspore.export`.
+             - obf_ratio (Union(float, str)): The ratio of nodes in the original model that would be obfuscated.
+               `obf_ratio` should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and
+               "large" correspond to 0.1, 0.3, and 0.6 respectively.
+             - customized_func (function): A python function used for the customized function mode, which is used to
+               control the switch branches of the obfuscation structure. The output of customized_func should be
+               boolean and constant (refer to 'my_func()' in
+               `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
+               This function needs to ensure that its result is constant for any input. Users can refer to opaque
+               predicates. If customized_func is set, then it should be passed to the :func:`mindspore.load` interface
+               when loading the obfuscated model.
+             - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
+               structure of obfuscated models corresponding to different random seeds is different. If
+               `obf_random_seed` is set, then it should be passed to the :class:`mindspore.nn.GraphCell`
+               interface when loading the obfuscated model. It should be noted that at least one of `customized_func`
+               or `obf_random_seed` should be set, and the latter mode would be applied if both of them are set.
+
+         kwargs (dict): Configuration options dictionary.
+
+             - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
+             - enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
+               Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
+
+     Raises:
+         TypeError: If `obf_config` is not a dict.
+         ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
+         ValueError: If `original_model_path` is not provided in `obf_config`.
+         ValueError: If the model saved in `original_model_path` has been obfuscated.
+         ValueError: If `save_model_path` is not provided in `obf_config`.
+         ValueError: If `obf_ratio` is not provided in `obf_config`.
+         ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
+         ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
+         ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
+
+     Examples:
+         >>> import mindspore as ms
+         >>> import mindspore.nn as nn
+         >>> import numpy as np
+         >>> # Download ori_net.mindir
+         >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
+         >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
+         >>> obf_config = {'original_model_path': "./net.mindir",
+         ...               'save_model_path': "./obf_net",
+         ...               'model_inputs': [input1, ],
+         ...               'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
+         >>> ms.obfuscate_model(obf_config)
+         >>> obf_func = ms.load("obf_net.mindir")
+         >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
+         >>> print(obf_net(input1).asnumpy())
+     """
+     if not isinstance(obf_config, dict):
+         raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
+     file_path = _check_param_type(obf_config, "original_model_path", str, True)
+     if not file_path.endswith(".mindir"):
+         raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
+                          "please input the correct 'file_path'.")
+     if not os.path.exists(file_path):
+         raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
+                          "please check whether the 'file_path' is correct.")
+     saved_path = _check_param_type(obf_config, "save_model_path", str, True)
+     model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
+     for item in model_inputs:
+         if not isinstance(item, Tensor):
+             raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
+         if -1 in item.shape:
+             raise ValueError(
+                 "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
+     obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
+     if customized_funcs and obf_random_seed > 0:
+         logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
+                        " applied, remember to set 'obf_random_seed' when loading the obfuscated model.")
+
+     if obf_random_seed == 0:  # apply customized_func mode
+         clean_funcs()
+         for func in customized_funcs:
+             add_opaque_predicate(func.__name__, func)
+         branch_control_input = 0
+     else:  # apply password mode
+         branch_control_input = _generate_branch_control_input(obf_random_seed)
+
+     if 'enc_key' in kwargs.keys():
+         enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
+         enc_mode = "AES-GCM"
+         if 'enc_mode' in kwargs.keys():
+             enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
+             if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
+                 raise ValueError(
+                     "Only MindIR files that are encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' are supported for "
+                     "obfuscate_model(), but got {}.".format(enc_mode))
+         obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
+                                              branch_control_input=branch_control_input, dec_key=enc_key,
+                                              key_len=len(enc_key),
+                                              dec_mode=enc_mode)
+     else:
+         obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
+                                              branch_control_input=branch_control_input)
+
+     obf_net = nn.GraphCell(obf_graph)
+     if obf_random_seed != 0:
+         append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
+         model_inputs += [append_y_tensor]
+     export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
+
+
+ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
+                           dec_mode, crc_check, format):
+     """Load parameters into parameter_dict."""
+     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
+     if format == "safetensors":
+         with safe_open(ckpt_file_name, framework='np') as f:
+             for k in f.keys():
+                 parameter_dict[k] = Parameter(f.get_tensor(k))
+         return
+     checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
+     try:
+         param_data_list = []
+         map_data_list = [[], [], []]
+         map_shape_list = [0, 0, 0]
+         if specify_prefix:
+             logger.warning("For load_checkpoint, the parameter `specify_prefix` will be deprecated, "
+                            "please use `choice_func` instead.")
+         if filter_prefix:
+             logger.warning("For load_checkpoint, the parameter `filter_prefix` will be deprecated, "
+                            "please use `choice_func` instead.")
+         for element_id, element in enumerate(checkpoint_list.value):
+             if element.tag == "random_op":
+                 parameter_dict["random_op"] = element.tensor.tensor_content
+                 continue
+             if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
+                 continue
+             if specify_prefix is None and filter_prefix is None and \
+                     choice_func is not None and not choice_func(element.tag):
+                 continue
+             if element.tensor.ByteSize() == 0:
+                 _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
+                                     parameter_dict)
+                 if element.tag in parameter_dict:
+                     map_data_list = [[], [], []]
+                     map_shape_list = [0, 0, 0]
+                 continue
+             data = element.tensor.tensor_content
+             data_type = element.tensor.tensor_type
+             np_type = tensor_to_np_type.get(data_type)
+             ms_type = tensor_to_ms_type[data_type]
+             if data_type == 'str':
+                 str_length = int(len(data) / 4)
+                 np_type = np_type + str(str_length)
+             param_data_list.append(data)
+             if (element_id == len(checkpoint_list.value) - 1) or \
+                     (element.tag != checkpoint_list.value[element_id + 1].tag):
+                 new_data = b"".join(param_data_list)
+                 param_data_list.clear()
+                 dims = element.tensor.dims
+                 if data_type == 'str':
+                     str_value = np.frombuffer(new_data, np_type)
+                     parameter_dict[element.tag] = str(str_value[0])
+                 else:
+                     if dims == [0]:
+                         dims = []
+                     param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
+                     parameter = Parameter(param_data, name=element.tag)
+                     parameter_dict[element.tag] = parameter
+                     _offload_if_config(parameter)
+
+         logger.info("Loading checkpoint files process is finished.")
+
+     except BaseException as e:
+         logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
+         raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
+                          "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
+
+
+ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
+                     dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
+                     crc_check=False, remove_redundancy=False, format="ckpt"):
+     """
+     Load checkpoint info from a specified file.
+
+     Note:
+         - `specify_prefix` and `filter_prefix` do not affect each other.
+         - If none of the parameters are loaded from the checkpoint file, a ValueError will be raised.
+         - `specify_prefix` and `filter_prefix` are in the process of being deprecated;
+           `choice_func` is recommended instead. Using either of those two args will override `choice_func`.
+         - When loading a checkpoint that was saved with redundancy removed, the network should be compiled.
+
+     Args:
+         ckpt_file_name (str): Checkpoint file name.
+         net (Cell): The network where the parameters will be loaded. Default: ``None`` .
+         strict_load (bool): Whether to strictly load the parameters into net. If ``False`` , it will load a parameter
+                             into net when the parameter name's suffix in the checkpoint file is the same as the
+                             parameter in the network. When the types are inconsistent, it performs type conversion
+                             on parameters of compatible types, such as float32 to float16. Default: ``False`` .
+         filter_prefix (Union[str, list[str], tuple[str]]): Deprecated (see `choice_func`). Parameters starting with
+                                                            the filter_prefix will not be loaded. Default: ``None`` .
+         dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
+                                       is not required. Default: ``None`` .
+         dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
+                         mode, currently supports ``"AES-GCM"`` , ``"AES-CBC"`` and ``"SM4-CBC"`` .
+                         Default: ``"AES-GCM"`` .
+         specify_prefix (Union[str, list[str], tuple[str]]): Deprecated (see `choice_func`). Parameters starting with
+                                                             the specify_prefix will be loaded. Default: ``None`` .
+         choice_func (Union[None, function]): The input value of the function is a Parameter name of type string,
+                                              and the return value is a bool. If it returns ``True`` , the Parameter
+                                              that matches the custom condition will be loaded. If it returns
+                                              ``False`` , the Parameter that matches the custom condition will be
+                                              removed. Default: ``None`` .
+         crc_check (bool): Whether to perform crc32 validation when loading the checkpoint. Default: ``False`` .
+         remove_redundancy (bool): Whether to enable loading of a checkpoint saved with redundancy removal.
+                                   Redundancy removal refers to eliminating redundant data in data parallelism mode.
+                                   Default: ``False`` , which means redundancy-free loading is not enabled.
+         format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
+
+     Returns:
+         Dict, key is the parameter name, value is a Parameter or string. When the `append_dict` parameter of
+         :func:`mindspore.save_checkpoint` or the `append_info` parameter of :class:`mindspore.train.CheckpointConfig`
+         is used to save the checkpoint and the appended value is a string, that entry is returned as a string;
+         in other cases the return value is a Parameter.
+
+     Raises:
+         ValueError: Checkpoint file's format is incorrect.
+         ValueError: Parameter's dict is None after loading the checkpoint file.
+         TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
+
+     Examples:
+         >>> import mindspore as ms
+         >>>
+         >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
+         >>> param_dict = ms.load_checkpoint(ckpt_file_name,
+         ...                                 choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
+         >>> print(param_dict["conv2.weight"])
+         Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
+         >>> def func(param_name):
+         ...     whether_load = False
+         ...     if param_name.startswith("conv"):
+         ...         whether_load = True
+         ...     if param_name.startswith("conv1"):
+         ...         whether_load = False
+         ...     return whether_load
+         >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
+         >>> print(param_dict1["conv2.weight"])
+         Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
+         >>> def func(param_name):
+         ...     whether_load = False
+         ...     if param_name.startswith("conv1"):
+         ...         whether_load = True
+         ...     return whether_load
+         >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
+         >>> print(param_dict2)
+         {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
+
+     Tutorial Examples:
+         - `Saving and Loading the Model - Saving and Loading the Model Weight
+           <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
+     """
+     specify_prefix = _check_prefix(specify_prefix)
+     filter_prefix = _check_prefix(filter_prefix)
+     dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
+     dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
+     crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+     remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+     _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
+     logger.info("Execute the process of loading checkpoint files.")
+
+     parameter_dict = {}
+
+     if os.getenv("AITURBO") == "1":
+         rank_id = get_rank()
+         from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
+         ckpt_path = os.path.dirname(ckpt_file_name)
+         ckpt_name = os.path.basename(ckpt_file_name)
+         np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id, crc_check)
+         for key, value in np_dict.items():
+             if crc_check and len(value) != 2:
+                 raise ValueError(f"When loading a checkpoint from AITurbo, if CRC check is enabled, "
+                                  f"the length of the value must be 2, but got {len(value)}.")
+             if isinstance(value, str):
+                 if crc_check and value[1] != binascii.crc32(np.array(value[0]).tobytes()):
+                     raise ValueError(f"When loading a checkpoint from AITurbo, the value of the string has not "
+                                      f"passed the CRC check and has been corrupted.")
+                 parameter_dict[key] = value[0]
+             else:
+                 if crc_check and value[1] != binascii.crc32(value[0].tobytes()):
+                     raise ValueError(f"When loading a checkpoint from AITurbo, the value of the parameter has not "
+                                      f"passed the CRC check and has been corrupted.")
+                 parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
+     else:
+         _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
+                               dec_mode, crc_check, format)
+
+     if not parameter_dict:
+         raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
+                          f"'filter_prefix' or 'specify_prefix' are set correctly.")
+
+     if _warm_up_host_cache_enabled(parameter_dict):
+         (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
+     if net is not None:
+         load_param_into_net(net, parameter_dict, strict_load, remove_redundancy)
+     if _warm_up_host_cache_enabled(parameter_dict):
+         _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
+
+     return parameter_dict
+
+
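# [Editor's example, not part of this file] Loading a safetensors file through the `format` switch
# documented above; the file name is illustrative.
import mindspore as ms

param_dict = ms.load_checkpoint("model.safetensors", format="safetensors")
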
+ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
+                           dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
+     """
+     Load checkpoint info from a specified file asynchronously.
+
+     .. warning::
+         This is an experimental API that is subject to change or deletion.
+
+     Note:
+         - `specify_prefix` and `filter_prefix` do not affect each other.
+         - If none of the parameters are loaded from the checkpoint file, a ValueError will be raised.
+         - `specify_prefix` and `filter_prefix` are in the process of being deprecated;
+           `choice_func` is recommended instead. Using either of those two args will override `choice_func`.
+
+     Args:
+         ckpt_file_name (str): Checkpoint file name.
+         net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
+         strict_load (bool, optional): Whether to strictly load the parameters into net. If ``False`` , it will load
+                                       a parameter into net when the parameter name's suffix in the checkpoint file
+                                       is the same as the parameter in the network. When the types are inconsistent,
+                                       it performs type conversion on parameters of compatible types, such as
+                                       float32 to float16. Default: ``False`` .
+         filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated (see `choice_func`). Parameters
+             starting with the `filter_prefix` will not be loaded. Default: ``None`` .
+         dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
+             the decryption is not required. Default: ``None`` .
+         dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
+             the decryption mode, currently supports ``"AES-GCM"`` , ``"AES-CBC"``
+             and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` .
+         specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated (see `choice_func`). Parameters
+             starting with the specify_prefix will be loaded. Default: ``None`` .
+         choice_func (Union[None, function], optional): The input value of the function is a Parameter name of type
+             string, and the return value is a bool. If it returns ``True`` , the Parameter
+             that matches the custom condition will be loaded. If it returns ``False`` , the Parameter that
+             matches the custom condition will be removed. Default: ``None`` .
+
+     Returns:
+         A custom inner class; calling its `result` method yields the :func:`mindspore.load_checkpoint` result.
+
+     Raises:
+         ValueError: Checkpoint file's format is incorrect.
+         ValueError: Parameter's dict is None after loading the checkpoint file.
+         TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import nn
+         >>> from mindspore.train import Model
+         >>> from mindspore.amp import FixedLossScaleManager
+         >>> from mindspore import context
+         >>> from mindspore import load_checkpoint_async
+         >>> from mindspore import load_param_into_net
+         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+         >>> # Create the dataset taking MNIST as an example. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+         >>> dataset = create_dataset()
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+         >>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
+         >>> net = LeNet5()
+         >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+         >>> loss_scale_manager = FixedLossScaleManager()
+         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
+         ...               loss_scale_manager=loss_scale_manager)
+         >>> pd_future = load_checkpoint_async(ckpt_file)
+         >>> model.build(train_dataset=dataset, epoch=2)
+         >>> param_dict = pd_future.result()
+         >>> load_param_into_net(net, param_dict)
+         >>> model.train(2, dataset)
+         >>> print("param dict len: ", len(param_dict), flush=True)
+     """
+     from concurrent.futures import ThreadPoolExecutor
+     executor = ThreadPoolExecutor(max_workers=2)
+     param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
+                                         dec_key, dec_mode, specify_prefix, choice_func)
+     return ParamDictFuture(executor, param_dict_future)
+
+
+ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
+                         map_shape_list, parameter_dict):
+     """Load a map parameter."""
+     logger.info("Checkpoint load map_parameter.")
+     if (element_id != len(checkpoint_list.value) - 1) and \
+             element.tag == checkpoint_list.value[element_id + 1].tag:
+         for index, tensor in enumerate(element.maptensor.tensor):
+             data = tensor.tensor_content
+             data_type = tensor.tensor_type
+             np_type = np_type_convert.get(data_type)
+             element_data = np.frombuffer(data, np_type)
+             map_data_list[index].append(element_data)
+             map_shape_list[index] += tensor.dims[0]
+     else:
+         map_array = []
+         for index, tensor in enumerate(element.maptensor.tensor):
+             data = tensor.tensor_content
+             data_type = tensor.tensor_type
+             np_type = np_type_convert.get(data_type)
+             element_data = np.frombuffer(data, np_type)
+             map_data_list[index].append(element_data)
+             new_data = b"".join(map_data_list[index])
+             param_data = np.frombuffer(new_data, np_type)
+             dims = tensor.dims
+             dims[0] += map_shape_list[index]
+             param_data = param_data.reshape(list(dims))
+             map_array.append(param_data)
+         parameter_dict[element.tag] = map_array
+
+
+ def _check_ckpt_file_name(ckpt_file_name, format):
+     """Check function load_checkpoint's ckpt_file_name."""
+     if not isinstance(ckpt_file_name, str):
+         raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
+                         "but got {}.".format(type(ckpt_file_name)))
+
+     if format not in ['ckpt', 'safetensors']:
+         raise ValueError("For 'load_checkpoint', the argument 'format' must be 'ckpt' or 'safetensors', please "
+                          "input the correct 'format'.")
+     if not ckpt_file_name.endswith(format):
+         raise ValueError(f"For 'load_checkpoint', the checkpoint file format must be the same as 'format', but got "
+                          f"file_name:'{ckpt_file_name}', format:'{format}'")
+
+     ckpt_file_name = os.path.realpath(ckpt_file_name)
+     if not os.path.exists(ckpt_file_name):
+         raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
+                          "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
+
+     return ckpt_file_name
+
+
+ def _check_prefix(prefix):
+     """Check the correctness of the parameters."""
+     if prefix is None:
+         return prefix
+     if not isinstance(prefix, (str, list, tuple)):
+         raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, "
+                         "list[string] or tuple[string], but got {}.".format(str(type(prefix))))
+     if isinstance(prefix, str):
+         prefix = (prefix,)
+     if not prefix:
+         raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when"
+                          " 'specify_prefix' or 'filter_prefix' is list or tuple.")
+     for index, pre in enumerate(prefix):
+         if not isinstance(pre, str):
+             raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, "
+                             "the element in it must be string, but got "
+                             f"{str(type(pre))} at index {index}.")
+         if pre == "":
+             raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' "
+                              "can't include ''.")
+     return prefix
+
+
+ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
+     """Parse checkpoint protobuf."""
+     checkpoint_list = Checkpoint()
+     try:
+         if dec_key is None:
+             with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
+                 pb_content = f.read()
+         else:
+             pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
+             if pb_content is None:
+                 raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
+         if crc_check and pb_content[-17:-10] != b"crc_num":
+             logger.warning("For 'load_checkpoint', the ckpt file does not contain a crc code, "
+                            "please check the file.")
+         if pb_content[-17:-10] == b"crc_num":
+             crc_num_bytes = pb_content[-10:]
+             pb_content = pb_content[:-17]
+             if crc_check:
+                 crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
+                 cal_crc_num = binascii.crc32(pb_content, 0)
+                 if cal_crc_num != crc_num:
+                     raise ValueError("For 'load_checkpoint', the crc check failed, "
+                                      "please check whether the ckpt file is damaged.")
+         checkpoint_list.ParseFromString(pb_content)
+     except BaseException as e:
+         if _is_cipher_file(ckpt_file_name):
+             err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tampered with; " \
+                        "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name)
+         else:
+             err_info = "Failed to read the checkpoint file {}. The process may not have permission to read it; " \
+                        "please check whether the file is correct.".format(ckpt_file_name)
+         logger.error(err_info)
+         raise ValueError(err_info) from e
+     return checkpoint_list
+
+
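# [Editor's note, not part of this file] Trailer layout assumed by the parser above: a checkpoint
# with CRC appended ends with the 7-byte marker b"crc_num" followed by a 10-byte big-endian CRC32,
# so content[-17:-10] == b"crc_num" and content[-10:] holds the stored value. The check mirrors:
#     binascii.crc32(pb_content, 0) == int.from_bytes(crc_num_bytes, byteorder='big')
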
+ def _whether_load_param(specify_prefix, filter_prefix, param_name):
+     """Checks whether to load the parameter, according to `specify_prefix` and `filter_prefix`."""
+     whether_load = True
+     if specify_prefix:
+         whether_load = False
+         for prefix in specify_prefix:
+             if param_name.startswith(prefix):
+                 whether_load = True
+                 break
+     if filter_prefix:
+         for prefix in filter_prefix:
+             if param_name.startswith(prefix):
+                 whether_load = False
+                 break
+     return whether_load
+
+
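# [Editor's example, not part of this file] Semantics of the prefix check above:
assert _whether_load_param(["conv"], None, "conv1.weight") is True        # kept by specify_prefix
assert _whether_load_param(["conv"], ["conv1"], "conv1.weight") is False  # filter_prefix then drops it
assert _whether_load_param(None, None, "fc.bias") is True                 # no prefixes: load everything
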
+ def _init_parameter_data_in_parallel_mode(net, parameter_dict):
+     """In parallel mode, only initialize the parameters that are in the ckpt."""
+     is_train_phase = net.phase.startswith('train')
+     for _, param in net.parameters_and_names():
+         if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
+             param.shape = tuple(parameter_dict[param.name].shape)
+             continue
+         if param.name in parameter_dict and param.has_init:
+             logger.warning("{} is not initialized while loading the ckpt.".format(param.name))
+             new_tensor = param.init_data()
+             param._update_tensor_data(new_tensor)
+
+
+ def _check_load_param_into_net(net, parameter_dict):
+     """Check the arguments of load_param_into_net."""
+     if not isinstance(net, nn.Cell):
+         logger.critical("Failed to combine the net and the parameters.")
+         msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
+         raise TypeError(msg)
+     if not isinstance(parameter_dict, dict):
+         logger.critical("Failed to combine the net and the parameters.")
+         msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
+                "but got {}.".format(type(parameter_dict)))
+         raise TypeError(msg)
+     if "random_op" in parameter_dict.keys():
+         net._add_attr("random_op_snapshot", parameter_dict["random_op"])
+         parameter_dict.pop("random_op")
+
+
+ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
+     """
+     Load parameters into the network, and return the lists of parameters that were not loaded.
+
+     Note:
+         - When loading a parameter dict that has removed redundancy, the network should be compiled.
+
+     Args:
+         net (Cell): The network where the parameters will be loaded.
+         parameter_dict (dict): The dictionary generated by loading a checkpoint file,
+                                it is a dictionary consisting of key: parameter's name, value: parameter.
+         strict_load (bool): Whether to strictly load the parameters into net. If ``False`` , it will load a parameter
+                             into net when the parameter name's suffix in the checkpoint file is the same as the
+                             parameter in the network. When the types are inconsistent, it performs type conversion
+                             on parameters of compatible types, such as float32 to float16. Default: ``False`` .
+         remove_redundancy (bool): Whether to enable loading of a checkpoint saved with redundancy removal.
+                                   Redundancy removal refers to eliminating redundant data in data parallelism mode.
+                                   Default: ``False`` , which means redundancy-free loading is not enabled.
+
+     Returns:
+         - param_not_load (List), the parameter names in the model which are not loaded into the network.
+         - ckpt_not_load (List), the parameter names in the checkpoint file which are not loaded into the network.
+
+     Raises:
+         TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
+
+     Examples:
+         >>> import mindspore as ms
+         >>>
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+         >>> net = LeNet5()
+         >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
+         >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
+         >>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
+         >>> print(param_not_load)
+         ['conv1.weight']
+
+     Tutorial Examples:
+         - `Saving and Loading the Model - Saving and Loading the Model Weight
+           <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
+     """
+     _check_load_param_into_net(net, parameter_dict)
+     for key, value in parameter_dict.items():
+         if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
+             logger.critical("Load parameters into net failed.")
+             msg = ("For 'parameter_dict', each element in the argument 'parameter_dict' should be a "
+                    "'str' key with a 'Parameter' value, but got {} and {}.".format(type(key), type(value)))
+             raise TypeError(msg)
+
+     strict_load = Validator.check_bool(strict_load)
+     remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+     logger.info("Execute the process of loading parameters into net.")
+     for _, param in net.parameters_and_names():
+         param.from_ckpt = True
+     if not (_is_in_auto_parallel_mode() or _is_parallel_mode()):
+         net.init_parameters_data()
+     else:
+         _init_parameter_data_in_parallel_mode(net, parameter_dict)
+     param_not_load = []
+     ckpt_not_load = list(parameter_dict.keys())
+     for _, param in net.parameters_and_names():
+         if param.name in parameter_dict:
+             if isinstance(param, MapParameter):
+                 param.import_data(parameter_dict[param.name])
+                 continue
+             # Add hasattr protection when loading a server checkpoint file on a worker.
+             if not hasattr(parameter_dict[param.name], "data"):
+                 continue
+             new_param = parameter_dict[param.name]
+             _update_param(param, new_param, strict_load)
+             ckpt_not_load.remove(param.name)
+         else:
+             param_not_load.append(param.name)
+
+     if param_not_load and not strict_load:
+         _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
+
+     logger.info("Loading parameters into net is finished.")
+     if param_not_load:
+         logger.warning("For 'load_param_into_net', "
+                        "{} parameters in the 'net' are not loaded, because they are not in the "
+                        "'parameter_dict', please check whether the network structure is consistent "
+                        "between training and loading checkpoint. Another possibility is that "
+                        "redundancy-free loading is not enabled, but the loaded checkpoint was saved with "
+                        "redundancy removed. ".format(len(param_not_load)))
+         logger.warning("{} are not loaded.".format(param_not_load))
+     if remove_redundancy:
+         parallel_mode = context.get_auto_parallel_context("parallel_mode")
+         if parallel_mode == "stand_alone":
+             raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
+                             f"in parallel scenarios, but got {parallel_mode}.")
+         if not net.compile_cache and not net.parameter_layout_dict:
+             raise ValueError("When loading a parameter dict that has removed redundancy, "
+                              "the network should be compiled.")
+         param_layout = net.parameter_layout_dict
+         rank_id = get_rank()
+         device_num = _get_device_num()
+         stage_num = _get_auto_parallel_context("pipeline_stages")
+         chunk_size = device_num // stage_num
+         initial_rank = (rank_id // chunk_size) * chunk_size
+         _single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
+
+     return param_not_load, ckpt_not_load
+
+
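# [Editor's example, not part of this file] Consuming the two return values documented above:
import mindspore as ms
from mindspore import nn

net = nn.Dense(2, 2)
param_dict = ms.load_checkpoint("net.ckpt")
param_not_load, ckpt_not_load = ms.load_param_into_net(net, param_dict)
if param_not_load:
    print("net params left uninitialized from ckpt:", param_not_load)
if ckpt_not_load:
    print("ckpt entries that matched nothing in net:", ckpt_not_load)
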
+ def _warm_up_host_cache_enabled(parameter_dict):
+     """Check whether host cache warm-up is enabled."""
+     if _cache_enable():
+         return True
+     for key in parameter_dict.keys():
+         if key.find(".__param_key__") != -1:
+             return True
+     return False
+
+
+ def _warm_up_host_cache(parameter_dict, net):
+     """Warm up host cache."""
+     ms_role = os.getenv("MS_ROLE")
+     is_worker = ms_role == "MS_WORKER"
+     param_key_dict = {}
+     # Traverse key, value in parameter_dict, warm up the param key and record it into param_key_dict.
+     if is_worker:
+         net.init_parameters_data()
+         net_dict = {}
+         for name, value in net.parameters_and_names():
+             net_dict[name] = value
+         for param_name, value in parameter_dict.items():
+             pos = param_name.find(".__param_key__")
+             if pos != -1:
+                 net_param_name = param_name[:pos]
+                 param_key_dict[param_name] = net_param_name
+                 net_value = None
+                 if net_param_name not in net_dict:
+                     logger.warning("net param name : %s is not in net", net_param_name)
+                 else:
+                     net_value = net_dict.get(net_param_name, None)
+                 pos += len(".__param_key__")
+                 param_key = int(param_name[pos:])
+                 value_is_map_parameter = isinstance(value, list) and len(value) == 3
+                 if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
+                     key_tensor = Tensor.from_numpy(value[0])
+                     value_tensor = Tensor.from_numpy(value[1])
+                     status_tensor = Tensor.from_numpy(value[2])
+                     _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
+                 elif not isinstance(value, list) and isinstance(net_value, Parameter):
+                     _store_warm_up_ptr_by_tensor(param_key, value)
+                 else:
+                     logger.warning("Unmatched parameter type %s and net value type %s", type(value), type(net_value))
+     else:
+         for param_name, value in parameter_dict.items():
+             pos = param_name.find(".__param_key__")
+             if pos != -1:
+                 net_param_name = param_name[:pos]
+                 param_key_dict[param_name] = net_param_name
+     # Split the param key from parameter_dict since the worker cannot load param keys.
+     warm_up_dict = {}
+     for key, value in param_key_dict.items():
+         if is_worker:
+             warm_up_dict[value] = parameter_dict.pop(key)
+         else:
+             parameter_dict[value] = parameter_dict.pop(key)
+     return (is_worker, parameter_dict, warm_up_dict)
+
+
+ def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
+     """Warm up host cache post process."""
+     if is_worker:
+         net_dict.update(warm_up_dict)
+     _set_checkpoint_load_status(True)
+
+
+ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
+     """When some net parameters were not loaded, try to continue loading by removing a common prefix."""
+     prefix_name = ""
+     longest_name = param_not_load[0]
+     while prefix_name != longest_name and param_not_load:
+         logger.debug("Count: {} parameters have not been loaded, try to continue loading."
+                      .format(len(param_not_load)))
+         prefix_name = longest_name
+         for net_param_name in param_not_load:
+             for dict_name in parameter_dict:
+                 if dict_name.endswith(net_param_name):
+                     prefix_name = dict_name[:-len(net_param_name)]
+                     break
+             if prefix_name != longest_name:
+                 break
+
+         if prefix_name != longest_name:
+             logger.warning(f"For 'load_param_into_net', remove parameter prefix name: {prefix_name},"
+                            f" continue to load.")
+             for _, param in net.parameters_and_names():
+                 new_param_name = prefix_name + param.name
+                 if param.name in param_not_load and new_param_name in parameter_dict:
+                     new_param = parameter_dict[new_param_name]
+                     _update_param(param, new_param, strict_load)
+                     param_not_load.remove(param.name)
+
+
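# [Editor's example, not part of this file] The retry above handles checkpoints whose names carry an
# extra wrapper prefix: a ckpt saved from a wrapped net may store "backbone.conv1.weight" while the
# bare net expects "conv1.weight"; the loop detects the shared prefix and retries with it.
dict_name = "backbone.conv1.weight"
net_param_name = "conv1.weight"
prefix_name = dict_name[:-len(net_param_name)]  # -> "backbone.", as computed in the loop above
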
+ def _save_graph(network, file_name):
1829
+ """
1830
+ Saves the graph of network to a file.
1831
+
1832
+ Args:
1833
+ network (Cell): Obtain a pipeline through network for saving graph.
1834
+ file_name (str): Graph file name into which the graph will be saved.
1835
+ """
1836
+ logger.info("Execute the process of saving graph.")
1837
+
1838
+ file_name = os.path.realpath(file_name)
1839
+ graph_pb = network.get_func_graph_proto()
1840
+ if graph_pb:
1841
+ with open(file_name, "wb") as f:
1842
+ os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
1843
+ f.write(graph_pb)
1844
+
1845
+
1846
+ def _reshape_tensor(tensor, dst_shape):
1847
+ """reshape tensor to dst shape"""
1848
+ np_tensor = tensor.asnumpy()
1849
+ np_tensor = np_tensor.reshape(dst_shape)
1850
+ return Tensor(np_tensor, tensor.dtype)
1851
+
1852
+
1853
+ def _check_param_for_integrate_save(pipeline_stages, uniform_split):
1854
+ """check whether current settings and parameters are supported in integrated save checkpoint mode"""
1855
+ if pipeline_stages > 1:
1856
+ raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
1857
+ if uniform_split == 0:
1858
+ raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
1859
+ "'integrated_save' to True, the checkpoint will be integrated save, it "
1860
+ "is only supports uniform split tensor now.")
1861
+
1862
+
1863
+ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
1864
+ """
1865
+ Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
1866
+
1867
+ Args:
1868
+ net (Cell): MindSpore network.
1869
+ param_name (str): The parameter name, which to be combined.
1870
+ param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data.
1871
+ integrated_save (bool): Whether to integrated save in automatic model parallel scene.
1872
+ Returns:
1873
+ Tensor, the combined tensor with the whole data value.
1874
+ """
1875
+ layout = parameter_layout_dict[param_name]
1876
+ if len(layout) < 8:
1877
+ logger.info("The layout dict does not contain the key %s", param_name)
1878
+ return param_data
1879
+
1880
+ dev_mat = layout[0]
1881
+ tensor_map = layout[1]
1882
+ uniform_split = layout[4]
1883
+ opt_shard_group = layout[5]
1884
+ before_reshape_slice_shape = layout[2]
1885
+ before_reshape_full_shape = layout[6]
1886
+ after_reshape_slice_shape = layout[7]
1887
+ do_reshape = False
1888
+ if before_reshape_full_shape and after_reshape_slice_shape \
1889
+ and after_reshape_slice_shape != before_reshape_slice_shape:
1890
+ do_reshape = True
1891
+
1892
+ allgather_net = None
1893
+ mp_weight = False
1894
+ for dim in tensor_map:
1895
+ if dim != -1:
1896
+ mp_weight = True
1897
+ break
1898
+ if param_name in net.parallel_parameter_merge_net_dict:
1899
+ allgather_net = net.parallel_parameter_merge_net_dict[param_name]
1900
+ else:
1901
+ logger.info("Need to create allgather net for %s", param_name)
1902
+ if integrated_save:
1903
+ _check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
1904
+ # If any dim is not equal to -1, the param is split and needs to be merged.
1905
+ # pipeline parallel need to be supported here later
1906
+ if mp_weight:
1907
+ allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
1908
+ tuple(after_reshape_slice_shape))
1909
+ object.__setattr__(allgather_net, "keep_input_unchanged", True)
1910
+ elif opt_shard_group:
1911
+ allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1912
+ tuple(after_reshape_slice_shape))
1913
+ elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
1914
+ allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1915
+ tuple(after_reshape_slice_shape))
1916
+ net.parallel_parameter_merge_net_dict[param_name] = allgather_net
1917
+ if allgather_net:
1918
+ param_data = allgather_net(param_data)
1919
+ if mp_weight and integrated_save:
1920
+ param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
1921
+ if do_reshape:
1922
+ param_data = _reshape_tensor(param_data, before_reshape_full_shape)
1923
+ return param_data
1924
+
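+ # An illustrative note on the positional `layout` fields consumed above,
+ # inferred only from the indices read in _get_merged_param_data:
+ #   layout[0] dev_mat, layout[1] tensor_map, layout[2] slice shape before
+ #   reshape, layout[4] uniform_split, layout[5] opt_shard_group,
+ #   layout[6] full shape before reshape, layout[7] slice shape after reshape.
+ # A hypothetical layout for an (8, 4) weight split along axis 1 over 2 devices
+ # (layout[3] is not read here, so a placeholder is used):
+ _example_layout = [[2], [-1, 0], [8, 2], None, 1, "", [8, 4], [8, 2]]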
1925
+
1926
+ def export(net, *inputs, file_name, file_format, **kwargs):
1927
+ """
1928
+ Export the MindSpore network into an offline model in the specified format.
1929
+
1930
+ Note:
1931
+ 1. When exporting to AIR or ONNX format, the size of a single tensor cannot exceed 2GB.
1932
+ 2. When `file_name` does not have a suffix, the system will automatically add one
1933
+ according to the `file_format`.
1934
+ 3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
1935
+ 4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
1936
+ class properties in calculations.
1937
+ 5. AIR format is deprecated and will be removed in a future version. Please use another format or use
1938
+ MindSpore Lite to do offline inference.
1939
+
1940
+ Args:
1941
+ net (Union[Cell, function]): MindSpore network.
1942
+ inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
1943
+ of the `net`; if the network has multiple inputs, set them together. When its type is Dataset,
1944
+ it represents the preprocessing behavior of the `net`, and data preprocessing operations will be serialized.
1945
+ In the second situation, you should manually adjust the batch size of the dataset script, which will affect
1946
+ the batch size of the 'net' input. Only the "image" column can be parsed from the dataset currently.
1947
+ file_name (str): File name of the model to be exported.
1948
+ file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
1949
+
1950
+ - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
1951
+ - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
1952
+ - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
1953
+ for MindSpore models. MINDIR does not support operators which have dictionary attribute.
1954
+
1955
+ kwargs (dict): Configuration options dictionary.
1956
+
1957
+ - enc_key (byte): Byte-type key used for encryption. The valid length is 16, 24, or 32.
1958
+ - enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set.
1959
+
1960
+ - For 'AIR' and 'ONNX' models, only customized encryption is supported.
1961
+ - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
1962
+ or Customized encryption.
1963
+ Default: ``'AES-GCM'``.
1964
+ - For details of using the customized encryption, please check the `tutorial
1965
+ <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
1966
+
1967
+ - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
1968
+ preprocessing of the dataset into MindIR.
1969
+
1970
+ - obf_config (dict): obfuscation config.
1971
+
1972
+ - type (str): The type of obfuscation, only 'dynamic' is supported until now.
1973
+ - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
1974
+ should be in the range (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large"
1975
+ correspond to 0.1, 0.3, and 0.6 respectively.
1976
+ - customized_func (function): A Python function used for customized function mode, which is used to control
1977
+ the switch branch of the obfuscation structure. The outputs of customized_func should be boolean and constant (
1978
+ Reference to 'my_func()' in
1979
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
1980
+ This function needs to ensure that its result is constant for any input. Users can refer to opaque
1981
+ predicates. If customized_func is set, then it should be passed to `load()` interface when loading
1982
+ obfuscated model.
1983
+ - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1984
+ structure of obfuscated models corresponding to different random seeds is different. If
1985
+ `obf_random_seed` is set, then it should be passed
1986
+ to :class:`mindspore.nn.GraphCell` interface when loading
1987
+ obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1988
+ be set, and the latter mode would be applied if both of them are set.
1989
+
1990
+ - incremental (bool): export MindIR incrementally.
1991
+
1992
+ - custom_func (function): Functions for custom defined export policies. This function will be used to
1993
+ customize the model during network export. Currently only support for files with mindir format. The
1994
+ function only accepts one input representing the proto object of the mindir file. When modifying a model,
1995
+ it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading
1996
+ failure or functional errors. Default: ``None`` .
1997
+
1998
+ Examples:
1999
+ >>> import mindspore as ms
2000
+ >>> import numpy as np
2001
+ >>> from mindspore import Tensor
2002
+ >>>
2003
+ >>> # Define the network structure of LeNet5. Refer to
2004
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
2005
+ >>> net = LeNet5()
2006
+ >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
2007
+ >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
2008
+ >>>
2009
+ >>> # Export model in MindIR format and modify the model info using custom_func
2010
+ >>> # The custom_func only supports one input representing the Proto object of the model
2011
+ >>> # And custom_func does not support a return value
2012
+ >>> def _custom_func(mindir_model):
2013
+ ... mindir_model.producer_name = "test11111"
2014
+ ... mindir_model.producer_version = "11.0"
2015
+ ... mindir_model.user_info["version"] = "11.0"
2016
+ >>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
2017
+
2018
+
2019
+ Tutorial Examples:
2020
+ - `Saving and Loading the Model - Saving and Loading MindIR
2021
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
2022
+ """
2023
+ old_ms_jit_value = context.get_context("jit_syntax_level")
2024
+ context.set_context(jit_syntax_level=mindspore.STRICT)
2025
+
2026
+ supported_formats = ['AIR', 'ONNX', 'MINDIR']
2027
+ if file_format not in supported_formats:
2028
+ raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
2029
+ if file_format == 'AIR':
2030
+ logger.warning("AIR format is deprecated, and will be removed in a future version, please use other format or "
2031
+ "use MindSpore Lite to do offline inference")
2032
+ Validator.check_file_name_by_regular(file_name)
2033
+ logger.info("exporting model file:%s format:%s.", file_name, file_format)
2034
+
2035
+ if check_input_dataset(*inputs, dataset_type=mindspore.dataset.Dataset):
2036
+ if len(inputs) != 1:
2037
+ raise RuntimeError(f"You can only serialize one dataset into MindIR, got " + str(len(inputs)) + " datasets")
2038
+ shapes, types, columns = inputs[0].output_shapes(), inputs[0].output_types(), inputs[0].get_col_names()
2039
+ kwargs['dataset'] = inputs[0]
2040
+ only_support_col = "image"
2041
+
2042
+ inputs_col = list()
2043
+ for c, s, t in zip(columns, shapes, types):
2044
+ if only_support_col != c:
2045
+ continue
2046
+ inputs_col.append(Tensor(np.random.uniform(-1.0, 1.0, size=s).astype(t)))
2047
+ if not inputs_col:
2048
+ raise RuntimeError(f"Only supports parse \"image\" column from dataset now, given dataset has columns: "
2049
+ + str(columns))
2050
+ inputs = tuple(inputs_col)
2051
+
2052
+ file_name = os.path.realpath(file_name)
2053
+ if 'enc_key' in kwargs.keys():
2054
+ kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
2055
+ _export(net, file_name, file_format, *inputs, **kwargs)
2056
+
2057
+ context.set_context(jit_syntax_level=old_ms_jit_value)
2058
+
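+ # A minimal illustrative sketch of calling export() with the built-in MindIR
+ # encryption options documented above; `net` and `input_tensor` are assumed
+ # to exist as in the docstring example, and the helper itself is hypothetical.
+ def _demo_export_encrypted(net, input_tensor):
+     import mindspore as ms
+     key = b"0123456789ABCDEF"  # valid enc_key lengths are 16, 24, or 32 bytes
+     ms.export(net, input_tensor, file_name="lenet_enc",
+               file_format="MINDIR", enc_key=key, enc_mode="AES-GCM")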
2059
+
2060
+ def _get_funcgraph(net, *inputs):
2061
+ """
2062
+ Compile the MindSpore network and get FuncGraph.
2063
+
2064
+ Arg:
2065
+ net (Union[Cell, function]): MindSpore network.
2066
+ inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
2067
+ of the `net`; if the network has multiple inputs, set them together. When its type is Dataset,
2068
+ it represents the preprocessing behavior of the `net`, and data preprocessing operations will be serialized.
2069
+ In the second situation, you should manually adjust the batch size of the dataset script, which will affect
2070
+ the batch size of the 'net' input. Only the "image" column can be parsed from the dataset currently.
2071
+
2072
+ Returns:
2073
+ FuncGraph, a mindspore._c_expression.FuncGraph object.
2074
+
2075
+ Raises:
2076
+ ValueError: If the input `net` is not an nn.Cell.
2077
+
2078
+ Examples:
2079
+ >>> import mindspore as ms
2080
+ >>> import numpy as np
2081
+ >>> from mindspore import Tensor
2082
+ >>>
2083
+ >>> # Define the network structure of LeNet5. Refer to
2084
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
2085
+ >>> net = LeNet5()
2086
+ >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
2087
+ >>> ms.get_funcgraph(net, input_tensor)
2088
+
2089
+ """
2090
+ if not isinstance(net, nn.Cell):
2091
+ raise ValueError(f"For get_funcgraph's parameter 'net', currently only support Cell right now.")
2092
+ phase_name = "lite_infer_predict" if _is_in_auto_parallel_mode() else "lite_infer_get_func_graph"
2093
+ graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
2094
+ # pylint: disable=protected-access
2095
+ func_graph = _executor._get_func_graph(net, graph_id)
2096
+ return func_graph
2097
+
2098
+
2099
+ def _export(net, file_name, file_format, *inputs, **kwargs):
2100
+ """
2101
+ It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
2102
+ """
2103
+ logger.info("exporting model file:%s format:%s.", file_name, file_format)
2104
+ if "obf_config" in kwargs and file_format != "MINDIR":
2105
+ raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
2106
+ if "custom_func" in kwargs and file_format != "MINDIR":
2107
+ raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
2108
+ if file_format == 'AIR':
2109
+ _save_air(net, file_name, *inputs, **kwargs)
2110
+ elif file_format == 'ONNX':
2111
+ _save_onnx(net, file_name, *inputs, **kwargs)
2112
+ elif file_format == 'MINDIR':
2113
+ _save_mindir(net, file_name, *inputs, **kwargs)
2114
+
2115
+
2116
+ def _check_key_mode_type(file_format, **kwargs):
2117
+ """check enc_key and enc_mode are valid"""
2118
+ enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
2119
+ enc_mode = kwargs.get('enc_mode')
2120
+
2121
+ if callable(enc_mode):
2122
+ return enc_key, enc_mode
2123
+
2124
+ enc_mode = 'AES-GCM'
2125
+ if 'enc_mode' in kwargs.keys():
2126
+ enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
2127
+
2128
+ if file_format in ('AIR', 'ONNX'):
2129
+ raise ValueError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")
2130
+
2131
+ if enc_mode in ('AES-CBC', 'AES-GCM', 'SM4-CBC'):
2132
+ return enc_key, enc_mode
2133
+ raise ValueError(f"MindIR only support AES-GCM/AES-CBC/SM4-CBC encryption, but got {enc_mode}")
2134
+
2135
+
2136
+ def _save_air(net, file_name, *inputs, **kwargs):
2137
+ """Save AIR format file."""
2138
+ phase_name = 'export.air'
2139
+ graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
2140
+ if not file_name.endswith('.air'):
2141
+ file_name += ".air"
2142
+ if os.path.exists(file_name):
2143
+ os.chmod(file_name, stat.S_IWUSR)
2144
+ if "/" in file_name:
2145
+ real_path = os.path.realpath(file_name[:file_name.rfind("/")])
2146
+ os.makedirs(real_path, mode=0o700, exist_ok=True)
2147
+ if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
2148
+ _executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
2149
+ else:
2150
+ _executor.export(file_name, graph_id)
2151
+ os.chmod(file_name, stat.S_IRUSR)
2152
+
2153
+
2154
+ def _save_onnx(net, file_name, *inputs, **kwargs):
2155
+ """Save ONNX format file."""
2156
+ # When dumping an ONNX file, switch the network to inference mode if it is training (NOTE: ONNX is only designed for prediction).
2157
+ if not isinstance(net, nn.Cell):
2158
+ raise ValueError(f"Export ONNX format model only support nn.Cell object, but got {type(net)}.")
2159
+ _check_dynamic_input(inputs)
2160
+ cell_mode = net.training
2161
+ net.set_train(mode=False)
2162
+ total_size = _calculation_net_size(net)
2163
+ if total_size > PROTO_LIMIT_SIZE:
2164
+ raise RuntimeError('Export onnx model failed. Network size is {}G, which exceeds the protobuf limit of {}G.'
2165
+ .format(total_size / 1024 / 1024, PROTO_LIMIT_SIZE / 1024 / 1024))
2166
+ phase_name = 'export.onnx'
2167
+ graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
2168
+ onnx_stream = _executor._get_func_graph_proto(net, graph_id)
2169
+ if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
2170
+ enc_mode = kwargs.get('enc_mode')
2171
+ onnx_stream = enc_mode(onnx_stream, kwargs.get('enc_key'))
2172
+ if not file_name.endswith('.onnx'):
2173
+ file_name += ".onnx"
2174
+ if os.path.exists(file_name):
2175
+ os.chmod(file_name, stat.S_IWUSR)
2176
+ with open(file_name, 'wb') as f:
2177
+ f.write(onnx_stream)
2178
+ os.chmod(file_name, stat.S_IRUSR)
2179
+ net.set_train(mode=cell_mode)
2180
+
2181
+
2182
+ def _check_dynamic_input(inputs):
2183
+ for ele in inputs:
2184
+ if isinstance(ele, Tensor) and -1 in ele.shape:
2185
+ raise ValueError(f"Export ONNX format model not support dynamic shape mode.")
2186
+
2187
+
2188
+ def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
2189
+ front_info = bytes()
2190
+ check_code = sys.byteorder == "little"
2191
+ front_info += check_code.to_bytes(1, byteorder=sys.byteorder)
2192
+ front_info += bytes(63)
2193
+ if is_encrypt():
2194
+ front_info = _encrypt(front_info, len(front_info), kwargs.get('enc_key'),
2195
+ len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2196
+ return front_info
2197
+
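+ # A minimal illustrative reproduction of the 64-byte front info built above:
+ # one byte recording whether the writer is little-endian, then 63 reserved
+ # zero bytes (the whole header may then be encrypted). The names below are
+ # hypothetical and not used by this module.
+ import sys as _demo_sys
+ _demo_front = (_demo_sys.byteorder == "little").to_bytes(1, byteorder=_demo_sys.byteorder) + bytes(63)
+ assert len(_demo_front) == 64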
2198
+
2199
+ def _change_file(f, dirname, external_local, is_encrypt, kwargs):
2200
+ """Change to another file to write parameter data."""
2201
+ # The parameter has not been written to the file yet.
2202
+ front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
2203
+ f.seek(0, 0)
2204
+ f.write(front_info)
2205
+ f.close()
2206
+ ori_data_file_name = f.name
2207
+ os.chmod(ori_data_file_name, stat.S_IRUSR)
2208
+ if os.path.getsize(ori_data_file_name) == 64:
2209
+ raise RuntimeError("The parameter size is exceed 1T,cannot export to the file")
2210
+ data_file_name = os.path.join(dirname, external_local)
2211
+ return _get_data_file(is_encrypt, kwargs, data_file_name)
2212
+
2213
+
2214
+ def _get_data_file(is_encrypt, kwargs, data_file_name):
2215
+ """Get Data File to write parameter data."""
2216
+ # Reserves 64 bytes as spare information such as check data
2217
+ offset = 64
2218
+ if os.path.exists(data_file_name):
2219
+ os.chmod(data_file_name, stat.S_IWUSR)
2220
+
2221
+ place_holder_data = bytes(offset)
2222
+ if is_encrypt():
2223
+ place_holder_data = _encrypt(place_holder_data, len(place_holder_data), kwargs["enc_key"],
2224
+ len(kwargs["enc_key"]), kwargs["enc_mode"])
2225
+ parameter_size = (offset / 1024)
2226
+ f = open(data_file_name, "wb")
2227
+ try:
2228
+ f.write(place_holder_data)
2229
+ except IOError:
2230
+ f.close()
+ raise
2231
+
2232
+ return f, parameter_size, offset
2233
+
2234
+
2235
+ def _encrypt_data(is_encrypt, write_data, kwargs):
2236
+ """Encrypt parameter data."""
2237
+ if is_encrypt():
2238
+ if callable(kwargs.get('enc_mode')):
2239
+ enc_func = kwargs.get('enc_mode')
2240
+ write_data = enc_func(write_data, kwargs.get('enc_key'))
2241
+ else:
2242
+ write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
2243
+ len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2244
+ return write_data
2245
+
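+ # When 'enc_mode' is a callable, _encrypt_data invokes it as
+ # enc_mode(data, key) and expects the encrypted bytes back. A hypothetical
+ # (deliberately NOT secure) XOR placeholder just to show the expected
+ # signature; do not use it for real encryption.
+ def _demo_enc_mode(data: bytes, key: bytes) -> bytes:
+     return bytes(b ^ key[i % len(key)] for i, b in enumerate(data))
+
+ assert _demo_enc_mode(_demo_enc_mode(b"abc", b"key"), b"key") == b"abc"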
2246
+
2247
+ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
2248
+ """The function to save parameter data."""
2249
+ logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
2250
+ # save parameter
2251
+ if model.graph.map_parameter:
2252
+ raise ValueError("MapParameter not support save in split MindIR file now.")
2253
+ file_prefix = file_name.split("/")[-1]
2254
+ if file_prefix.endswith(".mindir"):
2255
+ file_prefix = file_prefix[:-7]
2256
+ current_path = os.path.realpath(file_name)
2257
+ dirname = os.path.dirname(current_path)
2258
+ data_path = os.path.join(dirname, file_prefix + "_variables")
2259
+ if os.path.exists(data_path):
2260
+ shutil.rmtree(data_path)
2261
+ os.makedirs(data_path, mode=0o700, exist_ok=True)
2262
+ os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
2263
+ index = 0
2264
+ external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
2265
+ data_file_name = os.path.join(dirname, external_local)
2266
+ f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
2267
+ try:
2268
+ round = 0
2269
+ names = []
2270
+ for param_proto in model.graph.parameter:
2271
+ name = param_proto.name[param_proto.name.find(":") + 1:]
2272
+ names.append((name, param_proto))
2273
+ names.sort(key=lambda x: x[0])
2274
+ for pairs in names:
2275
+ name = pairs[0]
2276
+ param_proto = pairs[1]
2277
+ param = net_dict[name]
2278
+ raw_data = param.data.get_bytes()
2279
+ data_length = len(raw_data)
2280
+ append_size = 0
2281
+ if data_length % 64 != 0:
2282
+ append_size = 64 - (data_length % 64)
2283
+ parameter_size += ((append_size + data_length) / 1024)
2284
+ if parameter_size > PARAMETER_SPLIT_SIZE:
2285
+ index += 1
2286
+ external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
2287
+ f, parameter_size, offset = _change_file(f, dirname, external_local, is_encrypt, kwargs)
2288
+ parameter_size += ((append_size + data_length) / 1024)
2289
+ param_proto.external_data.location = external_local
2290
+ param_proto.external_data.length = data_length
2291
+ param_proto.external_data.offset = offset
2292
+ write_data = raw_data + bytes(append_size)
2293
+ offset += (data_length + append_size)
2294
+ write_data = _encrypt_data(is_encrypt, write_data, kwargs)
2295
+ f.write(write_data)
2296
+ round += 1
2297
+ logger.debug(f"writing {round}th split data, name:{name}")
2298
+
2299
+ graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
2300
+ if os.path.exists(graph_file_name):
2301
+ os.chmod(graph_file_name, stat.S_IWUSR)
2302
+ with open(graph_file_name, 'wb') as model_file:
2303
+ os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
2304
+ model_string = model.SerializeToString()
2305
+ if is_encrypt():
2306
+ model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
2307
+ len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2308
+ model_file.write(model_string)
2309
+ os.chmod(graph_file_name, stat.S_IRUSR)
2310
+
2311
+ front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
2312
+ f.seek(0, 0)
2313
+ f.write(front_info)
2314
+ finally:
2315
+ f.close()
2316
+ os.chmod(data_file_name, stat.S_IRUSR)
2317
+
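+ # A minimal illustrative reproduction of the 64-byte alignment arithmetic
+ # used by _split_save above; the helper is hypothetical and unused here.
+ def _demo_aligned_length(data_length, block=64):
+     """Return (padded_length, append_size) for a 64-byte aligned record."""
+     append_size = 0 if data_length % block == 0 else block - (data_length % block)
+     return data_length + append_size, append_size
+
+ assert _demo_aligned_length(100) == (128, 28)
+ assert _demo_aligned_length(128) == (128, 0)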
2318
+
2319
+ def _msfunc_info(net, *inputs):
2320
+ """Get mindir stream and parameter dict of ms_function"""
2321
+ # pylint: disable=protected-access
2322
+ net_dict = OrderedDict()
2323
+ _ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
2324
+ graph_id = _ms_func_executor.compile(net.__name__, *inputs)
2325
+ mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
2326
+ params = _ms_func_executor._graph_executor.get_params(graph_id)
2327
+ for name, value in params.items():
2328
+ net_dict[name] = Parameter(value, name=name)
2329
+ return mindir_stream, net_dict
2330
+
2331
+
2332
+ def _cell_info(net, incremental, *inputs):
2333
+ """Get mindir stream and net dict of cell"""
2334
+ phase_name = "export.mindir"
2335
+ graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
2336
+ # pylint: disable=protected-access
2337
+ mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
2338
+ # clean obfuscation config to prevent the next call
2339
+ _executor.obfuscate_config = None
2340
+
2341
+ net_dict = net.parameters_dict()
2342
+ return mindir_stream, net_dict
2343
+
2344
+
2345
+ def _set_obfuscate_config(**kwargs):
2346
+ """Set obfuscation config for executor."""
2347
+ logger.warning("Obfuscate model.")
2348
+ if 'enc_mode' in kwargs.keys():
2349
+ enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
2350
+ if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
2351
+ raise ValueError(
2352
+ "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
2353
+ "obfuscation, but got {}.".format(enc_mode))
2354
+ obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
2355
+ if customized_funcs and obf_random_seed > 0:
2356
+ logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
2357
+ " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
2358
+
2359
+ if obf_random_seed == 0: # apply customized_func mode
2360
+ device_target = context.get_context('device_target')
2361
+ if device_target in ["GPU", "Ascend"]:
2362
+ raise ValueError(
2363
+ "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
2364
+ clean_funcs()
2365
+ for func in customized_funcs:
2366
+ add_opaque_predicate(func.__name__, func)
2367
+ _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
2368
+
2369
+
2370
+ def _save_mindir(net, file_name, *inputs, **kwargs):
2371
+ """Save MindIR format file."""
2372
+ # set obfuscate configs
2373
+ if 'obf_config' in kwargs.keys():
2374
+ _set_obfuscate_config(**kwargs)
2375
+ for item in inputs:
2376
+ if -1 in item.shape:
2377
+ raise ValueError(
2378
+ "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
2379
+
2380
+ incremental = kwargs.get('incremental', False)
2381
+
2382
+ model = mindir_model()
2383
+ if not isinstance(net, nn.Cell):
2384
+ mindir_stream, net_dict = _msfunc_info(net, *inputs)
2385
+ else:
2386
+ mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
2387
+ model.ParseFromString(mindir_stream)
2388
+
2389
+ if kwargs.get('dataset'):
2390
+ check_input_data(kwargs.get('dataset'), data_class=mindspore.dataset.Dataset)
2391
+ dataset = kwargs.get('dataset')
2392
+ _save_dataset_to_mindir(model, dataset)
2393
+
2394
+ custom_func = kwargs.get('custom_func', None)
2395
+ if custom_func is not None:
2396
+ custom_func(model)
2397
+
2398
+ save_together = _save_together(net_dict, model)
2399
+ is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
2400
+ if save_together:
2401
+ _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
2402
+ else:
2403
+ _split_save(net_dict, model, file_name, is_encrypt, **kwargs)
2404
+
2405
+
2406
+ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
2407
+ """Save graph and parameter together."""
2408
+ for param_proto in model.graph.parameter:
2409
+ param_name = param_proto.name[param_proto.name.find(":") + 1:]
2410
+ if param_name in net_dict.keys():
2411
+ param_data = net_dict[param_name].data.get_bytes()
2412
+ param_proto.raw_data = param_data
2413
+ else:
2414
+ raise ValueError("The parameter '{}' is not belongs to any cell,"
2415
+ "the data of parameter cannot be exported.".format(param_proto.name))
2416
+ incremental = kwargs.get('incremental', False)
2417
+ for map_param_proto in model.graph.map_parameter:
2418
+ map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
2419
+ if map_param_name in net_dict.keys():
2420
+ map_parameter = net_dict[map_param_name]
2421
+ key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
2422
+ map_param_proto.key_tensor.raw_data = key_bytes
2423
+ map_param_proto.value_tensor.raw_data = value_bytes
2424
+ map_param_proto.status_tensor.raw_data = status_bytes
2425
+ else:
2426
+ raise ValueError("The map_parameter '{}' is not belongs to any cell,"
2427
+ "the data of parameter cannot be exported.".format(map_param_proto.name))
2428
+ if not file_name.endswith('.mindir'):
2429
+ file_name += ".mindir"
2430
+ current_path = os.path.realpath(file_name)
2431
+ dirname = os.path.dirname(current_path)
2432
+ os.makedirs(dirname, mode=0o700, exist_ok=True)
2433
+ if os.path.exists(file_name):
2434
+ os.chmod(file_name, stat.S_IWUSR)
2435
+ with open(file_name, 'wb') as f:
2436
+ os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
2437
+ model_string = model.SerializeToString()
2438
+ if is_encrypt():
2439
+ if callable(kwargs.get('enc_mode')):
2440
+ enc_func = kwargs.get('enc_mode')
2441
+ model_string = enc_func(model_string, kwargs.get('enc_key'))
2442
+ else:
2443
+ model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
2444
+ len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2445
+ f.write(model_string)
2446
+ os.chmod(file_name, stat.S_IRUSR)
2447
+
2448
+
2449
+ def _save_together(net_dict, model):
2450
+ """Whether graph and parameter save together during save mindir model."""
2451
+ data_total = 0
2452
+ for param_proto in model.graph.parameter:
2453
+ name = param_proto.name[param_proto.name.find(":") + 1:]
2454
+ if name in net_dict.keys():
2455
+ data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
2456
+ else:
2457
+ raise ValueError("The parameter '{}' is not belongs to any cell,"
2458
+ "the data of parameter cannot be exported.".format(param_proto.name))
2459
+ if data_total > TOTAL_SAVE:
2460
+ return False
2461
+ return True
2462
+
2463
+
2464
+ def _save_dataset_to_mindir(model, dataset):
2465
+ """Save dataset preprocess operations into mindir model."""
2466
+ dataset_json = dataset.to_json()
2467
+ reverse_dataset = []
2468
+ while dataset_json:
2469
+ reverse_dataset = [dataset_json] + reverse_dataset
2470
+ if len(dataset_json['children']) > 1:
2471
+ logger.warning("Need to support dataset_node with more than one child, using child 0 as default.")
2472
+ dataset_json = dataset_json['children'][0] if dataset_json['children'] else []
2473
+
2474
+ for op in reverse_dataset:
2475
+ if op['op_type'] == 'Map':
2476
+ model.preprocessor.op.add()
2477
+ model.preprocessor.op[-1].input_columns = json.dumps(op['input_columns'])
2478
+ model.preprocessor.op[-1].output_columns = json.dumps(op['output_columns'])
2479
+ model.preprocessor.op[-1].op_type = json.dumps(op['op_type'])
2480
+ model.preprocessor.op[-1].operations = json.dumps(op['operations'])
2481
+ model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
2482
+
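+ # A minimal illustrative reproduction of how _save_dataset_to_mindir walks
+ # the dataset JSON tree: each node is prepended while descending into child 0,
+ # so the resulting list runs leaf-first. The two-node pipeline below is
+ # hypothetical.
+ _demo_node = {"op_type": "Map", "children": [{"op_type": "ImageFolderDataset", "children": []}]}
+ _demo_reversed = []
+ while _demo_node:
+     _demo_reversed = [_demo_node] + _demo_reversed
+     _demo_node = _demo_node["children"][0] if _demo_node["children"] else None
+ assert [n["op_type"] for n in _demo_reversed] == ["ImageFolderDataset", "Map"]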
2483
+
2484
+ def check_checkpoint(ckpt_file_name):
2485
+ """
2486
+ Check whether the checkpoint is valid.
2487
+
2488
+ Args:
2489
+ ckpt_file_name (str): Checkpoint file name.
2490
+
2491
+ Returns:
2492
+ bool, whether the checkpoint is valid.
2493
+
2494
+ Examples:
2495
+ >>> import mindspore as ms
2496
+ >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
2497
+ >>> check_result = ms.check_checkpoint(ckpt_file_name)
2498
+ >>> print(check_result)
2499
+ True
2500
+ """
2501
+ if not ckpt_file_name.endswith('.ckpt'):
2502
+ return False
2503
+ checkpoint_list = Checkpoint()
2504
+ with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
2505
+ pb_content = f.read()
2506
+ if pb_content[-17:-10] == b"crc_num":
2507
+ crc_num_bytes = pb_content[-10:]
2508
+ pb_content = pb_content[:-17]
2509
+ crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
2510
+ cal_crc_num = binascii.crc32(pb_content, 0)
2511
+ if cal_crc_num != crc_num:
2512
+ logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
2513
+ return False
2514
+ try:
2515
+ checkpoint_list.ParseFromString(pb_content)
2516
+ except google.protobuf.message.DecodeError as e:
2517
+ logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
2518
+ logger.warning(e)
2519
+ return False
2520
+ return True
2521
+
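+ # A minimal illustrative reproduction of the optional CRC trailer that
+ # check_checkpoint parses above: the 7-byte literal b"crc_num" followed by a
+ # 10-byte big-endian CRC32 of the payload. The payload below is hypothetical.
+ import binascii as _demo_binascii
+ _demo_payload = b"example checkpoint bytes"
+ _demo_blob = _demo_payload + b"crc_num" + _demo_binascii.crc32(_demo_payload, 0).to_bytes(10, byteorder="big")
+ assert _demo_blob[-17:-10] == b"crc_num"
+ assert int.from_bytes(_demo_blob[-10:], byteorder="big") == _demo_binascii.crc32(_demo_blob[:-17], 0)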
2522
+
2523
+ def parse_print(print_file_name):
2524
+ """
2525
+ Parse data file generated by :class:`mindspore.ops.Print`.
2526
+
2527
+ Args:
2528
+ print_file_name (str): The file name needs to be parsed.
2529
+
2530
+ Returns:
2531
+ List, element of list is Tensor.
2532
+
2533
+ Raises:
2534
+ ValueError: The print file does not exist or is empty.
2535
+ RuntimeError: Failed to parse the file.
2536
+
2537
+ Examples:
2538
+ >>> import numpy as np
2539
+ >>> import mindspore as ms
2540
+ >>> from mindspore import nn, Tensor, ops
2541
+ >>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
2542
+ >>> class PrintInputTensor(nn.Cell):
2543
+ ... def __init__(self):
2544
+ ... super().__init__()
2545
+ ... self.print = ops.Print()
2546
+ ...
2547
+ ... def construct(self, input_pra):
2548
+ ... self.print('print:', input_pra)
2549
+ ... return input_pra
2550
+ >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
2551
+ >>> input_pra = Tensor(x)
2552
+ >>> net = PrintInputTensor()
2553
+ >>> net(input_pra)
2554
+ >>>
2555
+ >>> data = ms.parse_print('./log.data')
2556
+ >>> print(data)
2557
+ ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
2558
+ [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
2559
+ [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
2560
+ """
2561
+ print_file_path = os.path.realpath(print_file_name)
2562
+
2563
+ if os.path.getsize(print_file_path) == 0:
2564
+ raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
2565
+ "'print_file_name'.")
2566
+
2567
+ logger.info("Execute load print process.")
2568
+ print_list = Print()
2569
+
2570
+ try:
2571
+ with open(print_file_path, "rb") as f:
2572
+ pb_content = f.read()
2573
+ print_list.ParseFromString(pb_content)
2574
+ except BaseException as e:
2575
+ logger.critical("Failed to read the print file %s, please check whether the file is "
2576
+ "correct.", print_file_name)
2577
+ raise ValueError(e.__str__() + "\nFailed to read the print file {}, please check whether "
2578
+ "the file is correct.".format(print_file_name)) from e
2579
+
2580
+ tensor_list = []
2581
+
2582
+ try:
2583
+ for print_ in print_list.value:
2584
+ # String type
2585
+ if print_.HasField("desc"):
2586
+ tensor_list.append(print_.desc)
2587
+ elif print_.HasField("tensor"):
2588
+ dims = print_.tensor.dims
2589
+ data_type = print_.tensor.tensor_type
2590
+ data = print_.tensor.tensor_content
2591
+ np_type = tensor_to_np_type.get(data_type)
2592
+ param_data = np.frombuffer(data, np_type)  # np.fromstring is deprecated for binary data
2593
+ ms_type = tensor_to_ms_type.get(data_type)
2594
+ if dims and dims != [0]:
2595
+ param_value = param_data.reshape(dims)
2596
+ tensor_list.append(Tensor(param_value, ms_type))
2597
+ # Scalar type
2598
+ else:
2599
+ data_type_ = data_type.lower()
2600
+ if 'float' in data_type_:
2601
+ param_data = float(param_data[0])
2602
+ elif 'int' in data_type_:
2603
+ param_data = int(param_data[0])
2604
+ elif 'bool' in data_type_:
2605
+ param_data = bool(param_data[0])
2606
+ tensor_list.append(Tensor(param_data, ms_type))
2607
+
2608
+ except BaseException as e:
2609
+ logger.critical("Failed to load the print file %s.", print_list)
2610
+ raise RuntimeError(e.__str__() + "\nFailed to load the print file {}.".format(print_list)) from e
2611
+
2612
+ return tensor_list
2613
+
2614
+
2615
+ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
2616
+ """
2617
+ Merge data slices into one tensor with the whole data when strategy is not None.
2618
+
2619
+ Args:
2620
+ sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
2621
+ parameter_name (str): Name of parameter.
2622
+ strategy (dict): Parameter slice strategy.
2623
+ is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
2624
+
2625
+ Returns:
2626
+ Tensor, the merged Tensor which has the whole data.
2627
+
2628
+ Raises:
2629
+ ValueError: Failed to merge.
2630
+ """
2631
+ layout = strategy.get(parameter_name)
2632
+ try:
2633
+ dev_mat = list(layout.dev_matrix[0].dim)
2634
+ tensor_map = list(layout.tensor_map[0].dim)
2635
+ param_split_shape = list(layout.param_split_shape[0].dim)
2636
+ field_size = int(layout.field)
2637
+ except BaseException as e:
2638
+ raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
2639
+ f", please make sure that 'strategy' is correct.") from e
2640
+
2641
+ device_count = 1
2642
+ for dim in dev_mat:
2643
+ device_count *= dim
2644
+
2645
+ if len(sliced_data) != device_count:
2646
+ raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
2647
+ f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
2648
+ f"device_count is {device_count}.")
2649
+
2650
+ if not param_split_shape:
2651
+ if not is_even:
2652
+ raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
2653
+ "should be the same when slice manner is even.")
2654
+
2655
+ all_gather_tensor = Tensor(np.concatenate(sliced_data))
2656
+
2657
+ if field_size > 0:
2658
+ merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
2659
+ else:
2660
+ merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
2661
+
2662
+ else:
2663
+ tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
2664
+
2665
+ slice_count = 1
2666
+ for dim in tensor_strategy:
2667
+ slice_count *= dim
2668
+
2669
+ if len(param_split_shape) != slice_count:
2670
+ raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
2671
+ f"{slice_count}, but got {len(param_split_shape)}.")
2672
+
2673
+ tensor_slices_new = list(range(slice_count))
2674
+ tensor_slices = sliced_data
2675
+ for i in range(device_count):
2676
+ slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
2677
+ if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
2678
+ raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
2679
+ f"{param_split_shape[slice_index]} in 0 axis, but got "
2680
+ f"{tensor_slices[i].shape[0]}.")
2681
+ tensor_slices_new[slice_index] = np.array(tensor_slices[i])
2682
+
2683
+ dim_len = len(tensor_strategy)
2684
+ for i in range(dim_len):
2685
+ ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
2686
+ tensor_slices_new_inner = []
2687
+ for j in range(ele_count):
2688
+ new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
2689
+ for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
2690
+ (j + 1) * tensor_strategy[dim_len - 1 - i]):
2691
+ new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
2692
+ tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
2693
+ tensor_slices_new = tensor_slices_new_inner
2694
+ merged_tensor = Tensor(tensor_slices_new[0])
2695
+
2696
+ return merged_tensor
2697
+
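+ # A minimal illustrative 1-D demonstration of the even-slice path above:
+ # rank-ordered slices are simply concatenated before being reshaped back to
+ # the full tensor. The slices below are hypothetical.
+ import numpy as _demo_np
+ _demo_slices = [_demo_np.array([0.0, 1.0]), _demo_np.array([2.0, 3.0])]  # rank 0, rank 1
+ assert _demo_np.concatenate(_demo_slices).tolist() == [0.0, 1.0, 2.0, 3.0]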
2698
+
2699
+ def restore_group_info_list(group_info_file_name):
2700
+ """
2701
+ Build a rank list. The checkpoints of the ranks in the rank list have the same contents as the local rank
2702
+ that saves the `group_info_file_name`. To save the group info file, please export the GROUP_INFO_FILE
2703
+ environment variable like "export GROUP_INFO_FILE=/data/group_info.pb".
2704
+
2705
+ Args:
2706
+ group_info_file_name (str): Name of group information file.
2707
+
2708
+ Returns:
2709
+ List, the rank list.
2710
+
2711
+ Raises:
2712
+ ValueError: group information file is incorrect.
2713
+ TypeError: `group_info_file_name` is not str.
2714
+
2715
+ Examples:
2716
+ >>> import mindspore as ms
2717
+ >>> restore_list = ms.restore_group_info_list("./group_info.pb")
2718
+ """
2719
+ if not isinstance(group_info_file_name, str):
2720
+ raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
2721
+ f"but got {type(group_info_file_name)}.")
2722
+
2723
+ if not os.path.isfile(group_info_file_name):
2724
+ raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
2725
+
2726
+ if os.path.getsize(group_info_file_name) == 0:
2727
+ raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
2728
+
2729
+ return _restore_group_info_list(group_info_file_name)
2730
+
2731
+
2732
+ def build_searched_strategy(strategy_filename):
2733
+ """
2734
+ Build strategy of every parameter in network. Used in the case of distributed inference.
2735
+
2736
+ Args:
2737
+ strategy_filename (str): Name of strategy file.
2738
+
2739
+ Returns:
2740
+ Dict, whose key is parameter name and value is slice strategy of this parameter.
2741
+
2742
+ Raises:
2743
+ ValueError: Strategy file is incorrect.
2744
+ TypeError: `strategy_filename` is not a string.
2745
+
2746
+ Examples:
2747
+ >>> import mindspore as ms
2748
+ >>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
2749
+ """
2750
+ return _build_searched_strategy(strategy_filename)
2751
+
2752
+
2753
+ def merge_sliced_parameter(sliced_parameters, strategy=None):
2754
+ """
2755
+ Merge parameter slices into one parameter. Used in the case of distributed inference.
2756
+
2757
+ Args:
2758
+ sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
2759
+ strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
2760
+ value is slice strategy of this parameter. If strategy is None, just merge
2761
+ parameter slices in 0 axis order. Default: ``None``.
2762
+
2763
+ Returns:
2764
+ Parameter, the merged parameter which has the whole data.
2765
+
2766
+ Raises:
2767
+ ValueError: Failed to merge.
2768
+ TypeError: The sliced_parameters is incorrect or strategy is not dict.
2769
+ KeyError: The parameter name is not in keys of strategy.
2770
+
2771
+ Examples:
2772
+ >>> import numpy as np
2773
+ >>> import mindspore as ms
2774
+ >>> from mindspore import Tensor, Parameter
2775
+ >>>
2776
+ >>> sliced_parameters = [
2777
+ ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
2778
+ ... "network.embedding_table"),
2779
+ ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
2780
+ ... "network.embedding_table"),
2781
+ ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
2782
+ ... "network.embedding_table"),
2783
+ ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
2784
+ ... "network.embedding_table")]
2785
+ >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
2786
+ >>> print(merged_parameter)
2787
+ Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
2788
+ """
2789
+ if not isinstance(sliced_parameters, list):
2790
+ raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
2791
+ f"but got {type(sliced_parameters)}.")
2792
+
2793
+ if not sliced_parameters:
2794
+ raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
2795
+
2796
+ if strategy and not isinstance(strategy, dict):
2797
+ raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
2798
+ f"but got {type(strategy)}.")
2799
+
2800
+ try:
2801
+ parameter_name = sliced_parameters[0].name
2802
+ parameter_shape = sliced_parameters[0].data.shape
2803
+ parameter_shape_length = len(parameter_shape)
2804
+ except BaseException as e:
2805
+ raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
2806
+ f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
2807
+
2808
+ is_even = True
2809
+ for index, parameter in enumerate(sliced_parameters):
2810
+ if not isinstance(parameter, Parameter):
2811
+ raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
2812
+ f"but got {type(parameter)} at index {index}.")
2813
+
2814
+ if parameter.name != parameter_name \
2815
+ or len(parameter.data.shape) != parameter_shape_length \
2816
+ or parameter.data.shape[1:] != parameter_shape[1:]:
2817
+ raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
2818
+ f" have the same name, dimension length and shape except 0 axis. The name, dimension "
2819
+ f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
2820
+ f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
2821
+ f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
2822
+ f"at index {index}.")
2823
+
2824
+ if parameter.data.shape != parameter_shape:
2825
+ is_even = False
2826
+
2827
+ layerwise_parallel = sliced_parameters[0].layerwise_parallel
2828
+ requires_grad = sliced_parameters[0].requires_grad
2829
+ sliced_data = []
2830
+ for parameter in sliced_parameters:
2831
+ if parameter.data.dtype == mstype.bfloat16:
2832
+ sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
2833
+ else:
2834
+ sliced_data.append(parameter.data.asnumpy())
2835
+
2836
+ if not strategy:
2837
+ merged_tensor = Tensor(np.concatenate(sliced_data))
2838
+ merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
2839
+
2840
+ else:
2841
+ if parameter_name not in strategy.keys():
2842
+ raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
2843
+ f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
2844
+ merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
2845
+ merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
2846
+
2847
+ return merged_parameter
2848
+
2849
+
2850
+ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
2851
+ train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
2852
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
2853
+ """
2854
+ Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
2855
+
2856
+ Args:
2857
+ network (Cell): Network for distributed prediction.
2858
+ checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
2859
+ predict_strategy (dict): Strategy of the prediction process. When predict_strategy is set to
2860
+ None, it means that one device is used to predict. Default: ``None`` .
2861
+ train_strategy_filename (str): The filename of training strategy protocol buffer file.
2862
+ When train_strategy_filename is None, the training strategy file will be
2863
+ obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
2864
+ Therefore, the training strategy file needs to be specified
2865
+ in at least one of them. Default: ``None`` .
2866
+ strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
2867
+ into net when parameter name's suffix in checkpoint file is the same as the
2868
+ parameter in the network. When the types are inconsistent, perform type conversion
2869
+ on the parameters of the same type, such as float32 to float16. Default: ``False`` .
2870
+ dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
2871
+ is not required. Default: ``None`` .
2872
+ dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
2873
+ mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
2874
+ Default: ``'AES-GCM'`` .
2875
+ format (str): Input weight format to be loaded into the network.
2876
+ It can be set to either "ckpt" or "safetensors". Default: "ckpt".
2877
+ unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
2878
+ Default: ``None`` .
2879
+ dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
2880
+ rank_id (int): The logical sequence number of the card. In non-save mode, it is automatically obtained
2881
+ globally by initializing the network; in save mode, the file is saved according to the input
2882
+ sequence number. If it is not provided, the entire file is saved.
2883
+
2884
+ Raises:
2885
+ TypeError: The types of the inputs do not match the requirements.
2886
+ ValueError: Failed to load checkpoint into net.
2887
+
2888
+ Supported Platforms:
2889
+ ``Ascend`` ``GPU``
2890
+
2891
+ Examples:
2892
+ .. note::
2893
+ Before running the following examples, you need to configure the communication environment variables.
2894
+
2895
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
2896
+ Please see the `rank table startup
2897
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
2898
+ for more details.
2899
+
2900
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
2901
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
2902
+
2903
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
2904
+ Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
2905
+
2906
+ >>> import os
2907
+ >>> import numpy as np
2908
+ >>> import mindspore as ms
2909
+ >>> import mindspore.dataset as ds
2910
+ >>> from mindspore import nn, ops, train
2911
+ >>> from mindspore.communication import init
2912
+ >>>
2913
+ >>> step_per_epoch = 4
2914
+ >>> device_num = 8
2915
+ >>>
2916
+ >>> # Define the network structure.
2917
+ >>> class Net(nn.Cell):
2918
+ ... def __init__(self, matmul_size, strategy=None):
2919
+ ... super().__init__()
2920
+ ... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
2921
+ ... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
2922
+ ... self.matmul = ops.MatMul()
2923
+ ... self.neg = ops.Neg()
2924
+ ... if strategy is not None:
2925
+ ... self.matmul.shard(strategy)
2926
+ ...
2927
+ ... def construct(self, inputs):
2928
+ ... x = self.matmul(inputs, self.matmul_weight)
2929
+ ... x = self.neg(x)
2930
+ ... return x
2931
+ >>>
2932
+ >>> # Create dataset.
2933
+ >>> def get_dataset(*inputs):
2934
+ ... def generate():
2935
+ ... for _ in range(step_per_epoch):
2936
+ ... yield inputs
2937
+ ... return generate
2938
+ >>>
2939
+ >>> # Train network and save distributed checkpoint.
2940
+ >>> def train_net():
2941
+ ... ms.set_context(mode=ms.GRAPH_MODE)
2942
+ ... init()
2943
+ ... np.random.seed(1)
2944
+ ... input_data = np.random.rand(16, 96).astype(np.float32)
2945
+ ... label_data = np.random.rand(16, 16).astype(np.float32)
2946
+ ... fake_dataset = get_dataset(input_data, label_data)
2947
+ ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
2948
+ ...
2949
+ ... # Set parallel strategy.
2950
+ ... strategy = ((1, 4), (4, 1))
2951
+ ... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
2952
+ ... strategy_ckpt_save_file="./train_strategy.ckpt")
2953
+ ... network = Net(matmul_size=(96, 16), strategy=strategy)
2954
+ ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
2955
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
2956
+ ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
2957
+ ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
2958
+ ... global_rank_id = int(os.getenv("RANK_ID"))
2959
+ ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
2960
+ ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
2961
+ ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
2962
+ ... ms.reset_auto_parallel_context()
2963
+ >>>
2964
+ >>> # Load distributed checkpoint and test.
2965
+ >>> def load_model():
2966
+ ... ms.set_context(mode=ms.GRAPH_MODE)
2967
+ ... init()
2968
+ ... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
2969
+ ... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
2970
+ ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
2971
+ ... network = Net(matmul_size=(96, 16))
2972
+ ... model = ms.Model(network)
2973
+ ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
2974
+ ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
2975
+ ... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
2976
+ ... predict_result = model.predict(predict_data)
2977
+ ... print(predict_result)
2978
+ >>>
2979
+ >>> train_net()
2980
+ >>> load_model()
2981
+ [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
2982
+ [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
2983
+ ...
2984
+ [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
2985
+ """
2986
+ if format not in ['safetensors', 'ckpt']:
2987
+ raise ValueError(
2988
+ f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
2989
+
2990
+ if format == 'safetensors':
2991
+ if unified_safetensors_dir is None:
2992
+ raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
2993
+ f"when format is 'safetensors'.")
2994
+ unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
2995
+ for param in unsupport_param:
2996
+ if param is not None:
2997
+ raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
2998
+ f"when format is 'safetensors'.")
2999
+ if strict_load or dec_mode != 'AES-GCM':
3000
+ raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
3001
+ f"when format is 'safetensors'.")
3002
+ if network is not None:
3003
+ rank_id = get_rank()
3004
+ _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
3005
+ else:
3006
+ if dst_safetensors_dir is None:
3007
+ raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
3008
+ f"when network is None.")
3009
+ if rank_id is not None:
3010
+ _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3011
+ rank_id)
3012
+ else:
3013
+ dst_strategy_dict = _build_searched_strategy(predict_strategy)
3014
+ dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
3015
+ dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
3016
+ dst_device_num = dst_stage_device_num * dst_stage_num
3017
+ processes = []
3018
+ activate_processes = 0
3019
+ for rank in range(0, dst_device_num):
3020
+ p = Process(target=_load_parallel_checkpoint, args=(
3021
+ unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
3022
+ p.start()
3023
+ processes.append(p)
3024
+ activate_processes += 1
3025
+ max_processes = 64
3026
+ if activate_processes >= max_processes:
3027
+ p = processes.pop(0)
3028
+ p.join()
3029
+ activate_processes -= 1
3030
+ for p in processes:
3031
+ p.join()
3032
+ return
3033
+
3034
+ network = Validator.check_isinstance("network", network, nn.Cell)
3035
+ _check_checkpoint_file(checkpoint_filenames)
3036
+ _check_predict_strategy(predict_strategy)
3037
+
3038
+ dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
3039
+ dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
3040
+
3041
+ if train_strategy_filename is None:
3042
+ train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
3043
+ _train_strategy = build_searched_strategy(train_strategy_filename)
3044
+ train_strategy = _convert_to_list(_train_strategy)
3045
+
3046
+ train_dev_count = 1
3047
+ ckpt_file_len = len(checkpoint_filenames)
3048
+ for dim in train_strategy[list(train_strategy.keys())[0]][0]:
3049
+ train_dev_count *= dim
3050
+ if train_dev_count != ckpt_file_len:
3051
+ raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
3052
+ f"equal to the device count of training process. "
3053
+ f"But got the length of 'checkpoint_filenames'"
3054
+ f" is {ckpt_file_len} and the device count is {train_dev_count}.")
3055
+ rank_list = _infer_rank_list(train_strategy, predict_strategy)
3056
+
+    param_total_dict = defaultdict(dict)
+    for file_index, file_name in enumerate(checkpoint_filenames):
+        ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
+        for param_name, param in ckpt_dict.items():
+            param_total_dict[param_name][file_index] = param
+
+    param_dict = {}
+    param_not_in_strategy = []
+    param_not_in_ckpt = []
+    for _, param in network.parameters_and_names():
+        sliced_params = []
+        if param.name not in rank_list.keys():
+            param_not_in_strategy.append(param.name)
+            continue
+        if param.name not in param_total_dict:
+            param_not_in_ckpt.append(param.name)
+            continue
+
+        param_rank = rank_list.get(param.name)[0]
+        skip_merge_split = rank_list.get(param.name)[1]
+        shard_stride = train_strategy.get(param.name)[4]
+        tensor_map = train_strategy.get(param.name)[1]
+        first_dim_shard_idx = tensor_map[0] if tensor_map else -1
+        device_arrangement = train_strategy.get(param.name)[0]
+        first_dim_shard_size = 1
+        if first_dim_shard_idx >= 0:
+            first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
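+        # Field [5] of the train strategy appears to hold the optimizer-parallel
+        # (opt shard) size; when it is set, shard_size is the number of checkpoint
+        # files that together hold one full copy of this parameter.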
+        if train_strategy.get(param.name)[5]:
+            shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
+        else:
+            shard_size = 0
+        for rank in param_rank:
+            param_total_list = list(range(0, ckpt_file_len))
+            if first_dim_shard_size != 1:
+                param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
+            if shard_size > 0:
+                rank_index = param_total_list.index(rank)
+                start = rank_index // shard_size * shard_size
+                param_total_list = param_total_list[start:start + shard_size]
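+            # Gather every shard_stride-th slice around the current rank (scanning
+            # both backwards and forwards through the rank list), deduplicate and
+            # sort the indices, then concatenate the pieces for this rank.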
+            if shard_stride > 0:
+                param_stride = []
+                # merge the preceding parameter slices
+                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
+                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
+                param_index = list(set(param_index))
+                param_index.sort()
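+                # bfloat16 has no numpy representation, so such slices are cast to
+                # float32 on the CPU before conversion to numpy arrays.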
+                for rank_num in param_index:
+                    if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
+                        param_stride.append(
+                            cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
+                    else:
+                        param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
+
+                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
+            else:
+                sliced_param = param_total_dict[param.name][rank]
+
+            sliced_params.append(sliced_param)
+        if skip_merge_split:
+            split_param = sliced_params[0]
+        else:
+            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
+            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
+            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
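+        # If the predict strategy applies optimizer sharding to this parameter,
+        # keep only this rank's 1/size slice along the first dimension.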
+        opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
+        if opt_shard_group:
+            if split_param.data.dtype == mstype.bfloat16:
+                data = cpu_cast(split_param.data, mstype.float32).asnumpy()
+            else:
+                data = split_param.data.asnumpy()
+            rank = get_rank(opt_shard_group)
+            size = get_group_size(opt_shard_group)
+            try:
+                data_slice = np.split(data, size)[rank]
+            except BaseException as e:
+                logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
+                                " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
+                raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
+                                   f" in load distributed checkpoint for {param.name}. Data shape is "
+                                   f"{split_param.data.shape} and group is {opt_shard_group}.") from e
+            split_param = Parameter(Tensor(data_slice), param.name,
+                                    split_param.requires_grad, split_param.layerwise_parallel)
+        param_dict[param.name] = split_param
+
+    if param_not_in_strategy:
+        logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
+                       "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
+                       .format(param_not_in_strategy))
+    if param_not_in_ckpt:
+        logger.warning("For 'load_distributed_checkpoint', {} parameters are in the network and slice strategy "
+                       "but not in the checkpoint file, please check whether 'checkpoint_filenames' is correct."
+                       .format(param_not_in_ckpt))
+
+    load_param_into_net(network, param_dict, strict_load=strict_load)
+
+
+def async_ckpt_thread_status():
+    """
+    Get the status of the asynchronous save checkpoint thread.
+
+    When saving a checkpoint asynchronously, this can be used to determine whether the asynchronous thread
+    has completed.
+
+    Returns:
+        bool, True if the asynchronous save checkpoint thread is running,
+        False if it is not executing.
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.async_ckpt_thread_status()
+        False
+    """
+    thr_list = threading.enumerate()
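+    # Scan the live threads for "asyn_save_ckpt", the name presumably given to
+    # the worker thread by the asynchronous save routine.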
+    return any(ele.name == "asyn_save_ckpt" for ele in thr_list)
+
+
+def _check_predict_strategy(predict_strategy):
+    """Check predict strategy."""
+
+    def _check_int_list(arg):
+        if not isinstance(arg, list):
+            return False
+        for item in arg:
+            if not isinstance(item, int):
+                return False
+        return True
+
+    if predict_strategy is None:
+        return
+
+    flag = True
+    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
+    for key in predict_strategy.keys():
+        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
+                or len(predict_strategy[key]) < 4:
+            flag = False
+            break
+        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
+        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
+                not (_check_int_list(param_split_shape) or not param_split_shape) or \
+                not (isinstance(field_size, int) and field_size == 0):
+            flag = False
+
+    if not flag:
+        raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is a dict whose "
+                         f"keys must be strings and whose values must be lists or tuples, the first four "
+                         f"elements being dev_matrix (list[int]), tensor_map (list[int]), "
+                         f"param_split_shape (list[int]) and field_size (int, whose value is 0). "
+                         f"Please check whether 'predict_strategy' is correct.")
+
+
+def _check_checkpoint_file(checkpoint_filenames):
+    """Check checkpoint file names."""
+    for index, filename in enumerate(checkpoint_filenames):
+        if not isinstance(filename, str) or not os.path.exists(filename) \
+                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
+            raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
+                             f"make sure the {filename} at index {index} is a valid checkpoint file: it must "
+                             f"be a string ending with '.ckpt', and the checkpoint file it refers to must "
+                             f"exist and not be empty.")
+
+
+def _merge_and_split(sliced_params, train_strategy, predict_strategy):
+    """Merge sliced parameter and split it according to the predict strategy."""
+    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
+    if predict_strategy is None:
+        return merged_param
+    param_name = merged_param.name
+    tensor_layout = predict_strategy[param_name]
+    rank = get_rank()
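+    # Re-slice the merged full tensor for this rank according to the predict
+    # layout: tensor_layout[0] is the device matrix, tensor_layout[1] the tensor map.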
+    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
+    requires_grad = merged_param.requires_grad
+    layerwise_parallel = merged_param.layerwise_parallel
+    if merged_param.data.dtype == mstype.bfloat16:
+        split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
+    else:
+        split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
+    return split_param
+
+
+def _calculation_net_size(net):
+    """Calculate the size of parameters in the network."""
+    data_total = 0
+    net_dict = net.parameters_dict()
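+    # Sum the size of each parameter's raw bytes; the running total is in KB.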
+    for name in net_dict:
+        data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
+
+    return data_total
+
+
+def _get_mindir_inputs(file_name):
+    """
+    Get a MindIR file's inputs.
+
+    Note:
+        1. Parsing an encrypted MindIR file is not supported.
+        2. Parsing a dynamic shape MindIR file is not supported.
+
+    Args:
+        file_name (str): MindIR file name.
+
+    Returns:
+        Tensor or list(Tensor), the input(s) of the MindIR file.
+
+    Raises:
+        TypeError: If the parameter `file_name` is not `str`.
+        RuntimeError: If a MindIR input is not of tensor type or has no dims.
+
+    Examples:
+        >>> input_tensor = _get_mindir_inputs("lenet.mindir")
+    """
+    Validator.check_file_name_by_regular(file_name)
+    file_name = os.path.realpath(file_name)
+    model = read_proto(file_name)
+    input_tensor = []
+
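+    # Build a One-initialized Tensor for each graph input, rejecting dynamic
+    # shapes and unsupported MindIR data types along the way.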
+    for ele_input in model.graph.input:
+        input_shape = []
+        if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
+            raise RuntimeError("MindIR's input has no tensor or the tensor has no dims, please check the MindIR file.")
+
+        for ele_shape in ele_input.tensor[0].dims:
+            input_shape.append(ele_shape)
+        if is_shape_unknown(input_shape):
+            raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")
+
+        mindir_type = ele_input.tensor[0].data_type
+        if mindir_type not in mindir_to_tensor_type:
+            raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")
+
+        input_type = mindir_to_tensor_type.get(mindir_type)
+        input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))
+
+    if not input_tensor:
+        logger.warning("The MindIR model has no input, return None.")
+        return None
+    return input_tensor[0] if len(input_tensor) == 1 else input_tensor
+
+
+def convert_model(mindir_file, convert_file, file_format):
+    """
+    Convert a MindIR model to a model of another format. The current version only supports conversion to ONNX.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        mindir_file (str): MindIR file name.
+        convert_file (str): Converted model file name.
+        file_format (str): Format of the converted model; the current version only supports "ONNX".
+
+    Raises:
+        TypeError: If the parameter `mindir_file` is not `str`.
+        TypeError: If the parameter `convert_file` is not `str`.
+        ValueError: If the parameter `file_format` is not "ONNX".
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
+    """
+    Validator.check_file_name_by_regular(mindir_file)
+    Validator.check_file_name_by_regular(convert_file)
+    if file_format != "ONNX":
+        raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
+    net_input = _get_mindir_inputs(mindir_file)
+    graph = load(mindir_file)
+    net = nn.GraphCell(graph)
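+    # Wrap the loaded graph in a GraphCell and re-export through the ONNX
+    # exporter, feeding the One-initialized dummy inputs gathered above.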
+    if isinstance(net_input, Tensor):
+        export(net, net_input, file_name=convert_file, file_format=file_format)
+    else:
+        export(net, *net_input, file_name=convert_file, file_format=file_format)