mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/other_ops.py
CHANGED

@@ -19,7 +19,6 @@ from mindspore import log as logger
 from mindspore.ops import signature as sig
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
-from mindspore.common._decorator import deprecated
 from mindspore.ops.primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
 from mindspore.ops.operations._pyfunc_registry import add_pyfunc
 from mindspore._c_expression import typing

@@ -738,27 +737,6 @@ class Pull(PrimitiveWithInfer):
         return mstype.float32
 
 
-class identity(Primitive):
-    """
-    The :class:`mindspore.ops.identity` interface is deprecated, please use the :class:`mindspore.nn.Identity` instead.
-
-    Supported Platforms:
-        Deprecated
-    """
-
-    # Side effect will propagated from the first argument to return value.
-    side_effect_propagate = 1
-
-    @prim_attr_register
-    def __init__(self):
-        """Initialize identity."""
-        self.add_prim_attr('side_effect_propagate', 1)
-
-    @deprecated('2.0', 'nn.Identity', False)
-    def __call__(self, x):
-        return x
-
-
 class PyInterpret(Primitive):
     r"""
     Interpret Python expression.
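Note: the deprecated lowercase `ops.identity` primitive is removed outright in 2.2; its own docstring pointed to `mindspore.nn.Identity` as the replacement. A minimal sketch of the migration, assuming a MindSpore 2.2 install:

```python
import mindspore as ms
from mindspore import nn, Tensor

x = Tensor([1.0, 2.0, 3.0], ms.float32)

# nn.Identity is the documented replacement for the removed ops.identity.
identity = nn.Identity()
print(identity(x))  # [1. 2. 3.]
```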
mindspore/ops/operations/random_ops.py
CHANGED

@@ -83,15 +83,10 @@ class TruncatedNormal(Primitive):
     Note:
         - The value of `shape` must be greater than zero. The output length can not exceed 1000000.
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -152,15 +147,10 @@ class StandardNormal(Primitive):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -208,15 +198,10 @@ class StandardLaplace(Primitive):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -266,15 +251,10 @@ class RandomGamma(Primitive):
 
     Note:
        - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -380,15 +360,10 @@ class Gamma(PrimitiveWithInfer):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -468,15 +443,10 @@ class ParameterizedTruncatedNormal(Primitive):
     Note:
         - The value in tensor `min` must be strictly less than `max` at any position after broadcasting.
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -551,15 +521,10 @@ class Poisson(PrimitiveWithInfer):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -630,15 +595,10 @@ class RandomPoisson(Primitive):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -705,15 +665,10 @@ class UniformInt(Primitive):
     Note:
         - The number in tensor minval must be strictly less than maxval at any position after broadcasting.
        - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -769,15 +724,16 @@ class UniformReal(Primitive):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
+        - Currently, on the Ascend platform, `shape` as a Tensor is not supported.
+          This is supported on CPU/GPU platforms. When the input is a Tensor,
+          the supported data types are as follows:
+
+          - GPU: int32, int64.
+          - CPU: int16, int32, int64.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -787,7 +743,6 @@ class UniformReal(Primitive):
 
     Inputs:
         - **shape** (Union[tuple, Tensor]) - The shape of tensor to be generated. Only constant value is allowed.
-          Supported dtypes: int16, int32, int64.
 
     Outputs:
         Tensor. The shape that the input 'shape' denotes. The dtype is float32.

@@ -809,6 +764,7 @@ class UniformReal(Primitive):
         >>> print(result)
         (2, 2)
     """
+
     @prim_attr_register
     def __init__(self, seed=0, seed2=0):
         """Initialize UniformReal"""

@@ -826,15 +782,10 @@ class RandomChoiceWithMask(Primitive):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         count (int, optional): Number of items expected to get and the number must be greater than 0. Default: ``256`` .

@@ -850,8 +801,8 @@ class RandomChoiceWithMask(Primitive):
     Outputs:
         Two tensors, the first one is the index tensor and the other one is the mask tensor.
 
-        - **index** (Tensor) - The output shape is 2-D
-        - **mask** (Tensor) - The output shape is 1-D
+        - **index** (Tensor) - The output shape is 2-D, its shape is :math:`(count, rank of input_x)`.
+        - **mask** (Tensor) - The output shape is 1-D, its shape is :math:`(count)`.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``

@@ -945,15 +896,10 @@ class Multinomial(Primitive):
         - The rows of input do not need to sum to one (in which case we use the values as weights),
           but must be non-negative, finite and have a non-zero sum.
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,

@@ -1024,8 +970,8 @@ class MultinomialWithReplacement(Primitive):
     Inputs:
         - **x** (Tensor) - the input tensor containing the cumsum of probabilities, must be 1 or 2
           dimensions.
-        - **seed** (Tensor) - If `seed`
-
+        - **seed** (Tensor) - If `seed` and 'offset' are both set to 0, the random number generator
+          is seeded by a random seed. Otherwise, it is seeded by the given seed and offset.
           Supported dtype: int64.
         - **offset** (Tensor) - Offset used to avoid seed collision. Supported dtype: int64.
 

@@ -1072,7 +1018,9 @@ class UniformCandidateSampler(Primitive):
         range_max (int): The number of possible classes, must be non-negative.
         seed (int, optional): Used for random number generation, must be non-negative. If seed has a value of 0,
             the seed will be replaced with a randomly generated value. Default: ``0`` .
-        remove_accidental_hits (bool, optional): Whether accidental hit is removed.
+        remove_accidental_hits (bool, optional): Whether accidental hit is removed.
+            Accidental hit is when one of the true classes matches one of the sample classes.
+            Set ``True`` to remove which accidentally sampling the true class as sample class. Default: ``False`` .
 
     Inputs:
         - **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of

@@ -1128,7 +1076,6 @@ class UniformCandidateSampler(Primitive):
         self.add_prim_attr("side_effect_hidden", True)
 
 
-
 class LogUniformCandidateSampler(Primitive):
     r"""
     Generates random labels with a log-uniform distribution for sampled_candidates.

@@ -1206,15 +1153,10 @@ class RandomShuffle(Primitive):
 
     Note:
         - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
-          and the random seed
+          and the random seed determines the initial value of this random number. If the random seed is the same in two
           separate calls, the random number generated will not change.
-        - Global random seed and operator-level random seed are not set: Use the default value as the random seed.
-        - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
-          with 0 to generate random number.
-        - Global random seed is not set, operator-level random seed is set: 0
-          splices with the operator-level random seed to generate random number.
-        - Both Global random and operator-level random seed are set: the global random seed will splice with the
-          operator-level random seed to generate random number.
+        - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
+          to worry about which seed is more important.
 
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
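Every docstring hunk above swaps the old "global seed splices with operator-level seed" explanation for the Philox wording: `seed` and `seed2` are scrambled together, so neither one outranks the other. A minimal sketch of the resulting contract, assuming a MindSpore 2.2 install (StandardNormal is one of the ops touched above):

```python
import mindspore as ms
from mindspore import ops

shape = (2, 3)
# The same (seed, seed2) pair is scrambled by Philox into one random seed,
# so two separate calls with the same pair reproduce the same values.
a = ops.StandardNormal(seed=2, seed2=4)(shape)
b = ops.StandardNormal(seed=2, seed2=4)(shape)
print((a.asnumpy() == b.asnumpy()).all())  # True
```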
mindspore/ops/operations/sparse_ops.py
CHANGED

@@ -1615,7 +1615,7 @@ class SparseMatrixSoftmax(Primitive):
         if not isinstance(dtype, (type(mstype.float32), type(mstype.single), type(mstype.float64),
                                   type(mstype.double))):
             raise TypeError(
-                "Only float32 and float64 type data are supported, but got {}"
+                f"Only float32 and float64 type data are supported, but got {dtype}")
         self.add_prim_attr("dtype", dtype)
         self.init_prim_io_names(inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers',
                                         'x_col_indices', 'x_values'],

@@ -2602,6 +2602,8 @@ class RaggedTensorToTensor(Primitive):
             raise ValueError(
                 f"For {self.name}, the each element of row_partition_types must be 'ROW_SPLITS' "
                 f"when row_splits tensor.")
+        self.num_row_partition_tensors = len(row_partition_types)
+        self.add_prim_attr("num_row_partition_tensors", self.num_row_partition_tensors)
 
 
 class SparseCross(Primitive):
mindspore/ops/primitive.py
CHANGED

@@ -25,7 +25,7 @@ from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
 from mindspore.common.parameter import Parameter
 from mindspore.common.api import _pynative_executor
 from mindspore.common._stub_tensor import _convert_stub
-from mindspore._c_expression import Primitive_, prim_type
+from mindspore._c_expression import Primitive_, prim_type, typing
 from mindspore import _checkparam as Validator
 from mindspore.ops import signature as sig
 

@@ -746,6 +746,8 @@ def _check_contains_variable(item_dtype, item_value):
             if _check_contains_variable(item_dtype[i], element):
                 return True
     elif isinstance(item_value, dict):
+        if isinstance(item_dtype, typing.Keyword):
+            return item_value is None
         for i in range(len(item_value)):
             if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
                 return True

@@ -756,9 +758,7 @@ def _check_contains_variable(item_dtype, item_value):
 
 
 def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
-    """
-    Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
-    to compute constant value using the constants in the constructor.
+    """Used to calculate constant in graph copmpiling process and improve compile performance in GRAPH_MODE.
 
     Args:
         fn (function): A `fn` use as the infer_value of the output operator. Default: ``None`` .

@@ -772,22 +772,27 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
         and the warning message will raised if the parameter is not const value. Default: ``True`` .
 
     Examples:
-
-        >>>
-        >>> #
-        >>>
-
-        ...
+
+        >>> import mindspore as ms
+        >>> # define a constant calculate function with for loop inside and use use constexpr to accelerate the compile
+        >>> # process.
+        >>> @ms.constexpr
+        ... def for_loop_calculate(range_num):
+        ...     out = 0
+        ...     for i in range(range_num):
+        ...         if i %2 == 0 and i % 7 != 0:
+        ...             out = out + i
+        ...     return out // range_num
         ...
-        >>>
-
-
-
-        ...
-        ...     return len(x)
+        >>> # construct a net and run with GRAPH_MODE.
+        >>> @ms.jit
+        ... def my_func(x):
+        ...     new_shape = for_loop_calculate(100000)
+        ...     return ms.ops.broadcast_to(x, (new_shape, ))
         ...
-        >>>
-
+        >>> out = my_func(ms.Tensor([1]))
+        >>> print(out.shape)
+        >>> (21428, )
     """
 
     def deco(fn):

@@ -844,6 +849,7 @@ def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
         reuse_result (bool): If ``True`` , the operator will be executed once and reuse the result next time,
             otherwise the operator will always be executed. Default: ``True`` .
     """
+
     def deco(fn):
         """Decorator for CompileOp."""
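For convenience, the doctest added to `constexpr` above, rewritten as a plain script (nothing here beyond what the new docstring shows; assumes MindSpore 2.2):

```python
import mindspore as ms

@ms.constexpr
def for_loop_calculate(range_num):
    # Runs at graph-compile time; the result is folded into the graph as a
    # constant instead of being traced loop iteration by loop iteration.
    out = 0
    for i in range(range_num):
        if i % 2 == 0 and i % 7 != 0:
            out = out + i
    return out // range_num

@ms.jit
def my_func(x):
    new_shape = for_loop_calculate(100000)
    return ms.ops.broadcast_to(x, (new_shape,))

print(my_func(ms.Tensor([1])).shape)  # (21428,)
```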
mindspore/parallel/_auto_parallel_context.py
CHANGED

@@ -62,6 +62,7 @@ class _ParallelOptimizerConfig:
     """
     GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
     PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"
+    OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"
 
 
 class _AutoParallelContext:

@@ -176,7 +177,6 @@ class _AutoParallelContext:
         if comm_type == _ParallelFusionConfig.REDUCESCATTER:
             self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
 
-
     def fusion_threshold_mb(self):
         """Get all reduce threshold."""
         self.check_context_handle()

@@ -229,6 +229,22 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_pipeline_stage_split_num()
 
+    def set_pipeline_segments(self, segments):
+        """Set the segments of the pipeline"""
+        if isinstance(segments, bool) or not isinstance(segments, int):
+            raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
+                            "must be int, but got the type : {}.".format(type(segments)))
+        if segments < 1:
+            raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
+                             "should be greater or equal 1, but got the value of segments : {}.".format(segments))
+        self.check_context_handle()
+        self._context_handle.set_pipeline_segment_split_num(segments)
+
+    def get_pipeline_segments(self):
+        """Get the stages of the pipeline"""
+        self.check_context_handle()
+        return self._context_handle.get_pipeline_segment_split_num()
+
     def set_gradients_mean(self, gradients_mean):
         """
         Set gradients_mean flag.

@@ -491,6 +507,9 @@ class _AutoParallelContext:
         Args:
             grad_accumulation_step (int): The grad accumulation step.
         """
+        if grad_accumulation_step > 1:
+            raise ValueError("The interface is deprecated. To use gradient accumulation, "
+                             "please use GradAccumulationCell in mindspore.nn.wrap.cell_wrapper.")
         self.check_context_handle()
         Validator.check_positive_int(grad_accumulation_step)
         self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

@@ -758,6 +777,11 @@ class _AutoParallelContext:
                             .format(type(enable_parallel_optimizer)))
         self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
 
+    def get_enable_fold_pipeline(self):
+        """Get parallel optimizer flag."""
+        self.check_context_handle()
+        return self._context_handle.get_enable_fold_pipeline()
+
     def get_enable_parallel_optimizer(self):
         """Get parallel optimizer flag."""
         self.check_context_handle()

@@ -767,8 +791,6 @@ class _AutoParallelContext:
         r"""
         Set the configure for parallel optimizer. The configure provides more detailed behavior control about parallel
         training when parallel optimizer is enabled.
-        Currently it supports the key `gradient_accumulation_shard`. The configure will be effective
-        when we use context.set_auto_parallel_context(enable_parallel_optimizer=True).
 
         Args:
             parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer

@@ -786,14 +808,21 @@ class _AutoParallelContext:
                                                 enabled, parameters with size smaller than this threshold will not be
                                                 sharded across the devices. Parameter size = shape[0] \* ... \*
                                                 shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
+            - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specific the
+                                                maximum group size across devices when the parallel optimizer is
+                                                enabled. The numerical range can be (0, device_num]. Default value
+                                                is -1, which means the optimizer weight shard group size will
+                                                the data parallel group of each parameter. Default -1.
+
         """
         self.check_context_handle()
         grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
         threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
+        optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE
 
         for config_name in parallel_optimizer_config:
             unknown_config = []
-            if config_name not in [grad_shard_name, threshold_name]:
+            if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name]:
                 unknown_config.append(config_name)
 
         if unknown_config:

@@ -811,6 +840,11 @@ class _AutoParallelContext:
             self._context_handle.set_parallel_optimizer_threshold(
                 parallel_optimizer_config[threshold_name])
 
+        if optimizer_weight_shard_size_name in parallel_optimizer_config:
+            value = parallel_optimizer_config[optimizer_weight_shard_size_name]
+            Validator.check_positive_int(value)
+            self.set_optimizer_weight_shard_size(value)
+
     def get_grad_accumulation_shard(self):
         """Get grad accumulation shard."""
         self.check_context_handle()

@@ -890,6 +924,13 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_optimizer_weight_shard_size()
 
+    def set_ops_strategy_json_config(self, type, path, mode):
+        """
+        Set configuration of saving ops strategy in file .json.
+        """
+        self.check_context_handle()
+        self._context_handle.set_ops_strategy_json_config(type, path, mode)
+
     def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
         """
         Set optimizer_weight_shard_aggregated_save.

@@ -1027,8 +1068,28 @@ class _AutoParallelContext:
         self.set_enable_all_gather_fusion(openstate)
         self.set_enable_reduce_scatter_fusion(openstate)
 
+def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
+    """
+    Set strategy json configuration.
 
+    Args:
+        type (str): The parameter for choosing save or load .json file.
+        path (str): Path to save or load parallel strategy json.
+        mode (str): The parameter for choosing save all or important operators.
 
+    Raises:
+        KeyError: When type is not 'SAVE' or 'LOAD'.
+        KeyError: When mode is not 'all' or 'principal'.
+    """
+    dir_path = os.path.dirname(path)
+    if dir_path and not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+    check_type = ["SAVE", "LOAD"]
+    check_mode = ["all", "principal"]
+    if type in check_type and mode in check_mode:
+        auto_parallel_context().set_ops_strategy_json_config(type, path, mode)
+    else:
+        raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")
 
 _AUTO_PARALLEL_CONTEXT = None
 

@@ -1053,6 +1114,7 @@ _set_auto_parallel_context_func_map = {
     "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
     "pipeline_stages": auto_parallel_context().set_pipeline_stages,
+    "pipeline_segments": auto_parallel_context().set_pipeline_segments,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "search_mode": auto_parallel_context().set_strategy_search_mode,
     "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,

@@ -1074,7 +1136,6 @@ _set_auto_parallel_context_func_map = {
     "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
     "comm_fusion": auto_parallel_context().set_comm_fusion}
 
-
 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
     "global_rank": auto_parallel_context().get_global_rank,

@@ -1111,7 +1172,6 @@ _get_auto_parallel_context_func_map = {
                 communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
                 optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
                 strategy_ckpt_config=dict)
-
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.

@@ -1247,8 +1307,8 @@ def _reset_auto_parallel_context():
     - strategy_ckpt_load_file: ""
     - strategy_ckpt_save_file: ""
     - enable_parallel_optimizer: False
-    - search_mode:
-    - auto_parallel_search_mode:
+    - search_mode: 'recursive_programming'
+    - auto_parallel_search_mode: 'recursive_programming'
     - sharding_propagation: False
     - pipeline_stages: 0
     - gradient_accumulation_shard: True
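The new `optimizer_weight_shard_size` key is reachable through the public `mindspore.set_auto_parallel_context` interface alongside the two pre-existing keys. A hedged sketch, assuming a multi-device MindSpore 2.2 setup with the parallel optimizer enabled:

```python
import mindspore as ms

ms.set_auto_parallel_context(
    parallel_mode="semi_auto_parallel",
    enable_parallel_optimizer=True,
    parallel_optimizer_config={
        "gradient_accumulation_shard": True,   # pre-existing key
        "parallel_optimizer_threshold": 64,    # pre-existing key, unit: KB
        "optimizer_weight_shard_size": 2,      # new in 2.2; range (0, device_num]
    },
)
```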
mindspore/parallel/_cost_model_context.py
CHANGED

@@ -475,7 +475,7 @@ class _CostModelContext:
         """
         if self._context_handle is None:
             raise ValueError("Context handle is none in context!!!")
-        return self._context_handle.
+        return self._context_handle.get_rp_matmul_mem_coef()
 
     def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
         """

@@ -693,7 +693,7 @@ def _set_rp_matmul_mem_coef(coef):
     cost_model_context().set_rp_matmul_mem_coef(coef)
 
 
-def _get_rp_matmul_mem_coef(
+def _get_rp_matmul_mem_coef():
     """
     Get the matmul memory coef which is used in the RP algorithm.
     """