mindspore 2.4.0-cp311-cp311-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1049 @@
1
+ # Copyright 2021-2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Process data and Calc loss landscape."""
16
+ from __future__ import absolute_import
17
+
18
+ import os
19
+ import time
20
+ import json
21
+ import stat
22
+ import shutil
23
+ import numbers
24
+
25
+ from collections import defaultdict, namedtuple
26
+ from concurrent.futures import wait, ALL_COMPLETED, ProcessPoolExecutor
27
+
28
+ import numpy as np
29
+ from scipy import linalg, sparse
30
+
31
+ from mindspore import log as logger
32
+ from mindspore.common.tensor import Tensor
33
+ from mindspore.common.parameter import Parameter
34
+ from mindspore.train.serialization import load_checkpoint, load_param_into_net
35
+ from mindspore.train.summary_pb2 import LossLandscape
36
+ from mindspore.train.summary import SummaryRecord
37
+ from mindspore.train.summary.enums import PluginEnum
38
+ from mindspore.train.anf_ir_pb2 import DataType
39
+ from mindspore.train._utils import check_value_type, _make_directory
40
+ from mindspore.train.dataset_helper import DatasetHelper
41
+ from mindspore.train.metrics import get_metrics
42
+ from mindspore import context
43
+
44
# Coordinate container with one field per axis, used for a landscape grid, a loss path or a
# convergence point. When there is no loss path, pass an empty list instead of a ``Points``
# instance: ``Landscape`` only serializes ``path_points`` when it is truthy.
Points = namedtuple("Points", "x y z")
46
+
47
+
48
# NumPy scalar type -> name of the matching value of the ``DataType`` enum (anf_ir_pb2).
# ``np.float_`` is intentionally absent: it was only an alias of ``np.float64`` and was removed in
# NumPy 2.0, where merely referencing it raises AttributeError.
_NP_TO_PROTO_DTYPE = {
    np.bool_: 'DT_BOOL',
    np.int8: 'DT_INT8',
    np.int16: 'DT_INT16',
    np.int32: 'DT_INT32',
    np.int64: 'DT_INT64',
    np.uint8: 'DT_UINT8',
    np.uint16: 'DT_UINT16',
    np.uint32: 'DT_UINT32',
    np.uint64: 'DT_UINT64',
    np.float16: 'DT_FLOAT16',
    np.float32: 'DT_FLOAT32',
    np.float64: 'DT_FLOAT64',
}


def nptype_to_prototype(np_value):
    """
    Transform the np type to proto type.

    Args:
        np_value (Union[numpy.ndarray, numpy.generic, None]): Value whose ``dtype`` is mapped.

    Returns:
        str, name of the proto data type (for example ``'DT_FLOAT32'``), or ``None`` when
        ``np_value`` is None.

    Raises:
        TypeError: If the dtype of ``np_value`` has no matching proto data type.
    """
    if np_value is None:
        return None

    proto = _NP_TO_PROTO_DTYPE.get(np_value.dtype.type)
    if proto is None:
        raise TypeError("No match for proto data type.")
    return proto
82
+
83
+
84
def fill_array_to_tensor(np_value, summary_tensor):
    """
    Copy a numpy value into a summary tensor message.

    The proto data type, the flattened values (always written to ``float_data``) and the shape
    dimensions of ``np_value`` are written into ``summary_tensor``, which is modified in place.

    Args:
        np_value (numpy.ndarray): Value to serialize.
        summary_tensor: Tensor message to fill.

    Returns:
        The same ``summary_tensor`` object, for convenience.
    """
    proto_name = nptype_to_prototype(np_value)
    summary_tensor.data_type = DataType.Value(proto_name)

    flat_values = np_value.reshape(-1).tolist()
    summary_tensor.float_data.extend(flat_values)
    summary_tensor.dims.extend(np_value.shape)

    return summary_tensor
108
+
109
+
110
def transfer_tensor_to_tuple(inputs):
    """
    Normalize ``inputs`` so that a lone ``Tensor`` becomes a one-element tuple.

    Args:
        inputs: A ``Tensor`` or any other object.

    Returns:
        ``(inputs,)`` when ``inputs`` is a ``Tensor``; otherwise ``inputs`` itself, untouched.
    """
    return (inputs,) if isinstance(inputs, Tensor) else inputs
118
+
119
+
120
class Landscape:
    """Holds one loss landscape and converts it into a ``LossLandscape`` protobuf message."""

    def __init__(self,
                 intervals,
                 decomposition,
                 landscape_points: Points,
                 convergence_point=None,
                 path_points=None):
        # Grid of the landscape surface.
        self.landscape_points = landscape_points
        # Optional training path; falsy (e.g. an empty list) when there is none.
        self.path_points = path_points
        self.convergence_point = convergence_point
        self.intervals = intervals
        self.decomposition = decomposition
        # Defaults; callers may overwrite these after construction.
        self.num_samples = 2048
        self.unit = 'step'
        self.step_per_epoch = 1

    def set_convergence_point(self, convergence_point: Points):
        """Replace the stored convergence point."""
        self.convergence_point = convergence_point

    @staticmethod
    def _copy_xyz(source, target):
        """Fill ``target.x/.y/.z`` tensors from ``source.x/.y/.z`` arrays, in that order."""
        for axis in ('x', 'y', 'z'):
            fill_array_to_tensor(getattr(source, axis), getattr(target, axis))

    def transform_to_loss_landscape_msg(self, landscape_data):
        """
        Build a ``LossLandscape`` message from ``landscape_data``.

        The path and convergence point are only written when they are truthy. ``unit`` and
        ``step_per_epoch`` are read from ``self``; everything else comes from ``landscape_data``.
        """
        landscape_msg = LossLandscape()

        grid = landscape_data.landscape_points
        # Only one row of x and one column of y are stored (one dimension per axis);
        # presumably the grid is meshgrid-shaped -- verify against the producer.
        fill_array_to_tensor(grid.x[0], landscape_msg.landscape.x)
        fill_array_to_tensor(grid.y[:, 0], landscape_msg.landscape.y)
        fill_array_to_tensor(grid.z, landscape_msg.landscape.z)

        path = landscape_data.path_points
        if path:
            landscape_msg.loss_path.intervals.extend(landscape_data.intervals)
            self._copy_xyz(path, landscape_msg.loss_path.points)

        convergence = landscape_data.convergence_point
        if convergence:
            self._copy_xyz(convergence, landscape_msg.convergence_point)

        metadata = landscape_msg.metadata
        metadata.decomposition = landscape_data.decomposition
        metadata.unit = self.unit
        metadata.step_per_epoch = self.step_per_epoch

        return landscape_msg
