cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
"""
|
|
2
|
+
硬件设备管理器 (Device Manager)
|
|
3
|
+
|
|
4
|
+
自动检测可用硬件(NVIDIA GPU / 昇腾 NPU / 寒武纪 MLU / Apple MPS / CPU),
|
|
5
|
+
提供最优设备选择和配置策略。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Optional, List, Dict, Any, Tuple, Sequence
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
# Device type names accepted by `_parse_device_request`; "auto" asks
# `resolve_device_string` to pick the first available entry of its priority list.
SUPPORTED_DEVICE_TYPES = {"auto", "cpu", "cuda", "mps", "npu", "mlu"}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _try_import_torch_npu() -> bool:
    """Make sure ``torch.npu`` is registered, importing ``torch_npu`` on demand.

    Returns:
        True when ``torch.npu`` exists after the attempt; False when importing
        ``torch_npu`` raised (the failure is logged at debug level).
    """
    registered = hasattr(torch, "npu")
    if not registered:
        try:
            import torch_npu  # type: ignore # noqa: F401
        except Exception as exc:  # pragma: no cover - depends on the environment
            logger.debug("torch_npu import failed: %s", exc)
            return False
        # The import is only needed for its side effect of registering torch.npu.
        del torch_npu
        registered = hasattr(torch, "npu")
    return registered
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _try_import_torch_mlu() -> bool:
    """Make sure ``torch.mlu`` is registered, importing ``torch_mlu`` on demand.

    Returns:
        True when ``torch.mlu`` exists after the attempt; False when importing
        ``torch_mlu`` raised (the failure is logged at debug level).
    """
    registered = hasattr(torch, "mlu")
    if not registered:
        try:
            import torch_mlu  # type: ignore # noqa: F401
        except Exception as exc:  # pragma: no cover - depends on the environment
            logger.debug("torch_mlu import failed: %s", exc)
            return False
        # The import is only needed for its side effect of registering torch.mlu.
        del torch_mlu
        registered = hasattr(torch, "mlu")
    return registered
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def is_npu_available() -> bool:
    """Report whether an Ascend NPU is usable, lazily importing ``torch_npu`` first.

    Any exception raised by ``torch.npu.is_available()`` is logged at debug
    level and treated as "not available".
    """
    _try_import_torch_npu()
    backend = getattr(torch, "npu", None)
    if backend is not None:
        try:
            return bool(backend.is_available())
        except Exception as exc:  # pragma: no cover - depends on the environment
            logger.debug("torch.npu.is_available() failed: %s", exc)
    return False
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def is_mlu_available() -> bool:
    """Report whether a Cambricon MLU is usable, lazily importing ``torch_mlu`` first.

    Any exception raised by ``torch.mlu.is_available()`` is logged at debug
    level and treated as "not available".
    """
    _try_import_torch_mlu()
    backend = getattr(torch, "mlu", None)
    if backend is not None:
        try:
            return bool(backend.is_available())
        except Exception as exc:  # pragma: no cover - depends on the environment
            logger.debug("torch.mlu.is_available() failed: %s", exc)
    return False
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _parse_device_request(requested: Optional[str]) -> Tuple[str, Optional[int]]:
    """Split a device request such as "cuda:1" into ``(device_type, index)``.

    The request is case-insensitive and whitespace-tolerant. ``None`` or an
    empty string means "auto". "gpu" is an alias for "cuda", with or without
    an index (e.g. "gpu:1" -> ("cuda", 1)). Anything with a ``__str__`` that
    yields such a string (e.g. a ``torch.device``) is accepted too.

    Returns:
        ``(device_type, index)`` where index is None when no ":<n>" was given.

    Raises:
        ValueError: an empty, non-integer or negative index, or a device type
            that is not in ``SUPPORTED_DEVICE_TYPES``.
    """
    value = "auto" if requested is None else str(requested).strip().lower()
    if not value:
        value = "auto"

    device_type, sep, index_raw = value.partition(":")
    device_type = device_type.strip()
    device_index: Optional[int] = None
    if sep:
        index_raw = index_raw.strip()
        if not index_raw:
            raise ValueError(f"Invalid device string: {requested}")
        try:
            device_index = int(index_raw)
        except ValueError as exc:
            raise ValueError(f"Invalid device index in '{requested}'") from exc
        if device_index < 0:
            raise ValueError(f"Device index must be >= 0 in '{requested}'")

    # Apply the alias after splitting so that "gpu:1" works just like "gpu".
    if device_type == "gpu":
        device_type = "cuda"

    if device_type not in SUPPORTED_DEVICE_TYPES:
        supported = ", ".join(sorted(SUPPORTED_DEVICE_TYPES))
        raise ValueError(f"Unsupported device '{requested}'. Supported: {supported}")
    return device_type, device_index
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_device_type(device: Optional[str]) -> str:
    """Return the bare device type of *device*, e.g. "npu:0" -> "npu".

    "auto" (including ``None`` or an empty string) is first resolved to the
    device that automatic selection would pick.
    """
    kind = _parse_device_request(device)[0]
    if kind != "auto":
        return kind
    return _parse_device_request(resolve_device_string("auto"))[0]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _device_count(device_type: str) -> int:
    """Return how many devices of *device_type* (no index) are visible.

    "cpu" always counts as one device and "mps" as at most one. Unavailable
    backends, failing count queries and unknown types all yield 0.
    """

    def _backend_count(module_name: str, available) -> int:
        # torch.npu / torch.mlu only exist once the vendor plugin registered them.
        if not available():
            return 0
        backend = getattr(torch, module_name, None)
        if backend is None:
            return 0
        try:
            return int(backend.device_count())
        except Exception:  # pragma: no cover - depends on the environment
            return 0

    if device_type == "cpu":
        return 1
    if device_type == "mps":
        return int(hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    if device_type == "cuda":
        return int(torch.cuda.device_count()) if torch.cuda.is_available() else 0
    if device_type == "npu":
        return _backend_count("npu", is_npu_available)
    if device_type == "mlu":
        return _backend_count("mlu", is_mlu_available)
    return 0
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _is_device_available(device_type: str) -> bool:
    """Return True if *device_type* (without index) can be used on this machine.

    The CPU is always available. Unknown device types, including "auto",
    report False.
    """
    if device_type == "cpu":
        return True
    # Probes are wrapped in lambdas so that only the requested one is evaluated.
    probes = {
        "cuda": lambda: torch.cuda.is_available(),
        "npu": lambda: is_npu_available(),
        "mlu": lambda: is_mlu_available(),
        "mps": lambda: bool(hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
    }
    probe = probes.get(device_type)
    return probe() if probe is not None else False
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _format_device(device_type: str, device_index: Optional[int]) -> str:
    """Render ``(device_type, index)`` as a device string.

    "cpu" and "mps" are always rendered without an index, as is any type
    whose index is None; e.g. ("cuda", 1) -> "cuda:1", ("mps", 3) -> "mps".
    """
    indexless = device_index is None or device_type in ("cpu", "mps")
    return device_type if indexless else f"{device_type}:{device_index}"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def resolve_device_string(
    requested: Optional[str] = "auto",
    *,
    auto_priority: Sequence[str] = ("npu", "cuda", "mlu", "mps", "cpu"),
    allow_fallback: bool = True,
) -> str:
    """Resolve a device request into a concrete device string.

    Args:
        requested: e.g. "auto", "cpu", "cuda:1", "npu"; see ``_parse_device_request``.
        auto_priority: order in which device types are tried for "auto".
            Entries that are not supported device types (or are "auto") are skipped.
        allow_fallback: when True, an unavailable device falls back to the
            automatic choice, and an out-of-range index falls back to 0, each
            with a warning. When False, both cases raise RuntimeError.

    Returns:
        "cpu"/"mps" without an index, other types with ":<n>" only when an
        index was requested. "auto" returns "cpu" if no priority entry is available.

    Raises:
        ValueError: malformed request (from ``_parse_device_request``).
        RuntimeError: unavailable device or out-of-range index with
            ``allow_fallback=False``.
    """
    kind, index = _parse_device_request(requested)

    if kind == "auto":
        usable = (
            c for c in auto_priority
            if c in SUPPORTED_DEVICE_TYPES and c != "auto" and _is_device_available(c)
        )
        chosen = next(usable, None)
        return "cpu" if chosen is None else _format_device(chosen, None)

    if not _is_device_available(kind):
        if not allow_fallback:
            hints = {
                "npu": " (hint: install and configure torch_npu + CANN runtime)",
                "mlu": " (hint: install and configure torch_mlu runtime)",
            }
            raise RuntimeError(f"Requested device '{requested}' is unavailable{hints.get(kind, '')}.")
        fallback = resolve_device_string(
            "auto",
            auto_priority=auto_priority,
            allow_fallback=False,
        )
        logger.warning(
            "Requested device '%s' is unavailable, fallback to '%s'.",
            requested,
            fallback,
        )
        return fallback

    if index is None:
        return _format_device(kind, None)

    if kind in {"cpu", "mps"}:
        # These device types are always addressed without an index.
        logger.warning(
            "Device '%s' does not use index; ignore index %s.",
            kind,
            index,
        )
        return _format_device(kind, None)

    count = _device_count(kind)
    if 0 < count <= index:
        if not allow_fallback:
            raise RuntimeError(
                f"Requested {kind}:{index} out of range. "
                f"Available count={count}."
            )
        logger.warning(
            "Requested %s index=%s out of range (count=%s), fallback to index=0.",
            kind,
            index,
            count,
        )
        index = 0
    return _format_device(kind, index)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def resolve_dtype_for_device(
    dtype: Optional[str | torch.dtype],
    device: Optional[str],
) -> torch.dtype:
    """Resolve a dtype request into a concrete ``torch.dtype`` suitable for *device*.

    Args:
        dtype: a ``torch.dtype``, ``None``/"auto", or a dtype name. Names are
            case-insensitive, may be padded with whitespace and may carry a
            ``torch.`` prefix, so ``str(torch.bfloat16)`` (as emitted by
            ``DeviceManager``) round-trips. Accepted names are
            float16/fp16/half, bfloat16/bf16 and float32/fp32/float.
        device: a device request understood by ``get_device_type``; ``None``
            and "auto" resolve to the automatically selected device.

    Returns:
        For "auto": float16 on cuda/mps/npu/mlu and float32 otherwise. For an
        explicit dtype: that dtype, except that bfloat16 is downgraded to
        float16 on MPS. On NPU it is also downgraded when
        ``torch.npu.is_bf16_supported()`` returns False or raises.

    Raises:
        ValueError: unrecognized dtype name.
    """
    if isinstance(dtype, torch.dtype):
        resolved_dtype = dtype
    else:
        dtype_name = "auto" if dtype is None else str(dtype).strip().lower()
        # Accept the str() form of torch dtypes, e.g. "torch.float16".
        if dtype_name.startswith("torch."):
            dtype_name = dtype_name[len("torch."):]
        if dtype_name == "auto":
            base = get_device_type(device)
            return torch.float16 if base in {"cuda", "mps", "npu", "mlu"} else torch.float32
        mapping = {
            "float16": torch.float16,
            "fp16": torch.float16,
            "half": torch.float16,
            "bfloat16": torch.bfloat16,
            "bf16": torch.bfloat16,
            "float32": torch.float32,
            "fp32": torch.float32,
            "float": torch.float32,
        }
        if dtype_name not in mapping:
            raise ValueError(f"Unsupported dtype: {dtype}")
        resolved_dtype = mapping[dtype_name]

    base = get_device_type(device)
    if base == "mps" and resolved_dtype == torch.bfloat16:
        return torch.float16
    if base == "npu" and resolved_dtype == torch.bfloat16:
        npu = getattr(torch, "npu", None)
        if npu is not None and hasattr(npu, "is_bf16_supported"):
            try:
                if not bool(npu.is_bf16_supported()):
                    return torch.float16
            except Exception:  # pragma: no cover - depends on the environment
                return torch.float16
    return resolved_dtype
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@dataclass
class DeviceInfo:
    """Description of one detected hardware device.

    The ``supports_*`` flags describe what the device can run.
    ``optimal_dtype`` turns them into a recommended default dtype.
    """

    device_type: str  # "cuda", "npu", "mlu", "mps", "cpu"
    device_name: str  # human-readable device name
    device_index: int = 0  # device ordinal
    total_memory_gb: float = 0.0  # total device memory in GB (0.0 when unknown)
    compute_capability: str = ""  # compute-capability version string, e.g. "8.0"
    supports_bf16: bool = False
    supports_fp16: bool = False
    supports_flash_attn: bool = False
    dynamic_graph_support: str = "full"  # "full", "limited", "none"

    @property
    def torch_device(self) -> torch.device:
        """``torch.device`` for this entry; the CPU is addressed without an index."""
        if self.device_type == "cpu":
            return torch.device("cpu")
        return torch.device(f"{self.device_type}:{self.device_index}")

    @property
    def optimal_dtype(self) -> torch.dtype:
        """Recommended compute dtype for this device.

        The CPU always gets float32, matching ``resolve_dtype_for_device``'s
        "auto" rule: being able to run bf16/fp16 on a CPU does not make it
        the faster choice. Accelerators prefer bfloat16, then float16, then
        float32, according to their support flags.
        """
        if self.device_type == "cpu":
            return torch.float32
        if self.supports_bf16:
            return torch.bfloat16
        if self.supports_fp16:
            return torch.float16
        return torch.float32
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class DeviceManager:
    """Hardware device manager.

    On construction it probes for CUDA GPUs, Ascend NPUs, Cambricon MLUs and
    Apple MPS, and always records the CPU. It can then pick a best device
    and suggest an optimization config for a device.
    """

    def __init__(self):
        self.devices: List[DeviceInfo] = []
        self._detect_devices()

    def _detect_devices(self):
        """Detect all devices, appending them in the order CUDA, NPU, MLU, MPS, CPU."""
        self._detect_cuda()
        self._detect_npu()
        self._detect_mlu()
        self._detect_mps()
        self._detect_cpu()

    @staticmethod
    def _query_name_and_memory(module: Any, index: int, default_name: str) -> Tuple[str, float]:
        """Best-effort lookup of a device's name and total memory in GB.

        *module* is a CUDA-style backend namespace (``torch.npu``, ``torch.mlu``)
        queried via ``get_device_name(index)`` and
        ``get_device_properties(index).total_memory``. A failed or empty
        query keeps ``default_name`` / 0.0; exceptions are logged at debug level.
        """
        name = default_name
        mem_gb = 0.0
        try:
            queried_name = module.get_device_name(index)
            if queried_name:
                name = str(queried_name)
        except Exception as exc:  # pragma: no cover - depends on the environment
            logger.debug("device name query failed for index %s: %s", index, exc)
        try:
            props = module.get_device_properties(index)
            mem = getattr(props, "total_memory", 0)
            if mem:
                mem_gb = mem / (1024 ** 3)
        except Exception as exc:  # pragma: no cover - depends on the environment
            logger.debug("device properties query failed for index %s: %s", index, exc)
        return name, mem_gb

    def _detect_cuda(self) -> None:
        """Append one entry per visible CUDA GPU."""
        if not torch.cuda.is_available():
            return
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            # Read ``total_memory``, falling back to ``total_mem`` when it is absent.
            total_memory = getattr(props, "total_memory", getattr(props, "total_mem", 0))
            mem_gb = total_memory / (1024 ** 3) if total_memory else 0.0
            # bf16 and flash attention are only enabled from compute capability 8.x (Ampere+).
            ampere_plus = props.major >= 8
            self.devices.append(DeviceInfo(
                device_type="cuda",
                device_name=props.name,
                device_index=i,
                total_memory_gb=round(mem_gb, 1),
                compute_capability=f"{props.major}.{props.minor}",
                supports_bf16=ampere_plus,
                supports_fp16=True,
                supports_flash_attn=ampere_plus,
                dynamic_graph_support="full",
            ))

    def _detect_npu(self) -> None:
        """Append one entry per Ascend NPU exposed through ``torch.npu``."""
        if not is_npu_available():
            return
        npu = getattr(torch, "npu", None)
        if npu is None:
            return
        # is_bf16_supported() is called without a device index, so query it once.
        # Assume bf16 support when the API is missing or the call fails.
        supports_bf16 = True
        if hasattr(npu, "is_bf16_supported"):
            try:
                supports_bf16 = bool(npu.is_bf16_supported())
            except Exception as exc:  # pragma: no cover - depends on the environment
                logger.debug("torch.npu.is_bf16_supported() failed: %s", exc)
                supports_bf16 = True
        for i in range(_device_count("npu")):
            name, mem_gb = self._query_name_and_memory(npu, i, f"Ascend NPU {i}")
            self.devices.append(DeviceInfo(
                device_type="npu",
                device_name=name,
                device_index=i,
                total_memory_gb=round(mem_gb, 1),
                supports_bf16=supports_bf16,
                supports_fp16=True,
                supports_flash_attn=False,
                dynamic_graph_support="limited",
            ))

    def _detect_mlu(self) -> None:
        """Append one entry per Cambricon MLU exposed through ``torch.mlu``."""
        if not is_mlu_available():
            return
        mlu = getattr(torch, "mlu", None)
        if mlu is None:
            return
        for i in range(_device_count("mlu")):
            # Query name/memory like the NPU path so that memory-based settings
            # in get_optimization_config see a real value when available.
            name, mem_gb = self._query_name_and_memory(mlu, i, f"Cambricon MLU {i}")
            self.devices.append(DeviceInfo(
                device_type="mlu",
                device_name=name,
                device_index=i,
                total_memory_gb=round(mem_gb, 1),
                supports_bf16=False,
                supports_fp16=True,
                supports_flash_attn=False,
                dynamic_graph_support="limited",
            ))

    def _detect_mps(self) -> None:
        """Append the Apple MPS device when the MPS backend reports availability."""
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            self.devices.append(DeviceInfo(
                device_type="mps",
                device_name="Apple Silicon GPU",
                device_index=0,
                supports_bf16=False,
                supports_fp16=True,
                supports_flash_attn=False,
                dynamic_graph_support="full",
            ))

    def _detect_cpu(self) -> None:
        """Append the CPU entry, which is always present."""
        self.devices.append(DeviceInfo(
            device_type="cpu",
            device_name="CPU",
            device_index=0,
            supports_bf16=True,
            supports_fp16=True,
            supports_flash_attn=False,
            dynamic_graph_support="full",
        ))

    def get_best_device(self) -> DeviceInfo:
        """Return the best detected device.

        Device types are ranked CUDA > NPU > MLU > MPS > CPU; within a type
        the device with the most memory wins.
        NOTE(review): this ranking differs from ``resolve_device_string``'s
        default ``auto_priority`` (which tries NPU before CUDA).
        """
        priority = {"cuda": 0, "npu": 1, "mlu": 2, "mps": 3, "cpu": 4}
        sorted_devices = sorted(
            self.devices,
            key=lambda d: (priority.get(d.device_type, 99), -d.total_memory_gb),
        )
        best = sorted_devices[0]
        logger.info(
            "Best device: %s (%s:%s), memory=%sGB, dtype=%s",
            best.device_name,
            best.device_type,
            best.device_index,
            best.total_memory_gb,
            best.optimal_dtype,
        )
        return best

    def get_optimization_config(self, device: Optional[DeviceInfo] = None) -> Dict[str, Any]:
        """Build a suggested optimization config for *device* (default: the best device).

        Gradient checkpointing is suggested below 24 GB of reported memory.
        Devices with "limited" dynamic-graph support get the "batch_static"
        path mode.
        """
        if device is None:
            device = self.get_best_device()

        config = {
            "device": str(device.torch_device),
            "dtype": str(device.optimal_dtype),
            "use_flash_attention": device.supports_flash_attn,
            "use_gradient_checkpointing": device.total_memory_gb < 24,
            "dynamic_path_mode": "batch_static" if device.dynamic_graph_support == "limited" else "dynamic",
        }

        # NPU / MLU: prefer static graphs and derive the batch size from memory (8 GB per sample).
        if device.device_type in ("npu", "mlu"):
            config.update({
                "use_static_graph": True,
                "max_batch_size": max(1, int(device.total_memory_gb // 8)),
            })

        return config

    def list_devices(self) -> List[Dict[str, Any]]:
        """List every detected device as a plain dict (type, name, index, memory_gb, dtype)."""
        return [
            {
                "type": d.device_type,
                "name": d.device_name,
                "index": d.device_index,
                "memory_gb": d.total_memory_gb,
                "dtype": str(d.optimal_dtype),
            }
            for d in self.devices
        ]
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def get_best_device_info() -> DeviceInfo:
    """Convenience wrapper: run a fresh device detection and return the best device."""
    manager = DeviceManager()
    return manager.get_best_device()
|
cortexnet/ops/npu_ops.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"""
|
|
2
|
+
NPU 算子抽象层 (NPU Operator Abstraction)
|
|
3
|
+
|
|
4
|
+
为 CortexNet 核心模块提供 NPU 原生算子接口,
|
|
5
|
+
支持昇腾 CANN、寒武纪 CNNL,并包含 CPU/CUDA 模拟回退。
|
|
6
|
+
|
|
7
|
+
核心算子:
|
|
8
|
+
1. ssm_scan_op — SSM 扫描(分块并行)
|
|
9
|
+
2. sparse_attn_op — 稀疏注意力
|
|
10
|
+
3. moe_router_op — MoE 路由(向量化)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from typing import Optional, Tuple
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn.functional as F
|
|
20
|
+
|
|
21
|
+
# Prefer the package-relative import. Fall back to a top-level ``device_manager``
# module when this file is imported outside the ``cortexnet.ops`` package
# (presumably when it is loaded as a loose script; confirm).
try:
    from .device_manager import is_npu_available, is_mlu_available
