mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Define pijit context"""
|
|
16
|
+
|
|
17
|
+
import inspect
|
|
18
|
+
import types
|
|
19
|
+
import functools
|
|
20
|
+
import importlib.util
|
|
21
|
+
import mindspore
|
|
22
|
+
from mindspore import log as logger
|
|
23
|
+
from mindspore.common.jit_config import JitConfig
|
|
24
|
+
from mindspore._c_expression import GraphExecutor_, jit_mode_pi_enable, jit_mode_pi_disable, pi_jit_set_context
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _update_graph_executor_config(jit_config):
    """
    Update GraphExecutor jit_config.

    Args:
        jit_config (Union[JitConfig, dict, None]): When a ``JitConfig`` is given, its
            ``jit_config_dict`` is used; a plain dict is used directly. Any other value
            (including ``None``) leaves the executor untouched.

    Every key and value is converted to ``str`` and passed through ``JitConfig(**...)``
    before the resulting ``jit_config_dict`` is handed to
    ``GraphExecutor_.get_instance().set_jit_config``.
    """
    config_dict = jit_config.jit_config_dict if isinstance(jit_config, JitConfig) else jit_config
    if isinstance(config_dict, dict):
        stringified = {str(key): str(value) for key, value in config_dict.items()}
        GraphExecutor_.get_instance().set_jit_config(JitConfig(**stringified).jit_config_dict)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PIJitCaptureContext:
    """
    Context manager for pijit graph capture.

    An instance can be used as a decorator (``__call__``) on a plain function, a bound
    method, a ``Cell`` subclass or a ``Cell`` instance. Every call of the decorated
    function then runs inside ``with self``, which sets the pijit context for
    ``self.fn`` and enables jit mode for the duration of the call.
    """

    def __init__(self, jit_config=None, input_signature=None):
        _update_graph_executor_config(jit_config)
        config = {}
        if isinstance(jit_config, JitConfig):
            config.update(jit_config.jit_config_dict)
        elif jit_config is not None:
            config.update(jit_config)

        self.config = config
        self.input_signature = input_signature
        self.ret = None
        self.fn = None
        # One-shot iterator: the first ``__enter__`` forwards config and input_signature
        # to ``pi_jit_set_context``; later entries find it exhausted and forward neither.
        # NOTE(review): presumably intentional (config registered only once) - confirm.
        self._init_arg = iter([self.config, self.input_signature])

        # SKIP_RULES is cleared right after this registration, so the global
        # wrapper / skip-list registration only happens for the first instance created.
        if not SKIP_RULES:
            return
        pi_jit_set_context(wrapper=self._wrapper(),
                           skip_files=_get_skip_files(),
                           skip_codes=SKIP_RULES["codes"])
        SKIP_RULES.clear()

    @staticmethod
    def _is_unsupported(fn):
        # generator, coroutine, awaitable and a function that return them is unsupported
        return (inspect.isgeneratorfunction(fn) or inspect.iscoroutinefunction(fn)
                or inspect.isasyncgenfunction(fn) or inspect.isawaitable(fn))

    def _wrapper(self):
        """Return a function that calls ``self.fn`` inside this context and records the result in ``self.ret``."""
        def _fn(*args, **kwds):
            with self:
                self.ret = self.fn(*args, **kwds)
                return self.ret
        return _fn

    def __call__(self, fn):
        """
        Decorate ``fn`` for pijit capture.

        Returns:
            The (possibly patched) class/Cell, a re-bound method, a wrapper function, or
            ``fn`` itself when it is unsupported or defined inside mindspore.
        """
        if isinstance(fn, type) and issubclass(fn, mindspore.nn.Cell):
            fn.construct = self(fn.construct)
            return fn
        if isinstance(fn, mindspore.nn.Cell):
            # NOTE(review): this patches the class, so every instance of that Cell
            # subclass is affected, not only ``fn``.
            type(fn).construct = self(type(fn).construct)
            return fn
        if isinstance(fn, types.MethodType):
            return types.MethodType(self(fn.__func__), fn.__self__)
        if not isinstance(fn, types.FunctionType) or self._is_unsupported(fn):
            logger.warning("unsupported function type: " + str(fn))
            return fn

        # Functions that belong to mindspore itself are never captured. getmodule()
        # returns None for code without a module (e.g. created via exec); such
        # functions are treated as user code instead of raising AttributeError.
        module = inspect.getmodule(fn.__code__)
        if module is not None and module.__name__.startswith("mindspore"):
            return fn

        _fn = self._wrapper()
        # Re-decorating an already-wrapped function: ``_fn``'s only closure cell is the
        # previous context, so unwrap to the user function that context holds.
        if fn.__code__ is _fn.__code__:
            fn = fn.__closure__[0].cell_contents.fn
        # NOTE(review): one context tracks a single ``fn``; decorating a second function
        # with the same instance redirects earlier wrappers to it.
        self.fn = fn
        return functools.wraps(fn)(_fn)

    def __enter__(self):
        pi_jit_set_context(self.fn, *self._init_arg)
        jit_mode_pi_enable()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the block propagate.
        pi_jit_set_context(None)
        jit_mode_pi_disable()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _get_skip_files():
    """
    Get skip files by SKIP_RULES.

    Returns:
        tuple[str], the ``SKIP_RULES["skip_dirs"]`` entries followed by the resolved
        origin path of every module named in ``SKIP_RULES["builtins"]`` and then
        ``SKIP_RULES["third_party"]`` (order preserved). For packages the trailing
        ``__init__.py`` is stripped so the whole package directory is matched.

    Modules that cannot be found, or whose spec carries no origin (e.g. namespace
    packages), are skipped instead of raising.
    """
    init_file = "__init__.py"

    def _filter(path: str):
        # "<pkg>/__init__.py" -> "<pkg>/": cover everything under the package directory.
        if path.endswith(init_file):
            return path[:-len(init_file)]
        return path

    def _origin(name):
        # not import these modules, only find it
        spec = importlib.util.find_spec(name)
        return None if spec is None else spec.origin

    files = list(SKIP_RULES["skip_dirs"])
    for name in (*SKIP_RULES["builtins"], *SKIP_RULES["third_party"]):
        origin = _origin(name)
        if origin is None:
            continue
        files.append(_filter(origin))
    return tuple(files)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# complete the skip list...
|
|
136
|
+
SKIP_RULES = {
    # Literal filename prefixes matched as-is.
    "skip_dirs": ("<frozen importlib", "<__array_function__ internals>", "<string>"),
    # Standard-library (and mindspore) modules resolved to file paths by _get_skip_files.
    "builtins": (
        "mindspore",  # not capture any function of mindspore unless it's called by user
        "abc", "ast", "codecs", "collections", "contextlib", "copy", "copyreg",
        "dataclasses", "enum", "functools", "glob", "importlib", "inspect",
        "linecache", "logging", "multiprocessing", "operator", "os", "posixpath",
        "random", "re", "selectors", "signal", "tempfile", "threading", "tokenize",
        "traceback", "types", "typing", "unittest", "weakref",
        "_collections_abc", "_weakrefset",
        # others...
        "sre_compile", "sre_parse", "genericpath",
    ),
    # Optional packages; missing ones are ignored by _get_skip_files.
    "third_party": ("numpy", "pandas", "sklearn", "tqdm", "tree"),
    "codes": (),
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -43,3 +43,4 @@ class Registry:
|
|
|
43
43
|
|
|
44
44
|
|
|
45
45
|
# Two independent registries: the default tensor-operator registry, and a separate
# one whose entries are looked up by the mint-dispatching tensor method wrappers.
tensor_operator_registry, tensor_operator_registry_for_mint = Registry(), Registry()
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""Mint adaptor."""
|
|
17
|
+
|
|
18
|
+
from __future__ import absolute_import