165
+
166
+
167
+ class SummaryLandscape:
168
+ """
169
+ SummaryLandscape can help you to collect loss landscape information.
170
+ It can create landscape in PCA direction or random direction by calculating loss.
171
+
172
+ Note:
173
+ SummaryLandscape only supports Linux systems.
174
+
175
+ Args:
176
+ summary_dir (str): The path of summary is used to save the model weight,
177
+ metadata and other data required to create landscape.
178
+
179
+ Examples:
180
+ >>> import mindspore as ms
181
+ >>> import mindspore.nn as nn
182
+ >>> from mindspore.train import Model, Accuracy, Loss
183
+ >>> from mindspore import SummaryCollector, SummaryLandscape
184
+ >>>
185
+ >>> if __name__ == '__main__':
186
+ ... # If the device_target is Ascend, set the device_target to "Ascend"
187
+ ... ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
188
+ ... # Create the dataset taking MNIST as an example. Refer to
189
+ ... # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
190
+ ... ds_train = create_dataset()
191
+ ... # Define the network structure of LeNet5. Refer to
192
+ ... # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
193
+ ... network = LeNet5()
194
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
195
+ ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
196
+ ... model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
197
+ ... # Simple usage for collect landscape information:
198
+ ... interval_1 = [1, 2, 3, 4, 5]
199
+ ... summary_collector = SummaryCollector(summary_dir='./summary/lenet_interval_1',
200
+ ... collect_specified_data={'collect_landscape':{"landscape_size": 4,
201
+ ... "unit": "step",
202
+ ... "create_landscape":{"train":True,
203
+ ... "result":False},
204
+ ... "num_samples": 2048,
205
+ ... "intervals": [interval_1]}
206
+ ... })
207
+ ... model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=False)
208
+ ...
209
+ ... # Simple usage for visualization landscape:
210
+ ... def callback_fn():
211
+ ... # Define the network structure of LeNet5. Refer to
212
+ ... # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
213
+ ... network = LeNet5()
214
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
215
+ ... metrics = {"Loss": Loss()}
216
+ ... model = Model(network, net_loss, metrics=metrics)
217
+ ... # Create the dataset taking MNIST as an example. Refer to
218
+ ... # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
219
+ ... ds_eval = create_dataset()
220
+ ... return model, network, ds_eval, metrics
221
+ ...
222
+ ... summary_landscape = SummaryLandscape('./summary/lenet_interval_1')
223
+ ... # parameters of collect_landscape can be modified or unchanged
224
+ ... summary_landscape.gen_landscapes_with_multi_process(callback_fn,
225
+ ... collect_landscape={"landscape_size": 4,
226
+ ... "create_landscape":{"train":False,
227
+ ... "result":False},
228
+ ... "num_samples": 2048,
229
+ ... "intervals": [interval_1]},
230
+ ... device_ids=[1])
231
+ """
232
def __init__(self, summary_dir):
    """
    Resolve ``summary_dir`` and make sure its ``ckpt_dir`` sub-directory exists.

    Args:
        summary_dir (str): Directory holding the summary data; ``ckpt_dir`` is created inside it.
    """
    real_summary_dir = os.path.realpath(summary_dir)
    self._summary_dir = real_summary_dir
    self._ckpt_dir = os.path.join(real_summary_dir, 'ckpt_dir')
    _make_directory(self._ckpt_dir)

    # Epoch (looked up as a string key) -> path of the checkpoint file saved for that epoch.
    self._model_params_file_map = {}
    # Interval index -> list of epochs belonging to that interval.
    self._epoch_group = defaultdict(list)
    self._metric_fns = None
241
+
242
def _get_model_params(self, epochs):
    """
    Load the saved parameters for every epoch in ``epochs``.

    Returns:
        list[list], one list of parameter values per epoch, in the order of ``epochs``.
    """
    file_map = self._model_params_file_map
    return [list(load_checkpoint(file_map.get(str(epoch))).values()) for epoch in epochs]
249
+
250
def _create_epoch_group(self, intervals):
    """
    Rebuild ``self._epoch_group`` from ``intervals``.

    Key ``i`` maps to the list of epochs of the ``i``-th interval. The group is rebuilt from
    scratch instead of appended to, for two reasons:
    - repeated calls must not accumulate epochs from earlier intervals;
    - ``self._epoch_group`` may have been replaced by a plain dict loaded from JSON (string
      keys, no default factory), where appending to ``[i]`` raises KeyError.

    Args:
        intervals (List[List[int]]): Epochs of each interval.
    """
    epoch_group = defaultdict(list)
    for index, interval in enumerate(intervals):
        epoch_group[index].extend(interval)
    self._epoch_group = epoch_group
254
+
255
def clean_ckpt(self):
    """
    Delete the checkpoint directory (``<summary_dir>/ckpt_dir``) and everything inside it.

    Any error raised while deleting, including a missing directory, is ignored.

    Tutorial Examples:
        - `Training Optimization Process Visualization
          <https://www.mindspore.cn/mindinsight/docs/en/master/landscape.html>`_
    """
    target_dir = self._ckpt_dir
    shutil.rmtree(target_dir, ignore_errors=True)
