mindspore 2.5.0__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -228,21 +228,23 @@ class FTRL(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- initial_accum (float): The starting value for accumulators `m`, must be zero or positive values.
+ initial_accum (float, optional): The starting value for accumulators `m`, must be zero or positive values.
  Default: ``0.1`` .
- learning_rate (float): The learning rate value, must be zero or positive, dynamic learning rate is currently
- not supported. Default: ``0.001`` .
- lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less
+ learning_rate (float, optional): The learning rate value, must be zero or positive, dynamic learning rate
+ is currently not supported. Default: ``0.001`` .
+ lr_power (float, optional): Learning rate power controls how the learning rate decreases during training,
+ must be less
  than or equal to zero. Use fixed learning rate if lr_power is zero. Default: ``-0.5`` .
- l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- use_locking (bool): If true, use locks for updating operation. Default: ``False`` .
- loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+ l1 (float, optional): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ l2 (float, optional): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ use_locking (bool, optional): If true, use locks for updating operation. Default: ``False`` .
+ loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+ use the default value.
  Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

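The hunk above only reclassifies these arguments as optional; the defaults are unchanged. For reference, a minimal sketch of constructing the optimizer with those keywords at their documented defaults; the `nn.Dense` network is a placeholder, not part of the diff:

    import mindspore.nn as nn

    net = nn.Dense(16, 8)  # placeholder network, not from the diff
    # Keyword arguments below are the documented optional ones, at their stated defaults.
    opt = nn.FTRL(net.trainable_params(), initial_accum=0.1, learning_rate=0.001,
                  lr_power=-0.5, l1=0.0, l2=0.0, use_locking=False,
                  loss_scale=1.0, weight_decay=0.0)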
@@ -265,7 +265,7 @@ class Lamb(Optimizer):
  self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
  self.device_ascend = context.get_context("device_target") == "Ascend"

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  weight_decay = self.get_weight_decay()
  lr = self.get_lr()
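The same `@jit` to `@jit(backend="ms_backend")` change recurs in the Momentum, ProximalAdagrad and Rprop hunks below. A minimal sketch of the new decorator form on a stand-alone Cell, assuming the `backend` keyword behaves as the diff implies; the Cell itself is made up:

    from mindspore import Tensor, jit, nn
    from mindspore import dtype as mstype

    class ScaleGrads(nn.Cell):
        """Illustrative Cell, not taken from the diff."""
        # 2.6.0rc1 passes an explicit backend to jit, as the hunk above shows.
        @jit(backend="ms_backend")
        def construct(self, gradients):
            return gradients * 2

    out = ScaleGrads()(Tensor([1.0, 2.0], dtype=mstype.float32))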
@@ -321,7 +321,7 @@ class LazyAdam(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1e-3`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1e-3`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -337,20 +337,21 @@ class LazyAdam(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.9`` .
- beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
- Default: ``0.999`` .
- eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
+ beta1 (float, optional): The exponential decay rate for the 1st moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.9`` .
+ beta2 (float, optional): The exponential decay rate for the 2nd moment estimations.
+ Should be in range (0.0, 1.0). Default: ``0.999`` .
+ eps (float, optional): Term added to the denominator to improve numerical stability. Should be greater than 0.
  Default: ``1e-8`` .
- use_locking (bool): Whether to enable a lock to protect the updating process of variable tensors.
+ use_locking (bool, optional): Whether to enable a lock to protect the updating process of variable tensors.
  If ``true`` , updates of the `w`, `m`, and `v` tensors will be protected by a lock.
  If ``false`` , the result is unpredictable. Default: ``False`` .
- use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
+ use_nesterov (bool, optional): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to
+ update the gradients.
  If ``true`` , update the gradients using NAG.
  If ``false`` , update the gradients without using NAG. Default: ``False`` .

- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -359,7 +360,8 @@ class LazyAdam(Optimizer):
  - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
  the Cell with step as the input to get the weight decay value of current step.

- loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. In general,
+ loss_scale (float, optional): A floating point value for the loss scale. Should be equal to or greater than 1.
+ In general,
  use the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update`
  in `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
@@ -103,7 +103,7 @@ class Momentum(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional):

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -119,10 +119,10 @@ class Momentum(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- momentum (float): Hyperparameter of type float, means momentum for the moving average.
+ momentum (float, optional): Hyperparameter of type float, means momentum for the moving average.
  It must be at least 0.0.

- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -131,12 +131,13 @@ class Momentum(Optimizer):
  - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
  the Cell with step as the input to get the weight decay value of current step.

- loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale. It must be greater than 0.0.
+ In general, use the
  default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- use_nesterov (bool): Enable Nesterov momentum. Default: ``False`` .
+ use_nesterov (bool, optional): Enable Nesterov momentum. Default: ``False`` .

  Inputs:
  - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

@@ -199,7 +200,7 @@ class Momentum(Optimizer):
  self._get_distributed_optimizer_list("momentum", use_nesterov=self.use_nesterov)
  self.use_dist_optimizer = self._use_distibuted_optimizer()

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  params = self.params
  moments = self.moments
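These optimizer docstrings keep referring to the same parameter-group keys ("params", "lr", "weight_decay", "grad_centralization", "order_params"). A small illustrative grouping, with made-up values and a placeholder network:

    import mindspore.nn as nn

    net = nn.Dense(16, 8)  # placeholder network, not from the diff
    weight_params = [p for p in net.trainable_params() if 'weight' in p.name]
    bias_params = [p for p in net.trainable_params() if 'bias' in p.name]
    group_params = [
        {'params': weight_params, 'lr': 0.01, 'weight_decay': 0.01},
        {'params': bias_params, 'lr': 0.001},
        {'order_params': net.trainable_params()},
    ]
    opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)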
@@ -122,10 +122,10 @@ class Optimizer(Cell):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
+ weight_decay (Union[float, int], optional): An int or a floating point value for the weight decay.
  It must be equal to or greater than 0.
  If the type of `weight_decay` input is int, it will be converted to float. Default: ``0.0`` .
- loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
+ loss_scale (float, optional): A floating point value for the loss scale. It must be greater than 0. If the
  type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
  when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , this value needs to be the same as the `loss_scale` in
@@ -83,8 +83,8 @@ class ProximalAdagrad(Optimizer):

  Args:
  params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
- `params` is a list of `dict`, the string "params", "lr", "weight_decay", "grad_centralization" and
- "order_params" are the keys can be parsed.
+ `params` is a list of `dict`, the string `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+ `"order_params"` are the keys can be parsed.

  - params: Required. Parameters in current group. The value must be a list of `Parameter`.

@@ -108,8 +108,9 @@ class ProximalAdagrad(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- accum (float): The starting value for accumulators `accum`, must be zero or positive values. Default: ``0.1`` .
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .
+ accum (float, optional): The starting value for accumulators `accum`, must be zero or positive values.
+ Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.001`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -125,15 +126,16 @@ class ProximalAdagrad(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of the current step.

- l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
- use_locking (bool): If true, use locks for updating operation. Default: ``False`` .
- loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
+ l1 (float, optional): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ l2 (float, optional): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` .
+ use_locking (bool, optional): If ``True``, use locks for updating operation. Default: ``False`` .
+ loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+ use the default value.
  Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -199,7 +201,7 @@ class ProximalAdagrad(Optimizer):
  self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
  self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)

- @jit
+ @jit(backend="ms_backend")
  def construct(self, grads):
  params = self._parameters
  accum = self.accum
@@ -92,9 +92,9 @@ class RMSProp(Optimizer):
  :math:`t` represents the current step.

  Note:
- If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
- 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When
- parameters are grouped, each group can set `weight_decay`. If not, the `weight_decay` in optimizer will be
+ If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters
+ without 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight.
+ When parameters are grouped, each group can set `weight_decay`. If not, the `weight_decay` in optimizer will be
  applied.

  Args:
@@ -124,7 +124,7 @@ class RMSProp(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.1`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -140,21 +140,22 @@ class RMSProp(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of the current step.

- decay (float): Decay rate. Should be equal to or greater than 0. Default: ``0.9`` .
- momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
- greater than 0. Default: ``0.0`` .
- epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than
+ decay (float, optional): Decay rate. Should be equal to or greater than 0. Default: ``0.9`` .
+ momentum (float, optional): Hyperparameter of type float, means momentum for the moving average.
+ Should be equal to or greater than 0. Default: ``0.0`` .
+ epsilon (float, optional): Term added to the denominator to improve numerical stability. Should be greater than
  0. Default: ``1e-10`` .
- use_locking (bool): Whether to enable a lock to protect the updating process of variable tensors.
+ use_locking (bool, optional): Whether to enable a lock to protect the updating process of variable tensors.
  Default: ``False`` .
- centered (bool): If True, gradients are normalized by the estimated variance of the gradient.
+ centered (bool, optional): If True, gradients are normalized by the estimated variance of the gradient.
  Default: ``False`` .
- loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
+ loss_scale (float, optional): A floating point value for the loss scale. Should be greater than 0. In general,
+ use the
  default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
  Default: ``1.0`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.
@@ -58,8 +58,8 @@ class Rprop(Optimizer):

  Args:
  params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
- `parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and
- "order_params" are the keys can be parsed.
+ `parameters` is a list of `dict`, the `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+ `"order_params"` are the keys can be parsed.

  - params: Required. Parameters in current group. The value must be a list of `Parameter`.

@@ -83,7 +83,8 @@ class Rprop(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Learning_rate. Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Learning_rate.
+ Default: ``0.1`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -99,11 +100,12 @@ class Rprop(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- etas (tuple[float, float]): The factor of multiplicative increasing or
+ etas (tuple[float, float], optional): The factor of multiplicative increasing or
  descreasing(etaminus, etaplus). Default: ``(0.5, 1.2)`` .
- step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size).
+ step_sizes(tuple[float, float], optional): The allowed minimal and maximal
+ step size(min_step_sizes, max_step_size).
  Default: ``(1e-6, 50.)`` .
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+ weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .

  - float: The fixed weight decay value. Must be equal to or greater than 0.

@@ -199,7 +201,7 @@ class Rprop(Optimizer):
  self.select = P.Select()
  self.ones_like = P.OnesLike()

- @jit
+ @jit(backend="ms_backend")
  def construct(self, gradients):
  gradients = self.flatten_gradients(gradients)
  gradients = self.decay_weight(gradients)
mindspore/nn/optim/sgd.py CHANGED
@@ -90,7 +90,7 @@ class SGD(Optimizer):
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
  one group of `params`.

- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.1`` .
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.1`` .

  - float: The fixed learning rate value. Must be equal to or greater than 0.

@@ -106,12 +106,15 @@ class SGD(Optimizer):
  <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
  with step as the input to get the learning rate of current step.

- momentum (float): A floating point value the momentum. must be at least 0.0. Default: ``0.0`` .
- dampening (float): A floating point value of dampening for momentum. must be at least 0.0. Default: ``0.0`` .
- weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: ``0.0`` .
- nesterov (bool): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
+ momentum (float, optional): A floating point value the momentum. must be at least 0.0. Default: ``0.0`` .
+ dampening (float, optional): A floating point value of dampening for momentum. must be at least 0.0.
+ Default: ``0.0`` .
+ weight_decay (float, optional): Weight decay (L2 penalty). It must be equal to or greater than 0.
+ Default: ``0.0`` .
+ nesterov (bool, optional): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
  and dampening must be equal to 0.0. Default: ``False`` .
- loss_scale (float): A floating point value for the loss scale, which must be larger than 0.0. In general, use
+ loss_scale (float, optional): A floating point value for the loss scale, which must be larger than 0.0.
+ In general, use
  the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
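A minimal construction reflecting the constraint restated above (nesterov requires a positive momentum and a dampening of 0.0); the network and values are illustrative only:

    import mindspore.nn as nn

    net = nn.Dense(16, 8)  # placeholder network, not from the diff
    # Per the docstring, nesterov=True is only valid with momentum > 0 and dampening == 0.0.
    opt = nn.SGD(net.trainable_params(), learning_rate=0.1,
                 momentum=0.9, dampening=0.0, weight_decay=0.0,
                 nesterov=True, loss_scale=1.0)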
@@ -67,14 +67,16 @@ class OptTFTWrapper(Optimizer):
  raise TypeError(f"For 'OptTFTWrapper', the argument 'opt' must be Optimizer type, " f"but got {type(opt)}.")
  super(OptTFTWrapper, self).__init__(opt.learning_rate, opt._parameters) # pylint: disable=W0212
  tft_env = os.getenv("MS_ENABLE_TFT", "")
- if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
- raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
+ if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env) and ("ARF:1" not in tft_env):
+ raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1,ARF:1}']!")
  mode = context.get_context("mode")
  device_target = context.get_context("device_target")
  if device_target != "Ascend" or mode != context.GRAPH_MODE:
  raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
  self.opt = opt
  self.report = TensorReport()
+ self.report_end = TensorReport()
+ self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
  self.depend = ops.Depend()
  self.allreduce_sum = ops.AllReduce()
  self.allreduce_sum.add_prim_attr("tft_report_before", True)
@@ -121,4 +123,5 @@

  grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
  opt_ret = self.opt(grads)
+ self.report_end("tft_report", self.tft_g_one_flag)
  return opt_ret
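The widened environment check above now also accepts ARF:1. A sketch of satisfying it before constructing the wrapper; the accepted tokens come from the error message in the hunk, the rest of the launch environment (Ascend device, GRAPH_MODE) follows the same code:

    import os

    # Must be set before OptTFTWrapper is created; ARF:1 is newly accepted in 2.6.0rc1.
    os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1,ARF:1}"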
@@ -41,15 +41,20 @@ class Bijector(Cell):
  param (dict): The parameters used to initialize the Bijector. Default: ``None`` .

  Note:
- `dtype` of bijector represents the type of the distributions that the bijector could operate on.
- When `dtype` is None, there is no enforcement on the type of input value except that the input value
- has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
- of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
- Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
- will be casted into the same type as input value when `dtype` is None.
- When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
- When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
- raised. Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.
+ - `dtype` of bijector represents the type of the distributions that the bijector could operate on.
+ - When `dtype` is None, there is no enforcement on the type of input value except that the input value
+ has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
+ of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
+
+ Specifically, the parameter type will follow the dtype of the input value.
+
+ - Parameters of the bijector will be casted into the same type as input value when `dtype` is None.
+
+ - When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
+ When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
+ raised.
+
+ - Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.

  Supported Platforms:
  ``Ascend`` ``GPU``
@@ -226,7 +231,8 @@

  def cast_param_by_value(self, value, para):
  """
- Cast the parameter(s) of the bijector to be the same type of input_value.
+ Converts the data type of `para` in the input to the same type as `value`.
+ Typically used by subclasses of Bijector to convert data types of their own parameters.

  Args:
  value (Tensor): input value.
@@ -276,7 +282,7 @@
  **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

  Returns:
- Tensor, the value of logarithm of the derivative of the forward transformation.
+ Tensor, outputs the value of a random variable after mapping.
  """
  return self._forward_log_jacobian(value, *args, **kwargs)
@@ -33,11 +33,11 @@ class GumbelCDF(Bijector):
  name (str): The name of the Bijector. Default: ``'GumbelCDF'`` .

  Note:
- `scale` must be greater than zero.
- For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1).
- The dtype of `loc` and `scale` must be float.
- If `loc`, `scale` are passed in as numpy.ndarray or tensor, they have to have
- the same dtype otherwise an error will be raised.
+ - `scale` must be greater than zero.
+ - For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1).
+ - The dtype of `loc` and `scale` must be float.
+ - If `loc`, `scale` are passed in as numpy.ndarray or tensor, they have to have
+ the same dtype otherwise an error will be raised.

  Raises:
  TypeError: When the dtype of `loc` or `scale` is not float,
@@ -25,8 +25,8 @@ class Invert(Bijector):

  Args:
  bijector (Bijector): Base Bijector.
- name (str): The name of the Bijector. Default: ``""`` . When name is set to "", it is actually
- 'Invert' + bijector.name.
+ name (str): The name of the Bijector. Default: ``""`` . When name is set to ``""``, it is actually
+ ``'Invert' + Bijector.name``.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -26,11 +26,11 @@ class ScalarAffine(Bijector):
     .. math::
         Y = a * X + b
 
-    where a is the scale factor and b is the shift factor.
+    where :math:`a` is the scale factor and :math:`b` is the shift factor.
 
     Args:
-        scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``1.0`` .
-        shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: ``0.0`` .
+        scale (float, list, numpy.ndarray, Tensor): The scale factor. :math:`a` in the formula. Default: ``1.0`` .
+        shift (float, list, numpy.ndarray, Tensor): The shift factor. :math:`b` in the formula. Default: ``0.0`` .
         name (str): The name of the bijector. Default: ``'ScalarAffine'`` .
 
     Note:
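A brief usage sketch of the mapping above, assuming the `mindspore.nn.probability.bijector` import path and direct calls in PyNative mode::

    import numpy as np
    import mindspore.nn.probability.bijector as msb
    from mindspore import Tensor

    affine = msb.ScalarAffine(scale=2.0, shift=1.0)          # a = 2, b = 1
    x = Tensor(np.array([0.0, 1.0, 2.0], dtype=np.float32))

    y = affine.forward(x)        # a * x + b -> [1, 3, 5]
    x_back = affine.inverse(y)   # (y - b) / a -> [0, 1, 2]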
@@ -29,10 +29,11 @@ class Softplus(Bijector):
     .. math::
         Y = \frac{\log(1 + e ^ {kX})}{k}
 
-    where k is the sharpness factor.
+    where :math:`k` is the sharpness factor.
 
     Args:
-        sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``1.0`` .
+        sharpness (float, list, numpy.ndarray, Tensor): The scale factor. :math:`k` in the above formula.
+            Default: ``1.0`` .
         name (str): The name of the Bijector. Default: ``'Softplus'`` .
 
     Note:
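A plain-numpy sketch of the formula above, showing how the sharpness factor :math:`k` controls how closely the map approximates :math:`\max(0, x)` (illustrative only, not the MindSpore implementation)::

    import numpy as np

    def softplus_bijector(x, sharpness=1.0):
        # Y = log(1 + exp(k * X)) / k
        k = sharpness
        return np.log1p(np.exp(k * x)) / k

    x = np.linspace(-3.0, 3.0, 7)
    print(softplus_bijector(x, sharpness=1.0))    # smooth approximation of max(0, x)
    print(softplus_bijector(x, sharpness=10.0))   # larger k hugs max(0, x) more tightly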
@@ -37,9 +37,9 @@ class Beta(Distribution):
 
     Args:
         concentration1 (int, float, list, numpy.ndarray, Tensor): The concentration1,
-            also know as alpha of the Beta distribution. Default: ``None`` .
+            also known as :math:`\alpha` of the Beta distribution. Default: ``None`` .
         concentration0 (int, float, list, numpy.ndarray, Tensor): The concentration0, also know as
-            beta of the Beta distribution. Default: ``None`` .
+            :math:`\beta` of the Beta distribution. Default: ``None`` .
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
         dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
         name (str): The name of the distribution. Default: ``'Beta'`` .
@@ -51,7 +51,7 @@ class Beta(Distribution):
 
     Raises:
         ValueError: When concentration1 <= 0 or concentration0 >=1.
-        TypeError: When the input `dtype` is not a subclass of float.
+        TypeError: When the input `dtype` is not a float or a subclass of float.
 
     Supported Platforms:
         ``Ascend``
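To make the parameter naming concrete, a plain-numpy sketch of the standard Beta density with `concentration1` as :math:`\alpha` and `concentration0` as :math:`\beta` (the textbook form, not taken from the MindSpore implementation)::

    import numpy as np
    from math import gamma

    def beta_pdf(x, alpha, beta):
        # f(x) = x^(alpha - 1) * (1 - x)^(beta - 1) / B(alpha, beta), for x in (0, 1)
        norm = gamma(alpha) * gamma(beta) / gamma(alpha + beta)
        return x ** (alpha - 1) * (1 - x) ** (beta - 1) / norm

    x = np.array([0.25, 0.5, 0.75])
    print(beta_pdf(x, alpha=3.0, beta=1.0))   # both concentrations must be positive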
@@ -40,7 +40,7 @@ class Categorical(Distribution):
         probs (Tensor, list, numpy.ndarray): Event probabilities. Default: ``None`` .
         seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: ``None`` .
         dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` .
-        name (str): The name of the distribution. Default: ``Categorical`` .
+        name (str): The name of the distribution. Default: ``'Categorical'`` .
 
     Note:
         `probs` must have rank at least 1, values are proper probabilities and sum to 1.
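A short usage sketch of the rank and normalisation requirement on `probs`, assuming the `mindspore.nn.probability.distribution` import path::

    import mindspore as ms
    import mindspore.nn.probability.distribution as msd

    # probs has rank >= 1 and its values sum to 1.
    cat = msd.Categorical(probs=[0.2, 0.3, 0.5], dtype=ms.int32)
    print(cat.prob(ms.Tensor(2, ms.int32)))   # probability of category 2 -> 0.5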
@@ -35,8 +35,10 @@ class Cauchy(Distribution):
     Where :math:`a, b` are loc and scale parameter respectively.
 
     Args:
-        loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: ``None`` .
-        scale (int, float, list, numpy.ndarray, Tensor): The scale of the Cauchy distribution. Default: ``None`` .
+        loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution.
+            :math:`a` in the formula. Default: ``None`` .
+        scale (int, float, list, numpy.ndarray, Tensor): The scale of the Cauchy distribution.
+            :math:`b` in the formula. Default: ``None`` .
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
         dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
         name (str): The name of the distribution. Default: ``'Cauchy'`` .
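For reference, the standard Cauchy density with `loc` as :math:`a` and `scale` as :math:`b`, sketched in plain numpy to make the parameter mapping concrete (not the MindSpore implementation)::

    import numpy as np

    def cauchy_pdf(x, loc=0.0, scale=1.0):
        # f(x) = 1 / (pi * b * (1 + ((x - a) / b)^2)); scale must be positive
        z = (x - loc) / scale
        return 1.0 / (np.pi * scale * (1.0 + z ** 2))

    print(cauchy_pdf(np.array([-1.0, 0.0, 1.0]), loc=0.0, scale=2.0))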
@@ -35,15 +35,14 @@ class Exponential(Distribution):
     where :math:`\lambda` is the rate of the distribution.
 
     Args:
-        rate (int, float, list, numpy.ndarray, Tensor): The inverse scale. Default: ``None`` .
-        seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
-        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
-        name (str): The name of the distribution. Default: ``'Exponential'`` .
+        rate (int, float, list, numpy.ndarray, Tensor, optional): The inverse scale. :math:`\lambda` in the formula. Default: ``None`` .
+        seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype, optional): The type of the event samples. Default: ``mstype.float32`` .
+        name (str, optional): The name of the distribution. Default: ``'Exponential'`` .
 
     Note:
-        `rate` must be strictly greater than 0.
-        `dist_spec_args` is `rate`.
-        `dtype` must be a float type because Exponential distributions are continuous.
+        - `rate` must be strictly greater than 0.
+        - `dtype` must be a float type because Exponential distributions are continuous.
 
     Raises:
         ValueError: When rate <= 0.
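A hedged usage sketch of the arguments above, assuming the `mindspore.nn.probability.distribution` import path; `rate` is the :math:`\lambda` of the formula, so the mean of the distribution is :math:`1 / \lambda`::

    import mindspore as ms
    import mindspore.nn.probability.distribution as msd

    # rate must be strictly greater than 0.
    expo = msd.Exponential(rate=0.5, seed=5, dtype=ms.float32)

    mean = expo.mean()          # 1 / rate = 2.0
    samples = expo.sample((3,))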
@@ -39,9 +39,9 @@ class Gamma(Distribution):
 
     Args:
         concentration (int, float, list, numpy.ndarray, Tensor): The concentration,
-            also know as alpha of the Gamma distribution. Default: ``None`` .
+            also known as :math:`\alpha` of the Gamma distribution. Default: ``None`` .
         rate (int, float, list, numpy.ndarray, Tensor): The rate, also know as
-            beta of the Gamma distribution. Default: ``None`` .
+            :math:`\beta` of the Gamma distribution. Default: ``None`` .
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
         dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
         name (str): The name of the distribution. Default: ``'Gamma'`` .
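To make the :math:`\alpha` / :math:`\beta` naming concrete, a plain-numpy sketch of the standard Gamma density in its shape-rate form (textbook form, not the MindSpore implementation)::

    import numpy as np
    from math import gamma as gamma_fn

    def gamma_pdf(x, concentration, rate):
        # f(x) = rate^alpha * x^(alpha - 1) * exp(-rate * x) / Gamma(alpha), for x > 0
        a, b = concentration, rate
        return b ** a * x ** (a - 1) * np.exp(-b * x) / gamma_fn(a)

    x = np.array([0.5, 1.0, 2.0])
    print(gamma_pdf(x, concentration=3.0, rate=1.0))   # mean of this example is alpha / beta = 3.0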
@@ -37,8 +37,8 @@ class Gumbel(TransformedDistribution):
     Where :math:`a, b` are loc and scale parameter respectively.
 
     Args:
-        loc (int, float, list, numpy.ndarray, Tensor): The location of Gumbel distribution.
-        scale (int, float, list, numpy.ndarray, Tensor): The scale of Gumbel distribution.
+        loc (int, float, list, numpy.ndarray, Tensor): The location of Gumbel distribution. :math:`a` in the formula.
+        scale (int, float, list, numpy.ndarray, Tensor): The scale of Gumbel distribution. :math:`b` in the formula.
         seed (int): the seed used in sampling. The global seed is used if it is None. Default: ``0`` .
         dtype (mindspore.dtype): type of the distribution. Default: ``mstype.float32`` .
         name (str): the name of the distribution. Default: ``'Gumbel'`` .
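A brief usage sketch, assuming the `mindspore.nn.probability.distribution` import path and the usual `cdf`/`mean` methods; `loc` and `scale` are the :math:`a` and :math:`b` of the formula::

    import mindspore as ms
    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor

    gumbel = msd.Gumbel(loc=3.0, scale=4.0, dtype=ms.float32)   # a = 3, b = 4

    value = Tensor([1.0, 2.0, 3.0], dtype=ms.float32)
    print(gumbel.cdf(value))   # exp(-exp(-(value - a) / b))
    print(gumbel.mean())       # a + b * Euler-Mascheroni constant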
@@ -36,9 +36,11 @@ class HalfNormal(Distribution):
     where :math:`\mu, \sigma` are the mean and the standard deviation of the half normal distribution respectively.
 
     Args:
-        mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
+        mean (Union[int, float, list, numpy.ndarray, Tensor], optional):
+            The mean of the distribution. :math:`\mu` in the formula.
             If this arg is ``None`` , then the mean of the distribution will be passed in runtime. Default: ``None`` .
-        sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The standard deviation of the distribution.
+        sd (Union[int, float, list, numpy.ndarray, Tensor], optional):
+            The standard deviation of the distribution. :math:`\sigma` in the formula.
             If this arg is ``None`` , then the sd of the distribution will be passed in runtime. Default: ``None`` .
         seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
         dtype (mindspore.dtype, optional): The type of the event samples. Default: ``mstype.float32`` .
@@ -52,7 +54,7 @@ class HalfNormal(Distribution):
 
     Raises:
         ValueError: When sd <= 0.
-        TypeError: When the input `dtype` is not a subclass of float.
+        TypeError: When the input `dtype` is not a float or a subclass of float.
 
     Supported Platforms:
         ``CPU``
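For reference, the standard half-normal density with `mean` as :math:`\mu` and `sd` as :math:`\sigma`, in plain numpy (textbook form, shown only to anchor the symbols)::

    import numpy as np

    def half_normal_pdf(x, mean=0.0, sd=1.0):
        # f(x) = sqrt(2 / pi) / sigma * exp(-(x - mu)^2 / (2 * sigma^2)), for x >= mu; sd must be > 0
        z = (x - mean) / sd
        return np.where(x >= mean, np.sqrt(2.0 / np.pi) / sd * np.exp(-0.5 * z ** 2), 0.0)

    print(half_normal_pdf(np.array([0.0, 1.0, 2.0]), mean=0.0, sd=1.0))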
@@ -36,8 +36,10 @@ class Logistic(Distribution):
     where :math:`a, b` are loc and scale parameter respectively.
 
     Args:
-        loc (float, list, numpy.ndarray, Tensor): The location of the Logistic distribution. Default: ``None`` .
-        scale (float, list, numpy.ndarray, Tensor): The scale of the Logistic distribution. Default: ``None`` .
+        loc (float, list, numpy.ndarray, Tensor): The location of the Logistic distribution.
+            :math:`a` in the formula. Default: ``None`` .
+        scale (float, list, numpy.ndarray, Tensor): The scale of the Logistic distribution.
+            :math:`b` in the formula. Default: ``None`` .
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
         dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
         name (str): The name of the distribution. Default: ``'Logistic'`` .
@@ -49,7 +51,7 @@ class Logistic(Distribution):
 
     Raises:
         ValueError: When scale <= 0.
-        TypeError: When the input `dtype` is not a subclass of float.
+        TypeError: When the input `dtype` is not a float or a subclass of float.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
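Finally, the standard logistic CDF with `loc` as :math:`a` and `scale` as :math:`b`, in plain numpy, to make the parameter mapping concrete (not the MindSpore implementation)::

    import numpy as np

    def logistic_cdf(x, loc=0.0, scale=1.0):
        # F(x) = 1 / (1 + exp(-(x - a) / b)); scale must be > 0
        return 1.0 / (1.0 + np.exp(-(x - loc) / scale))

    print(logistic_cdf(np.array([-2.0, 0.0, 2.0]), loc=0.0, scale=1.0))   # ~[0.119, 0.5, 0.881]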