import functools
import os

from mindspore.common._register_for_tensor import tensor_operator_registry_for_mint
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def repeat_interleave_mint(orig_fn):
    """
    repeat_interleave wrapper.
    For details, please refer to :func:`mindspore.ops.repeat_interleave_ext`.

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'repeat_interleave'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``orig_fn`` is called. All
    arguments are forwarded unchanged in both cases.

    Args:
        orig_fn (function): The original method. Its name, docstring and
            ``__wrapped__`` reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(orig_fn)
    def wrapper(self, *args, **kwargs):
        # Checked per call, so changing the env var takes effect immediately.
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('repeat_interleave')(self, *args, **kwargs)
        return orig_fn(self, *args, **kwargs)
    return wrapper
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def isnan_mint(orig_fn):
    """
    isnan wrapper.

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, ``self`` is compared with itself via the ``'ne'`` entry of
    ``tensor_operator_registry_for_mint``. This relies on NaN being the only value not
    equal to itself. In that mode extra positional arguments are not forwarded; keyword
    arguments are. Otherwise ``orig_fn`` is called with all arguments.

    Args:
        orig_fn (function): The original method. Its name, docstring and
            ``__wrapped__`` reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(orig_fn)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('ne')(self, self, **kwargs)
        return orig_fn(self, *args, **kwargs)
    return wrapper
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def add_mint(add):
    """
    add wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'add'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``add`` is called. ``other`` and
    all keyword arguments are forwarded unchanged in both cases.

    Args:
        add (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(add)
    def wrapper(self, other, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('add')(self, other, **kwargs)
        return add(self, other, **kwargs)
    return wrapper
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def flatten_mint(flatten):
    """
    flatten wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, up to two positional arguments are bound to ``start_dim`` and ``end_dim``
    (defaults ``0`` and ``-1``). The call is then routed, keyword-only, to the
    ``'flatten'`` entry of ``tensor_operator_registry_for_mint``. Otherwise
    ``flatten`` is called with all arguments unchanged.

    Args:
        flatten (function): The original method. Its name, docstring and
            ``__wrapped__`` reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.

    Raises:
        TypeError: In mint mode, if more than two positional arguments are given, or if
            ``start_dim``/``end_dim`` is given both positionally and by keyword.
    """
    @functools.wraps(flatten)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            if len(args) > 2:
                raise TypeError(f"flatten() takes at most 2 positional arguments but {len(args)} were given")
            for name, value in zip(("start_dim", "end_dim"), args):
                # Mirror Python's own error instead of silently overriding the keyword.
                if name in kwargs:
                    raise TypeError(f"flatten() got multiple values for argument '{name}'")
                kwargs[name] = value
            kwargs.setdefault("start_dim", 0)
            kwargs.setdefault("end_dim", -1)
            return tensor_operator_registry_for_mint.get('flatten')(self, **kwargs)
        return flatten(self, *args, **kwargs)
    return wrapper
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def item_mint(fn):
    """
    item wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'item'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``fn`` is called. All arguments
    are forwarded unchanged in both cases.

    Args:
        fn (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('item')(self, *args, **kwargs)
        return fn(self, *args, **kwargs)
    return wrapper
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def max_mint(fn):
    """
    max wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'max'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``fn`` is called. All arguments
    are forwarded unchanged in both cases.

    Args:
        fn (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('max')(self, *args, **kwargs)
        return fn(self, *args, **kwargs)
    return wrapper
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def mean_mint(fn):
    """
    mean wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'mean'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``fn`` is called. All arguments
    are forwarded unchanged in both cases.

    Args:
        fn (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('mean')(self, *args, **kwargs)
        return fn(self, *args, **kwargs)
    return wrapper
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def min_mint(fn):
    """
    min wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'min'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``fn`` is called. All arguments
    are forwarded unchanged in both cases.

    Args:
        fn (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('min')(self, *args, **kwargs)
        return fn(self, *args, **kwargs)
    return wrapper
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def split_mint(split):
    """
    split wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'split'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``split`` is called. All arguments
    are forwarded unchanged in both cases.

    Args:
        split (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(split)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('split')(self, *args, **kwargs)
        return split(self, *args, **kwargs)
    return wrapper
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def sub_mint(sub):
    """
    sub wrapper

    When the environment variable ``MS_TENSOR_API_ENABLE_MINT`` equals ``'1'`` at call
    time, the call is routed to the ``'sub'`` entry of
    ``tensor_operator_registry_for_mint``; otherwise ``sub`` is called. All arguments
    are forwarded unchanged in both cases.

    Args:
        sub (function): The original method. Its name, docstring and ``__wrapped__``
            reference are copied onto the wrapper.

    Returns:
        function, the dispatching wrapper.
    """
    @functools.wraps(sub)
    def wrapper(self, *args, **kwargs):
        if os.environ.get('MS_TENSOR_API_ENABLE_MINT') == '1':
            return tensor_operator_registry_for_mint.get('sub')(self, *args, **kwargs)
        return sub(self, *args, **kwargs)
    return wrapper
|