mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/train/model.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-2022 Huawei Technologies Co., Ltd
+# Copyright 2020-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,9 +21,11 @@ from functools import wraps
 import os
 import math
 import copy
+import importlib
 import numpy as np

 import mindspore
+import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
@@ -37,7 +39,7 @@ from mindspore import context
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_parameter_broadcast, \
     _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
     _reset_op_id_with_offset
-from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, \
+from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
     _cache_enable, _enable_distributed_mindrt
 from mindspore.train.metrics import Loss
 from mindspore import nn
@@ -49,6 +51,7 @@ from mindspore.common.api import _pynative_executor
 from mindspore.dataset.core.config import get_debug_mode
 from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
+from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end


 def _transfer_tensor_to_tuple(inputs):
@@ -67,6 +70,17 @@ class _StepSync(Callback):
         _pynative_executor.sync()


+class _FrameworkProfilerCallback(Callback):
+    """
+    Profiler callback of framework for training.
+    """
+    def step_begin(self, run_context):
+        _framework_profiler_step_start()
+
+    def step_end(self, run_context):
+        _framework_profiler_step_end()
+
+
 def _save_final_ckpt(func):
     """
     Decorator function, which saves the current checkpoint when an exception occurs during training.
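
The _FrameworkProfilerCallback added above wraps every training step in framework-level profiling hooks. The same step_begin/step_end pattern is open to user code through the public mindspore.train.Callback class; below is a minimal sketch of a per-step timer built on it (TimingCallback is a hypothetical name, not part of this diff):

    import time
    from mindspore.train import Callback

    class TimingCallback(Callback):
        """Print the wall-clock duration of every training step."""
        def __init__(self):
            self._step_start = 0.0

        def step_begin(self, run_context):
            # Mark step entry, mirroring _framework_profiler_step_start() above.
            self._step_start = time.time()

        def step_end(self, run_context):
            # run_context.original_args() exposes per-step metadata such as cur_step_num.
            cb_params = run_context.original_args()
            print(f"step {cb_params.cur_step_num}: {time.time() - self._step_start:.4f}s")

Registered via model.train(epoch, dataset, callbacks=[TimingCallback()]), it fires around each step exactly as the framework profiler callback does.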
@@ -108,29 +122,33 @@ class Model:
     `Model` groups layers into an object with training and inference features based on the arguments.

     Note:
-        If use mixed precision functions, need to set parameter `optimizer` at the same time,
-        otherwise mixed precision functions do not take effect.
-        When uses mixed precision functions, `global_step` in optimizer may be different from `cur_step_num` in Model.
+        - If use mixed precision functions, need to set parameter `optimizer` at the same time,
+          otherwise mixed precision functions do not take effect.
+          When uses mixed precision functions, `global_step` in optimizer may be different from `cur_step_num`
+          in Model.
+        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
+          to perform the precision conversion again. If `Model` is used to train a converted network, `amp_level`
+          need to be configured to ``O0`` to avoid the duplicated accuracy conversion.

     Args:
         network (Cell): A training or testing network.
         loss_fn (Cell): Objective function. If `loss_fn` is None, the `network` should contain the calculation of loss
-            and parallel if needed. Default: None.
+            and parallel if needed. Default: ``None`` .
         optimizer (Cell): Optimizer for updating the weights. If `optimizer` is None, the `network` needs to
-            do backpropagation and update weights. Default value: None.
+            do backpropagation and update weights. Default: ``None`` .
         metrics (Union[dict, set]): A Dictionary or a set of metrics for model evaluation.
-            eg: {'accuracy', 'recall'}. Default: None.
+            eg: {'accuracy', 'recall'}. Default: ``None`` .
         eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
-            `eval_network` . Default: None.
+            `eval_network` . Default: ``None`` .
         eval_indexes (list): It is used when eval_network is defined. If `eval_indexes` is None by default, all outputs
             of the `eval_network` would be passed to metrics. If `eval_indexes` is set, it must contain
             three elements: the positions of loss value, predicted value and label in outputs of the
             `eval_network`. In this case, the loss value will be passed to the `Loss` metric, the
             predicted value and label will be passed to other metrics.
             :func:`mindspore.train.Metric.set_indexes` is recommended instead of `eval_indexes`.
-            Default: None.
+            Default: ``None`` .
         amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
-            precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
+            precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .

            - "O0": Do not change.
            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
@@ -138,7 +156,7 @@ class Model:
  Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
  - "O2": Cast network to float16, keep BatchNorm running in float32, using dynamic loss scale.
  - "O3": Cast network to float16, the BatchNorm is also cast to float16, loss scale will not be used.
- - auto: Set level to recommended level in different devices. Set level to "O2" on GPU, set
+ - "auto": Set level to recommended level in different devices. Set level to "O2" on GPU, set
  level to "O3" on Ascend. The recommended level is chosen by expert experience and is not applicable to all
  scenarios. Users should specify the level explicitly for special networks.
 
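As a hedged illustration of the `amp_level` options above (a minimal sketch; `LeNet5` is assumed to be the network from the lenet.py link used throughout these examples):

    from mindspore import nn
    from mindspore.train import Model

    net = LeNet5()  # assumed: LeNet5 from the linked lenet.py
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

    # "O2": network cast to float16, BatchNorm kept in float32, dynamic loss scale.
    model = Model(net, loss_fn=loss, optimizer=optim, amp_level="O2")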
@@ -149,7 +167,7 @@ class Model:
  The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network`.
 
  boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
- training. Supports ["O0", "O1", "O2"]. Default: "O0".
+ training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
 
  - "O0": Do not change.
  - "O1": Enable the boost mode, the performance is improved by about 20%, and
@@ -165,39 +183,23 @@ class Model:
  can obtain the same benefits. It is recommended to enable this function on
  the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
  boost_config_dict.
+
  Examples:
  >>> from mindspore import nn
  >>> from mindspore.train import Model
  >>>
- >>> class Net(nn.Cell):
- ... def __init__(self, num_class=10, num_channel=1):
- ... super(Net, self).__init__()
- ... self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
- ... self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
- ... self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
- ... self.fc2 = nn.Dense(120, 84, weight_init='ones')
- ... self.fc3 = nn.Dense(84, num_class, weight_init='ones')
- ... self.relu = nn.ReLU()
- ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
- ... self.flatten = nn.Flatten()
- ...
- ... def construct(self, x):
- ... x = self.max_pool2d(self.relu(self.conv1(x)))
- ... x = self.max_pool2d(self.relu(self.conv2(x)))
- ... x = self.flatten(x)
- ... x = self.relu(self.fc1(x))
- ... x = self.relu(self.fc2(x))
- ... x = self.fc3(x)
- ... return x
- >>>
- >>> net = Net()
- >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
+ >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
- >>> # For details about how to build the dataset, please refer to the variable `dataset_train` in tutorial
- >>> # document on the official website:
- >>> # https://www.mindspore.cn/tutorials/zh-CN/r2.0/beginner/quick_start.html
- >>> dataset = create_custom_dataset()
+ >>> model.train_network
+ >>> model.predict_network
+ >>> model.eval_network
+ >>> # Create the dataset taking MNIST as an example. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+ >>> dataset = create_dataset()
  >>> model.train(2, dataset)
  """
 
@@ -223,6 +225,7 @@ class Model:
  self._check_for_graph_cell(kwargs)
  self._build_boost_network(kwargs)
  self._train_network = self._build_train_network()
+ self._train_network._jit_config_dict = network.jit_config_dict
  self._build_eval_network(metrics, self._eval_network, eval_indexes)
  self._build_predict_network()
  self._current_epoch_num = 0
@@ -231,6 +234,12 @@ class Model:
  self.enable_recovery = False
  self._backbone_is_train = True
  self.need_load_ckpt = False
+ self._lite_full_predictor = None
+ self._lite_incremental_predictor = None
+ self._mindspore_lite = None
+ self._lite_infer = True # if backend lite infer fails, set False
+ self._mindspore_lite_model_group_id = id(self) & 0xFFFF
+
 
  def _check_for_graph_cell(self, kwargs):
  """Check for graph cell"""
@@ -458,7 +467,7 @@ class Model:
  Args:
  epoch (int): Total number of iterations on the data.
  train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
- initialized. Default: None.
+ initialized. Default: ``None``.
  sink_size (int): Control the amount of data in each sink. Default: -1.
  """
  if sink_size == -1:
@@ -485,9 +494,9 @@ class Model:
 
  Args:
  train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
- initialized. Default: None.
+ initialized. Default: ``None``.
  valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
- will be initialized, and `metrics` in `Model` can not be None. Default: None.
+ will be initialized, and `metrics` in `Model` cannot be None. Default: ``None``.
  sink_size (int): Control the amount of data in each sink. Default: -1.
  epoch (int): Total number of iterations on the data. Default: 1.
  """
@@ -562,16 +571,15 @@ class Model:
  returned and passed to the network. Otherwise, a tuple (data, label) will
  be returned. The data and label would be passed to the network and loss
  function respectively.
- callbacks (list): List of callback objects which should be executed while training. Default: None.
+ callbacks (list): List of callback objects which should be executed while training. Default: ``None``.
  dataset_sink_mode (bool): Determine whether the data should be passed through the dataset channel.
- Default: True.
+ Default: ``True``.
  If PyNative mode or CPU is configured, the training process will be performed with
  dataset not sink.
  sink_size (int): Control the amount of data in each sink. Default: -1.
  initial_epoch (int): Epoch at which to start training; it is used for resuming a previous training run.
  Default: 0.
  """
- epoch = Validator.check_positive_int(epoch)
  if self._parameter_broadcast:
  self._train_network.set_broadcast_flag()
 
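Since `initial_epoch` appears in several signatures below, here is a hedged resume sketch (the checkpoint file name and the `LeNet5`/`create_dataset` helpers are assumptions from the linked docs, not part of this file):

    import mindspore as ms
    from mindspore import nn
    from mindspore.train import Model

    net = LeNet5()  # assumed network from the linked lenet.py
    ms.load_checkpoint("lenet-5_1875.ckpt", net=net)  # hypothetical checkpoint file
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=True), optimizer=optim)

    # Resume a 10-epoch run that previously stopped after epoch 5:
    # only epochs 6..10 are executed.
    model.train(10, create_dataset(), initial_epoch=5)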
@@ -590,15 +598,14 @@ class Model:
  cb_params.train_dataset = train_dataset
  cb_params.list_callback = self._transform_callbacks(callbacks)
  valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
+ cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
  if context.get_context("mode") == context.PYNATIVE_MODE:
  cb_params.list_callback.insert(0, _StepSync())
- callbacks = cb_params.list_callback
+ callbacks = cb_params.list_callback
  cb_params.train_dataset_element = None
  cb_params.network = self._network
- if _is_role_sched():
- epoch = 1
  # Embedding cache server only runs one step.
- if (_is_role_pserver() or _is_role_sched()) and _cache_enable():
+ if _is_role_pserver() and _cache_enable():
  epoch = 1
  cb_params.last_save_ckpt_step = None
  cb_params.latest_ckpt_file = None
@@ -632,18 +639,23 @@ class Model:
  returned and passed to the network. Otherwise, a tuple (data, label) should
  be returned. The data and label would be passed to the network and loss
  function respectively.
- list_callback (Callback): Executor of callback list. Default: None.
- cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+ list_callback (Callback): Executor of callback list. Default: ``None``.
+ cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
  sink_size (int): Control the amount of data in each sink. Default: -1.
  initial_epoch (int): Epoch at which to start training; it is used for resuming a previous training run.
  Default: 0.
  """
  is_graph = (context.get_context("mode") == context.GRAPH_MODE)
+ dataset_size = train_dataset.get_dataset_size()
+ if dataset_size % sink_size != 0:
+ logger.warning("In dataset_sink mode, (dataset_size % sink_size) should be equal to 0; "
+ "it is suggested to pad/drop data or adjust sink_size. "
+ "But got 'dataset_size': {}, 'sink_size': {}.".format(dataset_size, sink_size))
  if sink_size == -1:
- epoch_num = epoch - initial_epoch
+ dataset_sink_num = epoch
  else:
- epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) - initial_epoch
- train_dataset.__total_batch__ = (epoch - initial_epoch) * sink_size
+ dataset_sink_num = math.ceil(epoch * sink_size / dataset_size)
+ train_dataset.__total_batch__ = epoch * sink_size
 
  cb_params.cur_step_num = 0
  cb_params.dataset_sink_mode = True
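A worked example of the sink accounting introduced above (the numbers are assumptions for illustration, not from the source):

    import math

    dataset_size = 1875   # batches produced per epoch
    sink_size = 500       # steps per sink
    epoch = 4

    # 1875 % 500 == 375 != 0, so the new warning about padding/dropping fires.
    dataset_sink_num = math.ceil(epoch * sink_size / dataset_size)  # ceil(2000/1875) = 2
    total_batches = epoch * sink_size                               # 2000

Note that the `- initial_epoch` offsets present in the replaced lines are gone from both quantities.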
@@ -659,7 +671,7 @@ class Model:
 
  self._check_enable_recovery()
  # Used to check whether recovery needs to be performed for a restarted process.
- self._check_need_load_ckpt(cb_params, train_dataset.get_dataset_size(), sink_size)
+ self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
  # Check whether this process is embedding cache server.
  is_embedding_cache_server = _is_role_pserver() and _cache_enable()
 
@@ -672,10 +684,11 @@ class Model:
  dataset=train_dataset,
  dataset_sink_mode=True,
  sink_size=sink_size,
- epoch_num=epoch_num,
+ epoch_num=dataset_sink_num,
  dataset_helper=dataset_helper)
 
  cb_params.train_network = train_network
+ cb_params.dataset_helper = dataset_helper
 
  # Perform recovery for a process which has been restarted.
  self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
@@ -695,9 +708,6 @@ class Model:
  outputs = train_network(*inputs)
  cb_params.net_outputs = outputs
 
- if _is_role_sched():
- os._exit(0)
-
  # In disaster recovery scenarios, callbacks need not be executed if this step failed.
  need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
  if need_exec_callback_step_end:
@@ -824,7 +834,7 @@ class Model:
  os.remove(cb_params.latest_ckpt_file)
  raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
  + cb_params.latest_ckpt_file) from e
- _reset_training_dataset(cb_params.cur_step_num, dataset_helper.sink_size())
+ _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
  self.need_load_ckpt = False
 
  def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
@@ -853,9 +863,9 @@ class Model:
  self.epoch_iter = recovery_epoch_num
  cb_params.cur_epoch_num = self.epoch_iter + 1
  cb_params.last_save_ckpt_step = cb_params.cur_step_num
- _reset_training_dataset(cb_params.cur_step_num, dataset_helper.sink_size())
+ _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
  else:
- _reset_training_dataset(0, dataset_helper.sink_size())
+ _reset_training_dataset(0, dataset_helper.iter.dataset.get_dataset_size())
 
  _set_recovery_context(need_reset=False)
 
@@ -871,15 +881,15 @@ class Model:
  returned and passed to the network. Otherwise, a tuple (data, label) should
  be returned. The data and label would be passed to the network and loss
  function respectively.
- list_callback (Callback): Executor of callback list. Default: None.
- cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+ list_callback (Callback): Executor of callback list. Default: ``None``.
+ cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
  initial_epoch (int): Epoch at which to start training; it is used for resuming a previous training run.
  Default: 0.
  """
  dataset_helper, _ = self._exec_preprocess(is_train=True,
  dataset=train_dataset,
  dataset_sink_mode=False,
- epoch_num=(epoch-initial_epoch))
+ epoch_num=epoch)
  cb_params.cur_step_num = 0
  cb_params.dataset_sink_mode = False
  run_context = RunContext(cb_params)
@@ -914,8 +924,6 @@ class Model:
  self._loss_scale_manager.update_loss_scale(overflow)
 
  list_callback.on_train_step_end(run_context)
- if _is_role_sched():
- os._exit(0)
  # Embedding cache server only runs one step.
  if is_embedding_cache_server:
  break
@@ -959,7 +967,7 @@ class Model:
  of data will be transferred one by one. The limitation of data transmission per time is 256M.
 
  When dataset_sink_mode is True, the `step_end` method of the instance of Callback will be called at the end
- of epoch.
+ of each step in PyNative mode, or at the end of each epoch in Graph mode.
 
  If dataset_sink_mode is True, dataset will be bound to this model and cannot be used by other models.
 
@@ -983,12 +991,12 @@ class Model:
  passed to the `network`.
  callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
  which should be executed while training.
- Default: None.
+ Default: ``None``.
  dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
  If PyNative mode or CPU is configured, the training process will be performed with
- dataset not sink. Default: False.
- sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if `dataset_sink_mode`
- is False.
+ dataset not sink. Default: ``False``.
+ sink_size (int): Control the number of steps for each sinking.
+ `sink_size` is invalid if `dataset_sink_mode` is False.
  If sink_size = -1, sink the complete dataset for each epoch.
  If sink_size > 0, sink sink_size data for each epoch.
  Default: -1.
@@ -999,17 +1007,21 @@ class Model:
  >>> from mindspore import nn
  >>> from mindspore.train import Model
  >>>
- >>> # For details about how to build the dataset, please refer to the tutorial
- >>> # document on the official website.
- >>> dataset = create_custom_dataset()
- >>> net = Net()
- >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+ >>> # Create the dataset taking MNIST as an example. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+ >>> dataset = create_dataset()
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
+ >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
  >>> loss_scale_manager = ms.FixedLossScaleManager(1024., False)
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
  ... loss_scale_manager=loss_scale_manager)
  >>> model.train(2, dataset)
  """
+ # prepare dataset for obfuscated model
+ train_dataset = self._prepare_obf_dataset(train_dataset)
  device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
  logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
@@ -1033,7 +1045,7 @@ class Model:
  self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)
 
  Validator.check_is_int(sink_size)
- Validator.check_non_negative_int(epoch)
+ Validator.check_positive_int(epoch)
  Validator.check_non_negative_int(initial_epoch)
  if initial_epoch >= epoch:
  raise ValueError(f"For 'Model.train', the parameter 'epoch' must be bigger than parameter 'initial_epoch',"
@@ -1121,42 +1133,48 @@ class Model:
  then a tuple (data1, data2, data3, ...) with all data returned from dataset
  will be passed to the `network`.
  valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
- will be performed on the end of training process. Default: None.
+ will be performed at the end of the training process. Default: ``None`` .
  valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies
  how many training epochs to run before a new validation run is performed,
  e.g. `valid_frequency=2` runs validation every 2 epochs.
  If a list, specifies the epochs on which to run validation,
  e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st, 5th epochs.
- Default: 1
+ Default: ``1`` .
  callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
  which should be executed while training.
- Default: None.
+ Default: ``None`` .
  dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel.
  If PyNative mode or CPU is configured, the training process will be performed with
- dataset not sink. Default: False.
+ dataset not sink. Default: ``False`` .
  valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
- Default: False.
- sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if `dataset_sink_mode`
- is False.
+ Default: ``False`` .
+ sink_size (int): Control the number of steps for each sinking.
+ `sink_size` is invalid if `dataset_sink_mode` is False.
  If sink_size = -1, sink the complete dataset for each epoch.
  If sink_size > 0, sink sink_size data for each epoch.
- Default: -1.
+ Default: ``-1`` .
  initial_epoch (int): Epoch at which to start training; it is useful for resuming a previous training run.
- Default: 0.
+ Default: ``0`` .
 
  Examples:
  >>> from mindspore import nn
  >>> from mindspore.train import Model
  >>>
- >>> # For details about how to build the dataset, please refer to the tutorial
- >>> # document on the official website.
- >>> train_dataset = create_custom_dataset()
- >>> valid_dataset = create_custom_dataset()
- >>> net = Net()
- >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+ >>> # Create the dataset taking MNIST as an example. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+ >>> train_dataset = create_dataset("train")
+ >>> valid_dataset = create_dataset("test")
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
+ >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
  >>> model.fit(2, train_dataset, valid_dataset)
+
+ Tutorial Examples:
+ - `Advanced Encapsulation: Model - Train and Save Model
+ <https://www.mindspore.cn/tutorials/en/r2.2/advanced/model.html#training-and-saving-model>`_
  """
  device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
@@ -1175,7 +1193,7 @@ class Model:
  .format(train_dataset._warmup_epoch, epoch))
 
  Validator.check_is_int(sink_size)
- Validator.check_non_negative_int(epoch)
+ Validator.check_positive_int(epoch)
  Validator.check_non_negative_int(initial_epoch)
  if initial_epoch >= epoch:
  raise ValueError(f"For 'Model.fit', the parameter 'epoch' must be bigger than parameter 'initial_epoch',"
@@ -1224,21 +1242,23 @@ class Model:
 
  Args:
  train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
- built. Default: None.
+ built. Default: ``None`` .
  valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
- will be built, and `metrics` in `Model` can not be None. Default: None.
- sink_size (int): Control the amount of data in each sink. Default: -1.
- epoch (int): Control the training epochs. Default: 1.
+ will be built, and `metrics` in `Model` cannot be None. Default: ``None`` .
+ sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
+ epoch (int): Control the training epochs. Default: ``1`` .
 
  Examples:
  >>> from mindspore import nn
  >>> from mindspore.train import Model
  >>> from mindspore.amp import FixedLossScaleManager
  >>>
- >>> # For details about how to build the dataset, please refer to the tutorial
- >>> # document on the official website.
- >>> dataset = create_custom_dataset()
- >>> net = Net()
+ >>> # Create the dataset taking MNIST as an example. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+ >>> dataset = create_dataset()
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
  >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  >>> loss_scale_manager = FixedLossScaleManager()
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -1247,6 +1267,10 @@ class Model:
  >>> model.build(dataset, epoch=2)
  >>> model.train(2, dataset)
  """
+ epoch = Validator.check_positive_int(epoch)
+ if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
+ self._train_network.check_names_and_refresh_name()
+ self._train_network._is_check_and_refresh = True
  self._init(train_dataset, valid_dataset, sink_size, epoch)
 
  def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
@@ -1255,12 +1279,12 @@ class Model:
 
  Args:
  valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
- will be performed on the end of training process. Default: None.
+ will be performed at the end of the training process. Default: ``None``.
  callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, which should be
- executed while evaluation. Default: None.
+ executed during evaluation. Default: ``None``.
  valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
- Default: True.
- cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+ Default: ``True``.
+ cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
 
  if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
  raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
@@ -1289,8 +1313,8 @@ class Model:
 
  Args:
  valid_dataset (Dataset): Dataset to evaluate the model.
- list_callback (Callback): Executor of callback list. Default: None.
- cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+ list_callback (Callback): Executor of callback list. Default: ``None``.
+ cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
 
  Returns:
  Dict, which returns the loss value and metrics values for the model in the test mode.
@@ -1313,8 +1337,6 @@ class Model:
  outputs = eval_network(*inputs)
  cb_params.net_outputs = outputs
  list_callback.on_eval_step_end(run_context)
- if _is_role_sched():
- os._exit(0)
  self._update_metrics(outputs)
  if add_eval_loss:
  eval_loss_fn = get_metric_fn("loss")
@@ -1337,8 +1359,8 @@ class Model:
 
  Args:
  valid_dataset (Dataset): Dataset to evaluate the model.
- list_callback (Callback): Executor of callback list. Default: None.
- cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+ list_callback (Callback): Executor of callback list. Default: ``None``.
+ cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
 
  Returns:
  Dict, which returns the loss value and metrics values for the model in the test mode.
@@ -1359,8 +1381,6 @@ class Model:
  outputs = self._eval_network(*next_element)
  cb_params.net_outputs = outputs
  list_callback.on_eval_step_end(run_context)
- if _is_role_sched():
- os._exit(0)
  self._update_metrics(outputs)
  if add_eval_loss:
  eval_loss_fn = get_metric_fn("loss")
@@ -1397,9 +1417,9 @@ class Model:
  valid_dataset (Dataset): Dataset to evaluate the model.
  callbacks (Optional[list(Callback), Callback]): List of callback objects or callback object,
  which should be executed during evaluation.
- Default: None.
+ Default: ``None`` .
  dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
- Default: False.
+ Default: ``False`` .
 
  Returns:
  Dict, the key is the metric name defined by users and the value is the metrics value for
@@ -1409,14 +1429,21 @@ class Model:
  >>> from mindspore import nn
  >>> from mindspore.train import Model
  >>>
- >>> # For details about how to build the dataset, please refer to the tutorial
- >>> # document on the official website.
- >>> dataset = create_custom_dataset()
- >>> net = Net()
- >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+ >>> # Create the dataset taking MNIST as an example. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+ >>> dataset = create_dataset()
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
+ >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
  >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
  >>> acc = model.eval(dataset, dataset_sink_mode=False)
+
+ Tutorial Examples:
+ - `Advanced Encapsulation: Model - Train and Save Model
+ <https://www.mindspore.cn/tutorials/en/r2.2/advanced/model.html#training-and-saving-model>`_
  """
+ valid_dataset = self._prepare_obf_dataset(valid_dataset)
  dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
 
  _device_number_check(self._parallel_mode, self._device_number)
@@ -1464,7 +1491,140 @@ class Model:
 
  return eval_result
 
- def predict(self, *predict_data):
+ def _predict_lite(self, *predict_data, config=None):
+ """
+ Generate output predictions for the input samples using backend 'lite'.
+
+ Args:
+ predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
+ The predict data, can be a single tensor,
+ a list of tensor, or a tuple of tensor.
+
+ config (dict, optional): The config parameter is enabled when the backend is 'lite'.
+ The config includes two parts: config_path ("configPath", str) and config_dict (str, dict).
+ When the config_dict is set, its priority is higher than the config_path. It can be used,
+ for example, to set the rank table file for inference.
+
+ config_path defines the path of a configuration file used to pass user-defined
+ options during model building, for example: "/home/user/config.ini". Default value: ``""`` .
+ Here is the content of such a config.ini file:
+
+ .. code-block::
+
+ [ascend_context]
+ rank_table_file = [path_a](storage initial path of the rank table file)
+ [execution_plan]
+ [op_name1] = data_type:float16 (operator named op_name1 is set to data type Float16)
+ [op_name2] = data_type:float32 (operator named op_name2 is set to data type Float32)
+
+ When only the config_path is configured, it is done as follows:
+
+ .. code-block::
+
+ config = {"configPath" : "/home/user/config.ini"}
+
+ When only the config_dict is configured, it is done as follows:
+
+ .. code-block::
+
+ config = {"ascend_context" : {"rank_table_file" : "path_b"},
+ "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
+
+ When both the `config_path` and the `config_dict` are configured, it is done as follows:
+
+ .. code-block::
+
+ config = {"configPath" : "/home/user/config.ini",
+ "ascend_context" : {"rank_table_file" : "path_b"},
+ "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
+
+ Note that when the rank table file is set both in the file referenced by "configPath" and
+ in the config_dict, path_b from the config_dict takes precedence.
+
+ Returns:
+ Tensor, array(s) of predictions.
+ """
+ def _get_lite_context(lite_context_input):
+ # use default lite context parameters for now
+ device_target = context.get_context("device_target").lower()
+ lite_context_input.target = [device_target]
+ if device_target == 'cpu':
+ inter_op_parallel_num = context.get_context('inter_op_parallel_num')
+ if inter_op_parallel_num and isinstance(inter_op_parallel_num, int):
+ lite_context_input.cpu.inter_op_parallel_num = inter_op_parallel_num
+ elif device_target == 'gpu':
+ device_id = context.get_context('device_id')
+ if device_id and isinstance(device_id, int):
+ lite_context_input.gpu.device_id = device_id
+ if context.get_auto_parallel_context("parallel_mode") == context.ParallelMode.SEMI_AUTO_PARALLEL:
+ from mindspore.communication import init, get_rank
+ init()
+ lite_context_input.gpu.rank_id = get_rank()
+ elif device_target == 'ascend':
+ device_id = context.get_context('device_id')
+ if device_id and isinstance(device_id, int):
+ lite_context_input.ascend.device_id = device_id
+ if context.get_auto_parallel_context("parallel_mode") == context.ParallelMode.SEMI_AUTO_PARALLEL:
+ from mindspore.communication import init, get_rank
+ init()
+ lite_context_input.ascend.rank_id = get_rank()
+ lite_context_input.ascend.provider = "ge"
+ else:
+ raise RuntimeError(f"For predict lite, device target should be in ['gpu', 'cpu', 'ascend']"
+ f" but got {device_target}")
+ return lite_context_input
+
+ if not self._mindspore_lite:
+ self._mindspore_lite = importlib.import_module('mindspore_lite')
+
+ use_past = False # by default, execute full model inference
+ model_group_id = None
+ if "is_first_iteration" in self._predict_network.get_flags():
+ is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
+ if isinstance(is_first_iteration, bool):
+ use_past = not is_first_iteration
+ model_group_id = self._mindspore_lite_model_group_id
+
+ check_input_data(*predict_data, data_class=Tensor)
+ if use_past:
+ # Execute incremental model inference
+ if not self._lite_incremental_predictor:
+ lite_context = _get_lite_context(self._mindspore_lite.Context())
+ self._lite_incremental_predictor = \
+ self._mindspore_lite.lite_infer.LiteInfer(self, *predict_data, context=lite_context,
+ model_group_id=model_group_id, config=config)
+
+ inputs = self._lite_incremental_predictor.get_inputs()
+ if len(predict_data) != len(inputs):
+ raise RuntimeError(f"For 'Model.predict', the number of predict_data {len(predict_data)} "
+ f"is not equal to the number of net inputs {len(inputs)}")
+ for i, single_data in enumerate(predict_data):
+ inputs[i].set_data_from_numpy(single_data.asnumpy())
+ outputs: list = self._lite_incremental_predictor.predict(inputs)
+ else:
+ # Execute full model inference
+ if not self._lite_full_predictor:
+ lite_context = _get_lite_context(self._mindspore_lite.Context())
+ self._lite_full_predictor = \
+ self._mindspore_lite.lite_infer.LiteInfer(self, *predict_data, context=lite_context,
+ model_group_id=model_group_id, config=config)
+
+ inputs = self._lite_full_predictor.get_inputs()
+ if len(predict_data) != len(inputs):
+ raise RuntimeError(f"For 'Model.predict', the number of predict_data {len(predict_data)} "
+ f"is not equal to the number of net inputs {len(inputs)}")
+ for i, single_data in enumerate(predict_data):
+ inputs[i].set_data_from_numpy(single_data.asnumpy())
+ outputs: list = self._lite_full_predictor.predict(inputs)
+ if not outputs:
+ return Tensor(outputs)
+ if len(outputs) == 1:
+ return Tensor(outputs[0].get_data_to_numpy())
+ outputs = [Tensor(single_output.get_data_to_numpy()) for single_output in outputs]
+ return tuple(outputs)
+
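The `is_first_iteration` flag read above is an ordinary Cell custom flag; here is a hedged sketch of how a caller would drive the full vs. incremental split (this assumes `Cell.add_flags_custom`, the usual way such flags are attached, and is not shown in this file):

    net = LeNet5()  # stand-in; the real use case is autoregressive decoders
    net.add_flags_custom(is_first_iteration=True)   # use_past=False -> full inference path
    net.add_flags_custom(is_first_iteration=False)  # use_past=True  -> incremental path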
+ def predict(self, *predict_data, backend=None, config=None):
  """
  Generate output predictions for the input samples.
 
@@ -1472,6 +1632,49 @@ class Model:
  predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
  The predict data, can be a single tensor,
  a list of tensor, or a tuple of tensor.
+ backend (str): Select the predict backend; this parameter is an experimental feature
+ and is mainly used for MindSpore Lite cloud-side inference. Default: ``None`` .
+ config (dict, optional): The config parameter is enabled when the backend is 'lite'.
+ The config includes two parts: config_path ("configPath", str) and config_dict (str, dict).
+ When the config_dict is set, its priority is higher than the config_path. It can be used,
+ for example, to set the rank table file for inference.
+
+ config_path defines the path of a configuration file used to pass user-defined
+ options during model building, for example: "/home/user/config.ini". Default value: ``""`` .
+ Here is the content of such a config.ini file:
+
+ .. code-block::
+
+ [ascend_context]
+ rank_table_file = [path_a](storage initial path of the rank table file)
+ [execution_plan]
+ [op_name1] = data_type:float16 (operator named op_name1 is set to data type Float16)
+ [op_name2] = data_type:float32 (operator named op_name2 is set to data type Float32)
+
+ When only the config_path is configured, it is done as follows:
+
+ .. code-block::
+
+ config = {"configPath" : "/home/user/config.ini"}
+
+ When only the config_dict is configured, it is done as follows:
+
+ .. code-block::
+
+ config = {"ascend_context" : {"rank_table_file" : "path_b"},
+ "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
+
+ When both the `config_path` and the `config_dict` are configured, it is done as follows:
+
+ .. code-block::
+
+ config = {"configPath" : "/home/user/config.ini",
+ "ascend_context" : {"rank_table_file" : "path_b"},
+ "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
+
+ Note that when the rank table file is set both in the file referenced by "configPath" and
+ in the config_dict, path_b from the config_dict takes precedence.
 
  Returns:
  Tensor, array(s) of predictions.
@@ -1483,9 +1686,27 @@ class Model:
  >>> from mindspore.train import Model
  >>>
  >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32)
- >>> model = Model(Net())
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> model = Model(LeNet5())
  >>> result = model.predict(input_data)
  """
+ if backend not in ['lite', None]:
+ raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
+ if backend == "lite" and self._lite_infer:
+ # pylint: disable=broad-except
+ try:
+ return self._predict_lite(*predict_data, config=config)
+ except RuntimeError:
+ self._lite_infer = False
+ logger.warning("Lite inference failed, falling back to original inference!")
+ except ImportError:
+ self._lite_infer = False
+ logger.warning("Import mindspore_lite failed, falling back to original inference!")
+ except BaseException as e:
+ self._lite_infer = False
+ logger.warning(f"Lite inference failed, {e.__str__()}, falling back to original inference!")
+
  self._check_network_mode(self._predict_network, False)
  check_input_data(*predict_data, data_class=(int, float, str, None, Tensor))
  _parallel_predict_check()
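A hedged usage sketch of the new `backend='lite'` path (it assumes the optional `mindspore_lite` wheel is installed and, as elsewhere, that `LeNet5` comes from the linked lenet.py; on any failure the code above falls back to the original inference path with a warning):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.train import Model

    model = Model(LeNet5())
    input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)

    # Experimental: route inference through MindSpore Lite, optionally passing
    # a config dict as documented in the Args section above.
    result = model.predict(input_data,
                           backend="lite",
                           config={"configPath": "/home/user/config.ini"})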
@@ -1550,12 +1771,12 @@ class Model:
  function respectively.
  dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
  If PyNative mode or CPU is configured, the training process will be performed with
- dataset not sink. Default: True.
- sink_size (int): Control the amount of data in each sink.
+ dataset not sink. Default: ``True`` .
+ sink_size (int): Control the number of steps for each sinking.
  If sink_size = -1, sink the complete dataset for each epoch.
  If sink_size > 0, sink sink_size data for each epoch.
  If dataset_sink_mode is False, set sink_size as invalid.
- Default: -1.
+ Default: ``-1`` .
 
  Returns:
  Dict, Parameter layout dictionary used for load distributed checkpoint
1794
  >>> init()
1574
1795
  >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
1575
1796
  >>>
1576
- >>> # For details about how to build the dataset, please refer to the tutorial
1577
- >>> # document on the official website.
1578
- >>> dataset = create_custom_dataset()
1579
- >>> net = Net()
1797
+ >>> # Create the dataset taking MNIST as an example. Refer to
1798
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
1799
+ >>> dataset = create_dataset()
1800
+ >>> # Define the network structure of LeNet5. Refer to
1801
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1802
+ >>> net = LeNet5()
1580
1803
  >>> loss = nn.SoftmaxCrossEntropyWithLogits()
1581
1804
  >>> loss_scale_manager = ms.FixedLossScaleManager()
1582
1805
  >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -1598,7 +1821,7 @@ class Model:
1598
1821
  return train_network.parameter_layout_dict
1599
1822
 
1600
1823
 
1601
- def infer_predict_layout(self, *predict_data):
1824
+ def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
1602
1825
  """
1603
1826
  Generate parameter layout for the predict network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
1604
1827
 
@@ -1611,6 +1834,9 @@ class Model:
  predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
  The predict data, can be a single tensor,
  a list of tensor, or a tuple of tensor.
+ skip_backend_compile (bool): Only run the frontend compile process and
+ skip the compile process on the device side. Setting this flag to True may
+ cause subsequent recompilation to miss the compile cache.
 
  Returns:
  Dict, Parameter layout dictionary used for load distributed checkpoint.
@@ -1646,7 +1872,14 @@ class Model:
  predict_net = self._predict_network
  # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
  predict_net = self._check_network_mode(predict_net, False)
- predict_net.compile(*predict_data)
+ if skip_backend_compile:
+ origin_phase = predict_net.phase
+ predict_net.phase = "export." + predict_net.phase
+ predict_net.compile(*predict_data)
+ # set the phase back to avoid hitting an incomplete compile cache later
+ predict_net.phase = origin_phase
+ else:
+ predict_net.compile(*predict_data)
  return predict_net.parameter_layout_dict
 
  def _flush_from_cache(self, cb_params):
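A hedged usage sketch of the new `skip_backend_compile` option (the setup mirrors the method's own Examples; `LeNet5` and the input shape are illustrative assumptions):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.train import Model
    from mindspore.communication import init

    init()
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)

    model = Model(LeNet5())
    input_data = Tensor(np.ones([1, 1, 32, 32]), ms.float32)

    # Frontend-only compile: faster layout inference, at the cost that a later
    # full compile may not hit the cache (the phase is temporarily prefixed
    # with "export.", as the hunk above shows).
    layout = model.infer_predict_layout(input_data, skip_backend_compile=True)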
@@ -1686,5 +1919,16 @@ class Model:
  """
  return self._eval_network
 
+ def _prepare_obf_dataset(self, dataset):
+ if not hasattr(self._network, 'obf_ratios'):
+ return dataset
+ data_size = dataset.get_dataset_size()
+ obf_ratio_dataset = []
+ for _ in range(data_size):
+ obf_ratio_dataset.append(self._network.obf_ratios)
+ obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
+ dataset = ds.zip((dataset, obf_ratio_dataset))
+ return dataset
+
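A hedged sketch of what `_prepare_obf_dataset` does: it zips a constant "y_obf" column onto every row so an obfuscated network receives its `obf_ratios` alongside each batch (the tiny source dataset and ratio vector below are stand-ins for illustration):

    import numpy as np
    import mindspore.dataset as ds

    base = ds.NumpySlicesDataset(data=np.arange(4, dtype=np.float32), column_names=["x"])
    obf_ratios = np.array([0.1, 0.9], dtype=np.float32)

    # One copy of obf_ratios per row, exposed as the extra "y_obf" column.
    obf_col = ds.NumpySlicesDataset(data=[obf_ratios] * base.get_dataset_size(),
                                    column_names=["y_obf"])
    zipped = ds.zip((base, obf_col))
    for row in zipped.create_dict_iterator(output_numpy=True):
        print(row["x"], row["y_obf"])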
 
  __all__ = ["Model"]