264
+
265
def gen_landscapes_with_multi_process(self, callback_fn, collect_landscape=None,
                                      device_ids=None, output=None):
    """
    Use the multi process to generate landscape.

    Args:
        callback_fn (python function): A python function object. User needs to write a function,
            it has no input, and the return requirements are as follows.

            - mindspore.train.Model: User's model object.
            - mindspore.nn.Cell: User's network object.
            - mindspore.dataset: User's dataset object for create loss landscape.
            - mindspore.train.Metrics: User's metrics object.
        collect_landscape (Union[dict, None]): The meaning of the parameters
            when creating loss landscape is consistent with the fields
            with the same name in SummaryCollector. The purpose of setting here
            is to allow users to freely modify creating parameters. Default: ``None`` .

            - landscape_size (int): Specify the image resolution of the generated loss landscape.
              For example, if it is set to ``128`` , the resolution of the landscape is 128 * 128.
              The calculation time increases with the increase of resolution.
              Default: ``40`` . Optional values: between 3 and 256.
            - create_landscape (dict): Select how to create loss landscape.
              Training process loss landscape(train) and training result loss landscape(result).
              Default: ``{"train": True, "result": True}``. Optional: ``True`` / ``False`` .
            - num_samples (int): The size of the dataset used to create the loss landscape.
              For example, in image dataset, You can set num_samples is 2048,
              which means that 2048 images are used to create loss landscape.
              Default: ``2048`` .
            - intervals (List[List[int]]): Specifies the interval
              in which the loss landscape. For example: If the user wants to
              create loss landscape of two training processes, they are 1-5 epoch
              and 6-10 epoch respectively. They can set [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]].
              Note: Each interval have at least three epochs.
        device_ids (List(int)): Specifies which devices are used to create loss landscape.
            For example: [0, 1] refers to creating loss landscape with device 0 and device 1.
            Default: ``None`` .
        output (str): Specifies the path to save the loss landscape.
            Default: ``None`` . The default save path is the same as the summary file.

    Raises:
        TypeError, ValueError: If ``device_ids`` or ``collect_landscape`` fails validation.
        FileNotFoundError: If ``collect_landscape`` is given but ``train_metadata.json`` does not
            exist in the checkpoint directory.
    """
    # Validate up front so that invalid input neither spawns worker processes nor opens a
    # SummaryRecord that would then be left behind.
    self._check_device_ids(device_ids)
    if collect_landscape is not None:
        self._check_collect_landscape_data(collect_landscape)

    executor = None
    summary_record = None
    try:
        # ``device_ids`` defaults to None; the old ``len(device_ids)`` crashed on that default.
        # NOTE(review): whether None is accepted at all is decided by _check_device_ids.
        if device_ids and len(device_ids) > 1:
            executor = ProcessPoolExecutor(len(device_ids))
            # NOTE(review): a process pool does not guarantee that each task runs in a distinct worker.
            futures = [executor.submit(self._set_context, i) for i in device_ids]
            wait(futures, return_when=ALL_COMPLETED)

        output_path = os.path.realpath(output) if output is not None else self._summary_dir
        summary_record = SummaryRecord(output_path)

        if collect_landscape is not None:
            json_path = os.path.join(self._ckpt_dir, 'train_metadata.json')
            if not os.path.exists(json_path):
                raise FileNotFoundError(f'For "{self.__class__.__name__}", '
                                        f'train_metadata.json file path of {json_path} not exists.')
            with open(json_path, 'r') as file:
                data = json.load(file)
            for key, value in collect_landscape.items():
                if key in data:
                    data[key] = value

            if "intervals" in collect_landscape:
                self._create_epoch_group(collect_landscape.get("intervals"))
                data["epoch_group"] = self._epoch_group
            # The file is left owner-read-only after every write (chmod below), so restore owner
            # write permission first; otherwise re-running on the same directory cannot open it.
            os.chmod(json_path, stat.S_IRUSR | stat.S_IWUSR)
            with open(json_path, 'w') as file:
                json.dump(data, file)
            os.chmod(json_path, stat.S_IRUSR)

        for interval, landscape in self._list_landscapes(callback_fn=callback_fn, executor=executor,
                                                        device_ids=device_ids):
            summary_record.add_value(PluginEnum.LANDSCAPE.value, f'landscape_{str(interval)}', landscape)
            summary_record.record(0)
        summary_record.flush()
    finally:
        # Release the summary writer and the worker processes on success and on any error.
        if summary_record is not None:
            summary_record.close()
        if executor is not None:
            executor.shutdown()
345
+
346
def _list_landscapes(self, callback_fn, executor=None, device_ids=None):
    """
    Create every configured landscape and yield them one by one.

    Reads ``train_metadata.json`` from the checkpoint directory and refreshes
    ``self._epoch_group`` and ``self._model_params_file_map`` from it. Then it yields:
    - when ``create_landscape['train']`` is set: ``([first_epoch, last_epoch], msg)`` for each
      epoch group, built with PCA;
    - when ``create_landscape['result']`` is set: ``([final_epoch], msg)``, built from random
      directions.

    Raises:
        FileNotFoundError: If ``train_metadata.json`` does not exist.
    """
    metadata_path = os.path.join(self._ckpt_dir, 'train_metadata.json')
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f'For "{self.__class__.__name__}", train_metadata.json file does not exist '
                                f'under the path, please use summary_collector to collect information to '
                                f'create the json file')
    with open(metadata_path, 'r') as file:
        data = json.load(file)
    self._check_json_file_data(data)

    self._epoch_group = data['epoch_group']
    self._model_params_file_map = data['model_params_file_map']
    shared_kwargs = {
        'proz': 0.2,
        'landscape_size': data['landscape_size'],
        'device_ids': device_ids,
        'callback_fn': callback_fn,
        'executor': executor,
    }

    def build_msg(create_fn, epochs):
        # Compute one landscape, stamp the metadata values on it and convert it to a message.
        began = time.time()
        landscape_data = create_fn(epochs=epochs, **shared_kwargs)
        logger.info("Create landscape end, use time: %s s." % (round(time.time() - began, 6)))
        landscape_data.unit = data['unit']
        landscape_data.step_per_epoch = data['step_per_epoch']
        landscape_data.num_samples = data['num_samples']
        return landscape_data.transform_to_loss_landscape_msg(landscape_data)

    start = time.time()
    create_flags = data['create_landscape']
    if create_flags['train']:
        for position, epochs in enumerate(self._epoch_group.values()):
            self._log_message(create_flags, index=position, interval=epochs)
            msg = build_msg(self._create_landscape_by_pca, epochs)
            yield [epochs[0], epochs[-1]], msg

    if create_flags['result']:
        final_epochs = [list(self._epoch_group.values())[-1][-1]]
        self._log_message(create_flags, final_epochs=final_epochs)
        msg = build_msg(self._create_landscape_by_random, final_epochs)
        yield final_epochs, msg
    logger.info("Total use time: %s s." % (round(time.time() - start, 6)))