except ImportError:
    from device_manager import is_npu_available, is_mlu_available

logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class NPUOperators:
    """NPU operator manager.

    Resolves a backend name ("ascend", "cambricon", "cuda" or "cpu") and
    dispatches each operator on it. All paths currently run plain PyTorch
    code: the Ascend SSM scan is a chunked form of the reference recurrence,
    and the remaining Ascend/Cambricon paths delegate to the PyTorch
    implementations.
    """

    # Device-type spellings mapped to backend names.
    _BACKEND_ALIASES = {"npu": "ascend", "mlu": "cambricon", "gpu": "cuda"}
    # Backend names with an explicit meaning; any other name uses the PyTorch paths.
    _KNOWN_BACKENDS = ("ascend", "cambricon", "cuda", "cpu")

    def __init__(self, backend: str = "auto"):
        """
        Args:
            backend: "auto", "ascend", "cambricon", "cuda" or "cpu"
                (case-insensitive; "npu", "mlu" and "gpu" are accepted as
                aliases for "ascend", "cambricon" and "cuda").
        """
        self.backend = self._detect_backend(backend)
        self._ops_loaded = False
        logger.info("NPU Operators initialized with backend: %s", self.backend)

    def _detect_backend(self, backend: str) -> str:
        """Normalize an explicit backend name, or auto-detect one.

        Auto-detection order: Ascend NPU, Cambricon MLU, CUDA, then CPU.
        """
        name = str(backend).strip().lower()
        if name != "auto":
            name = self._BACKEND_ALIASES.get(name, name)
            if name not in self._KNOWN_BACKENDS:
                # Not fatal: every dispatcher's fallback branch is the PyTorch implementation.
                logger.warning(
                    "Unknown NPU operator backend %r; PyTorch implementations will be used.",
                    backend,
                )
            return name

        if is_npu_available():
            return "ascend"
        if is_mlu_available():
            return "cambricon"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    # ═══════════════════════════════════════════════════════════════
    # SSM scan operator
    # ═══════════════════════════════════════════════════════════════

    def ssm_scan(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        dt: torch.Tensor,
        chunk_size: int = 64,
        past_state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Selective SSM scan.

        Recurrence per step t: ``h = exp(A * dt_t) * h + dt_t * B_t * x_t``
        and ``y_t = sum_N(h * C_t)``.

        Args:
            x: (batch, L, D) input sequence.
            A: (D, N) or (N,) state-transition parameters.
            B: input projection, (D, N), (N,) or per-step (batch, L, N).
            C: output projection, same accepted shapes as ``B``.
            dt: (batch, L, D) step sizes.
            chunk_size: chunk length for the Ascend path (must be >= 1 there).
            past_state: optional (batch, D, N) initial state.

        Returns:
            output: (batch, L, D) scan output.
            new_state: (batch, D, N) final state (``past_state`` or zeros if L == 0).
        """
        if self.backend == "ascend":
            return self._ssm_scan_ascend(x, A, B, C, dt, chunk_size, past_state)
        elif self.backend == "cambricon":
            return self._ssm_scan_cambricon(x, A, B, C, dt, chunk_size, past_state)
        else:
            return self._ssm_scan_pytorch(x, A, B, C, dt, chunk_size, past_state)

    @staticmethod
    def _as_step_tensor(P: torch.Tensor, name: str) -> torch.Tensor:
        """Reshape a B/C projection so that it broadcasts against (batch, L, D, N)."""
        if P.dim() == 1:  # (N,): shared by every batch, step and channel
            return P.reshape(1, 1, 1, -1)
        if P.dim() == 2:  # (D, N): per channel, shared across batch and time
            return P.unsqueeze(0).unsqueeze(0)
        if P.dim() == 3:  # (batch, L, N): input-dependent (selective) projection
            return P.unsqueeze(2)
        if P.dim() == 4:  # already (batch|1, L|1, D|1, N)
            return P
        raise ValueError(
            f"{name} must have shape (N,), (D, N), (batch, L, N) or 4-D; got {tuple(P.shape)}"
        )

    def _scan(self, x, A, B, C, dt, past_state, chunk_size: Optional[int] = None):
        """Shared recurrence; steps through the sequence in chunks of ``chunk_size``.

        Every step runs sequentially and the state carries over chunk
        boundaries, so chunking never changes the result.
        """
        batch, seq_len, d_inner = x.shape
        N = A.shape[-1]

        h = past_state if past_state is not None else torch.zeros(
            batch, d_inner, N, device=x.device, dtype=x.dtype
        )
        if seq_len == 0:
            # torch.stack([]) would raise; return an empty output and the unchanged state.
            return x.new_zeros((batch, 0, d_inner)), h

        # Discretization, both (batch, L, D, N).
        dt_ = dt.unsqueeze(-1)
        dA = torch.exp(A.unsqueeze(0).unsqueeze(0) * dt_)
        dB_x = (dt_ * self._as_step_tensor(B, "B")) * x.unsqueeze(-1)
        # (batch|1, L, D|1, N); expand() creates a view, no copy.
        C_steps = self._as_step_tensor(C, "C").expand(-1, seq_len, -1, -1)

        step = seq_len if chunk_size is None else chunk_size
        outputs = []
        for start in range(0, seq_len, step):
            for t in range(start, min(start + step, seq_len)):
                h = dA[:, t] * h + dB_x[:, t]
                # h is (batch, D, N); C_steps[:, t] is 3-D, so the product stays 3-D.
                outputs.append((h * C_steps[:, t]).sum(-1))  # (batch, D)

        return torch.stack(outputs, dim=1), h  # (batch, L, D), (batch, D, N)

    def _ssm_scan_pytorch(self, x, A, B, C, dt, chunk_size, past_state):
        """PyTorch reference implementation (CPU/CUDA); ``chunk_size`` is unused."""
        return self._scan(x, A, B, C, dt, past_state)

    def _ssm_scan_ascend(self, x, A, B, C, dt, chunk_size, past_state):
        """Ascend NPU path: chunked evaluation of the same recurrence.

        NOTE(review): no CANN-specific kernel is called here; within a chunk
        the steps run sequentially, so results match ``_ssm_scan_pytorch``.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        return self._scan(x, A, B, C, dt, past_state, chunk_size)

    def _ssm_scan_cambricon(self, x, A, B, C, dt, chunk_size, past_state):
        """Cambricon MLU path (delegates to the PyTorch implementation)."""
        return self._ssm_scan_pytorch(x, A, B, C, dt, chunk_size, past_state)

    # ═══════════════════════════════════════════════════════════════
    # Sparse attention operator
    # ═══════════════════════════════════════════════════════════════

    def sparse_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        top_k: int,
        importance_scores: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sparse attention over the ``top_k`` most important keys.

        Args:
            q: (B, H, Lq, D) queries.
            k: (B, H, Lk, D) keys.
            v: (B, H, Lk, Dv) values (Dv may differ from D).
            top_k: number of keys to keep; must be >= 1 when selection happens.
            importance_scores: (B, Lk) scores shared across heads, or
                (B, H, Lk) per-head scores. When None, or when
                ``top_k >= Lk``, all keys are attended.

        Returns:
            output: (B, H, Lq, Dv)
        """
        if self.backend == "ascend":
            return self._sparse_attn_ascend(q, k, v, top_k, importance_scores)
        else:
            return self._sparse_attn_pytorch(q, k, v, top_k, importance_scores)

    def _sparse_attn_pytorch(self, q, k, v, top_k, importance_scores):
        """PyTorch sparse attention: gather the top-k keys/values, then run SDPA."""
        Lk = k.shape[2]

        if importance_scores is not None and top_k < Lk:
            if top_k < 1:
                raise ValueError(f"top_k must be >= 1, got {top_k}")
            _, top_indices = importance_scores.topk(top_k, dim=-1)
            if top_indices.dim() == 2:
                # (B, top_k): the same selection for every head.
                top_indices = top_indices.unsqueeze(1).expand(-1, k.shape[1], -1)
            elif top_indices.dim() != 3:
                raise ValueError(
                    "importance_scores must be (B, Lk) or (B, H, Lk); "
                    f"got {tuple(importance_scores.shape)}"
                )
            idx = top_indices.unsqueeze(-1)  # (B, H, top_k, 1)
            # Expand separately for k and v: their head dims may differ.
            k_selected = k.gather(2, idx.expand(-1, -1, -1, k.shape[-1]))
            v_selected = v.gather(2, idx.expand(-1, -1, -1, v.shape[-1]))
        else:
            k_selected = k
            v_selected = v

        return F.scaled_dot_product_attention(q, k_selected, v_selected, is_causal=False)

    def _sparse_attn_ascend(self, q, k, v, top_k, importance_scores):
        """Ascend NPU sparse attention (delegates to the PyTorch implementation)."""
        return self._sparse_attn_pytorch(q, k, v, top_k, importance_scores)

    # ═══════════════════════════════════════════════════════════════
    # MoE routing operator
    # ═══════════════════════════════════════════════════════════════

    def moe_route(
        self,
        router_logits: torch.Tensor,
        num_active: int = 2,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """MoE routing (vectorized top-k), identical for every backend.

        Args:
            router_logits: (B * L, num_experts) routing logits.
            num_active: experts activated per token (must not exceed num_experts).

        Returns:
            top_k_weights: (B * L, num_active) softmax over the selected logits.
            top_k_indices: (B * L, num_active) selected expert indices.
        """
        top_k_logits, top_k_indices = router_logits.topk(num_active, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1)
        return top_k_weights, top_k_indices
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def get_operators(backend: str = "auto") -> NPUOperators:
    """Create a new ``NPUOperators`` for *backend*; a fresh instance on every call."""
    operators = NPUOperators(backend=backend)
    return operators
|