mindspore-2.4.0-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
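For context, a wheel is an ordinary zip archive, so the file list below can be reproduced locally. A minimal sketch, assuming the wheel from the header above has already been downloaded into the current directory:

    import zipfile

    # A .whl file is a standard zip archive; listing its members shows
    # exactly which files a release ships (compare against the diff below).
    WHEEL = "mindspore-2.4.0-cp311-cp311-win_amd64.whl"

    with zipfile.ZipFile(WHEEL) as wheel:
        for info in wheel.infolist():
            print(f"{info.filename}  ({info.file_size} bytes)")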

Potentially problematic release.

This version of mindspore might be problematic.
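Before installing a release flagged like this, it is worth confirming that the artifact you downloaded matches the digest published by the registry. A minimal sketch using only the standard library; the expected digest below is a placeholder, not the real hash for this wheel:

    import hashlib

    WHEEL = "mindspore-2.4.0-cp311-cp311-win_amd64.whl"
    # Placeholder: substitute the SHA-256 published on the registry's release page.
    EXPECTED_SHA256 = "<digest from the registry>"

    # Stream the file in chunks so large wheels need not fit in memory.
    sha256 = hashlib.sha256()
    with open(WHEEL, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)

    if sha256.hexdigest() != EXPECTED_SHA256:
        raise SystemExit("hash mismatch: do not install this wheel")
    print("hash matches the published digest")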

Files changed (1406)
  1. mindspore/.commit_id +1 -0
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +53 -0
  18. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +106 -0
  22. mindspore/_checkparam.py +1419 -0
  23. mindspore/_extends/__init__.py +23 -0
  24. mindspore/_extends/builtin_operations.py +224 -0
  25. mindspore/_extends/graph_kernel/__init__.py +17 -0
  26. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  27. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  28. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  29. mindspore/_extends/graph_kernel/model/model.py +553 -0
  30. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  31. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  32. mindspore/_extends/graph_kernel/splitter.py +140 -0
  33. mindspore/_extends/graph_kernel/utils.py +28 -0
  34. mindspore/_extends/parallel_compile/__init__.py +19 -0
  35. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  38. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  39. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  40. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  41. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  42. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  43. mindspore/_extends/parse/__init__.py +49 -0
  44. mindspore/_extends/parse/compile_config.py +299 -0
  45. mindspore/_extends/parse/namespace.py +136 -0
  46. mindspore/_extends/parse/parser.py +1448 -0
  47. mindspore/_extends/parse/resources.py +213 -0
  48. mindspore/_extends/parse/standard_method.py +4475 -0
  49. mindspore/_extends/parse/trope.py +97 -0
  50. mindspore/_extends/pijit/__init__.py +23 -0
  51. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  52. mindspore/_extends/remote/__init__.py +19 -0
  53. mindspore/_extends/remote/kernel_build_server.py +199 -0
  54. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  57. mindspore/_extends/utils.py +68 -0
  58. mindspore/_install_custom.py +43 -0
  59. mindspore/_profiler.py +30 -0
  60. mindspore/amp.py +433 -0
  61. mindspore/atlprov.dll +0 -0
  62. mindspore/avcodec-59.dll +0 -0
  63. mindspore/avdevice-59.dll +0 -0
  64. mindspore/avfilter-8.dll +0 -0
  65. mindspore/avformat-59.dll +0 -0
  66. mindspore/avutil-57.dll +0 -0
  67. mindspore/boost/__init__.py +42 -0
  68. mindspore/boost/adasum.py +319 -0
  69. mindspore/boost/base.py +535 -0
  70. mindspore/boost/boost.py +400 -0
  71. mindspore/boost/boost_cell_wrapper.py +790 -0
  72. mindspore/boost/dim_reduce.py +323 -0
  73. mindspore/boost/grad_accumulation.py +79 -0
  74. mindspore/boost/grad_freeze.py +382 -0
  75. mindspore/boost/group_loss_scale_manager.py +166 -0
  76. mindspore/boost/less_batch_normalization.py +174 -0
  77. mindspore/c1.dll +0 -0
  78. mindspore/c1xx.dll +0 -0
  79. mindspore/c2.dll +0 -0
  80. mindspore/cfgpersist.dll +0 -0
  81. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  82. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  83. mindspore/common/__init__.py +86 -0
  84. mindspore/common/_auto_dynamic.py +68 -0
  85. mindspore/common/_decorator.py +50 -0
  86. mindspore/common/_jit_fallback_utils.py +110 -0
  87. mindspore/common/_monad.py +25 -0
  88. mindspore/common/_pijit_context.py +190 -0
  89. mindspore/common/_register_for_adapter.py +74 -0
  90. mindspore/common/_register_for_recompute.py +48 -0
  91. mindspore/common/_register_for_tensor.py +46 -0
  92. mindspore/common/_stub_tensor.py +210 -0
  93. mindspore/common/_tensor_overload.py +139 -0
  94. mindspore/common/_utils.py +122 -0
  95. mindspore/common/api.py +2064 -0
  96. mindspore/common/auto_dynamic_shape.py +507 -0
  97. mindspore/common/dtype.py +422 -0
  98. mindspore/common/dump.py +130 -0
  99. mindspore/common/file_system.py +48 -0
  100. mindspore/common/generator.py +254 -0
  101. mindspore/common/hook_handle.py +143 -0
  102. mindspore/common/initializer.py +880 -0
  103. mindspore/common/jit_config.py +98 -0
  104. mindspore/common/lazy_inline.py +240 -0
  105. mindspore/common/mindir_util.py +111 -0
  106. mindspore/common/mutable.py +234 -0
  107. mindspore/common/no_inline.py +54 -0
  108. mindspore/common/np_dtype.py +25 -0
  109. mindspore/common/parameter.py +1081 -0
  110. mindspore/common/recompute.py +292 -0
  111. mindspore/common/seed.py +260 -0
  112. mindspore/common/sparse_tensor.py +1175 -0
  113. mindspore/common/symbol.py +122 -0
  114. mindspore/common/tensor.py +5039 -0
  115. mindspore/communication/__init__.py +37 -0
  116. mindspore/communication/_comm_helper.py +501 -0
  117. mindspore/communication/_hccl_management.py +297 -0
  118. mindspore/communication/comm_func.py +1395 -0
  119. mindspore/communication/management.py +673 -0
  120. mindspore/config/op_info.config +533 -0
  121. mindspore/context.py +2077 -0
  122. mindspore/d3dcompiler_47.dll +0 -0
  123. mindspore/dataset/__init__.py +90 -0
  124. mindspore/dataset/audio/__init__.py +61 -0
  125. mindspore/dataset/audio/transforms.py +3690 -0
  126. mindspore/dataset/audio/utils.py +386 -0
  127. mindspore/dataset/audio/validators.py +1172 -0
  128. mindspore/dataset/callback/__init__.py +20 -0
  129. mindspore/dataset/callback/ds_callback.py +368 -0
  130. mindspore/dataset/callback/validators.py +32 -0
  131. mindspore/dataset/core/__init__.py +13 -0
  132. mindspore/dataset/core/config.py +1095 -0
  133. mindspore/dataset/core/datatypes.py +101 -0
  134. mindspore/dataset/core/py_util_helpers.py +65 -0
  135. mindspore/dataset/core/validator_helpers.py +781 -0
  136. mindspore/dataset/debug/__init__.py +21 -0
  137. mindspore/dataset/debug/debug_hook.py +97 -0
  138. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  139. mindspore/dataset/engine/__init__.py +124 -0
  140. mindspore/dataset/engine/cache_admin.py +47 -0
  141. mindspore/dataset/engine/cache_client.py +129 -0
  142. mindspore/dataset/engine/datasets.py +4582 -0
  143. mindspore/dataset/engine/datasets_audio.py +911 -0
  144. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  145. mindspore/dataset/engine/datasets_text.py +2161 -0
  146. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  147. mindspore/dataset/engine/datasets_vision.py +4816 -0
  148. mindspore/dataset/engine/iterators.py +371 -0
  149. mindspore/dataset/engine/obs/__init__.py +23 -0
  150. mindspore/dataset/engine/obs/config_loader.py +68 -0
  151. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  152. mindspore/dataset/engine/obs/util.py +482 -0
  153. mindspore/dataset/engine/offload.py +596 -0
  154. mindspore/dataset/engine/queue.py +304 -0
  155. mindspore/dataset/engine/samplers.py +895 -0
  156. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  157. mindspore/dataset/engine/validators.py +2895 -0
  158. mindspore/dataset/text/__init__.py +51 -0
  159. mindspore/dataset/text/transforms.py +1703 -0
  160. mindspore/dataset/text/utils.py +715 -0
  161. mindspore/dataset/text/validators.py +642 -0
  162. mindspore/dataset/transforms/__init__.py +45 -0
  163. mindspore/dataset/transforms/c_transforms.py +638 -0
  164. mindspore/dataset/transforms/py_transforms.py +393 -0
  165. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  166. mindspore/dataset/transforms/transforms.py +1260 -0
  167. mindspore/dataset/transforms/validators.py +410 -0
  168. mindspore/dataset/utils/__init__.py +19 -0
  169. mindspore/dataset/utils/browse_dataset.py +190 -0
  170. mindspore/dataset/utils/line_reader.py +126 -0
  171. mindspore/dataset/vision/__init__.py +65 -0
  172. mindspore/dataset/vision/c_transforms.py +2641 -0
  173. mindspore/dataset/vision/py_transforms.py +2120 -0
  174. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  175. mindspore/dataset/vision/transforms.py +7295 -0
  176. mindspore/dataset/vision/utils.py +863 -0
  177. mindspore/dataset/vision/validators.py +1483 -0
  178. mindspore/default_config.py +2 -0
  179. mindspore/dnnl.dll +0 -0
  180. mindspore/dpcmi.dll +0 -0
  181. mindspore/experimental/__init__.py +20 -0
  182. mindspore/experimental/es/__init__.py +22 -0
  183. mindspore/experimental/es/embedding_service.py +883 -0
  184. mindspore/experimental/es/embedding_service_layer.py +581 -0
  185. mindspore/experimental/llm_boost/__init__.py +21 -0
  186. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  187. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  188. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  189. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  190. mindspore/experimental/llm_boost/register.py +129 -0
  191. mindspore/experimental/llm_boost/utils.py +31 -0
  192. mindspore/experimental/map_parameter.py +309 -0
  193. mindspore/experimental/optim/__init__.py +40 -0
  194. mindspore/experimental/optim/adadelta.py +161 -0
  195. mindspore/experimental/optim/adagrad.py +168 -0
  196. mindspore/experimental/optim/adam.py +193 -0
  197. mindspore/experimental/optim/adamax.py +170 -0
  198. mindspore/experimental/optim/adamw.py +290 -0
  199. mindspore/experimental/optim/asgd.py +153 -0
  200. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  201. mindspore/experimental/optim/nadam.py +157 -0
  202. mindspore/experimental/optim/optimizer.py +262 -0
  203. mindspore/experimental/optim/radam.py +194 -0
  204. mindspore/experimental/optim/rmsprop.py +154 -0
  205. mindspore/experimental/optim/rprop.py +164 -0
  206. mindspore/experimental/optim/sgd.py +156 -0
  207. mindspore/hal/__init__.py +40 -0
  208. mindspore/hal/_ascend.py +57 -0
  209. mindspore/hal/_base.py +57 -0
  210. mindspore/hal/_cpu.py +56 -0
  211. mindspore/hal/_gpu.py +57 -0
  212. mindspore/hal/contiguous_tensors_handle.py +175 -0
  213. mindspore/hal/device.py +356 -0
  214. mindspore/hal/event.py +179 -0
  215. mindspore/hal/memory.py +326 -0
  216. mindspore/hal/stream.py +357 -0
  217. mindspore/include/OWNERS +7 -0
  218. mindspore/include/api/allocator.h +97 -0
  219. mindspore/include/api/callback/callback.h +93 -0
  220. mindspore/include/api/callback/ckpt_saver.h +41 -0
  221. mindspore/include/api/callback/loss_monitor.h +33 -0
  222. mindspore/include/api/callback/lr_scheduler.h +51 -0
  223. mindspore/include/api/callback/time_monitor.h +34 -0
  224. mindspore/include/api/callback/train_accuracy.h +37 -0
  225. mindspore/include/api/cell.h +90 -0
  226. mindspore/include/api/cfg.h +82 -0
  227. mindspore/include/api/context.h +602 -0
  228. mindspore/include/api/data_type.h +47 -0
  229. mindspore/include/api/delegate.h +178 -0
  230. mindspore/include/api/delegate_api.h +75 -0
  231. mindspore/include/api/dual_abi_helper.h +208 -0
  232. mindspore/include/api/format.h +28 -0
  233. mindspore/include/api/graph.h +46 -0
  234. mindspore/include/api/kernel.h +58 -0
  235. mindspore/include/api/kernel_api.h +168 -0
  236. mindspore/include/api/metrics/accuracy.h +36 -0
  237. mindspore/include/api/metrics/metrics.h +41 -0
  238. mindspore/include/api/model.h +438 -0
  239. mindspore/include/api/model_group.h +91 -0
  240. mindspore/include/api/model_parallel_runner.h +168 -0
  241. mindspore/include/api/serialization.h +185 -0
  242. mindspore/include/api/status.h +192 -0
  243. mindspore/include/api/types.h +431 -0
  244. mindspore/include/api/visible.h +41 -0
  245. mindspore/include/c_api/context_c.h +179 -0
  246. mindspore/include/c_api/data_type_c.h +52 -0
  247. mindspore/include/c_api/format_c.h +46 -0
  248. mindspore/include/c_api/model_c.h +347 -0
  249. mindspore/include/c_api/status_c.h +79 -0
  250. mindspore/include/c_api/tensor_c.h +146 -0
  251. mindspore/include/c_api/types_c.h +67 -0
  252. mindspore/include/dataset/config.h +163 -0
  253. mindspore/include/dataset/constants.h +363 -0
  254. mindspore/include/dataset/execute.h +196 -0
  255. mindspore/include/dataset/text.h +1092 -0
  256. mindspore/include/dataset/transforms.h +638 -0
  257. mindspore/include/dataset/vision.h +2129 -0
  258. mindspore/include/dataset/vision_ascend.h +206 -0
  259. mindspore/include/dataset/vision_lite.h +625 -0
  260. mindspore/jpeg62.dll +0 -0
  261. mindspore/log.py +633 -0
  262. mindspore/mindrecord/__init__.py +43 -0
  263. mindspore/mindrecord/common/__init__.py +17 -0
  264. mindspore/mindrecord/common/constant.py +20 -0
  265. mindspore/mindrecord/common/enums.py +44 -0
  266. mindspore/mindrecord/common/exceptions.py +311 -0
  267. mindspore/mindrecord/config.py +809 -0
  268. mindspore/mindrecord/filereader.py +174 -0
  269. mindspore/mindrecord/filewriter.py +722 -0
  270. mindspore/mindrecord/mindpage.py +210 -0
  271. mindspore/mindrecord/shardheader.py +141 -0
  272. mindspore/mindrecord/shardindexgenerator.py +74 -0
  273. mindspore/mindrecord/shardreader.py +117 -0
  274. mindspore/mindrecord/shardsegment.py +128 -0
  275. mindspore/mindrecord/shardutils.py +185 -0
  276. mindspore/mindrecord/shardwriter.py +237 -0
  277. mindspore/mindrecord/tools/__init__.py +17 -0
  278. mindspore/mindrecord/tools/cifar10.py +140 -0
  279. mindspore/mindrecord/tools/cifar100.py +153 -0
  280. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  281. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  282. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  283. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  284. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  285. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  286. mindspore/mindspore_backend.dll +0 -0
  287. mindspore/mindspore_common.dll +0 -0
  288. mindspore/mindspore_core.dll +0 -0
  289. mindspore/mindspore_glog.dll +0 -0
  290. mindspore/mindspore_np_dtype.dll +0 -0
  291. mindspore/mindspore_ops.dll +0 -0
  292. mindspore/mint/__init__.py +1586 -0
  293. mindspore/mint/distributed/__init__.py +31 -0
  294. mindspore/mint/distributed/distributed.py +254 -0
  295. mindspore/mint/linalg/__init__.py +22 -0
  296. mindspore/mint/nn/__init__.py +757 -0
  297. mindspore/mint/nn/functional.py +679 -0
  298. mindspore/mint/nn/layer/__init__.py +39 -0
  299. mindspore/mint/nn/layer/activation.py +133 -0
  300. mindspore/mint/nn/layer/normalization.py +477 -0
  301. mindspore/mint/nn/layer/pooling.py +110 -0
  302. mindspore/mint/optim/__init__.py +24 -0
  303. mindspore/mint/optim/adamw.py +206 -0
  304. mindspore/mint/special/__init__.py +63 -0
  305. mindspore/msobj140.dll +0 -0
  306. mindspore/mspdb140.dll +0 -0
  307. mindspore/mspdbcore.dll +0 -0
  308. mindspore/mspdbst.dll +0 -0
  309. mindspore/mspft140.dll +0 -0
  310. mindspore/msvcdis140.dll +0 -0
  311. mindspore/msvcp140.dll +0 -0
  312. mindspore/msvcp140_1.dll +0 -0
  313. mindspore/msvcp140_2.dll +0 -0
  314. mindspore/msvcp140_atomic_wait.dll +0 -0
  315. mindspore/msvcp140_codecvt_ids.dll +0 -0
  316. mindspore/multiprocessing/__init__.py +73 -0
  317. mindspore/nn/__init__.py +47 -0
  318. mindspore/nn/cell.py +2787 -0
  319. mindspore/nn/dynamic_lr.py +482 -0
  320. mindspore/nn/grad/__init__.py +21 -0
  321. mindspore/nn/grad/cell_grad.py +196 -0
  322. mindspore/nn/layer/__init__.py +63 -0
  323. mindspore/nn/layer/activation.py +1822 -0
  324. mindspore/nn/layer/basic.py +1629 -0
  325. mindspore/nn/layer/channel_shuffle.py +90 -0
  326. mindspore/nn/layer/combined.py +248 -0
  327. mindspore/nn/layer/container.py +734 -0
  328. mindspore/nn/layer/conv.py +1505 -0
  329. mindspore/nn/layer/dense.py +204 -0
  330. mindspore/nn/layer/embedding.py +869 -0
  331. mindspore/nn/layer/image.py +661 -0
  332. mindspore/nn/layer/math.py +1069 -0
  333. mindspore/nn/layer/normalization.py +1273 -0
  334. mindspore/nn/layer/padding.py +880 -0
  335. mindspore/nn/layer/pooling.py +2302 -0
  336. mindspore/nn/layer/rnn_cells.py +388 -0
  337. mindspore/nn/layer/rnns.py +849 -0
  338. mindspore/nn/layer/thor_layer.py +963 -0
  339. mindspore/nn/layer/timedistributed.py +155 -0
  340. mindspore/nn/layer/transformer.py +823 -0
  341. mindspore/nn/learning_rate_schedule.py +512 -0
  342. mindspore/nn/loss/__init__.py +36 -0
  343. mindspore/nn/loss/loss.py +2924 -0
  344. mindspore/nn/metrics.py +53 -0
  345. mindspore/nn/optim/__init__.py +45 -0
  346. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  347. mindspore/nn/optim/ada_grad.py +217 -0
  348. mindspore/nn/optim/adadelta.py +206 -0
  349. mindspore/nn/optim/adafactor.py +448 -0
  350. mindspore/nn/optim/adam.py +1297 -0
  351. mindspore/nn/optim/adamax.py +220 -0
  352. mindspore/nn/optim/adasum.py +548 -0
  353. mindspore/nn/optim/asgd.py +216 -0
  354. mindspore/nn/optim/ftrl.py +401 -0
  355. mindspore/nn/optim/lamb.py +296 -0
  356. mindspore/nn/optim/lars.py +202 -0
  357. mindspore/nn/optim/lazyadam.py +533 -0
  358. mindspore/nn/optim/momentum.py +239 -0
  359. mindspore/nn/optim/optimizer.py +1034 -0
  360. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  361. mindspore/nn/optim/rmsprop.py +264 -0
  362. mindspore/nn/optim/rprop.py +251 -0
  363. mindspore/nn/optim/sgd.py +237 -0
  364. mindspore/nn/optim/tft_wrapper.py +127 -0
  365. mindspore/nn/optim/thor.py +1310 -0
  366. mindspore/nn/probability/__init__.py +22 -0
  367. mindspore/nn/probability/bijector/__init__.py +35 -0
  368. mindspore/nn/probability/bijector/bijector.py +337 -0
  369. mindspore/nn/probability/bijector/exp.py +65 -0
  370. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  371. mindspore/nn/probability/bijector/invert.py +126 -0
  372. mindspore/nn/probability/bijector/power_transform.py +196 -0
  373. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  374. mindspore/nn/probability/bijector/softplus.py +189 -0
  375. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  376. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  377. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  378. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  379. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  380. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  381. mindspore/nn/probability/distribution/__init__.py +56 -0
  382. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  383. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  384. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  385. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  386. mindspore/nn/probability/distribution/beta.py +391 -0
  387. mindspore/nn/probability/distribution/categorical.py +435 -0
  388. mindspore/nn/probability/distribution/cauchy.py +383 -0
  389. mindspore/nn/probability/distribution/distribution.py +827 -0
  390. mindspore/nn/probability/distribution/exponential.py +350 -0
  391. mindspore/nn/probability/distribution/gamma.py +391 -0
  392. mindspore/nn/probability/distribution/geometric.py +335 -0
  393. mindspore/nn/probability/distribution/gumbel.py +257 -0
  394. mindspore/nn/probability/distribution/half_normal.py +133 -0
  395. mindspore/nn/probability/distribution/laplace.py +128 -0
  396. mindspore/nn/probability/distribution/log_normal.py +272 -0
  397. mindspore/nn/probability/distribution/logistic.py +379 -0
  398. mindspore/nn/probability/distribution/normal.py +336 -0
  399. mindspore/nn/probability/distribution/poisson.py +288 -0
  400. mindspore/nn/probability/distribution/student_t.py +149 -0
  401. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  402. mindspore/nn/probability/distribution/uniform.py +375 -0
  403. mindspore/nn/reinforcement/__init__.py +24 -0
  404. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  405. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  406. mindspore/nn/reinforcement/tensor_array.py +145 -0
  407. mindspore/nn/sparse/__init__.py +23 -0
  408. mindspore/nn/sparse/sparse.py +147 -0
  409. mindspore/nn/wrap/__init__.py +49 -0
  410. mindspore/nn/wrap/cell_wrapper.py +968 -0
  411. mindspore/nn/wrap/grad_reducer.py +608 -0
  412. mindspore/nn/wrap/loss_scale.py +694 -0
  413. mindspore/numpy/__init__.py +121 -0
  414. mindspore/numpy/array_creations.py +2731 -0
  415. mindspore/numpy/array_ops.py +2629 -0
  416. mindspore/numpy/dtypes.py +185 -0
  417. mindspore/numpy/fft.py +966 -0
  418. mindspore/numpy/logic_ops.py +936 -0
  419. mindspore/numpy/math_ops.py +5911 -0
  420. mindspore/numpy/utils.py +214 -0
  421. mindspore/numpy/utils_const.py +565 -0
  422. mindspore/opencv_core452.dll +0 -0
  423. mindspore/opencv_imgcodecs452.dll +0 -0
  424. mindspore/opencv_imgproc452.dll +0 -0
  425. mindspore/ops/__init__.py +56 -0
  426. mindspore/ops/_constants.py +30 -0
  427. mindspore/ops/_grad_experimental/__init__.py +31 -0
  428. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  429. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  430. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  431. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  432. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  433. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  434. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  435. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  436. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  437. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  438. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  439. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  440. mindspore/ops/_op_impl/__init__.py +23 -0
  441. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  442. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  443. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  444. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  445. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  446. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  447. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  448. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  449. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  450. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  451. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  452. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  453. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  454. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  455. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  456. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  457. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  458. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  459. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  460. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  461. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  462. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  463. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  464. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  465. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  466. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  467. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  468. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  469. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  470. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  471. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  472. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  473. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  474. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  475. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  476. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  477. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  478. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  479. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  480. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  481. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  482. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  483. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  484. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  485. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  486. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  487. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  488. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  490. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  491. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  493. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  494. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  495. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  496. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  497. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  498. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  499. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  500. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  501. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  502. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  503. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  504. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  505. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  506. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  507. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  508. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  509. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  510. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  511. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  513. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  514. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  515. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  516. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  517. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  518. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  519. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  520. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  521. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  522. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  523. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  524. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  526. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  527. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  528. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  529. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  530. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  531. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  532. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  533. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  534. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  535. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  537. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  538. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  539. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  540. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  541. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  542. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  543. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  544. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  545. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  546. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  547. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  548. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  549. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  550. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  551. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  552. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  553. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  554. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  555. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  556. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  557. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  558. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  559. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  560. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  561. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  562. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  563. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  564. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  565. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  566. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  567. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  568. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  569. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  570. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  571. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  572. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  574. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  575. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  576. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  577. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  578. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  579. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  580. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  581. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  582. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  583. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  584. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  585. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  586. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  587. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  588. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  589. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  590. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  591. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  592. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  593. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  594. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  595. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  596. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  597. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  598. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  599. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  600. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  601. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  603. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  604. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  605. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  606. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  608. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  609. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  610. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  611. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  612. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  613. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  614. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  615. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  616. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  617. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  618. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  619. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  620. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  621. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  622. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  623. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  624. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  625. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  626. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  627. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  628. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  629. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  630. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  632. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  633. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  634. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  635. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  636. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  637. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  638. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  639. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  640. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  641. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  642. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  643. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  644. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  645. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  646. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  647. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  648. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  650. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  651. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  652. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  653. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  654. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  655. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  656. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  657. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  658. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  659. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  660. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  661. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  662. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  663. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  664. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  665. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  666. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  667. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  668. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  669. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  670. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  673. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  675. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  676. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  677. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  679. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  680. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  681. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  682. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  683. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  684. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  685. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  686. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  687. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  688. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  689. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  690. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  691. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  692. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  693. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  694. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  695. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  696. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  697. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  698. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  699. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  700. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  701. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  702. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  703. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  704. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  705. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  706. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  707. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  708. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  709. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  710. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  711. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  712. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  713. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  714. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  715. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  716. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  717. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  718. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  719. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  720. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  721. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  722. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  723. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  724. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  725. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  726. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  727. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  728. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  729. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  730. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  731. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  732. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  733. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  734. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  735. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  736. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  737. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  738. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  739. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  740. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  741. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  742. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  743. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  744. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  745. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  746. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  747. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  748. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  749. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  750. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  751. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  752. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  753. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  754. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  755. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  756. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  757. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  758. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  759. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  760. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  761. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  762. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  763. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  764. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  765. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  766. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  767. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  768. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  769. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  770. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  772. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  773. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  774. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  775. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  776. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  777. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  778. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  779. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  780. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  781. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  782. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  783. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  784. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  785. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  786. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  787. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  788. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  789. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  790. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  791. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  792. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  793. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  794. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  795. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  796. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  797. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  798. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  799. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  800. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  801. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  802. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  803. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  804. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  805. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  806. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  808. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  809. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  810. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  811. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  812. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  813. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  814. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  815. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  816. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  817. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  818. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  819. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  820. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  821. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  822. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  823. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  825. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  826. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  827. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  828. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  829. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  841. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  843. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  844. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  845. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  846. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  847. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  848. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  849. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  850. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  851. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  852. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  853. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  854. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  855. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  856. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  857. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  858. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  859. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  860. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  861. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  862. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  863. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  864. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  865. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  866. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  868. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  869. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  870. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  871. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  872. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  873. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  874. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  876. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  877. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  878. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  879. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  880. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  881. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  882. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  883. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  884. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  885. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  886. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  887. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  888. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  889. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  890. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  891. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  892. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  893. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  894. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  895. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  896. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  897. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  898. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  899. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  900. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  901. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  902. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  903. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  904. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  905. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  906. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  907. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  908. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  909. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  910. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  911. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  912. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  913. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  914. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  915. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  916. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  917. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  918. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  919. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  920. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  921. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  922. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  923. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  924. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  925. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  926. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  927. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  929. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  930. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  931. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  932. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  933. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  934. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  935. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  936. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  937. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  938. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  939. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  940. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  941. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  942. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  943. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  944. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  945. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  946. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  947. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  948. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  949. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  950. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  951. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  952. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  953. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  954. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  955. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  956. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  957. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  958. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  959. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  960. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  961. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  962. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  963. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  964. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  965. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  966. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  967. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  968. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  969. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  970. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  971. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  972. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  973. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  974. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  975. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  976. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  977. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  978. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  979. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  980. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  981. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  982. mindspore/ops/_op_impl/cpu/div.py +32 -0
  983. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  984. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  985. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  986. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  987. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  988. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  989. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  990. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  991. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  992. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  993. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  994. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  995. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  996. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  997. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  998. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  999. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  1000. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  1001. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  1002. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  1003. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  1004. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  1005. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  1006. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  1007. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  1009. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  1010. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  1011. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  1012. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  1013. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  1014. mindspore/ops/_op_impl/cpu/range.py +34 -0
  1015. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  1016. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  1017. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  1018. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  1019. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  1020. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  1021. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  1022. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1023. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1024. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1025. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1026. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1027. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1028. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1029. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1030. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1031. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1032. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1033. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1034. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1035. mindspore/ops/_primitive_cache.py +90 -0
  1036. mindspore/ops/_register_for_op.py +73 -0
  1037. mindspore/ops/_utils/__init__.py +20 -0
  1038. mindspore/ops/_utils/utils.py +147 -0
  1039. mindspore/ops/_vmap/__init__.py +25 -0
  1040. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1041. mindspore/ops/_vmap/vmap_base.py +533 -0
  1042. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1043. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1044. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1045. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1046. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1047. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1048. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1049. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1050. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1051. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1052. mindspore/ops/auto_generate/__init__.py +31 -0
  1053. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1054. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1055. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1056. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1057. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1058. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1059. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1060. mindspore/ops/composite/__init__.py +71 -0
  1061. mindspore/ops/composite/base.py +1318 -0
  1062. mindspore/ops/composite/env_ops.py +41 -0
  1063. mindspore/ops/composite/math_ops.py +125 -0
  1064. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1065. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1066. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1067. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1068. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1069. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1070. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1071. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1072. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1073. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1074. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1075. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1076. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1077. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1078. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1079. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1080. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1081. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1082. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1083. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1084. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1085. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1086. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1087. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1088. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1089. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1090. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1091. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1092. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1093. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1094. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1095. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1096. mindspore/ops/deprecated.py +315 -0
  1097. mindspore/ops/function/__init__.py +782 -0
  1098. mindspore/ops/function/array_func.py +7226 -0
  1099. mindspore/ops/function/clip_func.py +384 -0
  1100. mindspore/ops/function/debug_func.py +181 -0
  1101. mindspore/ops/function/fft_func.py +44 -0
  1102. mindspore/ops/function/grad/__init__.py +34 -0
  1103. mindspore/ops/function/grad/grad_func.py +1425 -0
  1104. mindspore/ops/function/image_func.py +292 -0
  1105. mindspore/ops/function/linalg_func.py +416 -0
  1106. mindspore/ops/function/math_func.py +12228 -0
  1107. mindspore/ops/function/nn_func.py +8609 -0
  1108. mindspore/ops/function/other_func.py +115 -0
  1109. mindspore/ops/function/parameter_func.py +134 -0
  1110. mindspore/ops/function/random_func.py +1715 -0
  1111. mindspore/ops/function/reshard_func.py +104 -0
  1112. mindspore/ops/function/sparse_func.py +884 -0
  1113. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1114. mindspore/ops/function/spectral_func.py +150 -0
  1115. mindspore/ops/function/vmap_func.py +117 -0
  1116. mindspore/ops/functional.py +464 -0
  1117. mindspore/ops/op_info_register.py +1572 -0
  1118. mindspore/ops/operations/__init__.py +722 -0
  1119. mindspore/ops/operations/_csr_ops.py +403 -0
  1120. mindspore/ops/operations/_custom_grad.py +181 -0
  1121. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1122. mindspore/ops/operations/_grad_ops.py +2978 -0
  1123. mindspore/ops/operations/_infer_ops.py +19 -0
  1124. mindspore/ops/operations/_inner_ops.py +2544 -0
  1125. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1126. mindspore/ops/operations/_ms_kernel.py +601 -0
  1127. mindspore/ops/operations/_ocr_ops.py +379 -0
  1128. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1129. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1130. mindspore/ops/operations/_quant_ops.py +1844 -0
  1131. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1132. mindspore/ops/operations/_scalar_ops.py +106 -0
  1133. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1134. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1135. mindspore/ops/operations/_tensor_array.py +359 -0
  1136. mindspore/ops/operations/_thor_ops.py +807 -0
  1137. mindspore/ops/operations/array_ops.py +6124 -0
  1138. mindspore/ops/operations/comm_ops.py +1985 -0
  1139. mindspore/ops/operations/control_ops.py +127 -0
  1140. mindspore/ops/operations/custom_ops.py +1129 -0
  1141. mindspore/ops/operations/debug_ops.py +678 -0
  1142. mindspore/ops/operations/image_ops.py +1041 -0
  1143. mindspore/ops/operations/inner_ops.py +697 -0
  1144. mindspore/ops/operations/linalg_ops.py +95 -0
  1145. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1146. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1147. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1148. mindspore/ops/operations/math_ops.py +5095 -0
  1149. mindspore/ops/operations/nn_ops.py +9575 -0
  1150. mindspore/ops/operations/other_ops.py +874 -0
  1151. mindspore/ops/operations/random_ops.py +1288 -0
  1152. mindspore/ops/operations/reshard_ops.py +53 -0
  1153. mindspore/ops/operations/rl_ops.py +288 -0
  1154. mindspore/ops/operations/sparse_ops.py +2753 -0
  1155. mindspore/ops/operations/spectral_ops.py +111 -0
  1156. mindspore/ops/primitive.py +1046 -0
  1157. mindspore/ops/signature.py +54 -0
  1158. mindspore/ops/vm_impl_registry.py +91 -0
  1159. mindspore/ops_generate/__init__.py +27 -0
  1160. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1161. mindspore/ops_generate/arg_handler.py +197 -0
  1162. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1163. mindspore/ops_generate/gen_constants.py +36 -0
  1164. mindspore/ops_generate/gen_ops.py +1099 -0
  1165. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1166. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1167. mindspore/ops_generate/gen_utils.py +209 -0
  1168. mindspore/ops_generate/op_proto.py +145 -0
  1169. mindspore/ops_generate/pyboost_utils.py +367 -0
  1170. mindspore/ops_generate/template.py +261 -0
  1171. mindspore/parallel/__init__.py +30 -0
  1172. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1173. mindspore/parallel/_cell_wrapper.py +174 -0
  1174. mindspore/parallel/_cost_model_context.py +700 -0
  1175. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1176. mindspore/parallel/_offload_context.py +275 -0
  1177. mindspore/parallel/_parallel_serialization.py +561 -0
  1178. mindspore/parallel/_ps_context.py +242 -0
  1179. mindspore/parallel/_recovery_context.py +110 -0
  1180. mindspore/parallel/_tensor.py +730 -0
  1181. mindspore/parallel/_transformer/__init__.py +35 -0
  1182. mindspore/parallel/_transformer/layers.py +765 -0
  1183. mindspore/parallel/_transformer/loss.py +251 -0
  1184. mindspore/parallel/_transformer/moe.py +693 -0
  1185. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1186. mindspore/parallel/_transformer/transformer.py +3119 -0
  1187. mindspore/parallel/_utils.py +612 -0
  1188. mindspore/parallel/algo_parameter_config.py +400 -0
  1189. mindspore/parallel/checkpoint_transform.py +650 -0
  1190. mindspore/parallel/cluster/__init__.py +15 -0
  1191. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1192. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1193. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1194. mindspore/parallel/cluster/run.py +136 -0
  1195. mindspore/parallel/mpi/__init__.py +14 -0
  1196. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1197. mindspore/parallel/parameter_broadcast.py +151 -0
  1198. mindspore/parallel/shard.py +481 -0
  1199. mindspore/parallel/transform_safetensors.py +993 -0
  1200. mindspore/perf_msvcbuildinsights.dll +0 -0
  1201. mindspore/pgodb140.dll +0 -0
  1202. mindspore/pgort140.dll +0 -0
  1203. mindspore/profiler/__init__.py +28 -0
  1204. mindspore/profiler/common/__init__.py +14 -0
  1205. mindspore/profiler/common/constant.py +29 -0
  1206. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1207. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1208. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1209. mindspore/profiler/common/process_pool.py +41 -0
  1210. mindspore/profiler/common/registry.py +47 -0
  1211. mindspore/profiler/common/singleton.py +28 -0
  1212. mindspore/profiler/common/struct_type.py +118 -0
  1213. mindspore/profiler/common/util.py +472 -0
  1214. mindspore/profiler/common/validator/__init__.py +14 -0
  1215. mindspore/profiler/common/validator/validate_path.py +84 -0
  1216. mindspore/profiler/dynamic_profiler.py +694 -0
  1217. mindspore/profiler/envprofiling.py +254 -0
  1218. mindspore/profiler/parser/__init__.py +14 -0
  1219. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1220. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1221. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1222. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1223. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1229. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1230. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1231. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1232. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1233. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1234. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1235. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1236. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1237. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1238. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1239. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1240. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1241. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1242. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1243. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1244. mindspore/profiler/parser/container.py +229 -0
  1245. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1246. mindspore/profiler/parser/flops_parser.py +531 -0
  1247. mindspore/profiler/parser/framework_enum.py +111 -0
  1248. mindspore/profiler/parser/framework_parser.py +464 -0
  1249. mindspore/profiler/parser/framework_struct.py +61 -0
  1250. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1251. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1252. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1253. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1254. mindspore/profiler/parser/hccl_parser.py +573 -0
  1255. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1256. mindspore/profiler/parser/integrator.py +526 -0
  1257. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1258. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1259. mindspore/profiler/parser/minddata_parser.py +186 -0
  1260. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1261. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1262. mindspore/profiler/parser/optime_parser.py +250 -0
  1263. mindspore/profiler/parser/profiler_info.py +213 -0
  1264. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1265. mindspore/profiler/profiler.py +153 -0
  1266. mindspore/profiler/profiling.py +1922 -0
  1267. mindspore/rewrite/__init__.py +28 -0
  1268. mindspore/rewrite/api/__init__.py +17 -0
  1269. mindspore/rewrite/api/node.py +519 -0
  1270. mindspore/rewrite/api/node_type.py +53 -0
  1271. mindspore/rewrite/api/pattern_engine.py +490 -0
  1272. mindspore/rewrite/api/scoped_value.py +181 -0
  1273. mindspore/rewrite/api/symbol_tree.py +497 -0
  1274. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1275. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1276. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1277. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1278. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1279. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1280. mindspore/rewrite/common/__init__.py +19 -0
  1281. mindspore/rewrite/common/config.py +24 -0
  1282. mindspore/rewrite/common/error_log.py +39 -0
  1283. mindspore/rewrite/common/event.py +28 -0
  1284. mindspore/rewrite/common/namer.py +271 -0
  1285. mindspore/rewrite/common/namespace.py +118 -0
  1286. mindspore/rewrite/common/observable.py +44 -0
  1287. mindspore/rewrite/common/observer.py +54 -0
  1288. mindspore/rewrite/node/__init__.py +22 -0
  1289. mindspore/rewrite/node/call_function.py +95 -0
  1290. mindspore/rewrite/node/cell_container.py +139 -0
  1291. mindspore/rewrite/node/control_flow.py +113 -0
  1292. mindspore/rewrite/node/node.py +1428 -0
  1293. mindspore/rewrite/node/node_manager.py +283 -0
  1294. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1295. mindspore/rewrite/parsers/__init__.py +29 -0
  1296. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1297. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1298. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1299. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1300. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1301. mindspore/rewrite/parsers/container_parser.py +88 -0
  1302. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1303. mindspore/rewrite/parsers/for_parser.py +61 -0
  1304. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1305. mindspore/rewrite/parsers/if_parser.py +85 -0
  1306. mindspore/rewrite/parsers/module_parser.py +117 -0
  1307. mindspore/rewrite/parsers/parser.py +43 -0
  1308. mindspore/rewrite/parsers/parser_register.py +86 -0
  1309. mindspore/rewrite/parsers/return_parser.py +37 -0
  1310. mindspore/rewrite/parsers/while_parser.py +59 -0
  1311. mindspore/rewrite/sparsify/__init__.py +0 -0
  1312. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1313. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1314. mindspore/rewrite/sparsify/utils.py +179 -0
  1315. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1316. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1317. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1318. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1319. mindspore/run_check/__init__.py +20 -0
  1320. mindspore/run_check/_check_version.py +507 -0
  1321. mindspore/run_check/run_check.py +66 -0
  1322. mindspore/safeguard/__init__.py +18 -0
  1323. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1324. mindspore/swresample-4.dll +0 -0
  1325. mindspore/swscale-6.dll +0 -0
  1326. mindspore/tbbmalloc.dll +0 -0
  1327. mindspore/tinyxml2.dll +0 -0
  1328. mindspore/train/__init__.py +48 -0
  1329. mindspore/train/_utils.py +465 -0
  1330. mindspore/train/amp.py +935 -0
  1331. mindspore/train/anf_ir_pb2.py +1517 -0
  1332. mindspore/train/callback/__init__.py +44 -0
  1333. mindspore/train/callback/_backup_and_restore.py +117 -0
  1334. mindspore/train/callback/_callback.py +613 -0
  1335. mindspore/train/callback/_checkpoint.py +814 -0
  1336. mindspore/train/callback/_cluster_monitor.py +201 -0
  1337. mindspore/train/callback/_dataset_graph.py +150 -0
  1338. mindspore/train/callback/_early_stop.py +239 -0
  1339. mindspore/train/callback/_flops_collector.py +239 -0
  1340. mindspore/train/callback/_history.py +92 -0
  1341. mindspore/train/callback/_lambda_callback.py +80 -0
  1342. mindspore/train/callback/_landscape.py +1049 -0
  1343. mindspore/train/callback/_loss_monitor.py +107 -0
  1344. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1345. mindspore/train/callback/_on_request_exit.py +298 -0
  1346. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1347. mindspore/train/callback/_summary_collector.py +1184 -0
  1348. mindspore/train/callback/_tft_register.py +352 -0
  1349. mindspore/train/callback/_time_monitor.py +141 -0
  1350. mindspore/train/checkpoint_pb2.py +233 -0
  1351. mindspore/train/data_sink.py +219 -0
  1352. mindspore/train/dataset_helper.py +692 -0
  1353. mindspore/train/lineage_pb2.py +1260 -0
  1354. mindspore/train/loss_scale_manager.py +213 -0
  1355. mindspore/train/memory_profiling_pb2.py +298 -0
  1356. mindspore/train/metrics/__init__.py +175 -0
  1357. mindspore/train/metrics/accuracy.py +133 -0
  1358. mindspore/train/metrics/auc.py +129 -0
  1359. mindspore/train/metrics/bleu_score.py +170 -0
  1360. mindspore/train/metrics/confusion_matrix.py +700 -0
  1361. mindspore/train/metrics/cosine_similarity.py +109 -0
  1362. mindspore/train/metrics/dice.py +116 -0
  1363. mindspore/train/metrics/error.py +175 -0
  1364. mindspore/train/metrics/fbeta.py +167 -0
  1365. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1366. mindspore/train/metrics/loss.py +97 -0
  1367. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1368. mindspore/train/metrics/metric.py +373 -0
  1369. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1370. mindspore/train/metrics/perplexity.py +133 -0
  1371. mindspore/train/metrics/precision.py +160 -0
  1372. mindspore/train/metrics/recall.py +159 -0
  1373. mindspore/train/metrics/roc.py +223 -0
  1374. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1375. mindspore/train/metrics/topk.py +167 -0
  1376. mindspore/train/mind_ir_pb2.py +1908 -0
  1377. mindspore/train/model.py +2252 -0
  1378. mindspore/train/node_strategy_pb2.py +653 -0
  1379. mindspore/train/print_pb2.py +184 -0
  1380. mindspore/train/profiling_parallel_pb2.py +151 -0
  1381. mindspore/train/serialization.py +3325 -0
  1382. mindspore/train/summary/__init__.py +23 -0
  1383. mindspore/train/summary/_lineage_adapter.py +41 -0
  1384. mindspore/train/summary/_summary_adapter.py +496 -0
  1385. mindspore/train/summary/_writer_pool.py +207 -0
  1386. mindspore/train/summary/enums.py +56 -0
  1387. mindspore/train/summary/summary_record.py +581 -0
  1388. mindspore/train/summary/writer.py +167 -0
  1389. mindspore/train/summary_pb2.py +1165 -0
  1390. mindspore/train/train_thor/__init__.py +20 -0
  1391. mindspore/train/train_thor/convert_utils.py +268 -0
  1392. mindspore/train/train_thor/dataset_helper.py +192 -0
  1393. mindspore/train/train_thor/model_thor.py +257 -0
  1394. mindspore/turbojpeg.dll +0 -0
  1395. mindspore/utils/__init__.py +21 -0
  1396. mindspore/utils/utils.py +60 -0
  1397. mindspore/vcmeta.dll +0 -0
  1398. mindspore/vcomp140.dll +0 -0
  1399. mindspore/vcruntime140.dll +0 -0
  1400. mindspore/vcruntime140_1.dll +0 -0
  1401. mindspore/version.py +1 -0
  1402. mindspore-2.4.0.dist-info/METADATA +352 -0
  1403. mindspore-2.4.0.dist-info/RECORD +1406 -0
  1404. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1405. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1406. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1371 @@
+ # Copyright 2023 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """LRScheduler."""
+ from collections import Counter
+ from bisect import bisect_right
+ import math
+ from mindspore import ops, Tensor, Parameter
+ from mindspore.experimental.optim.optimizer import Optimizer
+ from mindspore.common.api import jit_class
+ import mindspore.common.dtype as mstype
+ from mindspore.ops import functional as F
+ from mindspore.ops import operations as P
+ from mindspore import _checkparam as Validator
+
+
+ __all__ = ['StepLR', 'LinearLR', 'LRScheduler', 'ExponentialLR', 'PolynomialLR',
+            'MultiplicativeLR', 'ConstantLR', 'MultiStepLR', 'LambdaLR', 'SequentialLR', 'ReduceLROnPlateau',
+            'CyclicLR', 'CosineAnnealingWarmRestarts', 'CosineAnnealingLR']
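+
+ # All schedulers in this module share one protocol: wrap an optimizer,
+ # then call `step()` once per epoch (or per iteration) after the
+ # optimizer update. A minimal usage sketch, assuming a hypothetical
+ # `net`, `create_dataset` and `train_step` like the ones used in the
+ # docstring examples below:
+ #
+ #     optimizer = optim.SGD(net.trainable_params(), lr=0.1)
+ #     scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
+ #     for epoch in range(60):
+ #         for data, label in create_dataset():
+ #             train_step(data, label)    # optimizer(grads) inside
+ #         scheduler.step()               # decay the lr once per epoch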
+
+
+ @jit_class
+ class LRScheduler:
+     r"""
+     Base class for learning rate schedulers.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): The optimizer instance.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Raises:
+         TypeError: If `optimizer` is not an Optimizer.
+         TypeError: If `last_epoch` is not an int.
+         KeyError: If `last_epoch` != -1 and ``'initial_lr'`` not in param groups.
+         ValueError: If `last_epoch` is less than -1.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>>
+         >>> class ConstantLR(optim.lr_scheduler.LRScheduler):
+         ...     def __init__(self, optimizer, factor=0.5, total_iters=3, last_epoch=-1):
+         ...         self.factor = factor
+         ...         self.total_iters = total_iters
+         ...         super(ConstantLR, self).__init__(optimizer, last_epoch)
+         ...
+         ...     def get_lr(self):
+         ...         if self.last_epoch == 0:
+         ...             return [lr * self.factor for lr in self._last_lr]
+         ...         if self.last_epoch != self.total_iters:
+         ...             return [lr * 1. for lr in self._last_lr]
+         ...         return [lr / self.factor for lr in self._last_lr]
+         >>>
+         >>> net = nn.Dense(8, 2)
+         >>> optimizer = optim.SGD(net.trainable_params(), 0.01)
+         >>> scheduler = ConstantLR(optimizer)
+         >>> for i in range(4):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.01)]
+         [Tensor(shape=[], dtype=Float32, value= 0.01)]
+     """
+     def __init__(self, optimizer, last_epoch=-1):
+         if not isinstance(optimizer, Optimizer):
+             raise TypeError('{} is not an Optimizer'.format(
+                 type(optimizer).__name__))
+         Validator.check_value_type("last_epoch", last_epoch, [int])
+         if last_epoch < -1:
+             raise ValueError("Invalid last_epoch: {}".format(last_epoch))
+         if last_epoch == -1:
+             for group in optimizer.param_groups:
+                 group.setdefault('initial_lr', group['lr'].value())
+         else:
+             for i, group in enumerate(optimizer.param_groups):
+                 if 'initial_lr' not in group:
+                     raise KeyError(f"param 'initial_lr' is not specified "
+                                    f"in param_groups[{i}] when resuming an optimizer")
+         self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+         self.optimizer = optimizer
+         self._last_lr = [lr for lr in optimizer.lrs]
+         self.groups_num = len(optimizer.param_groups)
+         self.last_epoch = Parameter(Tensor(last_epoch, dtype=mstype.float32),
+                                     name='last_epoch_' + self.__class__.__name__)
+         self.increase_tensor = Tensor(1, mstype.int32)
+         self.step()
+
+     @staticmethod
+     def get_lr():
+         raise NotImplementedError
+
+     def get_last_lr(self):
+         """
+         Return last computed learning rate by current scheduler.
+         """
+         return [lr.value() for lr in self._last_lr]
+
+     def step(self, epoch=None):
+         """
+         Compute the new learning rate and update the optimizer's learning rate.
+
+         Args:
+             epoch (int, optional): The index of the last epoch. Default: ``None``.
+         """
+         if epoch is None:
+             ops.assign_add(self.last_epoch, self.increase_tensor)
+             values = self.get_lr()
+         else:
+             ops.assign(self.last_epoch, epoch)
+             if hasattr(self, "_get_closed_form_lr"):
+                 values = self._get_closed_form_lr()
+             else:
+                 values = self.get_lr()
+
+         for i in range(self.groups_num):
+             lr = values[i]
+             ops.assign(self._last_lr[i], lr)
+
+         return True
+
+
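+ # Note on LRScheduler.step(): called with no argument it advances
+ # `last_epoch` by one and applies the recursive rule in `get_lr()`,
+ # which rescales the *current* lr in place; called with an explicit
+ # `epoch` it jumps straight to that epoch and prefers the closed-form
+ # schedule whenever the subclass defines `_get_closed_form_lr`.
+ # A sketch (values hypothetical):
+ #
+ #     scheduler.step()    # last_epoch += 1, lr updated recursively
+ #     scheduler.step(10)  # last_epoch = 10, lr = closed form at epoch 10
+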
+ @jit_class
+ class StepLR(LRScheduler):
+     """Decays the learning rate of each parameter group by gamma every
+     step_size epochs. Notice that such decay can happen simultaneously with
+     other changes to the learning rate from outside this scheduler.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         step_size (int): Period of learning rate decay.
+         gamma (float, optional): Multiplicative factor of learning rate decay.
+             Default: ``0.1``.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+         >>> net = LeNet5()
+         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+         >>> optimizer = optim.Adam(net.trainable_params(), lr=0.05)
+         >>> # Assuming optimizer uses lr = 0.05 for all groups
+         >>> # lr = 0.05 if epoch < 2
+         >>> # lr = 0.005 if 2 <= epoch < 4
+         >>> # lr = 0.0005 if 4 <= epoch < 6
+         >>> scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
+         >>> def forward_fn(data, label):
+         ...     logits = net(data)
+         ...     loss = loss_fn(logits, label)
+         ...     return loss, logits
+         >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
+         >>> def train_step(data, label):
+         ...     (loss, _), grads = grad_fn(data, label)
+         ...     optimizer(grads)
+         ...     return loss
+         >>> for epoch in range(6):
+         ...     # Create the dataset taking MNIST as an example. Refer to
+         ...     # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+         ...     for data, label in create_dataset():
+         ...         train_step(data, label)
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+     """
+     def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
+         if not isinstance(step_size, int) or isinstance(step_size, bool):
+             raise TypeError(f"For 'StepLR', the 'step_size' must be int, but got {type(step_size)}.")
+         if not isinstance(gamma, float):
+             raise TypeError(f"For 'StepLR', the 'gamma' must be float, but got {type(gamma)}.")
+         self.step_size = step_size
+         self.gamma = gamma
+         super(StepLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
+             return [lr * 1. for lr in self._last_lr]
+         return [lr * self.gamma for lr in self._last_lr]
+
+     def _get_closed_form_lr(self):
+         return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
+                 for base_lr in self.base_lrs]
+
+
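+ # Note: StepLR.get_lr() scales whatever lr the optimizer currently
+ # holds, so adjustments made from outside the scheduler compound with
+ # the decay. E.g. with step_size=2 and gamma=0.1, an lr of 0.05 becomes
+ # 0.005 at epoch 2 and 0.0005 at epoch 4; `_get_closed_form_lr` computes
+ # the same values directly as base_lr * gamma ** (epoch // step_size).
+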
+ @jit_class
+ class LinearLR(LRScheduler):
+     """Decays the learning rate of each parameter group by linearly changing a small
+     multiplicative factor until the number of epochs reaches a pre-defined milestone: total_iters.
+     Notice that such decay can happen simultaneously with other changes to the learning rate
+     from outside this scheduler.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         start_factor (float, optional): The number we multiply learning rate in the first epoch.
+             The multiplication factor changes towards `end_factor` in the following epochs.
+             Default: ``1.0 / 3``.
+         end_factor (float, optional): The number we multiply learning rate at the end of linear changing
+             process. Default: ``1.0``.
+         total_iters (int, optional): The number of iterations after which the multiplicative factor
+             reaches `end_factor`. Default: ``5``.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Raises:
+         ValueError: If `start_factor` is not in the range of (0, 1].
+         ValueError: If `end_factor` is not in the range of [0, 1].
+         TypeError: If `total_iters` is not an int.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> # Define the network structure of LeNet5. Refer to
+         >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+         >>> net = LeNet5()
+         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+         >>> optimizer = optim.Adam(net.trainable_params(), lr=0.05)
+         >>> # Assuming optimizer uses lr = 0.05 for all groups
+         >>> # lr = 0.025 if epoch == 0
+         >>> # lr = 0.03125 if epoch == 1
+         >>> # lr = 0.0375 if epoch == 2
+         >>> # lr = 0.04375 if epoch == 3
+         >>> # lr = 0.05 if epoch >= 4
+         >>> scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.5, total_iters=4)
+         >>> def forward_fn(data, label):
+         ...     logits = net(data)
+         ...     loss = loss_fn(logits, label)
+         ...     return loss, logits
+         >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
+         >>> def train_step(data, label):
+         ...     (loss, _), grads = grad_fn(data, label)
+         ...     optimizer(grads)
+         ...     return loss
+         >>> for epoch in range(5):
+         ...     # Create the dataset taking MNIST as an example. Refer to
+         ...     # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+         ...     for data, label in create_dataset():
+         ...         train_step(data, label)
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+     """
+
+     def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1):
+         if start_factor > 1.0 or start_factor <= 0:
+             raise ValueError('Starting multiplicative factor expected to be greater than 0 and '
+                              'less than or equal to 1.')
+         if end_factor > 1.0 or end_factor < 0:
+             raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
+         if not isinstance(total_iters, int):
+             raise TypeError(f"For 'LinearLR', the type of total_iters must be int, but got {type(total_iters)}.")
+         self.start_factor = start_factor
+         self.end_factor = end_factor
+         self.total_iters = total_iters
+         super(LinearLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0:
+             return [lr * self.start_factor for lr in self._last_lr]
+
+         if self.last_epoch > self.total_iters:
+             return [lr * 1. for lr in self._last_lr]
+
+         factor = 1. + (self.end_factor - self.start_factor) / (
+             self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))
+         return [lr * factor for lr in self._last_lr]
+
+     def _get_closed_form_lr(self):
+         return [base_lr * (self.start_factor +
+                            (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch)
+                            / self.total_iters) for base_lr in self.base_lrs]
+
+
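+ # Derivation note for LinearLR.get_lr(): the closed form at epoch t is
+ # lr_t = base_lr * (s + (e - s) * min(t, T) / T), with s = start_factor,
+ # e = end_factor, T = total_iters. The per-step factor above is the
+ # ratio lr_t / lr_(t-1):
+ #
+ #     (s + (e - s) * t / T) / (s + (e - s) * (t - 1) / T)
+ #         = 1 + (e - s) / (T * s + (t - 1) * (e - s))
+ #
+ # so multiplying the current lr by `factor` each step reproduces the
+ # linear interpolation without re-reading the base lr.
+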
+ @jit_class
+ class ExponentialLR(LRScheduler):
+     r"""
+     For each epoch, the learning rate decays exponentially, multiplied by gamma.
+     Notice that such decay can happen simultaneously with other changes to the learning rate
+     from outside this scheduler.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         gamma (float): Learning rate scaling factor.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> class Net(nn.Cell):
+         ...     def __init__(self):
+         ...         super(Net, self).__init__()
+         ...         self.fc = nn.Dense(16 * 5 * 5, 120)
+         ...     def construct(self, x):
+         ...         return self.fc(x)
+         >>> net = Net()
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.01)
+         >>> scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
+         >>> for i in range(3):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.0025)]
+         [Tensor(shape=[], dtype=Float32, value= 0.00125)]
+     """
+
+     def __init__(self, optimizer, gamma, last_epoch=-1):
+         if not isinstance(gamma, float):
+             raise TypeError(f"For 'ExponentialLR', the 'gamma' must be float, but got {type(gamma)}.")
+         self.gamma = gamma
+         super(ExponentialLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0:
+             return [lr * 1. for lr in self._last_lr]
+         return [lr * self.gamma for lr in self._last_lr]
+
+     def _get_closed_form_lr(self):
+         return [base_lr * self.gamma ** self.last_epoch
+                 for base_lr in self.base_lrs]
+
+
+ @jit_class
+ class PolynomialLR(LRScheduler):
+     r"""
+     For each epoch, the learning rate is adjusted by polynomial fitting.
+     When the epoch is greater than or equal to `total_iters` , the learning rate is ``0`` .
+     Notice that such decay can happen simultaneously with other changes to the learning rate
+     from outside this scheduler.
+
+     The polynomial formula for learning rate calculation is as follows:
+
+     .. math::
+         \begin{split}
+         &factor = (\frac{1.0 - \frac{last\_epoch}{total\_iters}}{1.0 - \frac{last\_epoch - 1.0}{total\_iters}})
+         ^{power}\\
+         &lr = lr \times factor
+         \end{split}
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         total_iters (int, optional): The number of iterations adjusting learning rate by polynomial fitting.
+             Default: ``5``.
+         power (float, optional): Power of polynomial. Default: ``1.0``.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> class Net(nn.Cell):
+         ...     def __init__(self):
+         ...         super(Net, self).__init__()
+         ...         self.fc = nn.Dense(16 * 5 * 5, 120)
+         ...     def construct(self, x):
+         ...         return self.fc(x)
+         >>> net = Net()
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.01)
+         >>> scheduler = optim.lr_scheduler.PolynomialLR(optimizer)
+         >>> for i in range(6):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.008)]
+         [Tensor(shape=[], dtype=Float32, value= 0.006)]
+         [Tensor(shape=[], dtype=Float32, value= 0.004)]
+         [Tensor(shape=[], dtype=Float32, value= 0.002)]
+         [Tensor(shape=[], dtype=Float32, value= 0)]
+         [Tensor(shape=[], dtype=Float32, value= 0)]
+     """
+     def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1):
+         if not isinstance(power, float):
+             raise TypeError(f"For 'PolynomialLR', the 'power' must be float, but got {type(power)}.")
+         if power < 0:
+             raise ValueError(f"For 'PolynomialLR', the 'power' must be >= 0, but got {power}.")
+         if not isinstance(total_iters, int):
+             raise TypeError(f"For 'PolynomialLR', the type of total_iters must be int, but got {type(total_iters)}.")
+         self.total_iters = total_iters
+         self.power = power
+         self.min = P.Minimum()
+         self.cast = P.Cast()
+         super(PolynomialLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0 or self.last_epoch > self.total_iters:
+             return [lr * 1. for lr in self._last_lr]
+         factor = ((1.0 - self.last_epoch / self.total_iters) / (
+             1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power
+         return [lr * factor for lr in self._last_lr]
+
+     def _get_closed_form_lr(self):
+         return [
+             (base_lr * (1.0 - self.min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power)
+             for base_lr in self.base_lrs]
+
+
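+ # Note for PolynomialLR: the per-step factors telescope. The product of
+ # the factors for epochs 1..t collapses to (1 - t / total_iters) ** power,
+ # which is exactly `_get_closed_form_lr`. With the defaults (total_iters=5,
+ # power=1.0) and a base lr of 0.01: epoch 1 gives 0.008, epoch 5 gives 0.0,
+ # and the lr stays 0 afterwards because get_lr() returns lr * 1. once
+ # last_epoch exceeds total_iters.
+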
+ @jit_class
+ class LambdaLR(LRScheduler):
+     """Sets the learning rate of each parameter group to the initial lr
+     times a given function. When last_epoch=-1, sets initial lr as lr.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         lr_lambda (Union(function, list)): A function which computes a multiplicative
+             factor given a parameter `last_epoch`, or a list of such
+             functions, one for each group in `optimizer.param_groups`.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Raises:
+         ValueError: If the length of `lr_lambda` is not equal to the number of param groups.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> net = nn.Dense(2, 3)
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.01)
+         >>> lmbda = lambda epoch: 0.9 ** epoch
+         >>> scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lmbda])
+         >>> for i in range(3):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.009)]
+         [Tensor(shape=[], dtype=Float32, value= 0.0081)]
+         [Tensor(shape=[], dtype=Float32, value= 0.00729)]
+     """
+     def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
+             self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
+         else:
+             if len(lr_lambda) != len(optimizer.param_groups):
+                 raise ValueError("Expected {} lr_lambdas, but got {}".format(
+                     len(optimizer.param_groups), len(lr_lambda)))
+             self.lr_lambdas = list(lr_lambda)
+         super(LambdaLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         return [base_lr * lmbda(self.last_epoch)
+                 for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
+
+
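+ # Note: LambdaLR computes each lr from the *base* lr
+ # (base_lr * lmbda(epoch)), while MultiplicativeLR below rescales the
+ # *current* lr by lmbda(epoch). With lmbda = lambda epoch: 0.95 and a
+ # base lr of 0.01, LambdaLR would yield a constant 0.0095, whereas
+ # MultiplicativeLR compounds: 0.0095, 0.009025, ...
+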
+ @jit_class
+ class MultiplicativeLR(LRScheduler):
+     """Multiply the learning rate of each parameter group by the factor given
+     in the specified function. When last_epoch=-1, sets initial lr as lr.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         lr_lambda (Union(function, list)): A function which computes a multiplicative
+             factor given an integer parameter epoch, or a list of such
+             functions, one for each group in optimizer.param_groups.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> net = nn.Dense(2, 3)
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.01)
+         >>> lmbda = lambda epoch: 0.95
+         >>> scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
+         >>> for i in range(3):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.0095)]
+         [Tensor(shape=[], dtype=Float32, value= 0.009025)]
+         [Tensor(shape=[], dtype=Float32, value= 0.00857375)]
+     """
+     def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
+             self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
+         else:
+             if len(lr_lambda) != len(optimizer.param_groups):
+                 raise ValueError("Expected {} lr_lambdas, but got {}".format(
+                     len(optimizer.param_groups), len(lr_lambda)))
+             self.lr_lambdas = list(lr_lambda)
+         super(MultiplicativeLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch > 0:
+             return [lr * lmbda(self.last_epoch)
+                     for lmbda, lr in zip(self.lr_lambdas, self._last_lr)]
+         return [lr * 1. for lr in self._last_lr]
+
+
+ @jit_class
+ class MultiStepLR(LRScheduler):
+     """Multiply the learning rate of each parameter group by gamma once the
+     number of epochs reaches one of the milestones. Notice that such change can
+     happen simultaneously with other changes to the learning rate from outside
+     this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         milestones (list): List of epoch indices. When `last_epoch` reaches a milestone,
+             multiply the learning rate of each parameter group by `gamma`.
+         gamma (float, optional): Multiplicative factor of learning rate decay.
+             Default: ``0.1``.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Raises:
+         TypeError: If `milestones` is not a list.
+         TypeError: If elements of the `milestones` are not int.
+         TypeError: If the `gamma` is not float.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> net = nn.Dense(2, 3)
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.05)
+         >>> # Assuming optimizer uses lr = 0.05 for all groups
+         >>> # lr = 0.05 if epoch < 2
+         >>> # lr = 0.005 if 2 <= epoch < 4
+         >>> # lr = 0.0005 if epoch >= 4
+         >>> scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,4], gamma=0.1)
+         >>> for i in range(6):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.05)]
+         [Tensor(shape=[], dtype=Float32, value= 0.005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.0005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.0005)]
+         [Tensor(shape=[], dtype=Float32, value= 0.0005)]
+     """
+     def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
+         Validator.check_value_type('milestones', milestones, [list])
+         for milestone in milestones:
+             if not isinstance(milestone, int) or isinstance(milestone, bool):
+                 raise TypeError(f"For 'MultiStepLR', elements of the 'milestones' must be type of int, "
+                                 f"but got one element of 'milestones' type: {type(milestone)}.")
+         Validator.check_value_type('gamma', gamma, [float, int])
+         self.milestones = Counter(milestones)
+         self.milestones_keys = list(self.milestones.keys())
+         self.milestones_values = list(self.milestones.values())
+         self.gamma = gamma
+         super(MultiStepLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         lrs = [lr.value() for lr in self._last_lr]
+         tmp_epoch = int(self.last_epoch.value())
+
+         for i in range(len(self.milestones_keys)):
+             if tmp_epoch == self.milestones_keys[i]:
+                 value = self.milestones_values[i]
+                 return [lr * self.gamma ** value for lr in lrs]
+         return lrs
+
+     def _get_closed_form_lr(self):
+         # Apply gamma once for every milestone that has been passed.
+         milestones = sorted(self.milestones.elements())
+         return [base_lr * self.gamma ** bisect_right(milestones, int(self.last_epoch.value()))
+                 for base_lr in self.base_lrs]
+
+
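+ # Note: MultiStepLR keeps its milestones in a Counter, so a repeated
+ # milestone applies gamma more than once at that epoch. E.g. (hypothetical
+ # values) milestones=[2, 4, 4] with gamma=0.1 multiplies the lr by 0.1 at
+ # epoch 2 and by 0.1 ** 2 at epoch 4.
+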
+ @jit_class
+ class ConstantLR(LRScheduler):
+     """Decays the learning rate of each parameter group by a small constant factor until the
+     number of epochs reaches a pre-defined milestone: total_iters. Notice that such decay can
+     happen simultaneously with other changes to the learning rate from outside this scheduler.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         factor (float, optional): The factor by which the learning rate is multiplied. Default: ``1.0 / 3``.
+         total_iters (int, optional): The number of steps for which the scheduler decays the learning rate;
+             when `last_epoch` reaches `total_iters`, the learning rate is restored. Default: ``5``.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore import nn
+         >>> from mindspore.experimental import optim
+         >>> net = nn.Dense(2, 3)
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.05)
+         >>> # Assuming optimizer uses lr = 0.05 for all groups
+         >>> # lr = 0.025 if epoch < 4
+         >>> # lr = 0.05 if epoch >= 4
+         >>> scheduler = optim.lr_scheduler.ConstantLR(optimizer, factor=0.5, total_iters=4)
+         >>> for i in range(6):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.025)]
+         [Tensor(shape=[], dtype=Float32, value= 0.025)]
+         [Tensor(shape=[], dtype=Float32, value= 0.025)]
+         [Tensor(shape=[], dtype=Float32, value= 0.05)]
+         [Tensor(shape=[], dtype=Float32, value= 0.05)]
+         [Tensor(shape=[], dtype=Float32, value= 0.05)]
+     """
+     def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1):
+         if factor > 1.0 or factor < 0:
+             raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
+         if not isinstance(total_iters, int):
+             raise TypeError(f"For 'ConstantLR', the type of total_iters must be int, but got {type(total_iters)}.")
+         self.factor = factor
+         self.total_iters = total_iters
+         super(ConstantLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0:
+             return [lr * self.factor for lr in self._last_lr]
+         if self.last_epoch != self.total_iters:
+             return [lr * 1. for lr in self._last_lr]
+         return [lr / self.factor for lr in self._last_lr]
+
+     def _get_closed_form_lr(self):
+         return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
+                 for base_lr in self.base_lrs]
+
+
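+ # Note: ConstantLR restores the original lr at epoch == total_iters by
+ # dividing the current lr by `factor`; the closed form is equivalent:
+ # base_lr * factor while epoch < total_iters, and base_lr afterwards.
+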
+ @jit_class
+ class SequentialLR:
+     r"""
+     Receives a list of schedulers that are expected to be called sequentially during
+     the optimization process, and a list of milestone points that provides the exact
+     intervals reflecting which scheduler is supposed to be called at a given epoch.
+
+     .. warning::
+         This is an experimental lr scheduler module that is subject to change.
+         This module must be used with optimizers in `Experimental Optimizer
+         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+     Args:
+         optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+         schedulers (list[:class:`mindspore.experimental.optim.lr_scheduler.LRScheduler`]):
+             List of learning rate schedulers.
+         milestones (list): List of integers that reflects milestone points.
+         last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+     Raises:
+         ValueError: The optimizer in `schedulers` is different from the `optimizer` passed in.
+         ValueError: The optimizer in `schedulers` is different from the optimizer of `schedulers[0]`.
+         ValueError: Length of `milestones` is not equal to length of `schedulers` minus 1.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> from mindspore.experimental import optim
+         >>> from mindspore import nn
+         >>> net = nn.Dense(3, 2)
+         >>> optimizer = optim.Adam(net.trainable_params(), 0.1)
+         >>> scheduler1 = optim.lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=2)
+         >>> scheduler2 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
+         >>> scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
+         >>> for i in range(6):
+         ...     scheduler.step()
+         ...     current_lr = scheduler.get_last_lr()
+         ...     print(current_lr)
+         [Tensor(shape=[], dtype=Float32, value= 0.01)]
+         [Tensor(shape=[], dtype=Float32, value= 0.1)]
+         [Tensor(shape=[], dtype=Float32, value= 0.09)]
+         [Tensor(shape=[], dtype=Float32, value= 0.081)]
+         [Tensor(shape=[], dtype=Float32, value= 0.0729)]
+         [Tensor(shape=[], dtype=Float32, value= 0.06561)]
+     """
+     def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):
+         for sched_idx in range(len(schedulers)):
+             if schedulers[sched_idx].optimizer != optimizer:
+                 raise ValueError(
+                     "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
+                     f"got scheduler at index {sched_idx} is different from the optimizer passed in.")
+
+             if schedulers[sched_idx].optimizer != schedulers[0].optimizer:
+                 raise ValueError(
+                     "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
+                     f"got schedulers at index {0} and {sched_idx} are different.")
+
+         if len(milestones) != len(schedulers) - 1:
+             raise ValueError(
+                 "Sequential Schedulers expects number of schedulers provided to be one more "
+                 "than the number of milestone points, but got number of schedulers {} and the "
+                 "number of milestones {}".format(len(schedulers), len(milestones)))
+
+         self._schedulers = schedulers
+         self.milestones = milestones
+         self.milestones_len = len(milestones)
+         self.last_epoch = Parameter(Tensor(last_epoch + 1, dtype=mstype.float32),
+                                     name='last_epoch_' + self.__class__.__name__)
+         self.increase_tensor = Tensor(1, mstype.int32)
+
+         self.optimizer = optimizer
+         for group in self.optimizer.param_groups:
+             ops.assign(group["lr"], group["initial_lr"])
+
+         for scheduler in self._schedulers:
+             ops.assign_sub(scheduler.last_epoch, self.increase_tensor)
+
+         self._schedulers[0].step()
+         self._last_lr = schedulers[0]._last_lr  # pylint: disable=W0212
+
+     def step(self):
+         """
+         Advance the epoch counter and step the scheduler that is active at the current epoch.
+         """
+         ops.assign_add(self.last_epoch, self.increase_tensor)
+         tmp_epoch = int(self.last_epoch)
+
+         cur_idx = bisect_right(self.milestones, tmp_epoch)
+         scheduler = self._schedulers[cur_idx]
+         if cur_idx > 0 and self.milestones[cur_idx - 1] == tmp_epoch:
+             scheduler.step(0)
+         else:
+             scheduler.step()
+
+     def get_last_lr(self):
+         """
+         Return last computed learning rate by current scheduler.
+         """
+         return [lr.value() for lr in self._last_lr]
+
+
794
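The milestone lookup in `SequentialLR.step` is just `bisect_right` plus a restart check. The following plain-Python fragment (illustrative only, not package code) reproduces that selection logic for milestones `[2, 5]`:

    from bisect import bisect_right

    milestones = [2, 5]  # two switch points => three schedulers
    for epoch in range(1, 8):
        idx = bisect_right(milestones, epoch)               # index of the active scheduler
        restart = idx > 0 and milestones[idx - 1] == epoch  # step(0) exactly at a milestone
        print(epoch, "-> scheduler", idx, "(restart)" if restart else "")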
+@jit_class
+class ReduceLROnPlateau:
+    """
+    Reduce learning rate when a metric has stopped improving.
+    Models often benefit from reducing the learning rate by a factor
+    of 2-10 once learning stagnates. This scheduler reads a metric
+    quantity, and if no improvement is seen for a `patience` number
+    of epochs, the learning rate is reduced.
+
+    .. warning::
+        This is an experimental lr scheduler module that is subject to change.
+        This module must be used with optimizers in `Experimental Optimizer
+        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+    Args:
+        optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+        mode (str, optional): One of `min`, `max`. In `min` mode, lr will
+            be reduced when the quantity monitored has stopped
+            decreasing; in `max` mode it will be reduced when the
+            quantity monitored has stopped increasing. Default: ``'min'``.
+        factor (float, optional): Factor by which the learning rate will be
+            reduced. Default: ``0.1``.
+        patience (int, optional): Number of epochs with no improvement after
+            which the learning rate will be reduced. For example, if
+            `patience = 2`, the first 2 epochs with no improvement are
+            ignored, and the LR is only decreased after the 3rd epoch
+            if the loss still hasn't improved by then. Default: ``10``.
+        threshold (float, optional): Threshold for measuring the new optimum,
+            to only focus on significant changes. Default: ``1e-4``.
+        threshold_mode (str, optional): One of `rel`, `abs`. The dynamic_threshold
+            is the benchmark that defines whether the current metric is an
+            improvement: in ``'rel'`` mode, dynamic_threshold = best * (1 + threshold)
+            in ``'max'`` mode or best * (1 - threshold) in ``'min'`` mode;
+            in ``'abs'`` mode, dynamic_threshold = best + threshold in ``'max'`` mode
+            or best - threshold in ``'min'`` mode. Default: ``'rel'``.
+        cooldown (int, optional): Number of epochs to wait before resuming
+            normal operation after lr has been reduced. Default: ``0``.
+        min_lr (Union(float, list), optional): A scalar or a list of scalars. A
+            lower bound on the learning rate of all param groups
+            or each group respectively. Default: ``0``.
+        eps (float, optional): Minimal decay applied to lr. If the difference
+            between new and old lr is smaller than eps, the update is
+            ignored. Default: ``1e-8``.
+
+    Raises:
+        ValueError: `factor` is greater than or equal to 1.0.
+        TypeError: `optimizer` is not an `Optimizer`.
+        ValueError: When `min_lr` is a list or tuple, the length of it is not equal to the number of param groups.
+        ValueError: `mode` is neither ``'min'`` nor ``'max'``.
+        ValueError: `threshold_mode` is neither ``'rel'`` nor ``'abs'``.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore.experimental import optim
+        >>> from mindspore import nn
+        >>> net = nn.Dense(3, 2)
+        >>> optimizer = optim.Adam(net.trainable_params(), 0.1)
+        >>> scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=0)
+        >>> metrics = [1, 1.5, 1.8, 0.4, 0.5]
+        >>> for i in range(5):
+        ...     scheduler.step(metrics[i])
+        ...     current_lr = scheduler.get_last_lr()
+        ...     print(current_lr)
+        [Tensor(shape=[], dtype=Float32, value= 0.1)]
+        [Tensor(shape=[], dtype=Float32, value= 0.01)]
+        [Tensor(shape=[], dtype=Float32, value= 0.001)]
+        [Tensor(shape=[], dtype=Float32, value= 0.001)]
+        [Tensor(shape=[], dtype=Float32, value= 0.0001)]
+    """
+    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
+                 threshold=1e-4, threshold_mode='rel', cooldown=0,
+                 min_lr=0, eps=1e-8):
+
+        if factor >= 1.0:
+            raise ValueError("The lr factor should be less than 1.0.")
+        self.factor = factor
+
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError("Expected an `Optimizer`, but got type {}".format(type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        if isinstance(min_lr, (list, tuple)):
+            if len(min_lr) != len(optimizer.param_groups):
+                raise ValueError("Expected {} min_lrs, got {}".format(len(optimizer.param_groups), len(min_lr)))
+            self.min_lrs = [Tensor(lr, mstype.float32) for lr in min_lr]
+        else:
+            self.min_lrs = [Tensor(min_lr, mstype.float32)] * len(optimizer.param_groups)
+
+        self.mode = mode
+        self.patience = patience
+        self.threshold = threshold
+        self.threshold_mode = threshold_mode
+        self.cooldown = cooldown
+        self.eps = eps
+        self.mode_worse = None
+        self.assign = P.Assign()
+        self.cast = P.Cast()
+        self.last_epoch = Parameter(Tensor(0, dtype=mstype.int32),
+                                    name='last_epoch_' + self.__class__.__name__)
+
+        if self.mode not in {'min', 'max'}:
+            raise ValueError(f"`mode` should be 'min' or 'max', but got {self.mode}.")
+        if self.threshold_mode not in {'rel', 'abs'}:
+            raise ValueError(f"`threshold_mode` should be 'rel' or 'abs', but got {self.threshold_mode}.")
+
+        if self.mode == 'min':
+            self.mode_worse = float("inf")
+        else:
+            self.mode_worse = float("-inf")
+
+        self.best = Parameter(Tensor(self.mode_worse, dtype=mstype.float32), name='best')
+
+        self.cooldown_counter = Parameter(Tensor(0, dtype=mstype.float32), name='cooldown_counter')
+        self.wait = Parameter(Tensor(0, dtype=mstype.float32), name='wait')
+        self.increase_tensor = Tensor(1, mstype.int32)
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+    def step(self, metrics):
+        """
+        Get the current learning rate and change the learning rate.
+
+        Args:
+            metrics (Union(int, float)): the evaluation metric.
+        """
+        epoch = self.last_epoch + 1
+        current = self.cast(metrics, mstype.float32)
+        self.assign(self.last_epoch, epoch)
+
+        if self._is_improvement(current, self.best):
+            ops.assign(self.best, current)
+            ops.assign(self.wait, 0)
+        else:
+            ops.assign_add(self.wait, self.increase_tensor)
+
+        if self.in_cooldown():
+            ops.assign_sub(self.cooldown_counter, self.increase_tensor)
+            ops.assign(self.wait, 0)
+
+        if self.wait > self.patience:
+            self._reduce_lr(epoch)
+            ops.assign(self.cooldown_counter, self.cooldown)
+            ops.assign(self.wait, 0)
+
+        return True
+
+    def _reduce_lr(self, epoch):
+        for i, lr in enumerate(self._last_lr):
+            old_lr = lr.value()
+            new_lr = ops.maximum(old_lr * self.factor, self.min_lrs[i])
+            if old_lr > new_lr + self.eps:
+                ops.assign(lr, new_lr)
+        return True
+
+    def in_cooldown(self):
+        """ Whether in the cooldown period. """
+        return self.cooldown_counter > 0
+
+    def _is_improvement(self, current, best):
+        """ Whether the current metric value is better than the best so far. """
+        if self.mode == 'min' and self.threshold_mode == 'rel':
+            rel_epsilon = 1. - self.threshold
+            benchmark = best * rel_epsilon
+            return current < benchmark
+
+        if self.mode == 'min' and self.threshold_mode == 'abs':
+            benchmark = best - self.threshold
+            return current < benchmark
+
+        if self.mode == 'max' and self.threshold_mode == 'rel':
+            rel_epsilon = self.threshold + 1.
+            benchmark = best * rel_epsilon
+            return current > benchmark
+
+        benchmark = best + self.threshold
+        return current > benchmark
+
+    def get_last_lr(self):
+        """
+        Return last computed learning rate by current scheduler.
+        """
+        return [lr.value() for lr in self._last_lr]
+
+
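The four branches of `_is_improvement` above reduce to two comparison rules per mode. A plain-float restatement (hypothetical `is_improvement` helper, for illustration only, minus the Tensor plumbing):

    def is_improvement(current, best, mode='min', threshold_mode='rel', threshold=1e-4):
        # Mirrors ReduceLROnPlateau._is_improvement on Python floats.
        if mode == 'min' and threshold_mode == 'rel':
            return current < best * (1. - threshold)
        if mode == 'min' and threshold_mode == 'abs':
            return current < best - threshold
        if mode == 'max' and threshold_mode == 'rel':
            return current > best * (1. + threshold)
        return current > best + threshold

    print(is_improvement(0.999, 1.0))    # True: drops below 1.0 * (1 - 1e-4)
    print(is_improvement(0.99995, 1.0))  # False: within the relative threshold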
+@jit_class
+class CyclicLR(LRScheduler):
+    r"""
+    Sets the learning rate of each parameter group according to the
+    cyclical learning rate policy (CLR). The policy cycles the learning
+    rate between two boundaries with a constant frequency, as detailed in
+    the paper `Cyclical Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.
+    The distance between the two boundaries can be scaled on a per-iteration
+    or per-cycle basis.
+
+    This class has three built-in policies, as put forth in the paper:
+
+    - "triangular": A basic triangular cycle without amplitude scaling.
+    - "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
+    - "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
+      at each cycle iteration.
+
+    .. warning::
+        This is an experimental lr scheduler module that is subject to change.
+        This module must be used with optimizers in `Experimental Optimizer
+        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+    Args:
+        optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+        base_lr (Union(float, list)): Initial learning rate which is the
+            lower boundary in the cycle for each parameter group.
+        max_lr (Union(float, list)): Upper learning rate boundaries in the cycle
+            for each parameter group. Functionally, it defines the cycle amplitude (max_lr - base_lr).
+            The lr at any cycle is the sum of base_lr and some scaling of the amplitude.
+        step_size_up (int, optional): Number of training iterations in the
+            increasing half of a cycle. Default: ``2000``.
+        step_size_down (int, optional): Number of training iterations in the
+            decreasing half of a cycle. If step_size_down is None,
+            it is set to step_size_up. Default: ``None``.
+        mode (str, optional): One of {triangular, triangular2, exp_range}.
+            Values correspond to the policies detailed above.
+            If scale_fn is not None, this argument is ignored.
+            Default: ``'triangular'``.
+        gamma (float, optional): Constant in the 'exp_range' scaling function:
+            gamma**(cycle iterations). Default: ``1.0``.
+        scale_fn (function, optional): Custom scaling policy defined by a single
+            argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0.
+            If specified, then 'mode' is ignored. Default: ``None``.
+        scale_mode (str, optional): One of {'cycle', 'iterations'}.
+            Defines whether scale_fn is evaluated on cycle number or cycle iterations (training
+            iterations since start of cycle). Illegal inputs fall back to ``'iterations'``
+            by default. Default: ``'cycle'``.
+        last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+    Raises:
+        ValueError: When `base_lr` is a list or tuple, the length of it is not equal to the number of param groups.
+        ValueError: When `max_lr` is a list or tuple, the length of it is not equal to the number of param groups.
+        ValueError: `mode` is not in [``'triangular'``, ``'triangular2'``, ``'exp_range'``] and `scale_fn` is None.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore.experimental import optim
+        >>> from mindspore import nn
+        >>> net = nn.Dense(3, 2)
+        >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1, momentum=0.9)
+        >>> scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
+        >>> expect_list = [[0.010045], [0.01009], [0.010135], [0.01018], [0.010225]]
+        >>>
+        >>> for i in range(5):
+        ...     scheduler.step()
+        ...     current_lr = scheduler.get_last_lr()
+        ...     print(current_lr)
+        [Tensor(shape=[], dtype=Float32, value= 0.010045)]
+        [Tensor(shape=[], dtype=Float32, value= 0.01009)]
+        [Tensor(shape=[], dtype=Float32, value= 0.010135)]
+        [Tensor(shape=[], dtype=Float32, value= 0.01018)]
+        [Tensor(shape=[], dtype=Float32, value= 0.010225)]
+    """
+    def __init__(self,
+                 optimizer,
+                 base_lr,
+                 max_lr,
+                 step_size_up=2000,
+                 step_size_down=None,
+                 mode='triangular',
+                 gamma=1.,
+                 scale_fn=None,
+                 scale_mode='cycle',
+                 last_epoch=-1):
+
+        base_lrs = self._preprocess_input_param(optimizer, base_lr, 'base_lr')
+
+        if last_epoch == -1:
+            for lr, group in zip(base_lrs, optimizer.param_groups):
+                ops.assign(group['lr'], Parameter(lr))
+
+        self.max_lrs = self._preprocess_input_param(optimizer, max_lr, 'max_lr')
+        self.max_lrs = [Tensor(lr) for lr in self.max_lrs]
+
+        step_size_up = float(step_size_up)
+        step_size_down = step_size_up if step_size_down is None else float(step_size_down)
+
+        self.total_step_size = step_size_up + step_size_down
+        self.step_up_ratio = step_size_up / self.total_step_size
+
+        if mode not in ['triangular', 'triangular2', 'exp_range'] \
+                and scale_fn is None:
+            raise ValueError("When scale_fn is None, mode must be one of 'triangular', "
+                             "'triangular2' or 'exp_range', but got {}.".format(mode))
+
+        self.mode = mode
+        self.gamma = gamma
+
+        self._scale_fn_ref = None
+        self._scale_fn_custom = scale_fn
+        self.scale_mode = scale_mode
+        self._init_scale_fn()
+        self.floor = P.Floor()
+
+        super(CyclicLR, self).__init__(optimizer, last_epoch)
+        self.base_lrs = [Tensor(lr) for lr in base_lrs]
+
+    def _init_scale_fn(self):
+        """ Define the scale function. """
+        if self._scale_fn_custom is not None:
+            return
+        if self.mode == 'triangular':
+            self._scale_fn_ref = self._triangular_scale_fn
+            self.scale_mode = 'cycle'
+        elif self.mode == 'triangular2':
+            self._scale_fn_ref = self._triangular2_scale_fn
+            self.scale_mode = 'cycle'
+        elif self.mode == 'exp_range':
+            self._scale_fn_ref = self._exp_range_scale_fn
+            self.scale_mode = 'iterations'
+
+    def _preprocess_input_param(self, optimizer, param, name):
+        """Return correctly formatted lr/momentum for each param group."""
+        if isinstance(param, (list, tuple)):
+            if len(param) != len(optimizer.param_groups):
+                raise ValueError("Expected {} values for {}, got {}".format(
+                    len(optimizer.param_groups), name, len(param)))
+            return param
+        return [param] * len(optimizer.param_groups)
+
+    def scale_fn(self, x):
+        if self._scale_fn_custom is None:
+            return self._scale_fn_ref(x)
+        return self._scale_fn_custom(x)
+
+    def _triangular_scale_fn(self, x):
+        return 1.
+
+    def _triangular2_scale_fn(self, x):
+        return 1 / (2. ** (x - 1))
+
+    def _exp_range_scale_fn(self, x):
+        return self.gamma ** x
+
+    def get_lr(self):
+        cycle = self.floor(1 + self.last_epoch / self.total_step_size)
+        x = 1. + self.last_epoch / self.total_step_size - cycle
+        if x <= self.step_up_ratio:
+            scale_factor = x / self.step_up_ratio
+        else:
+            scale_factor = (x - 1) / (self.step_up_ratio - 1)
+        lrs = []
+        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
+            base_height = (max_lr - base_lr) * scale_factor
+            if self.scale_mode == 'cycle':
+                lr = base_lr + base_height * self.scale_fn(cycle)
+            else:
+                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
+            lrs.append(lr)
+        return lrs
+
+
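For the default 'triangular' policy (scale factor 1.0), `get_lr` above reduces to a piecewise-linear ramp between `base_lr` and `max_lr`. A plain-float rendering (hypothetical `triangular_lr` helper, illustration only; `step_size_down` is assumed equal to `step_size_up`):

    import math

    def triangular_lr(step, base_lr, max_lr, step_size_up=2000):
        # Same arithmetic as CyclicLR.get_lr with mode='triangular'.
        total = 2.0 * step_size_up
        cycle = math.floor(1 + step / total)
        x = 1. + step / total - cycle
        ratio = step_size_up / total
        scale = x / ratio if x <= ratio else (x - 1) / (ratio - 1)
        return base_lr + (max_lr - base_lr) * scale

    print(triangular_lr(1, 0.01, 0.1))  # 0.010045, matching the doctest above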
+@jit_class
+class CosineAnnealingWarmRestarts(LRScheduler):
+    r"""
+    Set the learning rate of each parameter group using a cosine annealing warm restarts
+    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`\eta_{min}` is the minimum value
+    for the learning rate, :math:`\eta_{t}` is the current learning rate, :math:`T_{0}` is the number of
+    iterations for the first restart, :math:`T_{i}` is the current number of iterations between two warm
+    restarts in SGDR, and :math:`T_{cur}` is the number of epochs since the last restart in SGDR.
+
+    .. math::
+        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
+
+    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
+
+    For more details, please refer to: `SGDR: Stochastic Gradient Descent with Warm Restarts
+    <https://arxiv.org/abs/1608.03983>`_.
+
+    .. warning::
+        This is an experimental lr scheduler module that is subject to change.
+        This module must be used with optimizers in `Experimental Optimizer
+        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+    Args:
+        optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+        T_0 (int): Number of iterations for the first restart.
+        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: ``1``.
+        eta_min (Union(float, int), optional): Minimum learning rate. Default: ``0``.
+        last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+    Raises:
+        ValueError: `T_0` is less than or equal to 0 or not an int.
+        ValueError: `T_mult` is less than 1 or not an int.
+        ValueError: `eta_min` is not int or float.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore.experimental import optim
+        >>> from mindspore import nn
+        >>> net = nn.Dense(3, 2)
+        >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1, momentum=0.9)
+        >>> scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 2)
+        >>> iters = 3
+        >>> for epoch in range(2):
+        ...     for i in range(iters):
+        ...         scheduler.step(epoch + i / iters)
+        ...         current_lr = scheduler.get_last_lr()
+        ...         print(current_lr)
+        [Tensor(shape=[], dtype=Float32, value= 0.1)]
+        [Tensor(shape=[], dtype=Float32, value= 0.0933013)]
+        [Tensor(shape=[], dtype=Float32, value= 0.075)]
+        [Tensor(shape=[], dtype=Float32, value= 0.05)]
+        [Tensor(shape=[], dtype=Float32, value= 0.025)]
+        [Tensor(shape=[], dtype=Float32, value= 0.00669873)]
+    """
+    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
+        if T_0 <= 0 or not isinstance(T_0, int):
+            raise ValueError("T_0 should be an integer greater than 0, but got {}".format(T_0))
+        if T_mult < 1 or not isinstance(T_mult, int):
+            raise ValueError("T_mult should be an integer equal to or greater than 1, but got {}".format(T_mult))
+        self.T_0 = Parameter(Tensor(T_0, dtype=mstype.float32), name='T_0')
+        self.T_i = Parameter(Tensor(T_0, dtype=mstype.float32), name='T_i')
+        self.T_mult = T_mult
+        Validator.check_value_type('eta_min', eta_min, [float, int])
+        self.eta_min = Tensor(eta_min)
+        self.T_cur = Parameter(Tensor(last_epoch, dtype=mstype.float32), name='T_cur')
+        self.increase_tensor = Tensor(1, mstype.int32)
+        self.zero_tensor = Tensor(0, mstype.int32)
+
+        self.math_pi = math.pi
+        self.cos = P.Cos()
+        self.cast = P.Cast()
+        self.log = P.Log()
+        self.assign = P.Assign()
+        self.floor = P.Floor()
+        self._last_lr = [group["lr"] for group in optimizer.param_groups]
+        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        pct = self.cast(self.math_pi * self.T_cur / self.T_i, mstype.float32)
+        return [self.eta_min + (base_lr - self.eta_min) * (1 + self.cos(pct)) / 2
+                for base_lr in self.base_lrs]
+
+    def step(self, epoch=None):
+        """
+        Get the current learning rate and change the learning rate.
+
+        Args:
+            epoch (int, optional): The index of the last epoch. Default: ``None``.
+        """
+        if epoch is None and self.last_epoch < 0:
+            epoch = self.zero_tensor
+
+        if epoch is None:
+            epoch = self.last_epoch + 1
+            ops.assign_add(self.T_cur, self.increase_tensor)
+            if self.T_cur >= self.T_i:
+                ops.assign(self.T_cur, self.T_cur - self.T_i)
+                ops.assign(self.T_i, self.T_i * self.T_mult)
+        else:
+            if epoch < 0:
+                raise ValueError("epoch should be a non-negative integer, but got {}".format(epoch))
+            epoch = self.cast(epoch, mstype.float32)
+
+            if epoch >= self.T_0:
+                if self.T_mult == 1:
+                    ops.assign(self.T_cur, epoch % self.T_0)
+                else:
+                    exp = int(self.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
+                    value = epoch - self.T_0 * (self.T_mult ** exp - 1) / (self.T_mult - 1)
+                    ops.assign(self.T_cur, value)
+                    ops.assign(self.T_i, self.T_0 * self.T_mult ** exp)
+            else:
+                ops.assign(self.T_i, self.T_0.value())
+                ops.assign(self.T_cur, self.cast(epoch, mstype.float32))
+
+        self.assign(self.last_epoch, self.floor(epoch))
+
+        for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
+            _, lr = data
+            F.assign(self.optimizer.param_groups[i]["lr"], lr)
+
+
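The values in the doctest above follow directly from the SGDR formula in the docstring. A plain-float check (hypothetical `sgdr_lr` helper, illustration only):

    import math

    def sgdr_lr(t_cur, t_i, eta_max=0.1, eta_min=0.0):
        # eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2
        return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2

    for t_cur in (0.0, 1.0 / 3, 1.0):
        print(sgdr_lr(t_cur, t_i=2))  # 0.1, ~0.0933, 0.05 -- cf. the doctest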
+@jit_class
+class CosineAnnealingLR(LRScheduler):
+    r"""
+    Set the learning rate of each parameter group using a cosine annealing lr
+    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`\eta_{min}` is the minimum value
+    for the learning rate, :math:`\eta_{t}` is the current learning rate, :math:`T_{max}` is the iteration
+    number of the cosine function, and :math:`T_{cur}` is the number of epochs since the last restart in SGDR.
+
+    .. math::
+        \begin{aligned}
+            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
+            & T_{cur} \neq (2k+1)T_{max}; \\
+            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
+            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
+            & T_{cur} = (2k+1)T_{max}.
+        \end{aligned}
+
+    For more details, please refer to: `SGDR: Stochastic Gradient Descent with Warm Restarts
+    <https://arxiv.org/abs/1608.03983>`_ .
+
+    .. warning::
+        This is an experimental lr scheduler module that is subject to change.
+        This module must be used with optimizers in `Experimental Optimizer
+        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#experimental-optimizer>`_ .
+
+    Args:
+        optimizer (:class:`mindspore.experimental.optim.Optimizer`): Wrapped optimizer.
+        T_max (int): Maximum number of iterations.
+        eta_min (float, optional): Minimum learning rate. Default: ``0.0``.
+        last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore.experimental import optim
+        >>> from mindspore import nn
+        >>> net = nn.Dense(3, 2)
+        >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1, momentum=0.9)
+        >>> scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2)
+        >>>
+        >>> for i in range(6):
+        ...     scheduler.step()
+        ...     current_lr = scheduler.get_last_lr()
+        ...     print(current_lr)
+        [Tensor(shape=[], dtype=Float32, value= 0.05)]
+        [Tensor(shape=[], dtype=Float32, value= 0)]
+        [Tensor(shape=[], dtype=Float32, value= 0.05)]
+        [Tensor(shape=[], dtype=Float32, value= 0.1)]
+        [Tensor(shape=[], dtype=Float32, value= 0.05)]
+        [Tensor(shape=[], dtype=Float32, value= 0)]
+    """
+    def __init__(self, optimizer, T_max, eta_min=0.0, last_epoch=-1):
+        if not isinstance(eta_min, (float, int)):
+            raise TypeError(f"For 'CosineAnnealingLR', the 'eta_min' must be float or int, but got {type(eta_min)}.")
+        if not isinstance(T_max, int) or isinstance(T_max, bool):
+            raise TypeError(f"For 'CosineAnnealingLR', the 'T_max' must be int, but got {type(T_max)}.")
+        self.T_max = T_max
+        self.eta_min = eta_min
+        self.math_pi = math.pi
+        self.cos = P.Cos()
+        self.cast = P.Cast()
+        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return [lr * 1. for lr in self._last_lr]
+
+        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
+            pct_pi = self.cast(self.math_pi / self.T_max, mstype.float32)
+            return [lr + (base_lr - self.eta_min) *
+                    (1 - self.cos(pct_pi)) / 2
+                    for base_lr, lr in
+                    zip(self.base_lrs, self._last_lr)]
+
+        return [(1 + self.cos(self.math_pi * self.last_epoch / self.T_max)) /
+                (1 + self.cos(self.math_pi * (self.last_epoch - 1) / self.T_max)) *
+                (lr - self.eta_min) + self.eta_min
+                for lr in self._last_lr]
+
+    def _get_closed_form_lr(self):
+        return [self.eta_min + (base_lr - self.eta_min) *
+                (1 + self.cos(self.math_pi * self.last_epoch / self.T_max)) / 2
+                for base_lr in self.base_lrs]
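Likewise, `_get_closed_form_lr` pins down the doctest values for `CosineAnnealingLR`. Recomputing them on plain floats (hypothetical `cosine_lr` helper, illustration only):

    import math

    def cosine_lr(epoch, t_max=2, base_lr=0.1, eta_min=0.0):
        # Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / T_max)) / 2
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

    print([round(cosine_lr(e), 4) for e in range(1, 7)])
    # [0.05, 0.0, 0.05, 0.1, 0.05, 0.0] -- the doctest sequence above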