mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/mind_ir_pb2.py CHANGED
@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
  syntax='proto2',
  serialized_options=None,
  create_key=_descriptor._internal_create_key,
- serialized_pb=b'\n\rmind_ir.proto\x12\x07mind_ir\"\x88\t\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x64\x18\x04 \x01(\x01\x12\t\n\x01s\x18\x05 \x01(\x0c\x12\x1f\n\x01t\x18\x06 \x01(\x0b\x32\x14.mind_ir.TensorProto\x12\x1e\n\x01g\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12\x0e\n\x06\x66loats\x18\x08 \x03(\x02\x12\x0f\n\x07\x64oubles\x18\t \x03(\x01\x12\x0c\n\x04ints\x18\n \x03(\x03\x12\x0f\n\x07strings\x18\x0b \x03(\x0c\x12%\n\x07tensors\x18\x0c \x03(\x0b\x32\x14.mind_ir.TensorProto\x12#\n\x06graphs\x18\r \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x12\n\ndoc_string\x18\x0e \x01(\t\x12\x15\n\rref_attr_name\x18\x0f \x01(\t\x12\x33\n\x04type\x18\x10 \x01(\x0e\x32%.mind_ir.AttributeProto.AttributeType\x12\'\n\x06values\x18\x11 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x36\n\x08seq_info\x18\x12 \x01(\x0b\x32$.mind_ir.AttributeProto.SeqInfoProto\x12&\n\x07\x66unctor\x18\x13 \x01(\x0b\x32\x15.mind_ir.FunctorProto\x1aT\n\x0cSeqInfoProto\x12\x12\n\nis_dyn_len\x18\x01 \x01(\x08\x12\x30\n\x0ftuple_elem_item\x18\x02 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xaf\x04\n\rAttributeType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\n\n\x06TENSOR\x10\x11\x12\t\n\x05GRAPH\x10\x12\x12\x0b\n\x07TENSORS\x10\x13\x12\t\n\x05TUPLE\x10\x14\x12\x08\n\x04LIST\x10\x15\x12\x08\n\x04\x44ICT\x10\x16\x12\n\n\x06UMONAD\x10\x17\x12\x0b\n\x07IOMONAD\x10\x18\x12\x08\n\x04NONE\x10\x19\x12\x14\n\x10PRIMITIVECLOSURE\x10\x1a\x12\x14\n\x10\x46UNCGRAPHCLOSURE\x10\x1b\x12\x12\n\x0ePARTIALCLOSURE\x10\x1c\x12\x14\n\x10UNIONFUNCCLOSURE\x10\x1d\x12\x0e\n\nCSR_TENSOR\x10\x1e\x12\x0e\n\nCOO_TENSOR\x10\x1f\x12\x0e\n\nROW_TENSOR\x10 \x12\x0e\n\nCLASS_TYPE\x10!\x12\x0e\n\nNAME_SPACE\x10\"\x12\n\n\x06SYMBOL\x10#\x12\r\n\tTYPE_NULL\x10$\x12\x0e\n\nMAP_TENSOR\x10%\x12\x0b\n\x07\x46UNCTOR\x10&\x12\n\n\x06SCALAR\x10\'\"\x9d\x01\n\x0c\x46unctorProto\x12/\n\x04type\x18\x01 \x01(\x0e\x32!.mind_ir.FunctorProto.FunctorType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\x06values\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"%\n\x0b\x46unctorType\x12\x16\n\x12SHAPE_CALC_FUNCTOR\x10\x01\"\x98\x01\n\x0eValueInfoProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06tensor\x18\x02 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x03 \x01(\t\x12\x12\n\ndenotation\x18\x04 \x01(\t\x12*\n\tattr_info\x18\x05 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xf3\x01\n\tNodeProto\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12*\n\tattribute\x18\x05 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\x0e\n\x06\x64omain\x18\x07 \x01(\t\x12*\n\tnode_attr\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12,\n\x0bprimal_attr\x18\t \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf8\x03\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\t\x12\x15\n\rproducer_name\x18\x02 \x01(\t\x12\x18\n\x10producer_version\x18\x03 \x01(\t\x12\x0e\n\x06\x64omain\x18\x04 \x01(\t\x12\x15\n\rmodel_version\x18\x05 \x01(\t\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\"\n\x05graph\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12&\n\tfunctions\x18\x08 \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x30\n\x0cpreprocessor\x18\t \x01(\x0b\x32\x1a.mind_ir.PreprocessorProto\x12\x15\n\rlittle_endian\x18\n \x01(\x08\x12(\n\x08parallel\x18\x0b \x01(\x0b\x32\x16.mind_ir.ParallelProto\x12+\n\nprimitives\x18\x0c \x03(\x0b\x32\x17.mind_ir.PrimitiveProto\x12\x17\n\x0fmind_ir_version\x18\r \x01(\x03\x12\x34\n\tuser_info\x18\x0e \x03(\x0b\x32!.mind_ir.ModelProto.UserInfoEntry\x1a/\n\rUserInfoEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\";\n\x11PreprocessorProto\x12&\n\x02op\x18\x01 \x03(\x0b\x32\x1a.mind_ir.PreprocessOpProto\"\x91\x01\n\x11PreprocessOpProto\x12\x15\n\rinput_columns\x18\x01 \x01(\t\x12\x16\n\x0eoutput_columns\x18\x02 \x01(\t\x12\x17\n\x0fproject_columns\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12\x12\n\noperations\x18\x05 \x01(\t\x12\x0f\n\x07offload\x18\x06 \x01(\x08\"\xd2\x02\n\nGraphProto\x12 \n\x04node\x18\x01 \x03(\x0b\x32\x12.mind_ir.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\tparameter\x18\x03 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x04 \x01(\t\x12&\n\x05input\x18\x05 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\'\n\x06output\x18\x06 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\x12\n\nbprop_hash\x18\x07 \x01(\t\x12*\n\tattribute\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x16\n\x0e\x62prop_filepath\x18\t \x01(\t\x12.\n\rmap_parameter\x18\n \x03(\x0b\x32\x17.mind_ir.MapTensorProto\"\xda\x07\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12\x11\n\tdata_type\x18\x02 \x01(\x05\x12\x12\n\nfloat_data\x18\x03 \x03(\x02\x12\x12\n\nint32_data\x18\x04 \x03(\x05\x12\x13\n\x0bstring_data\x18\x05 \x03(\x0c\x12\x12\n\nint64_data\x18\x06 \x03(\x03\x12\x0c\n\x04name\x18\x07 \x01(\t\x12\x12\n\ndoc_string\x18\x08 \x01(\t\x12\x10\n\x08raw_data\x18\t \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\n \x03(\x01\x12\x13\n\x0buint64_data\x18\x0b \x03(\x04\x12=\n\rexternal_data\x18\x0c \x01(\x0b\x32&.mind_ir.TensorProto.ExternalDataProto\x12\x0f\n\x07ref_key\x18\r \x01(\t\x12\x10\n\x08min_dims\x18\x0e \x03(\x03\x12\x10\n\x08max_dims\x18\x0f \x03(\x03\x12>\n\x10\x63ompression_type\x18\x10 \x01(\x0e\x32$.mind_ir.TensorProto.CompressionType\x12:\n\x0cquant_params\x18\x11 \x03(\x0b\x32$.mind_ir.TensorProto.QuantParamProto\x1a\x45\n\x11\x45xternalDataProto\x12\x10\n\x08location\x18\x01 \x01(\t\x12\x0e\n\x06offset\x18\x02 \x01(\x03\x12\x0e\n\x06length\x18\x03 \x01(\x03\x1aV\n\x0fQuantParamProto\x12\x17\n\x0fquant_algo_name\x18\x01 \x02(\t\x12*\n\tattribute\x18\x02 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf4\x01\n\x08\x44\x61taType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\x0b\n\x07\x46LOAT64\x10\x11\x12\x0b\n\x07QINT4X2\x10\x12\"u\n\x0f\x43ompressionType\x12\x12\n\x0eNO_COMPRESSION\x10\x00\x12\x0c\n\x08INDEXING\x10\x01\x12\n\n\x06SPARSE\x10\x02\x12\x07\n\x03\x46SE\x10\x03\x12\x0f\n\x0b\x42IT_PACKING\x10\x04\x12\x0b\n\x07\x46SE_INT\x10\x05\x12\r\n\tFSE_INFER\x10\x06\"\xd1\x01\n\x0eMapTensorProto\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\rdefault_value\x18\x02 \x02(\x0b\x32\x17.mind_ir.AttributeProto\x12(\n\nkey_tensor\x18\x03 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12*\n\x0cvalue_tensor\x18\x04 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12+\n\rstatus_tensor\x18\x05 \x02(\x0b\x32\x14.mind_ir.TensorProto\"5\n\rParallelProto\x12$\n\x06layout\x18\x01 \x03(\x0b\x32\x14.mind_ir.LayoutProto\"\xfd\x01\n\x0bLayoutProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1e\n\x16\x64\x65vice_arrangement_int\x18\x02 \x03(\x03\x12\x16\n\x0etensor_map_int\x18\x03 \x03(\x03\x12\x17\n\x0fslice_shape_int\x18\x04 \x03(\x03\x12\x12\n\nfield_size\x18\x05 \x01(\x03\x12\x15\n\runiform_split\x18\x06 \x01(\x08\x12\x17\n\x0fopt_shard_group\x18\x07 \x01(\t\x12\x17\n\x0fpipeline_shared\x18\x08 \x01(\x08\x12\x0f\n\x07is_send\x18\t \x01(\x08\x12\x11\n\tpeer_rank\x18\n \x01(\x03\x12\x0e\n\x06sr_tag\x18\x0b \x01(\x03\"\xda\x01\n\x0ePrimitiveProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07op_type\x18\x02 \x01(\t\x12*\n\tattribute\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x15\n\rinstance_name\x18\x04 \x01(\t\x12\x33\n\tprim_type\x18\x05 \x01(\x0e\x32 .mind_ir.PrimitiveProto.PrimType\"1\n\x08PrimType\x12\r\n\tPRIMITIVE\x10\x01\x12\x16\n\x12PRIMITIVE_FUNCTION\x10\x02*R\n\x07Version\x12\x14\n\x10IR_VERSION_START\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01\x12!\n\x1dIR_VERSION_WITH_PRIM_FUNCTION\x10\x02'
+ serialized_pb=b'\n\rmind_ir.proto\x12\x07mind_ir\"\x88\t\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x64\x18\x04 \x01(\x01\x12\t\n\x01s\x18\x05 \x01(\x0c\x12\x1f\n\x01t\x18\x06 \x01(\x0b\x32\x14.mind_ir.TensorProto\x12\x1e\n\x01g\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12\x0e\n\x06\x66loats\x18\x08 \x03(\x02\x12\x0f\n\x07\x64oubles\x18\t \x03(\x01\x12\x0c\n\x04ints\x18\n \x03(\x03\x12\x0f\n\x07strings\x18\x0b \x03(\x0c\x12%\n\x07tensors\x18\x0c \x03(\x0b\x32\x14.mind_ir.TensorProto\x12#\n\x06graphs\x18\r \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x12\n\ndoc_string\x18\x0e \x01(\t\x12\x15\n\rref_attr_name\x18\x0f \x01(\t\x12\x33\n\x04type\x18\x10 \x01(\x0e\x32%.mind_ir.AttributeProto.AttributeType\x12\'\n\x06values\x18\x11 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x36\n\x08seq_info\x18\x12 \x01(\x0b\x32$.mind_ir.AttributeProto.SeqInfoProto\x12&\n\x07\x66unctor\x18\x13 \x01(\x0b\x32\x15.mind_ir.FunctorProto\x1aT\n\x0cSeqInfoProto\x12\x12\n\nis_dyn_len\x18\x01 \x01(\x08\x12\x30\n\x0ftuple_elem_item\x18\x02 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xaf\x04\n\rAttributeType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\n\n\x06TENSOR\x10\x11\x12\t\n\x05GRAPH\x10\x12\x12\x0b\n\x07TENSORS\x10\x13\x12\t\n\x05TUPLE\x10\x14\x12\x08\n\x04LIST\x10\x15\x12\x08\n\x04\x44ICT\x10\x16\x12\n\n\x06UMONAD\x10\x17\x12\x0b\n\x07IOMONAD\x10\x18\x12\x08\n\x04NONE\x10\x19\x12\x14\n\x10PRIMITIVECLOSURE\x10\x1a\x12\x14\n\x10\x46UNCGRAPHCLOSURE\x10\x1b\x12\x12\n\x0ePARTIALCLOSURE\x10\x1c\x12\x14\n\x10UNIONFUNCCLOSURE\x10\x1d\x12\x0e\n\nCSR_TENSOR\x10\x1e\x12\x0e\n\nCOO_TENSOR\x10\x1f\x12\x0e\n\nROW_TENSOR\x10 \x12\x0e\n\nCLASS_TYPE\x10!\x12\x0e\n\nNAME_SPACE\x10\"\x12\n\n\x06SYMBOL\x10#\x12\r\n\tTYPE_NULL\x10$\x12\x0e\n\nMAP_TENSOR\x10%\x12\x0b\n\x07\x46UNCTOR\x10&\x12\n\n\x06SCALAR\x10\'\"\xae\x01\n\x0c\x46unctorProto\x12/\n\x04type\x18\x01 \x01(\x0e\x32!.mind_ir.FunctorProto.FunctorType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\x06values\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"6\n\x0b\x46unctorType\x12\x16\n\x12SHAPE_CALC_FUNCTOR\x10\x01\x12\x0f\n\x0b\x41NY_FUNCTOR\x10\x02\"\x98\x01\n\x0eValueInfoProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06tensor\x18\x02 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x03 \x01(\t\x12\x12\n\ndenotation\x18\x04 \x01(\t\x12*\n\tattr_info\x18\x05 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xf3\x01\n\tNodeProto\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12*\n\tattribute\x18\x05 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\x0e\n\x06\x64omain\x18\x07 \x01(\t\x12*\n\tnode_attr\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12,\n\x0bprimal_attr\x18\t \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf8\x03\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\t\x12\x15\n\rproducer_name\x18\x02 \x01(\t\x12\x18\n\x10producer_version\x18\x03 \x01(\t\x12\x0e\n\x06\x64omain\x18\x04 \x01(\t\x12\x15\n\rmodel_version\x18\x05 \x01(\t\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\"\n\x05graph\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12&\n\tfunctions\x18\x08 \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x30\n\x0cpreprocessor\x18\t \x01(\x0b\x32\x1a.mind_ir.PreprocessorProto\x12\x15\n\rlittle_endian\x18\n \x01(\x08\x12(\n\x08parallel\x18\x0b \x01(\x0b\x32\x16.mind_ir.ParallelProto\x12+\n\nprimitives\x18\x0c \x03(\x0b\x32\x17.mind_ir.PrimitiveProto\x12\x17\n\x0fmind_ir_version\x18\r \x01(\x03\x12\x34\n\tuser_info\x18\x0e \x03(\x0b\x32!.mind_ir.ModelProto.UserInfoEntry\x1a/\n\rUserInfoEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\";\n\x11PreprocessorProto\x12&\n\x02op\x18\x01 \x03(\x0b\x32\x1a.mind_ir.PreprocessOpProto\"\x91\x01\n\x11PreprocessOpProto\x12\x15\n\rinput_columns\x18\x01 \x01(\t\x12\x16\n\x0eoutput_columns\x18\x02 \x01(\t\x12\x17\n\x0fproject_columns\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12\x12\n\noperations\x18\x05 \x01(\t\x12\x0f\n\x07offload\x18\x06 \x01(\x08\"\xd2\x02\n\nGraphProto\x12 \n\x04node\x18\x01 \x03(\x0b\x32\x12.mind_ir.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\tparameter\x18\x03 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x04 \x01(\t\x12&\n\x05input\x18\x05 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\'\n\x06output\x18\x06 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\x12\n\nbprop_hash\x18\x07 \x01(\t\x12*\n\tattribute\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x16\n\x0e\x62prop_filepath\x18\t \x01(\t\x12.\n\rmap_parameter\x18\n \x03(\x0b\x32\x17.mind_ir.MapTensorProto\"\xda\x07\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12\x11\n\tdata_type\x18\x02 \x01(\x05\x12\x12\n\nfloat_data\x18\x03 \x03(\x02\x12\x12\n\nint32_data\x18\x04 \x03(\x05\x12\x13\n\x0bstring_data\x18\x05 \x03(\x0c\x12\x12\n\nint64_data\x18\x06 \x03(\x03\x12\x0c\n\x04name\x18\x07 \x01(\t\x12\x12\n\ndoc_string\x18\x08 \x01(\t\x12\x10\n\x08raw_data\x18\t \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\n \x03(\x01\x12\x13\n\x0buint64_data\x18\x0b \x03(\x04\x12=\n\rexternal_data\x18\x0c \x01(\x0b\x32&.mind_ir.TensorProto.ExternalDataProto\x12\x0f\n\x07ref_key\x18\r \x01(\t\x12\x10\n\x08min_dims\x18\x0e \x03(\x03\x12\x10\n\x08max_dims\x18\x0f \x03(\x03\x12>\n\x10\x63ompression_type\x18\x10 \x01(\x0e\x32$.mind_ir.TensorProto.CompressionType\x12:\n\x0cquant_params\x18\x11 \x03(\x0b\x32$.mind_ir.TensorProto.QuantParamProto\x1a\x45\n\x11\x45xternalDataProto\x12\x10\n\x08location\x18\x01 \x01(\t\x12\x0e\n\x06offset\x18\x02 \x01(\x03\x12\x0e\n\x06length\x18\x03 \x01(\x03\x1aV\n\x0fQuantParamProto\x12\x17\n\x0fquant_algo_name\x18\x01 \x02(\t\x12*\n\tattribute\x18\x02 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf4\x01\n\x08\x44\x61taType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\x0b\n\x07\x46LOAT64\x10\x11\x12\x0b\n\x07QINT4X2\x10\x12\"u\n\x0f\x43ompressionType\x12\x12\n\x0eNO_COMPRESSION\x10\x00\x12\x0c\n\x08INDEXING\x10\x01\x12\n\n\x06SPARSE\x10\x02\x12\x07\n\x03\x46SE\x10\x03\x12\x0f\n\x0b\x42IT_PACKING\x10\x04\x12\x0b\n\x07\x46SE_INT\x10\x05\x12\r\n\tFSE_INFER\x10\x06\"\xd1\x01\n\x0eMapTensorProto\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\rdefault_value\x18\x02 \x02(\x0b\x32\x17.mind_ir.AttributeProto\x12(\n\nkey_tensor\x18\x03 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12*\n\x0cvalue_tensor\x18\x04 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12+\n\rstatus_tensor\x18\x05 \x02(\x0b\x32\x14.mind_ir.TensorProto\"5\n\rParallelProto\x12$\n\x06layout\x18\x01 \x03(\x0b\x32\x14.mind_ir.LayoutProto\"\xfd\x01\n\x0bLayoutProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1e\n\x16\x64\x65vice_arrangement_int\x18\x02 \x03(\x03\x12\x16\n\x0etensor_map_int\x18\x03 \x03(\x03\x12\x17\n\x0fslice_shape_int\x18\x04 \x03(\x03\x12\x12\n\nfield_size\x18\x05 \x01(\x03\x12\x15\n\runiform_split\x18\x06 \x01(\x08\x12\x17\n\x0fopt_shard_group\x18\x07 \x01(\t\x12\x17\n\x0fpipeline_shared\x18\x08 \x01(\x08\x12\x0f\n\x07is_send\x18\t \x01(\x08\x12\x11\n\tpeer_rank\x18\n \x01(\x03\x12\x0e\n\x06sr_tag\x18\x0b \x01(\x03\"\xda\x01\n\x0ePrimitiveProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07op_type\x18\x02 \x01(\t\x12*\n\tattribute\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x15\n\rinstance_name\x18\x04 \x01(\t\x12\x33\n\tprim_type\x18\x05 \x01(\x0e\x32 .mind_ir.PrimitiveProto.PrimType\"1\n\x08PrimType\x12\r\n\tPRIMITIVE\x10\x01\x12\x16\n\x12PRIMITIVE_FUNCTION\x10\x02*R\n\x07Version\x12\x14\n\x10IR_VERSION_START\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01\x12!\n\x1dIR_VERSION_WITH_PRIM_FUNCTION\x10\x02'
  )

  _VERSION = _descriptor.EnumDescriptor(
@@ -48,8 +48,8 @@ _VERSION = _descriptor.EnumDescriptor(
  ],
  containing_type=None,
  serialized_options=None,
- serialized_start=4540,
- serialized_end=4622,
+ serialized_start=4557,
+ serialized_end=4639,
  )
  _sym_db.RegisterEnumDescriptor(_VERSION)

@@ -286,11 +286,16 @@ _FUNCTORPROTO_FUNCTORTYPE = _descriptor.EnumDescriptor(
  serialized_options=None,
  type=None,
  create_key=_descriptor._internal_create_key),
+ _descriptor.EnumValueDescriptor(
+ name='ANY_FUNCTOR', index=1, number=2,
+ serialized_options=None,
+ type=None,
+ create_key=_descriptor._internal_create_key),
  ],
  containing_type=None,
  serialized_options=None,
  serialized_start=1310,
- serialized_end=1347,
+ serialized_end=1364,
  )
  _sym_db.RegisterEnumDescriptor(_FUNCTORPROTO_FUNCTORTYPE)

@@ -399,8 +404,8 @@ _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor(
  ],
  containing_type=None,
  serialized_options=None,
- serialized_start=3431,
- serialized_end=3675,
+ serialized_start=3448,
+ serialized_end=3692,
  )
  _sym_db.RegisterEnumDescriptor(_TENSORPROTO_DATATYPE)

@@ -449,8 +454,8 @@ _TENSORPROTO_COMPRESSIONTYPE = _descriptor.EnumDescriptor(
  ],
  containing_type=None,
  serialized_options=None,
- serialized_start=3677,
- serialized_end=3794,
+ serialized_start=3694,
+ serialized_end=3811,
  )
  _sym_db.RegisterEnumDescriptor(_TENSORPROTO_COMPRESSIONTYPE)

@@ -474,8 +479,8 @@ _PRIMITIVEPROTO_PRIMTYPE = _descriptor.EnumDescriptor(
  ],
  containing_type=None,
  serialized_options=None,
- serialized_start=4489,
- serialized_end=4538,
+ serialized_start=4506,
+ serialized_end=4555,
  )
  _sym_db.RegisterEnumDescriptor(_PRIMITIVEPROTO_PRIMTYPE)

@@ -720,7 +725,7 @@ _FUNCTORPROTO = _descriptor.Descriptor(
  oneofs=[
  ],
  serialized_start=1190,
- serialized_end=1347,
+ serialized_end=1364,
  )


@@ -779,8 +784,8 @@ _VALUEINFOPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=1350,
- serialized_end=1502,
+ serialized_start=1367,
+ serialized_end=1519,
  )


@@ -867,8 +872,8 @@ _NODEPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=1505,
- serialized_end=1748,
+ serialized_start=1522,
+ serialized_end=1765,
  )


@@ -906,8 +911,8 @@ _MODELPROTO_USERINFOENTRY = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=2208,
- serialized_end=2255,
+ serialized_start=2225,
+ serialized_end=2272,
  )

  _MODELPROTO = _descriptor.Descriptor(
@@ -1028,8 +1033,8 @@ _MODELPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=1751,
- serialized_end=2255,
+ serialized_start=1768,
+ serialized_end=2272,
  )


@@ -1060,8 +1065,8 @@ _PREPROCESSORPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=2257,
- serialized_end=2316,
+ serialized_start=2274,
+ serialized_end=2333,
  )


@@ -1127,8 +1132,8 @@ _PREPROCESSOPPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=2319,
- serialized_end=2464,
+ serialized_start=2336,
+ serialized_end=2481,
  )


@@ -1222,8 +1227,8 @@ _GRAPHPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=2467,
- serialized_end=2805,
+ serialized_start=2484,
+ serialized_end=2822,
  )


@@ -1268,8 +1273,8 @@ _TENSORPROTO_EXTERNALDATAPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=3271,
- serialized_end=3340,
+ serialized_start=3288,
+ serialized_end=3357,
  )

  _TENSORPROTO_QUANTPARAMPROTO = _descriptor.Descriptor(
@@ -1306,8 +1311,8 @@ _TENSORPROTO_QUANTPARAMPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=3342,
- serialized_end=3428,
+ serialized_start=3359,
+ serialized_end=3445,
  )

  _TENSORPROTO = _descriptor.Descriptor(
@@ -1451,8 +1456,8 @@ _TENSORPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=2808,
- serialized_end=3794,
+ serialized_start=2825,
+ serialized_end=3811,
  )


@@ -1511,8 +1516,8 @@ _MAPTENSORPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=3797,
- serialized_end=4006,
+ serialized_start=3814,
+ serialized_end=4023,
  )


@@ -1543,8 +1548,8 @@ _PARALLELPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=4008,
- serialized_end=4061,
+ serialized_start=4025,
+ serialized_end=4078,
  )


@@ -1645,8 +1650,8 @@ _LAYOUTPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=4064,
- serialized_end=4317,
+ serialized_start=4081,
+ serialized_end=4334,
  )


@@ -1706,8 +1711,8 @@ _PRIMITIVEPROTO = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=4320,
- serialized_end=4538,
+ serialized_start=4337,
+ serialized_end=4555,
  )

  _ATTRIBUTEPROTO_SEQINFOPROTO.fields_by_name['tuple_elem_item'].message_type = _ATTRIBUTEPROTO
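
Note: the only semantic change hidden in the regenerated descriptor above is that FunctorProto.FunctorType gains a second enum value, ANY_FUNCTOR = 2. That addition makes the serialized descriptor exactly 17 bytes longer (the FunctorType block grows from 37 to 54 bytes), which is why every serialized_start/serialized_end offset from FunctorProto onward shifts by 17. A minimal sketch of the observable effect, assuming the regenerated module keeps the standard proto2 Python API; this is inferred from the descriptor bytes, not taken from MindSpore documentation:

# Sketch only; illustrates the FunctorType change inferred from the descriptor above.
from mindspore.train import mind_ir_pb2

functor = mind_ir_pb2.FunctorProto()
functor.type = mind_ir_pb2.FunctorProto.ANY_FUNCTOR  # new in 2.4.0; 2.3.0 only had SHAPE_CALC_FUNCTOR = 1

# Generated proto2 enums expose Name()/Value() lookups on the enum wrapper.
assert mind_ir_pb2.FunctorProto.FunctorType.Name(2) == 'ANY_FUNCTOR'
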
mindspore/train/model.py CHANGED
@@ -36,7 +36,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
  from mindspore._checkparam import check_input_data, check_output_data
  from mindspore import _checkparam as Validator
  from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
- FlopsUtilizationCollector, MindIOTTPAdapter
+ FlopsUtilizationCollector, TFTRegister
  from mindspore.train.callback import __all__ as internal_cb_names
  from mindspore.train.callback._cluster_monitor import ClusterMonitor
  from mindspore import context
@@ -119,6 +119,101 @@ def _save_final_ckpt(func):
  func(self, *args, **kwargs)
  return wrapper

+ def _handle_tft(func):
+ """
+ Decorator function, which starts uce handle process when an exception occurs during training.
+ """
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ obj = None
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+ obj = kwargs.get('callbacks')
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+ for item in kwargs.get('callbacks'):
+ if isinstance(item, TFTRegister):
+ obj = item
+ if obj:
+ tft = obj.tft
+ tft_env = os.getenv("MS_ENABLE_TFT", "")
+ uce_env = "UCE:1" in tft_env
+ while True:
+ try:
+ return func(self, *args, **kwargs)
+ except RuntimeError as e:
+ logger.info("uce wrapper caught RuntimeError")
+ if not uce_env:
+ logger.info("uce wrapper caught RuntimeError uce not enable")
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+ raise e
+ e_str = str(e)
+ logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+ if "UCEError" in e_str:
+ logger.info("uce wrapper report UCEError")
+ tft.tft_report_error(tft.ReportState.RS_UCE.value)
+ elif "ForceStopError" in e_str:
+ logger.info("uce wrapper caught RuntimeError ForceStopError")
+ force_stop_err = tft.ReportState.RS_NORMAL.value
+ tft.tft_report_error(force_stop_err)
+ else:
+ logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+ raise e
+ ret = tft.tft_wait_next_action()
+ if ret == tft.Action.EXIT.value:
+ raise e
+ repair_step = tft.tft_get_repair_step()
+ logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
+ {}".format(repair_step, self.batch_num))
+ initial_epoch = int(repair_step/self.batch_num)
+ initial_step = repair_step % self.batch_num
+ kwargs["initial_epoch"] = initial_epoch
+
+ train_dataset = args[1]
+ dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
+ sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
+
+ cb_initial_step = 0
+ if dataset_sink_mode:
+ train_dataset.set_init_step(initial_epoch)
+ dataset_size = train_dataset.get_dataset_size()
+ if sink_size != -1:
+ cb_initial_step = initial_epoch * sink_size + initial_step
+ else:
+ cb_initial_step = initial_epoch * dataset_size + initial_step
+ else:
+ train_dataset.set_init_step(initial_step)
+ cb_initial_step = initial_step
+
+ kwargs["initial_step"] = cb_initial_step
+
+ logger.info("uce wrapper repair complete \
+ initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+ continue
+ except BaseException as e:
+ logger.info("uce wrapper caught BaseException error")
+ tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+ raise e
+ else:
+ return func(self, *args, **kwargs)
+ return wrapper
+
+
+ def _check_tft():
+ """Check if TFT is supported"""
+ tft_env = os.getenv("MS_ENABLE_TFT")
+ device_target = context.get_context("device_target")
+ if tft_env and device_target == "Ascend":
+ from mindspore._c_expression import MSContext
+ ascend_target = MSContext.get_instance().get_ascend_soc_version()
+ if ascend_target == 'ascend910':
+ raise ValueError("TFT is not supported when using ascend910")
+ ms_mode = context.get_context("mode")
+ if ms_mode != mindspore.GRAPH_MODE:
+ raise ValueError("TFT is only supported in GRAPH_MODE")
+ jit_level = context.get_context("jit_level")
+ if jit_level == "O2" and "UCE:1" in tft_env:
+ raise ValueError("TFT is not supported when using jit_level == O2")
+

  def _append_ccae(callbacks):
  """Add cluster monitoring when CCAE is enabled."""
@@ -290,21 +385,11 @@
  amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
  precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .

- - "O0": Do not change.
- - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
- The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
- Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
- - "O2": Cast network to float16, keep BatchNorm run in float32, using dynamic loss scale.
- - "O3": Cast network to float16, the BatchNorm is also cast to float16, loss scale will not be used.
- - "auto": Set level to recommended level in different devices. Set level to "O2" on GPU, set
- level to "O3" on Ascend. The recommended level is chosen by the expert experience, not applicable to all
- scenarios. User should specify the level for special network.
-
- "O2" is recommended on GPU, "O3" is recommended on Ascend.
+ For details on `amp_level` , refer to :func:`mindspore.amp.auto_mixed_precision`.
+
  The BatchNorm strategy can be changed by `keep_batchnorm_fp32` settings in `kwargs`. `keep_batchnorm_fp32`
  must be a bool. The loss scale strategy can be changed by `loss_scale_manager` setting in `kwargs`.
  `loss_scale_manager` should be a subclass of :class:`mindspore.amp.LossScaleManager`.
- The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network`.

  boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
  training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
@@ -379,6 +464,7 @@
  self._mindspore_lite = None
  self._lite_infer = True # if backend lite infer fails, set False
  self._mindspore_lite_model_group_id = id(self) & 0xFFFF
+ self.batch_num = -1

  def _check_for_graph_cell(self, kwargs):
  """Check for graph cell"""
@@ -568,9 +654,11 @@
  dataset.__loop_size__ = 1

  if dataset_helper is None:
+ logger.info("Begin to create DatasetHelper.")
  dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)

  if dataset_sink_mode:
+ logger.info("Begin to connect network with dataset.")
  network = connect_network_with_dataset(network, dataset_helper)

  if _get_recovery_context("enable_recovery") and is_train:
@@ -589,6 +677,10 @@
  if self._backbone_is_train != is_train:
  network.set_train(is_train)
  self._backbone_is_train = is_train
+ # Mode train and eval are the same net, network will be set_grad in _build_train_network.
+ # But if mode just want to do predict or eval, must set network set_grad False
+ if not is_train:
+ network.set_grad(False)
  return network

  def _check_need_ckpt(self, callbacks):
@@ -687,6 +779,7 @@
  if not train_dataset and not valid_dataset:
  raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")

+ logger.info("Begin to check device number in model.build() procedure.")
  _device_number_check(self._parallel_mode, self._device_number)

  if train_dataset:
@@ -694,27 +787,34 @@
  raise TypeError("The type of 'train_dataset' must be `Dataset`, "
  "but got {}.".format(type(train_dataset)))

+ logger.info("Begin to check parameter broadcast in model.build() procedure.")
  _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
  if self._parameter_broadcast:
  self._train_network.set_broadcast_flag()

+ logger.info("Begin to exec preprocess in model.build() procedure.")
  train_dataset.__no_send__ = True
  train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
  dataset=train_dataset,
  dataset_sink_mode=True,
  sink_size=sink_size)
+ logger.info("Begin to warmup dataset in model.build() procedure.")
  self._warmup_dataset(epoch, train_dataset, sink_size)

  # Since dataset pipeline has been triggered, delete flag
  delattr(train_dataset, "__no_send__")

  # Waiting for the dataset warmup ready
+ logger.info("Begin waiting for dataset warmup in model.build() procedure.")
  self._waiting_for_dataset_warmup_ready(train_dataset)
+ logger.info("The dataset warmup was successful in model.build() procedure.")

  if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
  train_network.add_flags_recursive(is_first_iteration=True)
  for inputs in train_dataset_helper:
+ logger.info("Begin to compile train network in model.build() procedure.")
  train_network.compile(*inputs)
+ self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
  break

  if valid_dataset:
@@ -732,6 +832,7 @@
  if context.get_auto_parallel_context("pipeline_stages") > 1:
  eval_network.add_flags_recursive(is_first_iteration=False)
  for inputs in valid_dataset_helper:
+ logger.info("Begin to compile eval network in model.build() procedure.")
  eval_network.compile(*inputs)
  break

@@ -746,9 +847,10 @@

  return [callbacks]

+ @_handle_tft
  @_save_final_ckpt
  def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
- valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
+ valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True, initial_step=0):
  """
  Training.

@@ -772,12 +874,14 @@
  self._train_network.set_broadcast_flag()

  cb_params = _InternalCallbackParam()
+ cb_params.cur_step_num = initial_step
  cb_params.train_network = self._train_network
  cb_params.epoch_num = epoch - initial_epoch
  if dataset_sink_mode and sink_size > 0:
  cb_params.batch_num = sink_size
  else:
  cb_params.batch_num = train_dataset.get_dataset_size()
+ self.batch_num = cb_params.batch_num
  cb_params.mode = "train"
  cb_params.loss_fn = self._loss_fn
  cb_params.optimizer = self._optimizer
@@ -806,11 +910,13 @@
  with _CallbackManager(callbacks) as list_callback:
  self._check_reuse_dataset(train_dataset)
  if not dataset_sink_mode:
- self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
+ self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+ valid_infos)
  elif context.get_context("device_target") == "CPU":
  logger.info("The CPU cannot support dataset sink mode currently."
  "So the training process will be performed with dataset not sink.")
- self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
+ self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+ valid_infos)
  else:
  self._train_dataset_sink_process(epoch, train_dataset, list_callback,
  cb_params, sink_size, initial_epoch, valid_infos)
@@ -850,9 +956,7 @@
  train_dataset.__total_batch__ = epoch * sink_size

  cb_params.sink_size = sink_size
- cb_params.cur_step_num = 0
  cb_params.dataset_sink_mode = True
-
  run_context = RunContext(cb_params)
  list_callback.on_train_begin(run_context)
  # used to stop training for early stop, such as stopAtTIme or stopATStep
@@ -861,7 +965,6 @@
  dataset_helper = train_dataset._dataset_helper

  self.epoch_iter = 0
-
  self._check_enable_recovery()
  # Used to check whether need perform recovery for process which is restarted.
  self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
@@ -997,7 +1100,6 @@
  dataset_size (int): The number of batches in a dataset.
  sink_size (int): Control the amount of data in each sink. Default: -1.
  """
-
  if not self.enable_recovery:
  self.need_load_ckpt = False

@@ -1084,7 +1186,6 @@
  dataset=train_dataset,
  dataset_sink_mode=False,
  epoch_num=epoch)
- cb_params.cur_step_num = 0
  cb_params.dataset_sink_mode = False
  run_context = RunContext(cb_params)
  list_callback.on_train_begin(run_context)
@@ -1106,7 +1207,6 @@
  "returned by 'train_dataset'".format(len_element))
  cb_params.cur_step_num += 1
  self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-
  cb_params.train_dataset_element = next_element
  list_callback.on_train_step_begin(run_context)
  self._check_network_mode(self._train_network, True)
@@ -1150,31 +1250,6 @@

  list_callback.on_train_end(run_context)

- def _wrapper_train(self, callbacks):
- """
- This method used to wrap train function with ttp wrapper which will do event notify when
- exceptions throw.
-
- Args:
- callbacks (function): Callbacks passed by train method.
- """
-
- if not callbacks:
- return self._train
- cbs = callbacks if isinstance(callbacks, list) else [callbacks]
- obj = None
- _train_wrapper = None
- for item in cbs:
- if isinstance(item, MindIOTTPAdapter):
- obj = item
-
- if (obj is not None) and (obj.enable is True):
- logger.info("MindIO TTP is enable, so we wrapper ttp exception handdler for self train method.")
- _train_wrapper = obj.wrapper_ttp_persist(self._train)
-
- return self._train if not _train_wrapper else _train_wrapper
-
-
  def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=False, sink_size=-1, initial_epoch=0):
  """
  Training API.
@@ -1240,9 +1315,10 @@
  ... loss_scale_manager=loss_scale_manager)
  >>> model.train(2, dataset)
  """
+ _check_tft()
+ device_target = context.get_context("device_target")
  # prepare dataset for obfuscated model
  train_dataset = self._prepare_obf_dataset(train_dataset)
- device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
  logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
  dataset_sink_mode = False
@@ -1283,16 +1359,14 @@
  _device_number_check(self._parallel_mode, self._device_number)

  callbacks = _append_ccae(callbacks)
- _train_wrapper = None
  if callbacks:
  self._check_methods_for_custom_callbacks(callbacks, "train")
- _train_wrapper = self._wrapper_train(callbacks)
- _train_wrapper(epoch,
- train_dataset,
- callbacks=callbacks,
- dataset_sink_mode=dataset_sink_mode,
- sink_size=sink_size,
- initial_epoch=initial_epoch)
+ self._train(epoch,
+ train_dataset,
+ callbacks=callbacks,
+ dataset_sink_mode=dataset_sink_mode,
+ sink_size=sink_size,
+ initial_epoch=initial_epoch)

  # When it's distributed training and using MindRT,
  # the node id should be reset to start from 0.
@@ -1396,7 +1470,7 @@

  Tutorial Examples:
  - `Advanced Encapsulation: Model - Train and Save Model
- <https://www.mindspore.cn/tutorials/en/master/advanced/model.html#training-and-saving-model>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
  """
  device_target = context.get_context("device_target")
  if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
@@ -1493,7 +1567,9 @@
  if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
  self._train_network.check_names_and_refresh_name()
  self._train_network._is_check_and_refresh = True
+ logger.info("Begin to init dataset in model.build() procedure.")
  self._init(train_dataset, valid_dataset, sink_size, epoch)
+ logger.info("The model.build() which contains dataset warmup and network compile is success.")

  def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
  """
@@ -1663,7 +1739,7 @@

  Tutorial Examples:
  - `Advanced Encapsulation: Model - Train and Save Model
- <https://www.mindspore.cn/tutorials/en/master/advanced/model.html#training-and-saving-model>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
  """
  valid_dataset = self._prepare_obf_dataset(valid_dataset)
  dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
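
Note: taken together, the model.py changes replace the MindIO TTP hook (_wrapper_train plus the MindIOTTPAdapter callback, both deleted) with the _handle_tft decorator and the new TFTRegister callback. _check_tft() refuses to run outside GRAPH_MODE, on ascend910 silicon, or with jit_level "O2" while UCE is enabled; _handle_tft retries _train from the repaired step when a RuntimeError containing "UCEError" is reported. A hedged sketch of how user code would drive this path follows; the TFTRegister constructor arguments and the exact MS_ENABLE_TFT value format are assumptions, not shown in this diff (the diff only requires that the variable contain "UCE:1"):

# Sketch only. TFTRegister's signature and the env-var format below are assumptions.
import os
os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1}"  # assumed format; must contain "UCE:1" for UCE recovery

import numpy as np
import mindspore as ms
from mindspore.train import Model
from mindspore.train.callback import TFTRegister

# _check_tft() requires Ascend + GRAPH_MODE, and a jit_level other than "O2" when UCE is on.
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", jit_config={"jit_level": "O0"})

net = ms.nn.Dense(4, 2)
loss = ms.nn.MAELoss()
opt = ms.nn.Adam(net.trainable_params(), learning_rate=1e-3)
data = {"x": np.random.randn(32, 4).astype(np.float32),
        "y": np.random.randn(32, 2).astype(np.float32)}
train_ds = ms.dataset.NumpySlicesDataset(data, shuffle=False).batch(8)

tft_cb = TFTRegister(ctrl_rank_id=0, ctrl_ip="127.0.0.1", ctrl_port=30051,
                     ckpt_save_path="/tmp/tft_ckpt")  # hypothetical arguments

model = Model(net, loss_fn=loss, optimizer=opt)
# _handle_tft wraps _train: on a "UCEError" RuntimeError it reports the fault, waits
# for repair, then resumes from the recomputed initial_epoch/initial_step.
model.train(10, train_ds, callbacks=[tft_cb], dataset_sink_mode=True)
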