mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0

mindspore/.commit_id
CHANGED
@@ -1 +1 @@
-__commit_id__ = ''[sha1]:
+__commit_id__ = ''[sha1]:01847825,[branch]:(HEAD->r2.4.1,origin/r2.4.1)''

mindspore/__init__.py
CHANGED
@@ -23,6 +23,7 @@ from mindspore.mindrecord import *
 from mindspore.ops import _op_impl, grad, value_and_grad, vjp, jvp, jacfwd, jacrev, vmap, get_grad, constexpr, reshard
 from mindspore.train import *
 from mindspore.log import *
+from mindspore.utils import *
 from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_auto_parallel_context, \
     get_auto_parallel_context, reset_auto_parallel_context, ParallelMode, set_ps_context, \
     get_ps_context, reset_ps_context, set_offload_context, get_offload_context, STRICT, COMPATIBLE, LAX
@@ -30,7 +31,8 @@ from mindspore.version import __version__
 from mindspore.profiler import Profiler, EnvProfiler
 from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
     rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard, \
-    Layout, sync_pipeline_shared_parameters, parameter_broadcast, load_segmented_checkpoints
+    Layout, sync_pipeline_shared_parameters, parameter_broadcast, load_segmented_checkpoints, \
+    safetensors_to_ckpt, ckpt_to_safetensors, unified_safetensors
 from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType
 from mindspore.safeguard import obfuscate_ckpt, load_obf_params_into_net
 from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module, \

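The three new top-level exports (safetensors_to_ckpt, ckpt_to_safetensors, unified_safetensors) come from the new mindspore/parallel/transform_safetensors.py module listed above. A minimal usage sketch, assuming a simple (source, destination) call pattern; only the export names themselves are confirmed by this diff:

    import mindspore as ms

    # Assumed call pattern: convert safetensors checkpoint files to MindSpore .ckpt
    # files and back. Argument names and order are illustrative; only the export
    # names appear in this diff.
    ms.safetensors_to_ckpt("./model_safetensors", "./model_ckpt")
    ms.ckpt_to_safetensors("./model_ckpt", "./model_safetensors_out")
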
mindspore/_c_dataengine.cp39-win_amd64.pyd
CHANGED
Binary file (no diff shown)

mindspore/_c_expression.cp39-win_amd64.pyd
CHANGED
Binary file (no diff shown)

mindspore/_c_mindrecord.cp39-win_amd64.pyd
CHANGED
Binary file (no diff shown)

mindspore/_checkparam.py
CHANGED
@@ -29,7 +29,6 @@ from mindspore import log as logger
 from mindspore.common import dtype as mstype
 from mindspore._c_expression import Tensor as Tensor_
 
-
 EQ = 1  # ==
 NE = 2  # !=
 LT = 3  # <
@@ -148,7 +147,7 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
         ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
     elif len(arg_value) == 3:
         ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
-    else:
+    else:  # case: len(arg_value) == 5
         ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4])
 
     return ret
@@ -240,6 +239,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
         else:
             raise TypeError(f"{prim_name} type of {arg_name} must be '{arg_type.__name__}', " \
                             f"but got '{type(arg_value).__name__}'.")
+
     _check_param()
     return arg_value
 
@@ -265,6 +265,7 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
             rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
             raise ValueError(f"{prim_name} {arg_name} must be in range of {rel_str}, " \
                              f"but got {arg_value} with type '{type(arg_value).__name__}'.")
+
     _check_param()
     return arg_value
 
@@ -274,6 +275,7 @@ def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_c
     Method for judging relation between two int values or list/tuple made up of ints.
     This method is not suitable for judging relation between floats, since it does not consider float error.
     """
+
     def _check():
         if not _check_binary_rel(arg_value, value, rel):
             rel_str = _format_str_one_value(f'{value_name}: {value}', rel)
@@ -475,20 +477,24 @@ def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
 
 def check_number(arg_name, arg_value, value, rel, prim_name):
     """Number value judgment."""
+
     def _check():
         if not _check_binary_rel(arg_value, value, rel):
             rel_str = _format_str_one_value(value, rel)
             raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' ' \
                              f'must {rel_str}, but got {arg_value}.')
+
     _check()
     return arg_value
 
 
 def check_isinstance(arg_name, arg_value, classes):
     """Check arg isinstance of classes"""
+
     def _check():
         if not isinstance(arg_value, classes):
             raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.')
+
     _check()
     return arg_value
 
@@ -507,6 +513,7 @@ def check_bool(arg_value, arg_name=None, prim_name=None):
     def _check():
         if not isinstance(arg_value, bool):
             raise TypeError(f"{prim_name} {arg_name} must be a bool, but got {type(arg_value).__name__}.")
+
     _check()
     return arg_value
 
@@ -547,6 +554,7 @@ def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
         if not (isinstance(arg_value, str) and arg_value in valid_values):
             raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}'," \
                              f" but got '{arg_value}'.")
+
     _check()
     return arg_value
 
@@ -626,10 +634,12 @@ def check_subclass(arg_name, type_, template_types, prim_name, addition_error_in
 
 def check_valid_input(arg_name, arg_value, prim_name):
     """Checks valid value."""
+
     def _check():
         if arg_value is None:
             raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}'" \
                              f"can not be None, but got {arg_value}.")
+
     _check()
     return arg_value
 
@@ -786,6 +796,7 @@ def check_astype_dtype(dtype):
 
 def check_transpose_axis(axes, ndim):
     """Check the axis argument for tensor.transpose"""
+
     def _check_dim():
         # if multiple arguments provided, it must be `ndim` number of ints
         if len(axes) != ndim:
@@ -793,7 +804,7 @@ def check_transpose_axis(axes, ndim):
                             f"but got {len(axes)} in the number of axes.")
 
     if not axes or (len(axes) == 1 and axes[0] is None):
-        return tuple(range(ndim-1, -1, -1))
+        return tuple(range(ndim - 1, -1, -1))
 
     if len(axes) == 1:
         perm = axes[0]
@@ -912,6 +923,7 @@ def prepare_shape_for_squeeze(shape, axes):
 
 def check_axis_in_range(axis, ndim):
     """Checks axes are with the bounds of ndim"""
+
     def _check():
         if not isinstance(axis, int):
             raise TypeError(f'The axes must be integers, but got {type(axis)}')
@@ -928,6 +940,7 @@ def check_axis_valid(axes, ndim):
     Checks axes are valid given ndim, and returns axes that can be passed
     to the built-in operator (non-negative, int or tuple)
     """
+
     def _check_range(axes):
         for axis in axes:
             check_axis_in_range(axis, ndim)
@@ -977,16 +990,17 @@ def infer_out_shape(*shapes):
     """
     Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
     """
+
     def _check(items, max_size, shapes):
         for item in items:
             if item not in (1, max_size):
                 raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis' \
                                  f'to support broadcasting, but got shapes {shapes,}')
+
     shape_out = ()
     max_len = max([len(it) for it in shapes])
     for i in range(max_len):
-        items = [it[i-(max_len-len(it))] if i - (max_len - len(it))
-                 >= 0 else 1 for it in shapes]
+        items = [it[i - (max_len - len(it))] if i - (max_len - len(it)) >= 0 else 1 for it in shapes]
         max_size = 0 if 0 in items else max(items)
         _check(items, max_size, shapes)
         shape_out = shape_out + (max_size,)
@@ -1015,6 +1029,7 @@ def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
 
 def check_and_canonicalize_axes(axes, ndim):
     """Check whether the types and values of input axes are valid."""
+
     def _check(axes, ax, ndim):
         if not isinstance(ax, int):
             raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
@@ -1091,8 +1106,8 @@ def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
                          f"{len(csr_shp)}")
     if values_shp[1:] != csr_shp[2:]:
         raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ]," \
-                         f"but CSRTensor's shape[2: ] got: {csr_shp[2:
-                         f"got: {values_shp[1:
+                         f"but CSRTensor's shape[2: ] got: {csr_shp[2:]} and value's shape[1: ]" \
+                         f"got: {values_shp[1:]}")
 
 
 def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
@@ -1370,9 +1385,35 @@ def check_hook_fn(hook_type, hook_fn):
     if hook_fn.__code__.co_name == "staging_specialize":
         raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
 
-
-
+    tensor_hook_func_args_num = 1
+    pre_hook_func_args_num = 2
+    forward_hook_and_backward_hook_func_args_num = 3
+    # Real args number, exclude class method self param
+    hook_fn_args_num = len(inspect.signature(hook_fn).parameters)
+
+    if hook_type == "register_hook" and hook_fn_args_num != tensor_hook_func_args_num:
+        raise TypeError(f"Tensor hook function {hook_fn.__name__} arg num should be {tensor_hook_func_args_num}, but "
+                        f"got {hook_fn_args_num}")
+
+    if hook_type == "register_forward_pre_hook" and hook_fn_args_num != pre_hook_func_args_num:
+        raise TypeError(f"forward_pre_hook function {hook_fn.__name__} args num should be {pre_hook_func_args_num}, "
+                        f"but got {hook_fn_args_num}")
+
+    if (hook_type == "register_forward_hook" and
+            hook_fn_args_num != forward_hook_and_backward_hook_func_args_num):
+        raise TypeError(f"forward_hook function {hook_fn.__name__} args num should be "
+                        f"{forward_hook_and_backward_hook_func_args_num}, but got {hook_fn_args_num}")
+
+    if hook_type == "register_backward_pre_hook" and hook_fn_args_num != pre_hook_func_args_num:
+        raise TypeError(f"backward_pre_hook function {hook_fn.__name__} args num should be {pre_hook_func_args_num},"
+                        f" but got {hook_fn_args_num}")
+
+    if (hook_type == "register_backward_hook" and
+            hook_fn_args_num != forward_hook_and_backward_hook_func_args_num):
+        raise TypeError(f"backward_hook function {hook_fn.__name__} args num should be "
+                        f"{forward_hook_and_backward_hook_func_args_num}, but got {hook_fn_args_num}")
 
     return True
 
+
 _set_record = {}

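The new check_hook_fn logic simply counts the hook's declared parameters with inspect.signature and compares the count with the number expected for each registration type. A standalone sketch of hook signatures that would pass the new checks:

    import inspect

    def tensor_hook(grad):                    # register_hook expects 1 argument
        return grad

    def forward_pre_hook(cell, inputs):       # register_forward_pre_hook expects 2
        return None

    def forward_hook(cell, inputs, outputs):  # register_forward_hook expects 3
        return None

    # The validation reduces to a parameter count, for example:
    assert len(inspect.signature(forward_hook).parameters) == 3
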
mindspore/_extends/parse/compile_config.py
CHANGED
@@ -12,6 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+"""
+Name: AUTO_PASSES_OPTIMIZE_PATH
+Function: Whether to do optimize the passes configure.
+Value Range:
+    string: The passes configure file path.
+Default: '' .empty string. Disable to do optimize the passes.
+"""
+AUTO_PASSES_OPTIMIZE_PATH = ''
+
 """
 Name: COMPILE_PROFILE
 Function: Whether to do profile and print profile log.
@@ -29,6 +38,16 @@ Value Range:
 """
 COMPILE_PROFILE_FINISH_ACTION = ''
 
+"""
+Name: DEBUG_MODE
+Function: Whether to compile in debug mode.
+Value Range:
+    "debug": Debug mode
+    "release": Release mode
+Default: "debug"
+"""
+COMPILE_DEBUG_MODE = ''
+
 """
 Name: FALLBACK_SUPPORT_LIST_DICT_INPLACE
 Function: Whether to support the inplace operation of list and dict.
@@ -230,9 +249,28 @@ Value Range:
 """
 DUMP_VALIDATE_BEFORE_RESET_ID = ''
 
+"""
+Name: ENABLE_RECOMPUTE_BEFORE_INLINE
+Function: Whether to do recomputation before fprop and bprop being inlined.
+Value Range:
+    1: Enable
+Default: Disable.
+"""
+ENABLE_RECOMPUTE_BEFORE_INLINE = ''
+
+"""
+Name: STRICT_CHECK_PARENT_CONTEXT
+Function: Whether to check parent context strictly.
+Value Range:
+    1: Enable
+Default: Disable.
+"""
+STRICT_CHECK_PARENT_CONTEXT = ''
+
 __all__ = [
     "COMPILE_PROFILE",
     "COMPILE_PROFILE_FINISH_ACTION",
+    "COMPILE_DEBUG_MODE",
     "FALLBACK_SUPPORT_LIST_DICT_INPLACE",
     "FALLBACK_FORCE_ANY",
     "IF_PARALLEL_CALL",
@@ -255,4 +293,7 @@ __all__ = [
     "DUMP_IR_DDE_DETAIL",
     "COMBINE_LIKE_GRAPHS",
     "DUMP_VALIDATE_BEFORE_RESET_ID",
+    "ENABLE_RECOMPUTE_BEFORE_INLINE",
+    "STRICT_CHECK_PARENT_CONTEXT",
+    "AUTO_PASSES_OPTIMIZE_PATH",
 ]

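The new switches follow the same pattern as the existing ones: module-level strings that default to ''. A sketch of turning one on, assuming (as the docstrings above suggest) that a non-empty value such as '1' enables the behaviour; the mechanism that consumes these attributes is not part of this diff:

    from mindspore._extends.parse import compile_config

    # Assumption: '1' enables the switch, mirroring the "1: Enable / Default: Disable"
    # convention documented above.
    compile_config.ENABLE_RECOMPUTE_BEFORE_INLINE = '1'
    compile_config.STRICT_CHECK_PARENT_CONTEXT = '1'
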
mindspore/_extends/parse/parser.py
CHANGED
@@ -127,7 +127,7 @@ _modules_from_mindspore = (
     "mindspore_rl", "mindformers", "mindpet", "mindpose", "mindface", "mindsearch", "mindinsight", "mindelec",
     "mindflow", "mindsponge", "mindearth", "sciai", "mindquantum", "mindarmour", "mindpandas", "mindvision",
     "mindspore_gl", "mindspore_federated", "mindspore_gs", "mindspore_serving", "mindspore_xai", "mindspore_hub",
-    "ringmo_framework", "troubleshooter", "mindtorch",
+    "ringmo_framework", "troubleshooter", "mindtorch", "mindchemistry",
 )
 
 _global_params = {}
@@ -203,7 +203,7 @@ def get_parse_method_of_class(obj, parse_method=None):
     if parse_method is not None:
         method_name = parse_method
     elif isinstance(obj, nn.Cell):
-        if obj.
+        if obj._backward_hook:
             method_name = "_backward_hook_construct"
         else:
             method_name = "construct"
@@ -486,7 +486,7 @@ def convert_class_to_function(cls_str, cls_obj):
                          f"supported in 'construct' or @jit decorated function. Try to create {cls_str} "
                          f"instances external such as initialized in the method '__init__' before assigning. "
                          f"For more details, please refer to "
-                         f"https://www.mindspore.cn/docs/zh-CN/master/
+                         f"https://www.mindspore.cn/docs/zh-CN/master/model_train/program_form/overview.html \n")
     return convert_class_to_function_map.get(cls_str)
 
 
@@ -931,7 +931,7 @@ class ThirdPartyLibraryChecker:
     """
     def __init__(self):
         self.user_workspace_dir = self.get_top_level_module_path(os.getcwd())
-        self.python_builtin_dir = os.path.
+        self.python_builtin_dir = os.path.realpath(os.path.dirname(os.__file__))
 
     @staticmethod
     def get_jit_modules():
@@ -963,8 +963,8 @@ class ThirdPartyLibraryChecker:
 
     def get_top_level_module_path(self, module_path):
         """Get the path of the top level package of the current working directory."""
-        module_abspath = os.path.
-        upper_path = os.path.
+        module_abspath = os.path.realpath(module_path)
+        upper_path = os.path.realpath(os.path.dirname(module_abspath))
         if module_abspath == upper_path:
             return module_abspath
         # Check whether __init__.py exists in the upper directory.
@@ -990,7 +990,7 @@ class ThirdPartyLibraryChecker:
         # A modules without __file__ attribute is considered to be in user workspace.
         if not hasattr(module, '__file__'):
             return False
-        module_path = os.path.
+        module_path = os.path.realpath(module.__file__)
         # Python builtin modules are treated as third-party libraries.
         if module_path.startswith(self.python_builtin_dir):
             logger.debug(f"Found python builtin module '{module.__name__}', which is a third-party module.")
@@ -1180,6 +1180,7 @@ class Parser:
         return SYNTAX_SUPPORTED
 
     def check_lambda(self, src):
+        """Check if the lamda expressions is correct."""
         obj_type = get_obj_type(self.fn)
         if (obj_type != RESOLVE_TYPE_FUNCTION or src[:4] == "def ") and is_lambda_function(self.fn):
             logger.debug("fn is lambda: %r", self.fn)
@@ -1242,6 +1243,7 @@ class Parser:
         return None, None
 
     def get_name_from_namespace(self, value):
+        """Get the name of value from namespace"""
        try:
             value_str = value.__name__
             logger.debug(

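The ThirdPartyLibraryChecker methods above now resolve paths with os.path.realpath before comparing them. A self-contained illustration of the same detection idea, using only the standard library:

    import json
    import os

    # A module whose resolved file path lives under Python's own library directory
    # is treated as a builtin module (and therefore as third-party code for JIT).
    python_builtin_dir = os.path.realpath(os.path.dirname(os.__file__))

    def is_python_builtin(module):
        if not hasattr(module, '__file__'):
            return False
        return os.path.realpath(module.__file__).startswith(python_builtin_dir)

    print(is_python_builtin(json))  # True: json ships with the interpreter
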
mindspore/_extends/parse/standard_method.py
CHANGED
@@ -26,6 +26,7 @@ from mindspore.common.sparse_tensor import RowTensorInner
 from mindspore.ops.composite.base import _append, _insert, _pop, _list_clear, _reverse, \
     _extend, _dict_setitem, _dict_clear, _haskey, _update, _fromkeys
 from mindspore.ops.operations._sequence_ops import TensorToTuple
+from mindspore.ops.auto_generate import trace_v2_op, inplace_addmm_op
 
 from ... import _checkparam as validator
 from ..._checkparam import check_is_number, check_reshape_shp, check_axis_in_range, \
@@ -69,6 +70,38 @@ itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
 
 nan_tensor = Tensor(float('nan'), dtype=mstype.float32)
 
+_map = composite.HyperMap()
+
+
+def hypermap_dynamic_tuple(func, *inputs):
+    """Make hypermap for dynamic shape tuple."""
+    iter_len = len(inputs[0])
+    i = 0
+    ret = F.make_tuple()
+    while i < iter_len:
+        cur_input = F.make_tuple()
+        for m in inputs:
+            cur_input = cur_input + F.make_tuple(m[i])
+        new_out = _map(func, *cur_input)
+        ret = ret + F.make_tuple(new_out)
+        i = i + 1
+    return ret
+
+
+def hypermap_dynamic_list(func, *inputs):
+    """Make hypermap for dynamic shape list."""
+    iter_len = len(inputs[0])
+    i = 0
+    ret = F.make_list()
+    while i < iter_len:
+        cur_input = F.make_tuple()
+        for m in inputs:
+            cur_input = cur_input + F.make_tuple(m[i])
+        new_out = _map(func, *cur_input)
+        ret = ret + F.make_list(new_out)
+        i = i + 1
+    return ret
+
 
 def mean(x, axis=None, keep_dims=False):
     """
@@ -1598,17 +1631,7 @@ def trace(x, offset=0, axis1=0, axis2=1, dtype=None):
     >>> print(x.trace())
     3.0
     """
-
-    return F.trace(x)
-    d = x.diagonal(offset, axis1=axis1, axis2=axis2)
-    shape = d.shape
-    if dtype is None:
-        dtype = d.dtype
-    dtype = check_astype_dtype_const(dtype)
-    if shape[-1] == 0:
-        return F.fill(dtype, shape[:-1], 0)
-    res = F.reduce_sum(d.astype(mstype.float32), -1)
-    return res.astype(dtype)
+    return trace_v2_op(x, offset, axis1, axis2, dtype)
 
 
 def take(x, indices, axis=None, mode='clip'):
@@ -1794,7 +1817,7 @@ def searchsorted(x, v, side='left', sorter=None):
             no suitable index, return either 0 or N (where N is the length of `a`).
         sorter (Union[int, float, bool, list, tuple, Tensor]): 1-D optional array of
             integer indices that sort array `a` into ascending order. They are typically
-            the result of argsort.
+            the result of argsort. CPU and GPU can only use default values
 
     Returns:
         Tensor, array of insertion points with the same shape as `v`.
@@ -2435,6 +2458,7 @@ def list_func(data):
         ret = ret + F.make_list(i)
     return ret
 
+
 def tuple_func(data):
     """Implementation of `tuple`."""
     if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
@@ -2453,7 +2477,7 @@ def tuple_func(data):
 
 
 def ms_zip(*data):
-    """
+    """Packs elements in the corresponding positions in multiple sequences into tuples."""
     x = ()
     for i in data:
         if isinstance(i, Tensor):
@@ -3002,7 +3026,7 @@ def tensor_scatter_mul(input_x, indices, updates):
     `indices`, with values from `updates`. When multiple value are given for the same index,
     the output result will be the division of values.
     """
-    return F.
+    return F.tensor_scatter_mul(input_x, indices, updates)
 
 
 def tensor_sactter_div(input_x, indices, updates):
@@ -3813,6 +3837,20 @@ def addmm(x, mat1, mat2, *, beta=1, alpha=1):
     return F.addmm(x, mat1, mat2, beta=beta, alpha=alpha)
 
 
+def addmm_(self, mat1, mat2, *, beta=1, alpha=1):
+    r"""
+    For details, please refer to :func:`mindspore.ops.addmm`.
+
+    .. note::
+        The output results are directly updated in the Tensor.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    """
+    return inplace_addmm_op(self, mat1, mat2, beta=beta, alpha=alpha)
+
+
 def addmv(x, mat, vec, beta=1, alpha=1):
     r"""
     Multiplies matrix `mat` and vector `vec`. The vector `x` is added to the final result.

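The new hypermap_dynamic_tuple / hypermap_dynamic_list helpers apply a function position by position across several sequences of equal length. A plain-Python analogue of that mapping (the graph-mode versions build the result with F.make_tuple / F.make_list and a HyperMap instead):

    def hypermap_tuple(func, *inputs):
        """Apply func element-wise across equal-length sequences, one position at a time."""
        result = ()
        for position in zip(*inputs):
            result += (func(*position),)
        return result

    print(hypermap_tuple(lambda a, b: a + b, (1, 2, 3), (10, 20, 30)))  # (11, 22, 33)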
|