387
+
388
+ def _log_message(self, create_landscape, index=None, interval=None, final_epochs=None):
389
+ """Generate drawing information using log."""
390
+ if final_epochs is None:
391
+ if create_landscape['result']:
392
+ msg = f"Start to create the {index + 1}/{len(self._epoch_group) + 1} landscapes, " \
393
+ f"checkpoint is {interval}, decomposition is PCA."
394
+ else:
395
+ msg = f"Start to create the {index + 1}/{len(self._epoch_group)} landscapes, " \
396
+ f"checkpoint is {interval}, decomposition is PCA."
397
+ else:
398
+ if create_landscape['train']:
399
+ msg = f"Start to create the {len(self._epoch_group) + 1}/{len(self._epoch_group) + 1} landscapes, " \
400
+ f"checkpoint is {final_epochs}, decomposition is Random. "
401
+ else:
402
+ msg = f"Start to create the {1}/{1} landscapes, " \
403
+ f"checkpoint is {final_epochs}, decomposition is Random."
404
+ logger.info(msg)
405
+
406
+ @staticmethod
407
+ def _set_context(device_id):
408
+ """Set context."""
409
+ context.set_context(device_id=device_id)
410
+ context.set_context(mode=context.GRAPH_MODE)
411
+
412
def _create_landscape_by_pca(self, epochs, proz, landscape_size, device_ids=None, callback_fn=None, executor=None):
    """
    Create a loss landscape spanned by the first two PCA directions of the training trajectory.

    The weight/bias parameters (name contains ``weight`` or ``bias`` and not ``moment``) of every checkpoint
    in ``epochs`` are flattened and taken relative to the last checkpoint; PCA over these differences yields
    two directions. Every checkpoint is projected onto the directions to draw the training path, and the loss
    is evaluated on a ``landscape_size`` x ``landscape_size`` grid covering the path plus a margin of
    ``proz`` times its extent on each side.

    Args:
        epochs (list): Checkpoint ids of one epoch group; the last one is the origin of the landscape.
        proz (float): Relative margin added around the projected path on both axes.
        landscape_size (int): Number of grid points per axis.
        device_ids (list): Device ids. With an ``executor``, the path points and the grid rows are split
            into ``len(device_ids)`` parts of near-equal size that are evaluated in parallel.
        callback_fn (callable): Forwarded to ``_cont_loss_wrapper``, which calls it to build the model.
        executor: Optional object with ``submit`` (e.g. a ``concurrent.futures`` executor).

    Returns:
        Landscape: The PCA landscape with its grid, training path and convergence point.
    """
    multi_parameters = self._get_model_params(epochs)

    def _is_weight_or_bias(param):
        # MindSpore checkpoints also hold hyperparameters and optimizer moments; skip them so the PCA
        # only sees model weights.
        return ("weight" in param.name or "bias" in param.name) and ("moment" not in param.name)

    param_matrixs = []
    for parameters in multi_parameters:
        # Concatenate once per checkpoint instead of re-allocating a growing vector for every parameter.
        # The leading empty float64 array keeps the result float64 whatever the parameter dtype.
        chunks = [np.empty(0)] + [param.data.asnumpy() for param in parameters if _is_weight_or_bias(param)]
        param_matrixs.append(np.concatenate(chunks, axis=None))
    param_matrixs = np.vstack(param_matrixs)
    # Express every earlier checkpoint relative to the final one.
    param_matrixs = param_matrixs[:-1] - param_matrixs[-1]
    # Only 2 components are needed: parameters are reduced to a 2D plane and the third axis is the loss.
    pca = _PCA(n_comps=2)
    principal_components = pca.compute(param_matrixs.T)
    v_ori, w_ori = np.array(principal_components[:, 0]), np.array(principal_components[:, -1])
    final_params = list(multi_parameters[-1])

    # Reshape the PCA directions into the shapes of all model parameters.
    v_ndarray = self._reshape_vector(v_ori, final_params)
    w_ndarray = self._reshape_vector(w_ori, final_params)

    # Reshape the PCA directions into the shapes of the weight/bias parameters only.
    final_params_filtered = self._filter_weight_and_bias(final_params)
    v_ndarray_filtered = self._reshape_vector(v_ori, final_params_filtered)
    w_ndarray_filtered = self._reshape_vector(w_ori, final_params_filtered)

    v_ndarray, w_ndarray = self._normalize_vector(final_params, v_ndarray, w_ndarray)
    v_ndarray_filtered, w_ndarray_filtered = self._normalize_vector(final_params_filtered, v_ndarray_filtered,
                                                                    w_ndarray_filtered)
    # Flatten into single vectors and project every checkpoint onto them to get (alpha, beta).
    v_param = self._flat_ndarray(v_ndarray_filtered)
    w_param = self._flat_ndarray(w_ndarray_filtered)
    final_params_numpy = [param.data.asnumpy() for param in final_params]
    final_params_filtered_numpy = [param.data.asnumpy() for param in final_params_filtered]
    coefs = self._calc_coefs(multi_parameters, final_params_filtered_numpy, v_param, w_param)

    # Grid coordinates of the loss landscape.
    coefs_x = coefs[:, 0][np.newaxis]
    coefs_y = coefs[:, 1][np.newaxis]

    def _padded_axis(values):
        # Span [min, max] of the projected path, widened by proz * (max - min) on both sides.
        low, high = min(values), max(values)
        return np.linspace(low - proz * (high - low), high + proz * (high - low), landscape_size)

    x_axis = _padded_axis(coefs_x[0])
    y_axis = _padded_axis(coefs_y[0])
    x_points, y_points = np.meshgrid(x_axis, y_axis)

    test_final_params = dict()
    for param in final_params:
        test_final_params[param.name] = param.data.asnumpy()

    if executor is not None:
        logger.info("Use multi process, device_id: %s." % (device_ids))
        # np.array_split spreads any remainder over the first parts instead of piling it onto the last
        # device; the results are concatenated back in order, so the landscape itself is unchanged.
        coefs_parts = np.array_split(coefs, len(device_ids))
        y_points_parts = np.array_split(y_points, len(device_ids))
        futures = [executor.submit(self._cont_loss_wrapper, callback_fn, test_final_params, final_params_numpy,
                                   v_ndarray, w_ndarray, x_points, y_part, coefs=coef_part)
                   for coef_part, y_part in zip(coefs_parts, y_points_parts)]
        wait(futures, return_when=ALL_COMPLETED)

        z_points, paths = [], []
        for future in futures:
            part_paths, part_z_points = future.result()
            paths += part_paths
            z_points += part_z_points
    else:
        paths, z_points = self._cont_loss_wrapper(callback_fn, test_final_params, final_params_numpy,
                                                  v_ndarray, w_ndarray, x_points, y_points, coefs=coefs)

    paths = np.array(paths)
    landscape_points = Points(x_points, y_points, np.vstack(z_points))
    path_points = Points(coefs_x[0], coefs_y[0], paths.T[0])
    # The final checkpoint is the origin of the PCA plane. Pick the path point nearest to the origin rather
    # than requiring exactly one x == 0 match: int() on the 2-D np.argwhere result relies on a deprecated
    # ndim>0 conversion and fails outright when zero or several points match.
    zero_index = int(np.argmin(np.hypot(coefs_x[0], coefs_y[0])))
    convergence_point = Points(np.array([0]), np.array([0]), np.array([path_points.z[zero_index]]))
    landscape = Landscape(intervals=epochs, decomposition='PCA', landscape_points=landscape_points,
                          path_points=path_points, convergence_point=convergence_point)
    return landscape
