mindspore 2.2.10__cp38-none-any.whl → 2.2.14__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +2 -1
- mindspore/_akg/akg/composite/build_module.py +95 -5
- mindspore/_akg/akg/topi/cpp/impl.py +1 -1
- mindspore/_akg/akg/tvm/_ffi/base.py +1 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/util.py +18 -1
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +12 -2
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/tensor.py +0 -2
- mindspore/communication/management.py +3 -0
- mindspore/context.py +34 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets.py +23 -0
- mindspore/dataset/engine/validators.py +1 -1
- mindspore/dataset/vision/py_transforms_util.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +5 -6
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +118 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +49 -57
- mindspore/mindrecord/tools/cifar10_to_mr.py +46 -55
- mindspore/mindrecord/tools/csv_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +4 -9
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -4
- mindspore/nn/layer/activation.py +1 -1
- mindspore/nn/layer/embedding.py +2 -2
- mindspore/nn/layer/flash_attention.py +48 -135
- mindspore/nn/loss/loss.py +1 -1
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/sgd.py +3 -2
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +6 -3
- mindspore/numpy/math_ops.py +1 -1
- mindspore/ops/__init__.py +3 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -31
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +37 -17
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/function/array_func.py +6 -5
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/linalg_func.py +21 -11
- mindspore/ops/function/math_func.py +3 -0
- mindspore/ops/function/nn_func.py +13 -11
- mindspore/ops/function/parameter_func.py +2 -0
- mindspore/ops/function/sparse_unary_func.py +2 -2
- mindspore/ops/function/vmap_func.py +1 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +3 -4
- mindspore/ops/operations/_inner_ops.py +56 -1
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +15 -4
- mindspore/ops/operations/custom_ops.py +1 -1
- mindspore/ops/operations/debug_ops.py +1 -1
- mindspore/ops/operations/image_ops.py +3 -3
- mindspore/ops/operations/inner_ops.py +49 -0
- mindspore/ops/operations/math_ops.py +65 -3
- mindspore/ops/operations/nn_ops.py +95 -28
- mindspore/ops/operations/random_ops.py +2 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/silent_check.py +162 -0
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +82 -3
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_tensor.py +3 -1
- mindspore/parallel/_transformer/transformer.py +8 -8
- mindspore/parallel/checkpoint_transform.py +191 -45
- mindspore/profiler/parser/ascend_cluster_generator.py +111 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +315 -0
- mindspore/profiler/parser/ascend_flops_generator.py +8 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +2 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +30 -6
- mindspore/profiler/parser/ascend_msprof_generator.py +16 -5
- mindspore/profiler/parser/ascend_op_generator.py +15 -7
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -2
- mindspore/profiler/parser/base_timeline_generator.py +11 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +8 -2
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +4 -2
- mindspore/profiler/parser/msadvisor_parser.py +9 -3
- mindspore/profiler/profiling.py +97 -25
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/api/symbol_tree.py +2 -2
- mindspore/rewrite/parsers/for_parser.py +6 -6
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_checkpoint.py +8 -8
- mindspore/train/callback/_landscape.py +2 -3
- mindspore/train/callback/_summary_collector.py +6 -7
- mindspore/train/dataset_helper.py +6 -0
- mindspore/train/model.py +17 -5
- mindspore/train/serialization.py +6 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +5 -6
- mindspore/version.py +1 -1
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/METADATA +3 -2
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/RECORD +140 -148
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/WHEEL +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/top_level.txt +0 -0
mindspore/profiler/parser/framework_parser.py
CHANGED

@@ -768,7 +768,7 @@ class DynamicFrameWorkParser:
         rank_id (int): The rank ID.
     """

-    def __init__(self, output_path, rank_id):
+    def __init__(self, output_path, rank_id, pretty=False):
         """Initialization of parsing framework data."""
         self._output_path = output_path
         self._all_op_exe_time = defaultdict(list)
@@ -779,6 +779,12 @@ class DynamicFrameWorkParser:
         self._exe_time_and_shape_detail = defaultdict(dict)
         self._dynamic_shape_info = defaultdict(list)
         self._step = 0
+        self._pretty = pretty
+
+    @property
+    def indent(self):
+        indent = 1 if self._pretty else None
+        return indent

     def write_dynamic_shape_data(self, df_op_summary):
         """Analyze dynamic shape data and write to dynamic shape file."""
@@ -804,7 +810,7 @@ class DynamicFrameWorkParser:
         self._dynamic_shape_info['op_type'] = self._op_info.get("op_type")
         dynamic_shape_file_path = os.path.join(self._output_path, output_dynamic_shape_file_name)
         with os.fdopen(os.open(dynamic_shape_file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as fp:
-            json.dump(self._dynamic_shape_info, fp)
+            json.dump(self._dynamic_shape_info, fp, indent=self.indent)
         os.chmod(dynamic_shape_file_path, stat.S_IREAD | stat.S_IWRITE)

     def _analyse_op_execute_time(self, op_summary):
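The same pattern recurs in every profiler parser touched below: the constructor gains a `pretty` flag and an `indent` property that is passed straight to `json.dump`. A minimal standalone sketch of that behaviour (the writer class here is illustrative only, not MindSpore code):

    import json

    class PrettyAwareWriter:
        """Illustrates the pretty/indent pattern added to the profiler parsers."""

        def __init__(self, pretty=False):
            self._pretty = pretty

        @property
        def indent(self):
            # indent=None keeps json.dump's compact single-line output;
            # indent=1 switches to pretty-printed JSON with one-space indentation.
            return 1 if self._pretty else None

        def dump(self, data, fp):
            json.dump(data, fp, indent=self.indent)

With the default `pretty=False`, the generated files keep their previous compact layout, so downstream tools that parse them are unaffected.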
mindspore/profiler/parser/memory_usage_parser.py
CHANGED

@@ -40,7 +40,7 @@ GIGABYTES = 1024 * 1024 * 1024
 class MemoryUsageParser:
     """MemoryUsageParser to parse memory raw data."""

-    def __init__(self, profiling_dir, device_id):
+    def __init__(self, profiling_dir, device_id, pretty=False):
         self._profiling_dir = profiling_dir
         self._device_id = device_id
         self._proto_file_path = 'memory_usage_{}.pb'
@@ -57,6 +57,12 @@ class MemoryUsageParser:
         }
         self._framework = {}
         self._points = {}
+        self._pretty = pretty
+
+    @property
+    def indent(self):
+        indent = 1 if self._pretty else None
+        return indent

     @staticmethod
     def _process_framework_info(aicore_detail_data):
@@ -164,7 +170,7 @@ class MemoryUsageParser:

         try:
             with os.fdopen(os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as json_file:
-                json.dump(content, json_file)
+                json.dump(content, json_file, indent=self.indent)
             os.chmod(file_path, stat.S_IREAD | stat.S_IWRITE)
         except (IOError, OSError) as err:
             logger.critical('Fail to write memory file.\n%s', err)
mindspore/profiler/parser/minddata_analyzer.py
CHANGED

@@ -41,7 +41,7 @@ class MinddataProfilingAnalyzer:
         ProfilerFileNotFoundException: If any of the MindData profiling input files do not exist.
     """

-    def __init__(self, source_dir, device_id, output_path='./'):
+    def __init__(self, source_dir, device_id, output_path='./', pretty=False):
         # Validate and save input parameters
         self._device_id = device_id
         self._source_dir = self._validate_directory(source_dir, 'Source directory')
@@ -55,6 +55,12 @@ class MinddataProfilingAnalyzer:

         # Save output filename
         self._save_path = self._get_save_path(output_path)
+        self._pretty = pretty
+
+    @property
+    def indent(self):
+        indent = 1 if self._pretty else None
+        return indent

     @property
     def save_path(self):
@@ -624,7 +630,7 @@ class MinddataProfilingAnalyzer:

         # Save summary output dictionary to JSON output file (format#1)
         with os.fdopen(os.open(self._save_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as save_file:
-            json.dump(summary_dict, save_file)
+            json.dump(summary_dict, save_file, indent=self.indent)

         os.chmod(self._save_path, stat.S_IREAD | stat.S_IWRITE)

mindspore/profiler/parser/minddata_parser.py
CHANGED

@@ -110,7 +110,7 @@ class MinddataParser:
             input_path=os.path.join(source_path, "data"), file_name='DATA_PREPROCESS.dev.AICPUMI')
         if not minddata_aicpu_source_path:
             return
-        minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + device_id + ".txt")
+        minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + str(device_id) + ".txt")
         minddata_aicpu_data = MinddataParser.parse_minddata_aicpu_data(minddata_aicpu_source_path)
         if minddata_aicpu_data:
             fwrite_format(minddata_aicpu_output_path, " ".join(col_names), is_start=True)
mindspore/profiler/parser/msadvisor_analyzer.py
CHANGED

@@ -29,10 +29,11 @@ class Msadvisor:
     """
     The interface to call MSAdvisor(CANN) by command line.
    """
-    def __init__(self, job_id, rank_id, output_path):
+    def __init__(self, job_id, rank_id, output_path, pretty=False):
         self._job_id, self._device_id = job_id.split("/")
         self._rank_id = rank_id
         self._output_path = output_path
+        self._pretty = pretty

     def call_msadvisor(self):
         """
@@ -75,6 +76,7 @@ class Msadvisor:
         """
         Execute the MSAdvisor parser, generate timeline file and call MSAdvisor by command line.
         """
-        reformater = MsadvisorParser(self._job_id, self._device_id,
+        reformater = MsadvisorParser(self._job_id, self._device_id,
+                                     self._rank_id, self._output_path, pretty=self._pretty)
         reformater.parse()
         self.call_msadvisor()
mindspore/profiler/parser/msadvisor_parser.py
CHANGED

@@ -36,7 +36,7 @@ class MsadvisorParser:
     Data format conversion for MSAdvisor AICPU model.
     """

-    def __init__(self, job_id, device_id, rank_id, output_path):
+    def __init__(self, job_id, device_id, rank_id, output_path, pretty=False):
         self._job_id = job_id
         self._device_id = device_id
         self._rank_id = rank_id
@@ -45,6 +45,12 @@ class MsadvisorParser:
         self._aicpu_path = ""
         self._time_start = 0
         self._time_end = 0
+        self._pretty = pretty
+
+    @property
+    def indent(self):
+        indent = 1 if self._pretty else None
+        return indent

     @staticmethod
     def check_clear_make_dir(dir_path):
@@ -199,7 +205,7 @@ class MsadvisorParser:
                 break
             if tid > 1:
                 output_file.write(",")
-            json.dump(op, output_file)
+            json.dump(op, output_file, indent=self.indent)

     def write_aicpu(self):
         """
@@ -221,7 +227,7 @@ class MsadvisorParser:
                 if op.get("ts") > self._time_end:
                     break
                 output_file.write(",")
-                json.dump(op, output_file)
+                json.dump(op, output_file, indent=self.indent)
             output_file.write("]")

     def parse(self):
mindspore/profiler/profiling.py
CHANGED
@@ -20,6 +20,8 @@ import json
 import glob
 import subprocess
 import csv
+import socket
+import shutil
 from enum import Enum
 import numpy as np

@@ -53,7 +55,9 @@ from mindspore.profiler.parser.ascend_fpbp_generator import AscendFPBPGenerator
 from mindspore.profiler.parser.ascend_op_generator import AscendOPGenerator
 from mindspore.profiler.parser.ascend_steptrace_generator import AscendStepTraceGenerator
 from mindspore.profiler.parser.ascend_flops_generator import AscendFlopsGenerator
+from mindspore.profiler.parser.ascend_cluster_generator import AscendClusterGenerator
 from mindspore.profiler.parser.ascend_hccl_generator import AscendHCCLGenerator, AscendHCCLGeneratorOld
+from mindspore.profiler.parser.ascend_communicate_generator import AscendCommunicationGenerator

 INIT_OP_NAME = 'Default/InitDataSetQueue'

@@ -460,6 +464,7 @@ class Profiler:
         self._dynamic_status = False
         self._profile_framework = "all"
         self._msprof_enable = os.getenv("PROFILER_SAMPLECONFIG")
+        self._pretty_json = False
         if self._msprof_enable:
             return
         self._start_time = int(time.time() * 1000000)
@@ -467,9 +472,7 @@ class Profiler:
         if kwargs.get("env_enable"):
             self._profiler_init(kwargs)
             return
-
-        msg = "Do not init twice in the profiler."
-        raise RuntimeError(msg)
+
         Profiler._has_initialized = True
         # get device_id and device_target
         self._get_devid_rankid_and_devtarget()
@@ -527,6 +530,9 @@ class Profiler:
         dev_info = info_dict.get("DeviceInfo", [])
         dev_id = dev_info[0].get("id", -1)

+        if int(rank_id) < 0:
+            rank_id = 0
+
         return str(rank_id), str(dev_id)

     def op_analyse(self, op_name, device_id=None):
@@ -594,7 +600,7 @@ class Profiler:
                 return message
             return op_info

-    def analyse(self, offline_path=None):
+    def analyse(self, offline_path=None, pretty=False):
         """
         Collect and analyze training performance data, support calls during and after training. The example shows above.

@@ -602,7 +608,9 @@ class Profiler:
             offline_path (Union[str, None], optional): The data path which need to be analysed with offline mode.
                 Offline mode isused in abnormal exit scenario. This parameter should be set to ``None``
                 for online mode. Default: ``None``.
+            pretty (bool, optional): Whether to pretty json files. Default: ``False``.
         """
+        self._pretty_json = pretty
         self._analyse(offline_path=offline_path)

     def _analyse(self, offline_path=None, model_iteration_dict=None):
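The new `pretty` argument on `analyse()` is stored as `self._pretty_json` and, as the hunks below show, threaded into every generator and parser. A hedged usage sketch (the network construction and training call are placeholders, not part of this diff; only the `pretty` keyword is new in 2.2.14):

    from mindspore import Profiler

    profiler = Profiler(output_path="./profiler_data")   # starts collecting by default
    net = build_network()                                 # placeholder
    train(net)                                            # placeholder workload
    # pretty=True makes the emitted JSON files indented (indent=1);
    # the default False keeps the former compact output.
    profiler.analyse(pretty=True)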
@@ -643,8 +651,6 @@ class Profiler:
         self._dynamic_status = self._profiler_manager.dynamic_status()
         _environment_check()

-        self._cpu_profiler.stop()
-
         cpu_op_file = glob.glob(os.path.join(self._output_path, 'cpu_op_type_info_*'))
         if self._device_target and self._device_target != DeviceTarget.CPU.value and cpu_op_file:
             self._is_heterogeneous = True
@@ -673,7 +679,6 @@ class Profiler:

         Raises:
             RuntimeError: If the profiler has already started.
-            RuntimeError: If MD profiling has stopped, repeated start action is not supported.
             RuntimeError: If the `start_profile` parameter is not set or is set to ``True``.

         Examples:
@@ -707,13 +712,8 @@ class Profiler:
         if not self._has_started:
             if not self._has_started_twice:
                 self._has_started = True
-                self._has_started_twice = True
-            else:
-                raise RuntimeError("MindSpore Profiling has finished, repeated start and stop actions are not "
-                                   "supported.")
         else:
-            raise RuntimeError("The profiler has already started.
-                               "is set to False.")
+            raise RuntimeError("The profiler has already started. Do not turn on again in the open state.")

         # No need to start anything if parse profiling data offline
         if self._is_offline_parser():
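The hunk above also drops the old "repeated start and stop actions are not supported" error from `start()`; only starting a profiler that is already running still raises. A hedged sketch of the step-scoped start/stop pattern this code path governs (the loop bounds and the training-step function are placeholders, not part of this diff):

    from mindspore import Profiler

    # start_profile=False defers collection until start() is called explicitly.
    profiler = Profiler(start_profile=False, output_path="./profiler_data")

    for step in range(total_steps):                # placeholder loop
        if step == warmup_steps:                   # placeholder boundary
            profiler.start()
        run_one_training_step()                    # placeholder workload
        if step == warmup_steps + profiled_steps:  # placeholder boundary
            profiler.stop()

    profiler.analyse()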
@@ -785,7 +785,8 @@ class Profiler:
         # Stop data collection after all operators are executed.
         _pynative_executor.sync()

-
+        self._cpu_profiler.stop()
+        if self._data_process and self._md_profiler is not None:
             self._md_profiler.stop()
             self._md_profiler.save(self._output_path)

@@ -962,9 +963,6 @@ class Profiler:
         if self._profile_communication:
             hccl_option = {"output": self._output_path, "task_trace": "on"}
             os.environ['PROFILING_OPTIONS'] = json.dumps(hccl_option)
-            if not self.start_profile:
-                raise RuntimeError(f"For '{self.__class__.__name__}', the parameter profile_communication can "
-                                   f"not be True while starting profiler in the process of training.")

         self._profile_memory = kwargs.pop("profile_memory", False)
         if not isinstance(self._profile_memory, bool):
@@ -1060,7 +1058,8 @@ class Profiler:
         # Analyze minddata information
         logger.info("Profiling: analyzing the minddata information.")
         try:
-            MinddataProfilingAnalyzer(self._output_path, store_id,
+            MinddataProfilingAnalyzer(self._output_path, store_id,
+                                      self._output_path, pretty=self._pretty_json).analyze()
         except ProfilerException as err:
             logger.warning(err.message)
         finally:
@@ -1080,7 +1079,7 @@ class Profiler:

             step_trace_point_info_path = validate_and_normalize_path(step_trace_point_info_path)

-            fpbp_analyse = AscendFPBPGenerator(op_summary, steptrace)
+            fpbp_analyse = AscendFPBPGenerator(op_summary, steptrace, pretty=self._pretty_json)
             points, _ = fpbp_analyse.parse()
             fpbp_analyse.write(step_trace_point_info_path)
         except ProfilerException as err:
@@ -1149,7 +1148,7 @@ class Profiler:
             logger.info("Profiling: analyzing the timeline data")
             timeline_analyser = AscendTimelineGenerator(self._output_path, self._dev_id, self._rank_id, self._rank_size,
                                                         context.get_context('mode'))
-            timeline_analyser.init_timeline(op_summary, steptrace)
+            timeline_analyser.init_timeline(op_summary, steptrace, pretty=self._pretty_json)
             timeline_analyser.write_timeline(self._timeline_size_limit_byte)
             timeline_analyser.write_timeline_summary()
         except (ProfilerIOException, ProfilerFileNotFoundException, RuntimeError) as err:
@@ -1166,7 +1165,7 @@ class Profiler:
             logger.warning("The profile_memory parameter cannot be set on the dynamic shape network.")
             logger.warning(
                 "[Profiler]Dynamic Shape network does not support collecting step trace performance data currently.")
-        dynamic_parser = DynamicFrameWorkParser(self._output_path, self._rank_id)
+        dynamic_parser = DynamicFrameWorkParser(self._output_path, self._rank_id, pretty=self._pretty_json)
         dynamic_parser.write_dynamic_shape_data(op_summary)

     def _ascend_flops_analyse(self, op_summary):
@@ -1184,7 +1183,7 @@ class Profiler:
             flops_path = validate_and_normalize_path(flops_path)
             flops_summary_path = validate_and_normalize_path(flops_summary_path)

-            flops_analyser = AscendFlopsGenerator(op_summary)
+            flops_analyser = AscendFlopsGenerator(op_summary, pretty=self._pretty_json)
             flops_analyser.parse()
             flops_analyser.write(flops_path, flops_summary_path)

@@ -1208,6 +1207,73 @@ class Profiler:
         finally:
             pass

+    def _ascend_ms_analyze(self, source_path):
+        """Ascend ms generate"""
+        time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+        if self._rank_id:
+            ascend_ms_path = f"rank-{self._rank_id}_{time_stamp}_ascend_ms"
+        else:
+            ascend_ms_path = f"{socket.gethostname()}--{os.getpid()}_{time_stamp}_ascend_ms"
+        self._ascend_ms_path = os.path.join(self._output_path, ascend_ms_path)
+        if not os.path.exists(self._ascend_ms_path):
+            os.makedirs(self._ascend_ms_path, exist_ok=True)
+            os.chmod(self._ascend_ms_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
+
+        dev_id = self._rank_id if self._device_target == DeviceTarget.ASCEND.value else self._dev_id
+        ascend_profiler_output_path = os.path.join(self._ascend_ms_path, 'ASCEND_PROFILER_OUTPUT')
+        os.makedirs(ascend_profiler_output_path, exist_ok=True)
+
+        source_profiler_info_path = os.path.join(self._output_path, f"profiler_info_{dev_id}.json")
+        target_profiler_info_path = os.path.join(self._ascend_ms_path, f"profiler_info_{dev_id}.json")
+        shutil.copy(source_profiler_info_path, target_profiler_info_path)
+
+        source_timeline_path = os.path.join(self._output_path, f"ascend_timeline_display_{dev_id}.json")
+        target_timeline_path = os.path.join(ascend_profiler_output_path, f"trace_view.json")
+        shutil.copy(source_timeline_path, target_timeline_path)
+
+        self._ascend_graph_cluster_analyse(source_path, ascend_profiler_output_path)
+        self._ascend_graph_communicate_analyse(source_path, ascend_profiler_output_path)
+
+    def _ascend_graph_cluster_analyse(self, source_path, ascend_profiler_output_path):
+        """Analyse step trace time info"""
+
+        try:
+            logger.info("Profiling: analyzing the step trace time profiler info.")
+
+            step_trace_time_path = os.path.join(ascend_profiler_output_path, f'step_trace_time.csv')
+            step_trace_time_path = validate_and_normalize_path(step_trace_time_path)
+
+            cluster_analyse = AscendClusterGenerator(os.path.join(source_path, 'timeline'))
+            cluster_analyse.parse()
+            cluster_analyse.write(step_trace_time_path)
+        except (ProfilerIOException, ProfilerFileNotFoundException, ProfilerRawFileException) as err:
+            logger.warning(err.message)
+        finally:
+            pass
+
+    def _ascend_graph_communicate_analyse(self, source_path, ascend_profiler_output_path):
+        """Analyse communicate info"""
+        if not self._profile_communication:
+            return
+
+        try:
+            logger.info("Profiling: analyzing the communicate and communicate_matrix profiler info.")
+
+            communication_file_path = os.path.join(ascend_profiler_output_path, f'communication.json')
+            communication_file_path = validate_and_normalize_path(communication_file_path)
+
+            communication_matrix_file_path = os.path.join(ascend_profiler_output_path, f"communication_matrix.json")
+            communication_matrix_file_path = validate_and_normalize_path(communication_matrix_file_path)
+
+            analyze_path = os.path.join(os.path.dirname(source_path), 'analyze')
+            communicate_analyser = AscendCommunicationGenerator(analyze_path)
+            communicate_analyser.parse()
+            communicate_analyser.write(communication_file_path, communication_matrix_file_path)
+        except (ProfilerIOException, ProfilerFileNotFoundException, ProfilerRawFileException) as err:
+            logger.warning(err.message)
+        finally:
+            pass
+
     def _ascend_graph_hccl_analyse(self, source_path, steptrace, flag):
         """Analyse hccl profiler info."""
         if not self._profile_communication:
@@ -1237,7 +1303,7 @@ class Profiler:
     def _ascend_graph_msadvisor_analyse(self, job_id):
         """Call MSAdvisor function."""
         logger.info("MSAdvisor starts running.")
-        msadvisor = Msadvisor(job_id, self._rank_id, self._output_path)
+        msadvisor = Msadvisor(job_id, self._rank_id, self._output_path, pretty=self._pretty_json)
         try:
             msadvisor.analyse()
         except FileNotFoundError as err:
@@ -1283,6 +1349,7 @@ class Profiler:
         self._ascend_dynamic_net_analyse(op_summary)
         self._ascend_flops_analyse(op_summary)
         self._ascend_graph_memory_analyse(points)
+        self._ascend_ms_analyze(source_path)
         self._ascend_graph_hccl_analyse(source_path, steptrace, flag)
         self._ascend_graph_msadvisor_analyse(job_id)
         ProfilerInfo.set_graph_ids(graph_ids)
@@ -1368,11 +1435,15 @@ class Profiler:

     def _cpu_analyse(self):
         """Collect and analyse cpu performance data."""
+        if self._has_started:
+            self.stop()
+        else:
+            logger.info("No need to stop profiler because profiler has been stopped or profiler has not been started.")
         if not self._op_time:
             return
         try:
             timeline_generator = CpuTimelineGenerator(self._output_path, self._rank_id, context.get_context("mode"))
-            timeline_generator.init_timeline()
+            timeline_generator.init_timeline(pretty=self._pretty_json)
             timeline_generator.write_timeline(self._timeline_size_limit_byte)
             timeline_generator.write_timeline_summary()
         except (ProfilerIOException, ProfilerFileNotFoundException, RuntimeError) as err:
@@ -1462,7 +1533,7 @@ class Profiler:
         """Analyse memory usage data."""
         integrator = Integrator(self._output_path, self._rank_id)
         aicore_detail_data = integrator.get_aicore_detail_data()
-        memory_parser = MemoryUsageParser(self._output_path, self._rank_id)
+        memory_parser = MemoryUsageParser(self._output_path, self._rank_id, pretty=self._pretty_json)
         memory_parser.init_memory_usage_info(aicore_detail_data, points)
         memory_parser.write_memory_files()

@@ -1630,6 +1701,7 @@ class Profiler:
         else:
             output_path = kwargs.pop("output_path")
             self._output_path = validate_and_normalize_path(output_path)
+
         self._output_path = os.path.join(self._output_path, "profiler")
         if not os.path.exists(self._output_path):
             os.makedirs(self._output_path, exist_ok=True)
mindspore/rewrite/api/node.py
CHANGED
@@ -283,7 +283,7 @@ class Node:
             >>> dst_node = stree.get_node("relu_3")
             >>> dst_node.set_arg_by_node(0, src_node, 0)
             >>> print(dst_node.get_args())
-            [
+            [fc1_var]
         """
         Validator.check_value_type("arg_idx", arg_idx, [int], "Node")
         Validator.check_value_type("src_node", src_node, [Node], "Node")

mindspore/rewrite/api/symbol_tree.py
CHANGED

@@ -168,8 +168,8 @@ class SymbolTree:
             >>> net = LeNet5()
             >>> stree = SymbolTree.create(net)
             >>> print([node.get_name() for node in stree.nodes()])
-            ['input_x', 'Expr', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
-            'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', '
+            ['input_x', 'Expr', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1', 'attribute_assign',
+            'unaryop_not', 'if_node', 'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return_1']
         """
         Validator.check_value_type("all_nodes", all_nodes, [bool], "nodes")
         nodes = self._symbol_tree.all_nodes() if all_nodes else self._symbol_tree.nodes()

mindspore/rewrite/parsers/for_parser.py
CHANGED

@@ -72,7 +72,7 @@ class ForParser(Parser):
             return
         iter_code = astunparse.unparse(node.iter)
         if not iter_code.startswith(EVAL_WHITE_LIST):
-            logger.
+            logger.info(
                 f"For MindSpore Rewrtie, illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
             return
         if "self" in iter_code:
@@ -82,7 +82,7 @@ class ForParser(Parser):
         except (NameError, TypeError) as e:
             _info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
                     f"an error occurred: {str(e)}"
-            logger.
+            logger.info(_info)
             stree.try_append_python_node(node, node, node_manager)
             return

@@ -115,13 +115,13 @@ class ForParser(Parser):
             stree.on_change(Event.CodeChangeEvent)
             return
         if isinstance(iter_obj, range):
-            logger.
+            logger.info("For MindSpore Rewrite, range not support.")
         elif isinstance(iter_obj, zip):
-            logger.
+            logger.info("For MindSpore Rewrite, zip not support.")
         elif isinstance(iter_obj, enumerate):
-            logger.
+            logger.info("For MindSpore Rewrite, enumerate not support.")
         else:
-            logger.
+            logger.info(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
         stree.try_append_python_node(node, node, node_manager)
         return

mindspore/rewrite/parsers/module_parser.py
CHANGED

@@ -170,15 +170,15 @@ class ModuleParser(Parser):
                 level_count += 1
                 continue
             except Exception as e: # pylint: disable=W0703
-                logger.
-
+                logger.info(f"For MindSpore Rewrite, in module parser, process import code: "
+                            f"{import_code} failed: {e}. Ignore this import code.")
                 return None, None
             else:
                 # try test code success
                 return import_node_test.module, file_path_tmp
         # try codes with all level failed
-        logger.
-
+        logger.info(f"For MindSpore Rewrite, in module parser, test import code: "
+                    f"{astunparse.unparse(import_node).strip()} failed. Ignore this import code.")
         return None, None

     @staticmethod
mindspore/scipy/ops.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -156,14 +156,64 @@ class LU(PrimitiveWithInfer):


 class LinearSumAssignment(Primitive):
-    """
+    r"""
+    Solve the linear sum assignment problem.
+
+    The assignment problem is represented as follows:
+
+    .. math::
+        min\sum_{i}^{} \sum_{j}^{} C_{i,j} X_{i,j}
+
+    where :math:`C` is cost matrix, :math:`X_{i,j} = 1` means column :math:`j` is assigned to row :math:`i` .
+
+    Inputs:
+        - **cost_matrix** (Tensor) - 2-D cost matrix. Tensor of shape :math:`(M, N)` .
+        - **dimension_limit** (Tensor, optional) - A scalar used to limit the actual size of the 2nd dimension of
+          ``cost_matrix``. Default is ``Tensor(sys.maxsize)``, which means no limitation. The type is 0-D int64
+          Tensor.
+        - **maximize** (bool) - Calculate a maximum weight matching if true, otherwise calculate a minimum weight
+          matching.
+
+    Outputs:
+        A tuple of tensors containing 'row_idx' and 'col_idx'.
+
+        - **row_idx** (Tensor) - Row indices of the problem. If `dimension_limit` is given, -1 would be padded at the
+          end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
+        - **col_idx** (Tensor) - Column indices of the problem. If `dimension_limit` is given, -1 would be padded at
+          the end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
+
+    Raises:
+        TypeError: If the data type of `cost_matrix` is not the type in [float16, float32, float64,
+            int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool]
+        TypeError: If the type of `maximize` is not bool.
+        TypeError: If the data type of `dimension_limit` is not int64.
+        ValueError: If the rank of `cost_matrix` is not 2.
+        ValueError: If the number of input args is not 3.
+
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.scipy.ops import LinearSumAssignment
+        >>> lsap = LinearSumAssignment()
+        >>> cost_matrix = Tensor(np.array([[2, 3, 3], [3, 2, 3], [3, 3, 2]])).astype(ms.float64)
+        >>> dimension_limit = Tensor(2)
+        >>> maximize = False
+        >>> a, b = lsap(cost_matrix, dimension_limit, maximize)
+        >>> print(a)
+        [0 1 -1]
+        >>> print(b)
+        [0 1 -1]
+    """

     @prim_attr_register
     def __init__(self):
-        super().__init__("LinearSumAssignment")
+        super().__init__(name="LinearSumAssignment")
         self.init_prim_io_names(inputs=['cost_matrix', 'dimension_limit', 'maximize'], outputs=['row_ind', 'col_ind'])
-        self.add_prim_attr("cust_aicpu", "mindspore_aicpu_kernels")
-

 # pylint: disable=C0413,W0611
 from .ops_grad import get_bprpo_eigh, get_bprpo_trsm

mindspore/scipy/optimize/__init__.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,5 +15,6 @@
 """Optimize submodule"""
 from .minimize import minimize
 from .line_search import line_search
+from .linear_sum_assignment import linear_sum_assignment

-__all__ = ["minimize", "line_search"]
+__all__ = ["minimize", "line_search", "linear_sum_assignment"]