mindspore 2.1.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/parallel/_offload_context.py
CHANGED

@@ -33,6 +33,7 @@ class _OffloadConfig:
     OFFLOAD_PARAM = "offload_param"
     OFFLOAD_PATH = "offload_path"
     OFFLOAD_CPU_SIZE = "offload_cpu_size"
+    OFFLOAD_CHECKPOINT = "offload_checkpoint"
     OFFLOAD_DISK_SIZE = "offload_disk_size"
     ENABLE_AIO = "enable_aio"
     AIO_BLOCK_SIZE = "aio_block_size"
@@ -84,6 +85,16 @@ class _OffloadContext:
         Validator.check_string(offload_param.lower(), ["cpu", "disk"], "offload_param", "set_offload_param")
         self._context_handle.set_offload_param(offload_param.lower())
 
+    def set_offload_checkpoint(self, offload_checkpoint):
+        """Set offload_checkpoint"""
+        if not isinstance(offload_checkpoint, str):
+            raise TypeError("For 'set_offload_checkpoint', "
+                            "the argument 'offload_checkpoint' must be str, but got the type : {}."
+                            .format(type(offload_checkpoint)))
+        Validator.check_string(offload_checkpoint.lower(), ["cpu", "disk"], "offload_checkpoint",
+                               "set_offload_checkpoint")
+        self._context_handle.set_offload_checkpoint(offload_checkpoint.lower())
+
     def set_offload_path(self, offload_path):
         """Set offload_path"""
         if not isinstance(offload_path, str):
@@ -194,7 +205,8 @@ class _OffloadContext:
                                _OffloadConfig.HBM_RATIO, _OffloadConfig.OFFLOAD_CPU_SIZE,
                                _OffloadConfig.OFFLOAD_DISK_SIZE, _OffloadConfig.ENABLE_AIO,
                                _OffloadConfig.AIO_BLOCK_SIZE, _OffloadConfig.AIO_QUEUE_DEPTH,
-                               _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD]:
+                               _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD,
+                               _OffloadConfig.OFFLOAD_CHECKPOINT]:
                 unknown_config.append(config_name)
 
         if unknown_config:
@@ -220,7 +232,8 @@ class _OffloadContext:
             _OffloadConfig.AUTO_OFFLOAD: self._context_handle.auto_offload(),
             _OffloadConfig.HOST_MEM_BLOCk_SIZE: self._context_handle.host_mem_block_size(),
             _OffloadConfig.CPU_RATIO: self._context_handle.cpu_ratio(),
-            _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio()
+            _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio(),
+            _OffloadConfig.OFFLOAD_CHECKPOINT: self._context_handle.offload_checkpoint()
         }
         return offload_config
 
@@ -257,5 +270,6 @@ _set_offload_context_func_map = {
     _OffloadConfig.AUTO_OFFLOAD: offload_context().set_auto_offload,
     _OffloadConfig.HOST_MEM_BLOCk_SIZE: offload_context().set_host_mem_block_size,
     _OffloadConfig.CPU_RATIO: offload_context().set_cpu_ratio,
-    _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio
+    _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio,
+    _OffloadConfig.OFFLOAD_CHECKPOINT: offload_context().set_offload_checkpoint
 }
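The new `offload_checkpoint` option follows the same contract as the existing `offload_param`: it must be a string, either "cpu" or "disk" (case-insensitive), and anything else raises a TypeError. A minimal usage sketch, assuming the `mindspore.set_offload_context`/`mindspore.get_offload_context` entry points of the 2.2 release route the `offload_config` dict through `_set_offload_context_func_map` as shown above:

import mindspore as ms

# Hedged sketch: "offload_checkpoint" is the key added in this diff;
# the other keys already existed in 2.1.
ms.set_offload_context(offload_config={
    "offload_param": "cpu",        # where parameter data may be offloaded
    "offload_checkpoint": "cpu",   # new in 2.2: accepts "cpu" or "disk", like offload_param
    "offload_path": "./offload/",
})
print(ms.get_offload_context())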
mindspore/parallel/_parallel_serialization.py
CHANGED

@@ -330,12 +330,13 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
         device_list = list(range(0, np.prod(from_tensor_layout[0])))
         param_rank_list = _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, rank_id)
         param_rank_list_new = [rank % from_device_num for rank in param_rank_list]
-
-        result_list.update(
+        param_rank_set_new = set(param_rank_list_new)
+        result_list.update(param_rank_set_new)
     return list(result_list)
 
 
-def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
+def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
+                                   dst_strategy_list, param_type_dict):
     """
     Transform model parallel dimension for distributed checkpoint files.
     """
@@ -397,15 +398,21 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])
         requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
         layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
-
+        transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
+        if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
+            transform_para.set_dtype(ms.bfloat16)
+        transform_param_dict[param_name] = transform_para
 
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name, _ in param_total_dict.items():
         if param_name not in transform_param_dict:
-
+            transform_para = ms.Parameter(
                 ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                 param_attr_dict[param_name][rank_id % device_num][0],
                 param_attr_dict[param_name][rank_id % device_num][1])
+            if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
+                transform_para.set_dtype(ms.bfloat16)
+            transform_param_dict[param_name] = transform_para
 
     transform_param_list = [{"name": param_name, "data": param_data}
                             for param_name, param_data in transform_param_dict.items()]
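Both branches of `_transform_parallel_checkpoint` now end with the same dtype-restoring pattern: numpy (and therefore the host copy of the checkpoint) has no bfloat16, so a parameter recorded as "BFloat16" round-trips as float32 and is switched back via `set_dtype`. A standalone sketch of that pattern, where `recorded_type` stands in for the new `param_type_dict` lookup:

import numpy as np
import mindspore as ms

def restore_param(name, np_data, recorded_type):
    """Rebuild a Parameter, restoring the bfloat16 dtype that numpy cannot carry."""
    para = ms.Parameter(ms.Tensor(np_data), name)
    if recorded_type == "BFloat16":  # same check as the diff above
        para.set_dtype(ms.bfloat16)
    return para

p = restore_param("w", np.ones((2, 2), np.float32), "BFloat16")
print(p.dtype)  # BFloat16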
mindspore/parallel/_ps_context.py
CHANGED

@@ -228,3 +228,15 @@ def _enable_distributed_mindrt():
     This method is used to distinguish from old distributed training mode.
     '''
     return ps_context().enable_distributed_mindrt()
+
+
+def _set_checkpoint_load_status(status):
+    return ps_context().set_checkpoint_load_status(status)
+
+
+def _store_warm_up_ptr_by_tensor(param_key, tensor):
+    return ps_context().store_warm_up_ptr_by_tensor(param_key, tensor)
+
+
+def _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor):
+    return ps_context().store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
mindspore/parallel/_tensor.py
CHANGED

@@ -17,7 +17,7 @@ from __future__ import division
 from __future__ import absolute_import
 
 import numpy as np
-
+from mindspore.common import dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore._c_expression import TensorTransform
@@ -41,7 +41,7 @@ def _get_tensor_strategy(dev_mat, tensor_map):
         if dim == -1:
             tensor_strategy.append(1)
         else:
-            tensor_strategy.append(dev_mat[-dim-1])
+            tensor_strategy.append(dev_mat[-dim - 1])
     return tensor_strategy
 
 
@@ -198,7 +198,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     return tensor_slice_index
 
 
-def _load_tensor(tensor, dev_mat, tensor_map):
+def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
     """
     Get the tensor slice of the local device by the device matrix and the tensor map
 
@@ -216,16 +216,21 @@ def _load_tensor(tensor, dev_mat, tensor_map):
     >>> tensor_map = [1, -1]
     >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
     """
-
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    if tensor.dtype == mstype.bfloat16:
+        tensor = tensor.float()
     np_tensor = tensor.asnumpy()
     np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
     return np_tensor_slice
 
 
-def _load_tensor_by_layout(tensor, layout):
+def _load_tensor_by_layout(tensor, layout, rank_id):
     """
     Load tensor by layout.
 
@@ -246,19 +251,19 @@ def _load_tensor_by_layout(tensor, layout):
         raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
-    if
+    if not tensor_map:
         return tensor
     uniform_split = layout[4]
     group = layout[5]
     if uniform_split == 0:
         raise RuntimeError("The load tensor only support uniform split now")
-    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
+    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, rank_id)
     if group:
         # get a totally shard tensor slice for parallel optimizer
         rank = get_rank(group)
         size = get_group_size(group)
         tensor_slice = np.split(tensor_slice, size)[rank]
-    return Tensor(tensor_slice)
+    return Tensor(tensor_slice, tensor.dtype)
 
 
 def _reshape_param_data(param_data, dev_mat, tensor_map):
@@ -315,7 +320,6 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
     return Tensor(tensor_slices_new[0])
 
 
-
 def _extract_layout_item(layout_item):
     dev_matrix = layout_item[0]
     tensor_map = layout_item[1]
@@ -541,6 +545,7 @@ def _check_operator(operator):
 
 def _apply_operator(operator_name):
     """apply transform operator"""
+
     def _apply_reshape_operator(numpy_data, reshape_op):
         """
         Apply reshape operator.
@@ -597,8 +602,8 @@ def _apply_operator(operator_name):
             raise ValueError("The slice operator information is wrong.")
         shape_size = len(slice_op[1]) // 3
         begin = slice_op[1][:shape_size]
-        end = slice_op[1][shape_size:shape_size*2]
-        stride = slice_op[1][shape_size*2:]
+        end = slice_op[1][shape_size:shape_size * 2]
+        stride = slice_op[1][shape_size * 2:]
         slice_index = []
         for begin_i, end_i, strides_i in zip(begin, end, stride):
             s = slice(begin_i, end_i, strides_i)
@@ -637,8 +642,8 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(len(tensor_slices[0][0])):
         tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
         for j in range(1, device_count):
-            tensor_slices_new = np.concatenate((tensor_slices_new
-
+            tensor_slices_new = np.concatenate((tensor_slices_new, \
+                                                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
         tensor_slices_col.append(tensor_slices_new)
     new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
     for i in range(1, len(tensor_slices_col)):
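The slicing pipeline in `_load_tensor` is: derive a per-dimension cut count (the "tensor strategy") from the device matrix and tensor map, chunk the host array accordingly, and index the chunk that belongs to this rank. A simplified numpy-only sketch of the chunking step; the real `_chunk_tensor_by_strategy` and `_get_tensor_slice_index` handle arbitrary device matrices:

import numpy as np

def chunk_by_strategy(arr, strategy):
    """Split arr into strategy[axis] equal parts along each axis; return the flat block list."""
    blocks = [arr]
    for axis, cuts in enumerate(strategy):
        blocks = [piece for b in blocks for piece in np.split(b, cuts, axis=axis)]
    return blocks

x = np.arange(16).reshape(4, 4)
print(chunk_by_strategy(x, [2, 2])[3])  # bottom-right 2x2 block: [[10 11] [14 15]]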
mindspore/parallel/_transformer/layers.py
CHANGED

@@ -424,9 +424,11 @@ class _Linear(Cell):
         self.out_channels = out_channels
         if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
             raise TypeError(f"For Linear cell, the activation should str type or nn.Cell type, but got {activation}.")
-
-
-
+
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels \
+                    or weight_init.shape[1] != in_channels:
+                raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
         weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
         self.expert_num = expert_num
         self.outer_batch = outer_batch
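The added check only fires when `weight_init` is a concrete Tensor: it must be 2-D with shape `(out_channels, in_channels)`. A hedged standalone sketch of the same predicate:

import numpy as np
from mindspore import Tensor

def check_weight_init(weight_init, in_channels, out_channels):
    # Same condition as the diff above.
    if isinstance(weight_init, Tensor):
        if weight_init.ndim != 2 or weight_init.shape[0] != out_channels \
                or weight_init.shape[1] != in_channels:
            raise ValueError("'weight_init' must have shape (out_channels, in_channels)")

check_weight_init(Tensor(np.zeros((10, 4), np.float32)), in_channels=4, out_channels=10)  # passes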
mindspore/parallel/_transformer/loss.py
CHANGED

@@ -139,6 +139,7 @@ class _NLLLoss(Cell):
         self.add = P.Add().shard(((dp, mp), ()))
 
     def construct(self, softmax_result, one_hot_label):
+        """The forward of _NLLLoss"""
         log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
         loss = self.mul(log_softmax_result, one_hot_label)
         loss_unsum = self.neg(loss)
mindspore/parallel/_transformer/moe.py
CHANGED

@@ -273,7 +273,7 @@ class MoE(Cell):
         if self.group_wise_a2a:
             # If capacity can't div by mp, pad for mp shard.
             if capacity % self.mp != 0:
-                pad_size = self.mp-(capacity % self.mp)
+                pad_size = self.mp - (capacity % self.mp)
             if pad_size != 0:
                 capacity += pad_size
                 pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
@@ -330,7 +330,7 @@ class MoE(Cell):
         # Pad capacity for comp_comm_parallel_degree split.
         pad_size = 0
         if capacity % self.comp_comm_parallel_degree != 0:
-            pad_size = self.comp_comm_parallel_degree-(capacity % self.comp_comm_parallel_degree)
+            pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
             capacity += pad_size
             pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
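The two MoE hunks only add spaces around a minus sign, but the arithmetic they touch is worth spelling out: when `capacity` is not a multiple of the shard count `n`, `pad_size = n - (capacity % n)` is the smallest padding that makes it one. A worked example:

def padded_capacity(capacity, n):
    """Round capacity up to the next multiple of n (mp or comp_comm_parallel_degree)."""
    if capacity % n != 0:
        capacity += n - (capacity % n)
    return capacity

print(padded_capacity(10, 4))  # 12, since pad_size = 4 - (10 % 4) = 2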
mindspore/parallel/_transformer/op_parallel_config.py
CHANGED

@@ -147,10 +147,11 @@ class _PipeLineConfig(_Config):
     >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
     """
 
-    def __init__(self, pipeline_stage=1, micro_batch_num=1):
+    def __init__(self, pipeline_stage=1, micro_batch_num=1, pipeline_segment=1):
         Validator.check_positive_int(pipeline_stage, "pipeline_stage")
         Validator.check_positive_int(micro_batch_num, "micro_batch_num")
         self.pipeline_stage = pipeline_stage
+        self.pipeline_segment = pipeline_segment
         self.micro_batch_num = micro_batch_num
 
     @property
@@ -163,6 +164,16 @@ class _PipeLineConfig(_Config):
         self._pipeline_stage = value
         context.set_auto_parallel_context(pipeline_stages=value)
 
+    @property
+    def pipeline_segment(self):
+        return self._pipeline_segment
+
+    @pipeline_segment.setter
+    def pipeline_segment(self, value):
+        Validator.check_positive_int(value, "pipeline_segment")
+        self._pipeline_segment = value
+        context.set_auto_parallel_context(pipeline_segments=value)
+
     @property
     def micro_batch_num(self):
         return self._micro_batch_num
mindspore/parallel/_transformer/transformer.py
CHANGED

@@ -226,7 +226,8 @@ class TransformerOpParallelConfig(_Config):
     >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
     """
 
-    def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1,
+    def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, pipeline_segment=1,
+                 micro_batch_num=1,
                  recompute=default_transformer_recompute_config,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
         self.recompute = recompute
@@ -234,7 +235,8 @@ class TransformerOpParallelConfig(_Config):
         self.gradient_aggregation_group = gradient_aggregation_group
         self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                                              vocab_emb_dp=vocab_emb_dp)
-        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num
+        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num,
+                                          pipeline_segment=pipeline_segment)
         self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                              expert_parallel=expert_parallel)
 
@@ -309,6 +311,14 @@ class TransformerOpParallelConfig(_Config):
     def pipeline_stage(self, value):
         self._pp_config.pipeline_stage = value
 
+    @property
+    def pipeline_segment(self):
+        return self._pp_config.pipeline_segment
+
+    @pipeline_segment.setter
+    def pipeline_segment(self, value):
+        self._pp_config.pipeline_segment = value
+
     @property
     def optimizer_shard(self):
         return self._optimizer_shard
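With the plumbing above, `pipeline_segment` travels from `TransformerOpParallelConfig` through `_PipeLineConfig` into `context.set_auto_parallel_context(pipeline_segments=...)`. A hedged usage sketch, importing from the private module this diff touches (historically the same classes were exposed under `mindspore.nn.transformer`):

from mindspore.parallel._transformer.transformer import TransformerOpParallelConfig

# pipeline_segment is the parameter added in this diff; the rest existed in 2.1.
config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4,
                                     pipeline_stage=2, pipeline_segment=1,
                                     micro_batch_num=4)
print(config.pipeline_segment)  # 1; the setter also forwards pipeline_segments=1 to the context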
@@ -429,6 +439,7 @@ class FeedForward(Cell):
     >>> print(output.shape)
     (2, 20, 15)
     """
+
     @_LogActionOnce(logger=logger, key='FeedForward',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -622,6 +633,7 @@ class AttentionMask(Cell):
     [1. 1. 1. 0]
     [0. 0. 0. 0]]]
     """
+
     @_LogActionOnce(logger=logger, key='AttentionMask',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(seq_length=Validator.check_positive_int,
@@ -710,6 +722,7 @@ class VocabEmbedding(Cell):
     >>> print(table.shape)
     (30, 30)
     """
+
     @_LogActionOnce(logger=logger, key='VocabEmbedding',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(vocab_size=Validator.check_positive_int,
@@ -866,6 +879,7 @@ class MultiHeadAttention(Cell):
     >>> print(past[1].shape)
     (2, 3, 20, 5)
     """
+
     @_LogActionOnce(logger=logger, key='MultiHeadAttention',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -1203,7 +1217,8 @@ class MultiHeadAttention(Cell):
     def _get_batch_size_from_query(self, query):
         r"""Get the batch size from query tensor"""
         # For the incremental prediction, the seq length for the input is 1.
-
+        incr_infer = self.use_past and self.is_first_iteration
+        if len(F.shape(query)) == 2 and ((incr_infer) or (not self.use_past)):
             return F.shape(query)[0] // self.src_seq_length
         return F.shape(query)[0]
 
@@ -1459,6 +1474,7 @@ class TransformerEncoderLayer(Cell):
     >>> print(past[1].shape)
     (2, 2, 16, 4)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -1848,6 +1864,7 @@ class TransformerDecoderLayer(Cell):
     >>> print(past[3].shape)
     (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -2379,6 +2396,7 @@ class TransformerEncoder(Cell):
     >>> print(past[0][1].shape)
     (2, 2, 16, 4)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerEncoder',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -2613,6 +2631,7 @@ class TransformerDecoder(Cell):
     >>> print(past[0][3].shape)
     (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerDecoder',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -2882,6 +2901,7 @@ class Transformer(Cell):
     >>> print(de_past[0][3].shape)
     (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='Transformer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
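The rewritten `_get_batch_size_from_query` makes the flattened-input case explicit: a 2-D query of shape `(batch * seq_length, hidden)` implies `batch = rows // src_seq_length` only outside incremental decoding, or on its first iteration; during later incremental steps the sequence length is 1 and the row count already equals the batch size. The same logic as a plain-Python sketch:

def batch_size_from_query(rows_2d, src_seq_length, use_past, is_first_iteration):
    """Mirror of the diff's branch for a 2-D query, with rows_2d = shape[0]."""
    incr_infer = use_past and is_first_iteration
    if incr_infer or not use_past:
        return rows_2d // src_seq_length  # full sequences flattened into the row dim
    return rows_2d                        # incremental step: one token per sample

print(batch_size_from_query(32, 16, use_past=False, is_first_iteration=False))  # 2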
mindspore/parallel/_utils.py
CHANGED

@@ -100,13 +100,14 @@ def _slice_parameter(parameter, phase, layout):
         parameter.sliced = True
         return
     if not parameter.sliced:
-
+        rank = get_rank()
+        new_tensor = _load_tensor_by_layout(parameter, layout, rank)
         parameter.set_data(new_tensor, True)
 
 
-def _slice_tensor(tensor, layout):
+def _slice_tensor(tensor, layout, rank_id):
     """Slice python tensor obj according to the layout."""
-    new_tensor = _load_tensor_by_layout(tensor, layout)
+    new_tensor = _load_tensor_by_layout(tensor, layout, rank_id)
     return new_tensor
 
 
@@ -136,14 +137,17 @@ def _to_full_shapes(shapes, device_num):
                                  "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
             new_shape = ()
             for i, item in enumerate(shape):
-
+                if item > 0:
+                    new_shape += (item * dataset_strategy[index][i],)  # static shape
+                else:
+                    new_shape += (item,)  # dynamic shape
             new_shapes.append(new_shape)
         return new_shapes
     for shape in shapes:
         new_shape = ()
         for i, item in enumerate(shape):
-            if i == 0:
-                new_shape += (item * device_num,)
+            if i == 0 and item > 0:
+                new_shape += (item * device_num,)  # only for static shape
             else:
                 new_shape += (item,)
         new_shapes.append(new_shape)
@@ -201,7 +205,7 @@ def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
             slice_index += (s,)
         new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
         new_tensor_numpy[slice_index] = data.asnumpy()
-        new_tensor = Tensor(new_tensor_numpy)
+        new_tensor = Tensor(new_tensor_numpy, dtype=type_)
         lst.append(new_tensor)
     if scaling_sens:
         lst.append(Tensor(scaling_sens, mstype.float32))
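The `_to_full_shapes` change lets dynamic dimensions (non-positive values, typically -1) pass through unscaled, while static batch dimensions are still multiplied up to the full cluster shape. A standalone sketch of the default branch (no dataset strategy configured):

def to_full_shape(shape, device_num):
    """Scale only a static leading batch dim; leave dynamic (-1) dims untouched."""
    return tuple(item * device_num if i == 0 and item > 0 else item
                 for i, item in enumerate(shape))

print(to_full_shape((32, 224, 224), 8))  # (256, 224, 224)
print(to_full_shape((-1, 224, 224), 8))  # (-1, 224, 224): dynamic shape preserved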
mindspore/parallel/algo_parameter_config.py
CHANGED

@@ -229,7 +229,7 @@ def set_algo_parameters(**kwargs):
     """
     Set parameters in the algorithm for parallel strategy searching. See a typical use in
     `test_auto_parallel_resnet.py
-    <https://gitee.com/mindspore/mindspore/blob/r2.
+    <https://gitee.com/mindspore/mindspore/blob/r2.2/tests/ut/python/parallel/test_auto_parallel_resnet.py>`_.
 
     Note:
         The attribute name is required. This interface works ONLY in AUTO_PARALLEL mode.
@@ -239,10 +239,10 @@ def set_algo_parameters(**kwargs):
             Default: ``True`` . For example with 8 devices available, if set ``True`` , strategy (4, 1) will not be
             included in ReLU's candidate strategies, because strategy (4, 1) only utilizes 4 devices.
         elementwise_op_strategy_follow (bool): Whether the elementwise operator has the consistent strategies as its
-            subsequent operators.
-
-
-            strategy ((8, 1), (8, 1)).
+            subsequent operators. Elementwise operators refer to operators that operate on input element by element,
+            such as Add, ReLU, etc. Default: ``False`` . For the example of ReLU followed by Add, if this flag is set
+            ``True`` , then the searched strategy by the algorithm guarantees that strategies of these two operators
+            are consistent, e.g., ReLU's strategy (8, 1) and Add's strategy ((8, 1), (8, 1)).
         enable_algo_approxi (bool): Whether to enable the approximation in the algorithms. Default: ``False`` . Due to
             large solution space in searching parallel strategy for large DNN model, the algorithm takes fairly long
             time in this case. To mitigate it, if this flag is set ``True`` , an approximation is made to discard some
@@ -261,8 +261,87 @@ def set_algo_parameters(**kwargs):
         ValueError: If context keyword is not recognized.
 
     Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+            Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
+
+        >>> import numpy as np
         >>> import mindspore as ms
+        >>> import mindspore.dataset as ds
+        >>> from mindspore import nn, ops, train
+        >>> from mindspore.communication import init
+        >>> from mindspore.common.initializer import initializer
+        >>>
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
+        >>>                              search_mode="sharding_propagation")
+        >>> init()
+        >>> ms.set_algo_parameters(fully_use_devices=True)
         >>> ms.set_algo_parameters(elementwise_op_strategy_follow=True)
+        >>> ms.set_algo_parameters(enable_algo_approxi=True)
+        >>> ms.set_algo_parameters(algo_approxi_epsilon=0.2)
+        >>> ms.set_algo_parameters(tensor_slice_align_enable=True)
+        >>> ms.set_algo_parameters(tensor_slice_align_size=8)
+        >>>
+        >>> # Define the network structure.
+        >>> class Dense(nn.Cell):
+        ...     def __init__(self, in_channels, out_channels):
+        ...         super().__init__()
+        ...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
+        ...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
+        ...         self.matmul = ops.MatMul()
+        ...         self.add = ops.Add()
+        ...
+        ...     def construct(self, x):
+        ...         x = self.matmul(x, self.weight)
+        ...         x = self.add(x, self.bias)
+        ...         return x
+        >>>
+        >>> class FFN(nn.Cell):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.flatten = ops.Flatten()
+        ...         self.dense1 = Dense(28*28, 64)
+        ...         self.relu = ops.ReLU()
+        ...         self.dense2 = Dense(64, 10)
+        ...
+        ...     def construct(self, x):
+        ...         x = self.flatten(x)
+        ...         x = self.dense1(x)
+        ...         x = self.relu(x)
+        ...         x = self.dense2(x)
+        ...         return x
+        >>> net = FFN()
+        >>> net.dense1.matmul.shard(((2, 1), (1, 2)))
+        >>>
+        >>> # Create dataset.
+        >>> step_per_epoch = 16
+        >>> def get_dataset(*inputs):
+        ...     def generate():
+        ...         for _ in range(step_per_epoch):
+        ...             yield inputs
+        ...     return generate
+        >>>
+        >>> input_data = np.random.rand(1, 28, 28).astype(np.float32)
+        >>> label_data = np.random.rand(1).astype(np.int32)
+        >>> fake_dataset = get_dataset(input_data, label_data)
+        >>> dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+        >>> # Train network.
+        >>> optimizer = nn.Momentum(net.trainable_params(), 1e-3, 0.1)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> loss_cb = train.LossMonitor()
+        >>> model = ms.Model(network=net, loss_fn=loss_fn, optimizer=optimizer)
+        >>> model.train(epoch=2, train_dataset=dataset, callbacks=[loss_cb])
     """
     for key, value in kwargs.items():
         if key not in set_algo_parameters_config_func_map:
@@ -282,6 +361,7 @@ def get_algo_parameters(attr_key):
         attr_key (str): The key of the attribute. The keys include: "fully_use_devices",
             "elementwise_op_strategy_follow", "enable_algo_approxi", "algo_approxi_epsilon",
             "tensor_slice_align_enable","tensor_slice_align_size".
+            See :func:`mindspore.set_algo_parameters` for more details about the meaning of the attributes.
 
     Returns:
         Return attribute value according to the key.