509
+
510
+ def _cont_loss_wrapper(self, callback_fn, test_final_params, final_params_numpy,
511
+ v_ndarray, w_ndarray, x_points, y_points, coefs=None):
512
+ """Compute loss wrapper."""
513
+ model, network, valid_dataset, metrics = callback_fn()
514
+ with open(os.path.join(self._ckpt_dir, 'train_metadata.json'), 'r') as file:
515
+ data = json.load(file)
516
+ self._check_json_file_data(data)
517
+ num_samples = data['num_samples']
518
+ batch_size = valid_dataset.get_batch_size()
519
+ num_batches = num_samples // batch_size
520
+ valid_dataset = valid_dataset.take(num_batches)
521
+
522
+ paths, final_params = [], []
523
+ for (key, value) in test_final_params.items():
524
+ parameter = Parameter(Tensor(value), name=key, requires_grad=True)
525
+ final_params.append(parameter)
526
+ if coefs is not None:
527
+ for i, coef in enumerate(coefs):
528
+ loss_data = self._cont_loss(valid_dataset, network, model, metrics, final_params,
529
+ final_params_numpy, [coef[0]], coef[1], v_ndarray, w_ndarray, path=True)
530
+ paths.append(loss_data)
531
+ print("Drawing landscape path total progress is %s/%s, landscape path loss is %s."
532
+ % (i+1, len(coefs), loss_data[0]))
533
+ # Start to calc loss landscape
534
+ z_points = list()
535
+
536
+ # Compute loss landscape
537
+ for i, _ in enumerate(y_points):
538
+ print("Drawing landscape total progress: %s/%s." % (i+1, len(y_points)))
539
+ vals = self._cont_loss(valid_dataset, network, model, metrics, final_params,
540
+ final_params_numpy, x_points[i], y_points[i][0],
541
+ v_ndarray, w_ndarray)
542
+ z_points.append(vals)
543
+
544
+ return paths, z_points
545
+
546
def _create_landscape_by_random(self, epochs, proz, landscape_size, device_ids=None,
                                callback_fn=None, executor=None):
    """
    Create a loss landscape around the final checkpoint along two random directions.

    Two Gaussian random vectors are reshaped like the final parameters and normalized per parameter by
    ``_normalize_vector``. The loss is evaluated on a ``landscape_size`` x ``landscape_size`` grid spanning
    ``[-5 * proz, 5 * proz]`` on both axes.

    Args:
        epochs (list): Checkpoint ids; the last one is the centre of the landscape.
        proz (float): Scale of the grid extent (the grid spans ``+-5 * proz``).
        landscape_size (int): Number of grid points per axis.
        device_ids (list): Device ids. With an ``executor``, grid rows are split into ``len(device_ids)``
            near-equal parts that are evaluated in parallel.
        callback_fn (callable): Forwarded to ``_cont_loss_wrapper``, which calls it to build the model.
        executor: Optional object with ``submit`` (e.g. a ``concurrent.futures`` executor).

    Returns:
        Landscape: The Random landscape with its grid and the grid centre as convergence point.
    """
    multi_parameters = self._get_model_params(epochs)
    final_params = list(multi_parameters[-1])
    final_params_numpy = [param.data.asnumpy() for param in final_params]
    total_params = sum(np.size(p) for p in final_params_numpy)
    v_rand = np.random.normal(size=total_params)
    w_rand = np.random.normal(size=total_params)

    # Reshape the random directions into the shapes of all model parameters.
    v_ndarray = self._reshape_random_vector(v_rand, final_params_numpy)
    w_ndarray = self._reshape_random_vector(w_rand, final_params_numpy)
    v_ndarray, w_ndarray = self._normalize_vector(final_params, v_ndarray, w_ndarray)

    boundaries_x, boundaries_y = 5, 5
    x_axis = np.linspace(-proz * boundaries_x, proz * boundaries_x, landscape_size)
    y_axis = np.linspace(-proz * boundaries_y, proz * boundaries_y, landscape_size)
    x_points, y_points = np.meshgrid(x_axis, y_axis)
    test_final_params = {param.name: param.data.asnumpy() for param in final_params}
    if executor is not None:
        logger.info("Use multi process, device_id: %s." % (device_ids))
        # np.array_split spreads any remainder over the first parts instead of piling it onto the last
        # device; the rows are concatenated back in order, so the landscape itself is unchanged.
        futures = [executor.submit(self._cont_loss_wrapper, callback_fn, test_final_params, final_params_numpy,
                                   v_ndarray, w_ndarray, x_points, y_part)
                   for y_part in np.array_split(y_points, len(device_ids))]
        wait(futures, return_when=ALL_COMPLETED)
        z_points = []
        for future in futures:
            z_points += future.result()[1]
    else:
        _, z_points = self._cont_loss_wrapper(callback_fn, test_final_params, final_params_numpy,
                                              v_ndarray, w_ndarray, x_points, y_points)

    landscape_points = Points(x_points, y_points, np.vstack(z_points))
    # np.meshgrid(x_axis, y_axis) is laid out row = y, column = x, and z_points has one row per y value,
    # so the centre loss is z_points[mid_y][mid_x].
    mid_x, mid_y = len(x_axis) // 2, len(y_axis) // 2
    convergence_point = Points(np.array([x_axis[mid_x]]), np.array([y_axis[mid_y]]),
                               np.array([z_points[mid_y][mid_x]]))
    landscape = Landscape(intervals=epochs, decomposition='Random', landscape_points=landscape_points,
                          convergence_point=convergence_point)
    return landscape
