mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory page for this release for more details.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1,116 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """
16
- The MPI config, used to configure the MPI environment.
17
- """
18
- import threading
19
- from mindspore._c_expression import MpiConfig
20
- from mindspore._checkparam import args_type_check
21
-
22
-
23
- class _MpiConfig:
24
- """
25
- _MpiConfig is the config tool for controlling MPI
26
-
27
- Note:
28
- Create a config through instantiating MpiConfig object is not recommended.
29
- should use MpiConfig() to get the config since MpiConfig is singleton.
30
- """
31
- _instance = None
32
- _instance_lock = threading.Lock()
33
-
34
- def __init__(self):
35
- self._mpiconfig_handle = MpiConfig.get_instance()
36
-
37
- def __new__(cls, *args, **kwargs):
38
- if cls._instance is None:
39
- cls._instance_lock.acquire()
40
- cls._instance = object.__new__(cls)
41
- cls._instance_lock.release()
42
- return cls._instance
43
-
44
- def __getattribute__(self, attr):
45
- value = object.__getattribute__(self, attr)
46
- if attr == "_mpiconfig_handle" and value is None:
47
- raise ValueError("mpiconfig handle is none in MpiConfig!!!")
48
- return value
49
-
50
- @property
51
- def enable_mpi(self):
52
- """Get enable mpi."""
53
- return self._mpiconfig_handle.get_enable_mpi()
54
-
55
- @enable_mpi.setter
56
- def enable_mpi(self, enable_mpi):
57
- self._mpiconfig_handle.set_enable_mpi(enable_mpi)
58
-
59
- _k_mpi_config = None
60
-
61
-
62
- def _mpi_config():
63
- """
64
- Get the global mpi config, if mpi config is not created, create a new one.
65
-
66
- Returns:
67
- _MpiConfig, the global mpi config.
68
- """
69
- global _k_mpi_config
70
- if _k_mpi_config is None:
71
- _k_mpi_config = _MpiConfig()
72
- return _k_mpi_config
73
-
74
-
75
- @args_type_check(enable_mpi=bool)
76
- def _set_mpi_config(**kwargs):
77
- """
78
- Sets mpi config for running environment.
79
-
80
- mpi config should be configured before running your program. If there is no configuration,
81
- mpi module will be disabled by default.
82
-
83
- Note:
84
- Attribute name is required for setting attributes.
85
-
86
- Args:
87
- enable_mpi (bool): Whether to enable mpi. Default: False.
88
-
89
- Raises:
90
- ValueError: If input key is not an attribute in mpi config.
91
-
92
- Examples:
93
- >>> mpiconfig.set_mpi_config(enable_mpi=True)
94
- """
95
- for key, value in kwargs.items():
96
- if not hasattr(_mpi_config(), key):
97
- raise ValueError("Set mpi config keyword %s is not recognized!" % key)
98
- setattr(_mpi_config(), key, value)
99
-
100
-
101
- def _get_mpi_config(attr_key):
102
- """
103
- Gets mpi config attribute value according to the input key.
104
-
105
- Args:
106
- attr_key (str): The key of the attribute.
107
-
108
- Returns:
109
- Object, The value of given attribute key.
110
-
111
- Raises:
112
- ValueError: If input key is not an attribute in config.
113
- """
114
- if not hasattr(_mpi_config(), attr_key):
115
- raise ValueError("Get context keyword %s is not recognized!" % attr_key)
116
- return getattr(_mpi_config(), attr_key)
@@ -1,84 +0,0 @@
1
- # Copyright 2019 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """Validate the input path."""
16
- import os
17
- import re
18
-
19
-
20
- def check_valid_character_of_path(file_path):
21
- """
22
- Validates path.
23
-
24
- The output path of profiler only supports alphabets(a-zA-Z), digit(0-9) or {'-', '_', '.', '/'}.
25
-
26
- Note:
27
- Chinese and other paths are not supported at present.
28
-
29
- Args:
30
- path (str): Normalized Path.
31
-
32
- Returns:
33
- bool, whether valid.
34
- """
35
- re_path = r'^[/\\_a-zA-Z0-9-_.@]+$'
36
- path_valid = re.fullmatch(re_path, file_path)
37
- if not path_valid:
38
- msg = "The output path of profiler only supports alphabets(a-zA-Z), " \
39
- "digit(0-9) or {'-', '_', '.', '/', '@'}, but got the absolute path= " + file_path
40
- raise RuntimeError(msg)
41
-
42
-
43
- def validate_and_normalize_path(
44
- path,
45
- check_absolute_path=False,
46
- allow_parent_dir=True,
47
- ):
48
- """
49
- Validates path and returns its normalized form.
50
-
51
- If path has a valid scheme, treat path as url, otherwise consider path a
52
- unix local path.
53
-
54
- Note:
55
- File scheme (rfc8089) is currently not supported.
56
-
57
- Args:
58
- path (str): Path to be normalized.
59
- check_absolute_path (bool): Whether check path scheme is supported.
60
- allow_parent_dir (bool): Whether allow parent dir in path.
61
-
62
- Returns:
63
- str, normalized path.
64
- """
65
- if not path:
66
- raise RuntimeError("The path is invalid!")
67
-
68
- path_str = str(path)
69
- if not allow_parent_dir:
70
- path_components = path_str.split("/")
71
- if ".." in path_components:
72
- raise RuntimeError("The parent path is not allowed!")
73
-
74
- # path does not have valid schema, treat it as unix local path.
75
- if check_absolute_path:
76
- if not path_str.startswith("/"):
77
- raise RuntimeError("The path is invalid!")
78
- try:
79
- # most unix systems allow
80
- normalized_path = os.path.realpath(path)
81
- except ValueError as err:
82
- raise RuntimeError("The path is invalid!") from err
83
- check_valid_character_of_path(normalized_path)
84
- return normalized_path
@@ -1,298 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Generated by the protocol buffer compiler. DO NOT EDIT!
3
- # source: memory_profiling.proto
4
- """Generated protocol buffer code."""
5
- from google.protobuf import descriptor as _descriptor
6
- from google.protobuf import message as _message
7
- from google.protobuf import reflection as _reflection
8
- from google.protobuf import symbol_database as _symbol_database
9
- # @@protoc_insertion_point(imports)
10
-
11
- _sym_db = _symbol_database.Default()
12
-
13
-
14
-
15
-
16
- DESCRIPTOR = _descriptor.FileDescriptor(
17
- name='memory_profiling.proto',
18
- package='mindspore.profiler',
19
- syntax='proto3',
20
- serialized_options=None,
21
- create_key=_descriptor._internal_create_key,
22
- serialized_pb=b'\n\x16memory_profiling.proto\x12\x12mindspore.profiler\"V\n\x0bMemoryProto\x12\x34\n\tgraph_mem\x18\x01 \x03(\x0b\x32!.mindspore.profiler.GraphMemProto\x12\x11\n\ttotal_mem\x18\x02 \x01(\x04\"\xc5\x01\n\rGraphMemProto\x12\x10\n\x08graph_id\x18\x01 \x01(\x03\x12\x12\n\nstatic_mem\x18\x02 \x01(\x03\x12\x33\n\tnode_mems\x18\x03 \x03(\x0b\x32 .mindspore.profiler.NodeMemProto\x12\x37\n\x0btensor_mems\x18\x04 \x03(\x0b\x32\".mindspore.profiler.TensorMemProto\x12\x10\n\x08\x66p_start\x18\x05 \x01(\t\x12\x0e\n\x06\x62p_end\x18\x06 \x01(\t\"\x82\x01\n\x0cNodeMemProto\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x0f\n\x07node_id\x18\x02 \x01(\x04\x12\x17\n\x0finput_tensor_id\x18\x03 \x03(\x04\x12\x18\n\x10output_tensor_id\x18\x04 \x03(\x04\x12\x1b\n\x13workspace_tensor_id\x18\x05 \x03(\x04\"x\n\x0eTensorMemProto\x12\x11\n\ttensor_id\x18\x01 \x01(\x04\x12\x0c\n\x04size\x18\x02 \x01(\x04\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x12\n\nlife_start\x18\x04 \x01(\x04\x12\x10\n\x08life_end\x18\x05 \x01(\x04\x12\x11\n\tlife_long\x18\x06 \x01(\tb\x06proto3'
23
- )
24
-
25
-
26
-
27
-
28
- _MEMORYPROTO = _descriptor.Descriptor(
29
- name='MemoryProto',
30
- full_name='mindspore.profiler.MemoryProto',
31
- filename=None,
32
- file=DESCRIPTOR,
33
- containing_type=None,
34
- create_key=_descriptor._internal_create_key,
35
- fields=[
36
- _descriptor.FieldDescriptor(
37
- name='graph_mem', full_name='mindspore.profiler.MemoryProto.graph_mem', index=0,
38
- number=1, type=11, cpp_type=10, label=3,
39
- has_default_value=False, default_value=[],
40
- message_type=None, enum_type=None, containing_type=None,
41
- is_extension=False, extension_scope=None,
42
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
43
- _descriptor.FieldDescriptor(
44
- name='total_mem', full_name='mindspore.profiler.MemoryProto.total_mem', index=1,
45
- number=2, type=4, cpp_type=4, label=1,
46
- has_default_value=False, default_value=0,
47
- message_type=None, enum_type=None, containing_type=None,
48
- is_extension=False, extension_scope=None,
49
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
50
- ],
51
- extensions=[
52
- ],
53
- nested_types=[],
54
- enum_types=[
55
- ],
56
- serialized_options=None,
57
- is_extendable=False,
58
- syntax='proto3',
59
- extension_ranges=[],
60
- oneofs=[
61
- ],
62
- serialized_start=46,
63
- serialized_end=132,
64
- )
65
-
66
-
67
- _GRAPHMEMPROTO = _descriptor.Descriptor(
68
- name='GraphMemProto',
69
- full_name='mindspore.profiler.GraphMemProto',
70
- filename=None,
71
- file=DESCRIPTOR,
72
- containing_type=None,
73
- create_key=_descriptor._internal_create_key,
74
- fields=[
75
- _descriptor.FieldDescriptor(
76
- name='graph_id', full_name='mindspore.profiler.GraphMemProto.graph_id', index=0,
77
- number=1, type=3, cpp_type=2, label=1,
78
- has_default_value=False, default_value=0,
79
- message_type=None, enum_type=None, containing_type=None,
80
- is_extension=False, extension_scope=None,
81
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
82
- _descriptor.FieldDescriptor(
83
- name='static_mem', full_name='mindspore.profiler.GraphMemProto.static_mem', index=1,
84
- number=2, type=3, cpp_type=2, label=1,
85
- has_default_value=False, default_value=0,
86
- message_type=None, enum_type=None, containing_type=None,
87
- is_extension=False, extension_scope=None,
88
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
89
- _descriptor.FieldDescriptor(
90
- name='node_mems', full_name='mindspore.profiler.GraphMemProto.node_mems', index=2,
91
- number=3, type=11, cpp_type=10, label=3,
92
- has_default_value=False, default_value=[],
93
- message_type=None, enum_type=None, containing_type=None,
94
- is_extension=False, extension_scope=None,
95
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
96
- _descriptor.FieldDescriptor(
97
- name='tensor_mems', full_name='mindspore.profiler.GraphMemProto.tensor_mems', index=3,
98
- number=4, type=11, cpp_type=10, label=3,
99
- has_default_value=False, default_value=[],
100
- message_type=None, enum_type=None, containing_type=None,
101
- is_extension=False, extension_scope=None,
102
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
103
- _descriptor.FieldDescriptor(
104
- name='fp_start', full_name='mindspore.profiler.GraphMemProto.fp_start', index=4,
105
- number=5, type=9, cpp_type=9, label=1,
106
- has_default_value=False, default_value=b"".decode('utf-8'),
107
- message_type=None, enum_type=None, containing_type=None,
108
- is_extension=False, extension_scope=None,
109
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
110
- _descriptor.FieldDescriptor(
111
- name='bp_end', full_name='mindspore.profiler.GraphMemProto.bp_end', index=5,
112
- number=6, type=9, cpp_type=9, label=1,
113
- has_default_value=False, default_value=b"".decode('utf-8'),
114
- message_type=None, enum_type=None, containing_type=None,
115
- is_extension=False, extension_scope=None,
116
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
117
- ],
118
- extensions=[
119
- ],
120
- nested_types=[],
121
- enum_types=[
122
- ],
123
- serialized_options=None,
124
- is_extendable=False,
125
- syntax='proto3',
126
- extension_ranges=[],
127
- oneofs=[
128
- ],
129
- serialized_start=135,
130
- serialized_end=332,
131
- )
132
-
133
-
134
- _NODEMEMPROTO = _descriptor.Descriptor(
135
- name='NodeMemProto',
136
- full_name='mindspore.profiler.NodeMemProto',
137
- filename=None,
138
- file=DESCRIPTOR,
139
- containing_type=None,
140
- create_key=_descriptor._internal_create_key,
141
- fields=[
142
- _descriptor.FieldDescriptor(
143
- name='node_name', full_name='mindspore.profiler.NodeMemProto.node_name', index=0,
144
- number=1, type=9, cpp_type=9, label=1,
145
- has_default_value=False, default_value=b"".decode('utf-8'),
146
- message_type=None, enum_type=None, containing_type=None,
147
- is_extension=False, extension_scope=None,
148
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
149
- _descriptor.FieldDescriptor(
150
- name='node_id', full_name='mindspore.profiler.NodeMemProto.node_id', index=1,
151
- number=2, type=4, cpp_type=4, label=1,
152
- has_default_value=False, default_value=0,
153
- message_type=None, enum_type=None, containing_type=None,
154
- is_extension=False, extension_scope=None,
155
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
156
- _descriptor.FieldDescriptor(
157
- name='input_tensor_id', full_name='mindspore.profiler.NodeMemProto.input_tensor_id', index=2,
158
- number=3, type=4, cpp_type=4, label=3,
159
- has_default_value=False, default_value=[],
160
- message_type=None, enum_type=None, containing_type=None,
161
- is_extension=False, extension_scope=None,
162
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
163
- _descriptor.FieldDescriptor(
164
- name='output_tensor_id', full_name='mindspore.profiler.NodeMemProto.output_tensor_id', index=3,
165
- number=4, type=4, cpp_type=4, label=3,
166
- has_default_value=False, default_value=[],
167
- message_type=None, enum_type=None, containing_type=None,
168
- is_extension=False, extension_scope=None,
169
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
170
- _descriptor.FieldDescriptor(
171
- name='workspace_tensor_id', full_name='mindspore.profiler.NodeMemProto.workspace_tensor_id', index=4,
172
- number=5, type=4, cpp_type=4, label=3,
173
- has_default_value=False, default_value=[],
174
- message_type=None, enum_type=None, containing_type=None,
175
- is_extension=False, extension_scope=None,
176
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
177
- ],
178
- extensions=[
179
- ],
180
- nested_types=[],
181
- enum_types=[
182
- ],
183
- serialized_options=None,
184
- is_extendable=False,
185
- syntax='proto3',
186
- extension_ranges=[],
187
- oneofs=[
188
- ],
189
- serialized_start=335,
190
- serialized_end=465,
191
- )
192
-
193
-
194
- _TENSORMEMPROTO = _descriptor.Descriptor(
195
- name='TensorMemProto',
196
- full_name='mindspore.profiler.TensorMemProto',
197
- filename=None,
198
- file=DESCRIPTOR,
199
- containing_type=None,
200
- create_key=_descriptor._internal_create_key,
201
- fields=[
202
- _descriptor.FieldDescriptor(
203
- name='tensor_id', full_name='mindspore.profiler.TensorMemProto.tensor_id', index=0,
204
- number=1, type=4, cpp_type=4, label=1,
205
- has_default_value=False, default_value=0,
206
- message_type=None, enum_type=None, containing_type=None,
207
- is_extension=False, extension_scope=None,
208
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
209
- _descriptor.FieldDescriptor(
210
- name='size', full_name='mindspore.profiler.TensorMemProto.size', index=1,
211
- number=2, type=4, cpp_type=4, label=1,
212
- has_default_value=False, default_value=0,
213
- message_type=None, enum_type=None, containing_type=None,
214
- is_extension=False, extension_scope=None,
215
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
216
- _descriptor.FieldDescriptor(
217
- name='type', full_name='mindspore.profiler.TensorMemProto.type', index=2,
218
- number=3, type=9, cpp_type=9, label=1,
219
- has_default_value=False, default_value=b"".decode('utf-8'),
220
- message_type=None, enum_type=None, containing_type=None,
221
- is_extension=False, extension_scope=None,
222
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
223
- _descriptor.FieldDescriptor(
224
- name='life_start', full_name='mindspore.profiler.TensorMemProto.life_start', index=3,
225
- number=4, type=4, cpp_type=4, label=1,
226
- has_default_value=False, default_value=0,
227
- message_type=None, enum_type=None, containing_type=None,
228
- is_extension=False, extension_scope=None,
229
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
230
- _descriptor.FieldDescriptor(
231
- name='life_end', full_name='mindspore.profiler.TensorMemProto.life_end', index=4,
232
- number=5, type=4, cpp_type=4, label=1,
233
- has_default_value=False, default_value=0,
234
- message_type=None, enum_type=None, containing_type=None,
235
- is_extension=False, extension_scope=None,
236
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
237
- _descriptor.FieldDescriptor(
238
- name='life_long', full_name='mindspore.profiler.TensorMemProto.life_long', index=5,
239
- number=6, type=9, cpp_type=9, label=1,
240
- has_default_value=False, default_value=b"".decode('utf-8'),
241
- message_type=None, enum_type=None, containing_type=None,
242
- is_extension=False, extension_scope=None,
243
- serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
244
- ],
245
- extensions=[
246
- ],
247
- nested_types=[],
248
- enum_types=[
249
- ],
250
- serialized_options=None,
251
- is_extendable=False,
252
- syntax='proto3',
253
- extension_ranges=[],
254
- oneofs=[
255
- ],
256
- serialized_start=467,
257
- serialized_end=587,
258
- )
259
-
260
- _MEMORYPROTO.fields_by_name['graph_mem'].message_type = _GRAPHMEMPROTO
261
- _GRAPHMEMPROTO.fields_by_name['node_mems'].message_type = _NODEMEMPROTO
262
- _GRAPHMEMPROTO.fields_by_name['tensor_mems'].message_type = _TENSORMEMPROTO
263
- DESCRIPTOR.message_types_by_name['MemoryProto'] = _MEMORYPROTO
264
- DESCRIPTOR.message_types_by_name['GraphMemProto'] = _GRAPHMEMPROTO
265
- DESCRIPTOR.message_types_by_name['NodeMemProto'] = _NODEMEMPROTO
266
- DESCRIPTOR.message_types_by_name['TensorMemProto'] = _TENSORMEMPROTO
267
- _sym_db.RegisterFileDescriptor(DESCRIPTOR)
268
-
269
- MemoryProto = _reflection.GeneratedProtocolMessageType('MemoryProto', (_message.Message,), {
270
- 'DESCRIPTOR' : _MEMORYPROTO,
271
- '__module__' : 'memory_profiling_pb2'
272
- # @@protoc_insertion_point(class_scope:mindspore.profiler.MemoryProto)
273
- })
274
- _sym_db.RegisterMessage(MemoryProto)
275
-
276
- GraphMemProto = _reflection.GeneratedProtocolMessageType('GraphMemProto', (_message.Message,), {
277
- 'DESCRIPTOR' : _GRAPHMEMPROTO,
278
- '__module__' : 'memory_profiling_pb2'
279
- # @@protoc_insertion_point(class_scope:mindspore.profiler.GraphMemProto)
280
- })
281
- _sym_db.RegisterMessage(GraphMemProto)
282
-
283
- NodeMemProto = _reflection.GeneratedProtocolMessageType('NodeMemProto', (_message.Message,), {
284
- 'DESCRIPTOR' : _NODEMEMPROTO,
285
- '__module__' : 'memory_profiling_pb2'
286
- # @@protoc_insertion_point(class_scope:mindspore.profiler.NodeMemProto)
287
- })
288
- _sym_db.RegisterMessage(NodeMemProto)
289
-
290
- TensorMemProto = _reflection.GeneratedProtocolMessageType('TensorMemProto', (_message.Message,), {
291
- 'DESCRIPTOR' : _TENSORMEMPROTO,
292
- '__module__' : 'memory_profiling_pb2'
293
- # @@protoc_insertion_point(class_scope:mindspore.profiler.TensorMemProto)
294
- })
295
- _sym_db.RegisterMessage(TensorMemProto)
296
-
297
-
298
- # @@protoc_insertion_point(module_scope)
mindspore/utils/hooks.py DELETED
@@ -1,81 +0,0 @@
1
- # Copyright 2025 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """hooks"""
16
- from collections import OrderedDict
17
- import weakref
18
- from typing import Any, Tuple
19
-
20
-
21
- class _RemovableHandle:
22
- r"""
23
- A handle which provides the capability to remove a hook.
24
-
25
- Args:
26
- hooks_dict (dict): A dictionary of hooks, indexed by hook `id`.
27
-
28
- Keyword Args:
29
- extra_dict (Union[dict, list[dict]], optional): An additional dictionary or list of
30
- dictionaries whose keys will be deleted when the same keys are
31
- removed from `hooks_dict`. Default ``None``.
32
- """
33
-
34
- id: int
35
- next_id: int = 0
36
-
37
- def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
38
- self.hooks_dict_ref = weakref.ref(hooks_dict)
39
- self.id = _RemovableHandle.next_id
40
- _RemovableHandle.next_id += 1
41
-
42
- self.extra_dict_ref: Tuple = ()
43
- if isinstance(extra_dict, dict):
44
- self.extra_dict_ref = (weakref.ref(extra_dict),)
45
- elif isinstance(extra_dict, list):
46
- self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)
47
-
48
- def remove(self) -> None:
49
- hooks_dict = self.hooks_dict_ref()
50
- if hooks_dict is not None and self.id in hooks_dict:
51
- del hooks_dict[self.id]
52
-
53
- for ref in self.extra_dict_ref:
54
- extra_dict = ref()
55
- if extra_dict is not None and self.id in extra_dict:
56
- del extra_dict[self.id]
57
-
58
- def __getstate__(self):
59
- if self.extra_dict_ref is None:
60
- return (self.hooks_dict_ref(), self.id)
61
- return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))
62
-
63
- def __setstate__(self, state) -> None:
64
- if state[0] is None:
65
- # create a dead reference
66
- self.hooks_dict_ref = weakref.ref(OrderedDict())
67
- else:
68
- self.hooks_dict_ref = weakref.ref(state[0])
69
- self.id = state[1]
70
- _RemovableHandle.next_id = max(_RemovableHandle.next_id, self.id + 1)
71
-
72
- if len(state) < 3 or state[2] is None:
73
- self.extra_dict_ref = ()
74
- else:
75
- self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])
76
-
77
- def __enter__(self) -> "_RemovableHandle":
78
- return self
79
-
80
- def __exit__(self, type: Any, value: Any, tb: Any) -> None:
81
- self.remove()