599
+
600
+ @staticmethod
601
+ def _filter_weight_and_bias(parameters):
602
+ """Filter the weight and bias of parameters."""
603
+
604
+ filter_params = []
605
+ for param in parameters:
606
+ if ('weight' not in param.name and 'bias' not in param.name) or ('moment' in param.name):
607
+ continue
608
+ filter_params.append(param)
609
+ return filter_params
610
+
611
+ @staticmethod
612
+ def _reshape_vector(vector, parameters):
613
+ """Reshape vector into model shape."""
614
+ ndarray = list()
615
+ index = 0
616
+ for param in parameters:
617
+ data = param.data.asnumpy()
618
+ if ("weight" not in param.name and "bias" not in param.name) or ("moment" in param.name):
619
+ ndarray.append(np.array(data, dtype=np.float32))
620
+ continue
621
+
622
+ vec_it = vector[index:(index + data.size)].reshape(data.shape)
623
+ ndarray.append(np.array(vec_it, dtype=np.float32))
624
+ index += data.size
625
+ return ndarray
626
+
627
+ @staticmethod
628
+ def _reshape_random_vector(vector, params_numpy):
629
+ """ Reshape random vector into model shape."""
630
+ ndarray = list()
631
+ index = 0
632
+ for param in params_numpy:
633
+ len_p = np.size(param)
634
+ p_size = np.shape(param)
635
+ vec_it = vector[index:(index + len_p)].reshape(p_size)
636
+ ndarray.append(np.array(vec_it, dtype=np.float32))
637
+ index += len_p
638
+ return ndarray
639
+
640
+ @staticmethod
641
+ def _normalize_vector(parameters, get_v, get_w):
642
+ """
643
+ Normalizes the vectors spanning the 2D space, to make trajectories comparable between each other.
644
+ """
645
+ for i, param in enumerate(parameters):
646
+ # Here as MindSpore ckpt has hyperparameters, we should skip them to make sure
647
+ # PCA calculation is correct.
648
+ data = param.data.asnumpy()
649
+ if ("weight" in param.name or "bias" in param.name) and ("moment" not in param.name):
650
+ factor_v = np.linalg.norm(data) / np.linalg.norm(get_v[i])
651
+ factor_w = np.linalg.norm(data) / np.linalg.norm(get_w[i])
652
+ get_v[i] = get_v[i] * factor_v
653
+ get_w[i] = get_w[i] * factor_w
654
+ else:
655
+ get_v[i] = get_v[i] * 0
656
+ get_w[i] = get_w[i] * 0
657
+
658
+ return get_v, get_w
659
+
660
+ @staticmethod
661
+ def _flat_ndarray(ndarray_vector):
662
+ """Concatenates a python array of numpy arrays into a single, flat numpy array."""
663
+ return np.concatenate([item.flatten() for item in ndarray_vector], axis=None)
664
+
665
+ def _calc_coefs(self, parameter_group, final_param_ndarray, v_vector, w_vector):
666
+ """
667
+ Calculates the scale factors for plotting points
668
+ in the 2D space spanned by the vectors v and w.
669
+ """
670
+
671
+ matris = [v_vector, w_vector]
672
+ matris = np.vstack(matris)
673
+ matris = matris.T
674
+
675
+ pas = self._flat_ndarray(final_param_ndarray)
676
+ coefs = list()
677
+ for parameters in parameter_group:
678
+ testi = list()
679
+ for param in parameters:
680
+ # Here as MindSpore ckpt has hyperparameters,
681
+ # we should skip them to make sure PCA calculation is correct
682
+ if ('weight' not in param.name and 'bias' not in param.name) or ('moment' in param.name):
683
+ continue
684
+ testi.append(param.data.asnumpy())
685
+
686
+ st_vec = self._flat_ndarray(testi)
687
+ b_vec = st_vec - pas
688
+ # Here using least square method to get solutions of a equation system to generate alpha and beta.
689
+ coefs.append(np.hstack(np.linalg.lstsq(matris, b_vec, rcond=None)[0]))
690
+
691
+ return np.array(coefs)
692
+
693
+ def _cont_loss(self, ds_eval, network, model, metrics, parameters,
694
+ final_params_numpy, alph, beta, get_v, get_w, path=False):
695
+ """
696
+ Calculates the loss landscape based on vectors v and w (which can be principal components).
697
+ Changes the internal state of model. Executes model.
698
+ """
699
+ logger.info("start to cont loss")
700
+ vals = list()
701
+
702
+ al_item = 0
703
+ for i, _ in enumerate(alph):
704
+ # calculate new parameters for model
705
+
706
+ parameters_dict = dict()
707
+ for j, param in enumerate(parameters):
708
+ parameters_dict[param.name] = self._change_parameter(j, param, final_params_numpy,
709
+ alph[al_item], beta,
710
+ get_v, get_w)
711
+
712
+ al_item += 1
713
+ # load parameters into model and calculate loss
714
+
715
+ load_param_into_net(network, parameters_dict)
716
+ del parameters_dict
717
+ loss = self._loss_compute(model, ds_eval, metrics)
718
+ if path is False:
719
+ print("Current local landscape progress is %s/%s, landscape loss is %s."
720
+ % (i+1, len(alph), loss.get('Loss')))
721
+ vals = np.append(vals, loss.get('Loss'))
722
+
723
+ return vals
724
+
725
+ @staticmethod
726
+ def _change_parameter(index, parameter, final_params_numpy, alpha, beta, get_v, get_w):
727
+ """Function for changing parameter value with map and lambda."""
728
+ data = final_params_numpy[index]
729
+ data_target = data + alpha * get_v[index] + beta * get_w[index]
730
+ data_target = Tensor(data_target.astype(np.float32))
731
+ parameter.set_data(Tensor(data_target))
732
+ return parameter
733
+
734
+ def _loss_compute(self, model, data, metrics):
735
+ """Compute loss."""
736
+ dataset_sink_mode = False
737
+ self._metric_fns = get_metrics(metrics)
738
+ for metric in self._metric_fns.values():
739
+ metric.clear()
740
+
741
+ network = model.train_network
742
+ dataset_helper = DatasetHelper(data, dataset_sink_mode)
743
+
744
+ network.set_train(True)
745
+ network.phase = 'train'
746
+
747
+ for inputs in dataset_helper:
748
+ inputs = transfer_tensor_to_tuple(inputs)
749
+ outputs = network(*inputs)
750
+ self._update_metrics(outputs)
751
+
752
+ metrics = self._get_metrics()
753
+ return metrics
754
+
755
+ def _update_metrics(self, outputs):
756
+ """Update metrics local values."""
757
+ if isinstance(outputs, Tensor):
758
+ outputs = (outputs,)
759
+ if not isinstance(outputs, tuple):
760
+ raise ValueError(f"The argument 'outputs' should be tuple, but got {type(outputs)}. "
761
+ f"Modify 'output' to Tensor or tuple. ")
762
+
763
+ for metric in self._metric_fns.values():
764
+ metric.update(outputs[0])
765
+
766
+ def _get_metrics(self):
767
+ """Get metrics local values."""
768
+ metrics = dict()
769
+ for key, value in self._metric_fns.items():
770
+ metrics[key] = value.eval()
771
+ return metrics
772
+
773
+ def _check_unit(self, unit):
774
+ """Check unit type and value."""
775
+ check_value_type('unit', unit, str)
776
+ if unit not in ["step", "epoch"]:
777
+ raise ValueError(f'For "{self.__class__.__name__}", the "unit" in train_metadata.json should be '
778
+ f'step or epoch, but got the: {unit}')
779
+
780
+ def _check_landscape_size(self, landscape_size):
781
+ """Check landscape size type and value."""
782
+ check_value_type('landscape_size', landscape_size, int)
783
+ # landscape size should be between 3 and 256.
784
+ if landscape_size < 3 or landscape_size > 256:
785
+ raise ValueError(f'For "{self.__class__.__name__}", "landscape_size" in train_metadata.json should be '
786
+ f'between 3 and 256, but got the: {landscape_size}')
787
+
788
+ def _check_create_landscape(self, create_landscape):
789
+ """Check create landscape type and value."""
790
+ check_value_type('create_landscape', create_landscape, dict)
791
+ for param, value in create_landscape.items():
792
+ if param not in ["train", "result"]:
793
+ raise ValueError(f'For "{self.__class__.__name__}", the key of "create_landscape" should be in '
794
+ f'["train", "result"], but got the: {param}.')
795
+ if len(create_landscape) < 2:
796
+ raise ValueError(f'For "{self.__class__.__name__}", the key of "create_landscape" should be train '
797
+ f'and result, but only got the: {param}')
798
+ check_value_type(param, value, bool)
799
+
800
+ def _check_intervals(self, intervals):
801
+ """Check intervals type and value."""
802
+ check_value_type('intervals', intervals, list)
803
+ for _, interval in enumerate(intervals):
804
+ check_value_type('each interval in intervals', interval, list)
805
+ #Each interval have at least three epochs.
806
+ if len(interval) < 3:
807
+ raise ValueError(f'For "{self.__class__.__name__}", the length of each list in "intervals" '
808
+ f'should not be less than three, but got the: {interval}.')
809
+ for j in interval:
810
+ if not isinstance(j, int):
811
+ raise TypeError(f'For "{self.__class__.__name__}", the type of each value in "intervals" '
812
+ f'should be int, but got the: {type(j)}.')
813
+
814
+ def _check_device_ids(self, device_ids):
815
+ """Check device_ids type and value."""
816
+ check_value_type('device_ids', device_ids, list)
817
+ for i in device_ids:
818
+ if not isinstance(i, int):
819
+ raise TypeError(f'For "{self.__class__.__name__}.gen_landscapes_with_multi_process", the parameter '
820
+ f'"device_ids" type should be int, but got the: {type(i)}.')
821
+ #device_id should be between 0 and 7.
822
+ if i < 0 or i > 7:
823
+ raise ValueError(f'For "{self.__class__.__name__}.gen_landscapes_with_multi_process", the parameter '
824
+ f'"device_ids" should be between 0 and 7, but got {i}.')
825
+
826
+ def _check_collect_landscape_data(self, collect_landscape):
827
+ """Check collect landscape data type and value."""
828
+ for param in collect_landscape.keys():
829
+ if param not in ["landscape_size", "unit", "num_samples", "create_landscape", "intervals"]:
830
+ raise ValueError(f'For "{self.__class__.__name__}", the key of collect landscape should be '
831
+ f'landscape_size, unit, num_samples create_landscape or intervals, '
832
+ f'but got the: {param}. ')
833
+ if "landscape_size" in collect_landscape:
834
+ landscape_size = collect_landscape.get("landscape_size")
835
+ self._check_landscape_size(landscape_size)
836
+ if "unit" in collect_landscape:
837
+ unit = collect_landscape.get("unit")
838
+ self._check_unit(unit)
839
+ if "num_samples" in collect_landscape:
840
+ num_samples = collect_landscape.get("num_samples")
841
+ check_value_type("num_samples", num_samples, int)
842
+ if "create_landscape" in collect_landscape:
843
+ create_landscape = collect_landscape.get("create_landscape")
844
+ self._check_create_landscape(create_landscape)
845
+ if "intervals" in collect_landscape:
846
+ intervals = collect_landscape.get("intervals")
847
+ self._check_intervals(intervals)
848
+
849
+ def _check_json_file_data(self, json_file_data):
850
+ """Check json file data."""
851
+ file_key = ["epoch_group", "model_params_file_map", "step_per_epoch", "unit",
852
+ "num_samples", "landscape_size", "create_landscape"]
853
+ for key in json_file_data.keys():
854
+ if key not in file_key:
855
+ raise ValueError(f'"train_metadata" json file should be {file_key}, but got the: {key}')
856
+ epoch_group = json_file_data["epoch_group"]
857
+ model_params_file_map = json_file_data["model_params_file_map"]
858
+ step_per_epoch = json_file_data["step_per_epoch"]
859
+ unit = json_file_data["unit"]
860
+ num_samples = json_file_data["num_samples"]
861
+ landscape_size = json_file_data["landscape_size"]
862
+ create_landscape = json_file_data["create_landscape"]
863
+
864
+ for _, epochs in enumerate(epoch_group.values()):
865
+ # Each epoch_group have at least three epochs.
866
+ if len(epochs) < 3:
867
+ raise ValueError(f'For "{self.__class__.__name__}", the "epoch_group" in train_metadata.json, '
868
+ f'length of each list in "epoch_group" should not be less than 3, '
869
+ f'but got: {len(epochs)}. ')
870
+ for epoch in epochs:
871
+ if str(epoch) not in model_params_file_map.keys():
872
+ raise ValueError(f'For "{self.__class__.__name__}", the "model_params_file_map" in '
873
+ f'train_metadata.json does not exist {epoch}th checkpoint in intervals.')
874
+
875
+ check_value_type('step_per_epoch', step_per_epoch, int)
876
+ self._check_landscape_size(landscape_size)
877
+ self._check_unit(unit)
878
+ check_value_type("num_samples", num_samples, int)
879
+ self._check_create_landscape(create_landscape)
880
+
881
+
882
+ class _PCA:
883
+ r"""
884
+ The internal class for computing PCA vectors.
885
+
886
+ .. math::
887
+
888
+ u, s, vt = svd(x - mean(x)),
889
+ u_i = u_i * s_i,
890
+
891
+ where :math:`mean` is the mean operator, :math:`svd` is the singular value decomposition operator.
892
+ :math:`u_i` is line :math:`i` of the :math:`u`, :math:`s_i` is column :math:`i` of the :math:`s`,
893
+ :math:`i` ranges from :math:`0` to :math:`n\_comps`.
894
+
895
+ Args:
896
+ n_comps (int): Number of principal components needed.
897
+ """
898
+ def __init__(self, n_comps):
899
+ self._n_comps = n_comps
900
+ self._random_status = None
901
+ self._iterated_power = "auto"
902
+ self._n_oversamples = 10
903
+
904
+ @staticmethod
905
+ def _safe_dot(a, b):
906
+ """Dot product that handle the matrix case correctly."""
907
+ if a.ndim > 2 or b.ndim > 2:
908
+ if sparse.issparse(b):
909
+ # Sparse is always 2 dimensional. Implies a is above 3 dimensional.
910
+ # [n, ..., o, p] @ [l, m] -> [n, ..., o, m]
911
+ a_2d = a.reshape(-1, a.shape[-1])
912
+ ret = a_2d @ b
913
+ ret = ret.reshape(*a.shape[:-1], b.shape[1])
914
+ elif sparse.issparse(a):
915
+ # Sparse is always 2 dimensional. Implies b is above 3 dimensional.
916
+ # [l, m] @ [n, ..., o, p, q] -> [l, n, ..., o, q]
917
+ b_ = np.rollaxis(b, -2)
918
+ b_2d = b_.reshape((b.shape[-2], -1))
919
+ ret = a @ b_2d
920
+ ret = ret.reshape(a.shape[0], *b_.shape[1:])
921
+ else:
922
+ ret = np.dot(a, b)
923
+
924
+ else:
925
+ ret = a @ b
926
+
927
+ return ret
928
+
929
+ @staticmethod
930
+ def _svd_turn(u, v, u_decision=True):
931
+ """Confirm correction to ensure deterministic output from SVD."""
932
+ if u_decision:
933
+ # rows of v, columns of u
934
+ max_cols = np.argmax(np.abs(u), axis=0)
935
+ signs = np.sign(u[max_cols, list(range(u.shape[1]))])
936
+ v *= signs[:, np.newaxis]
937
+ u *= signs
938
+ else:
939
+ # rows of u, columns of v
940
+ max_rows = np.argmax(np.abs(v), axis=1)
941
+ signs = np.sign(v[list(range(v.shape[0])), max_rows])
942
+ v *= signs[:, np.newaxis]
943
+ u *= signs
944
+ return u, v
945
+
946
+ @staticmethod
947
+ def _check_random_status(seed):
948
+ """Transform seed into a np.random.RandomState instance."""
949
+ if isinstance(seed, np.random.RandomState):
950
+ return seed
951
+ if seed is None or seed is np.random:
952
+ return np.random.RandomState()
953
+ if isinstance(seed, numbers.Integral):
954
+ return np.random.RandomState(seed)
955
+ raise ValueError(
956
+ "%r cannot be used to seed a numpy.random.RandomState instance" % seed
957
+ )
958
+
959
+ def compute(self, x):
960
+ """Main method for computing principal components."""
961
+ n_components = self._n_comps
962
+ # small dimension (the shape is less than 500), and the full amount is calculated.
963
+ if max(x.shape) <= 500:
964
+ u, s, _ = self._fit_few(x)
965
+ # When dimension of x is much, truncated SVD is used for calculation.
966
+ elif 1 <= n_components < 0.8 * min(x.shape):
967
+ u, s, _ = self._fit_much(x, n_components)
968
+ # A case of n_components in (0, 1)
969
+ else:
970
+ u, s, _ = self._fit_few(x)
971
+
972
+ for i, _ in enumerate(s):
973
+ # To prevent s from being equal to 0, a small fixed noise is added.
974
+ # Adjust 1e-19 was found a good compromise for s.
975
+ if s[i] == 0:
976
+ s[i] = 1e-19
977
+ u = u[:, :self._n_comps]
978
+ u *= s[:self._n_comps]
979
+
980
+ return u
981
+
982
+ def _fit_few(self, x):
983
+ """Compute principal components with full SVD on x, when dimension of x is few."""
984
+ mean_ = np.mean(x, axis=0)
985
+ x -= mean_
986
+ u, s, vt = linalg.svd(x, full_matrices=False)
987
+ u, vt = self._svd_turn(u, vt)
988
+
989
+ return u, s, vt
990
+
991
+ def _fit_much(self, x, n_components):
992
+ """Compute principal components with truncated SVD on x, when dimension of x is much."""
993
+ random_state = self._check_random_status(self._random_status)
994
+ mean_ = np.mean(x, axis=0)
995
+ x -= mean_
996
+ u, s, vt = self._random_svd(x, n_components, n_oversamples=self._n_oversamples, random_state=random_state)
997
+ return u, s, vt
998
+
999
+ def _random_svd(self, m, n_components, n_oversamples=10, random_state="warn"):
1000
+ """Compute a truncated randomized SVD."""
1001
+ n_random = n_components + n_oversamples
1002
+ n_samples, n_features = m.shape
1003
+ # Adjust 7 or 4 was found a good compromise for randomized SVD.
1004
+ n_iter = 7 if n_components < 0.1 * min(m.shape) else 4
1005
+ if n_samples < n_features:
1006
+ m = m.T
1007
+
1008
+ q = self._random_range_finder(m, size=n_random, n_iter=n_iter, random_state=random_state)
1009
+ # Project m to the low dimensional space using the basis vectors (q vector).
1010
+ b = self._safe_dot(q.T, m)
1011
+ # Compute the svd on this matrix (b matrix)
1012
+ uhat, s, vt = linalg.svd(b, full_matrices=False)
1013
+
1014
+ del b
1015
+ u = np.dot(q, uhat)
1016
+
1017
+ if n_samples < n_features:
1018
+ u, vt = self._svd_turn(u, vt, u_decision=False)
1019
+ else:
1020
+ u, vt = self._svd_turn(u, vt)
1021
+
1022
+ if n_samples < n_features:
1023
+ return vt[:n_components, :].T, s[:n_components], u[:, :n_components].T
1024
+
1025
+ return u[:, :n_components], s[:n_components], vt[:n_components, :]
1026
+
1027
+ def _random_range_finder(self, a, size, n_iter, random_state=None):
1028
+ """Computes an orthonormal matrix whose range approximates the range of A."""
1029
+ random_state = self._check_random_status(random_state)
1030
+ # Generate normal random vectors.
1031
+ q = random_state.normal(size=(a.shape[1], size))
1032
+ if a.dtype.kind == "f":
1033
+ # Ensure f32 is retained as f32
1034
+ q = q.astype(a.dtype, copy=False)
1035
+ if n_iter <= 2:
1036
+ power_iteration_normalizer = "none"
1037
+ else:
1038
+ power_iteration_normalizer = "LU"
1039
+ # use power iterations with q to further compute the top singular vectors of a in q
1040
+ for _ in range(n_iter):
1041
+ if power_iteration_normalizer == "none":
1042
+ q = self._safe_dot(a, q)
1043
+ q = self._safe_dot(a.T, q)
1044
+ elif power_iteration_normalizer == "LU":
1045
+ q, _ = linalg.lu(self._safe_dot(a, q), permute_l=True)
1046
+ q, _ = linalg.lu(self._safe_dot(a.T, q), permute_l=True)
1047
+ # The orthogonal basis is extracted by the linear projection of Q, and the range of a is sampled.
1048
+ q, _ = linalg.qr(self._safe_dot(a, q), mode="economic")
1049
+ return q