cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,496 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet 量化封装 (Quantization Wrapper)
|
|
3
|
+
|
|
4
|
+
支持多种量化策略:
|
|
5
|
+
1. Dynamic INT8: PyTorch 原生动态量化,适用于 CPU 推理加速
|
|
6
|
+
2. FP16/BF16: 半精度推理,适用于 GPU
|
|
7
|
+
3. Weight-Only INT8: 仅量化权重,激活保持原精度
|
|
8
|
+
4. Smooth Quantization: 预处理权重使量化误差更均匀
|
|
9
|
+
|
|
10
|
+
用法:
|
|
11
|
+
# 动态 INT8 量化
|
|
12
|
+
model = quantize_dynamic(model)
|
|
13
|
+
|
|
14
|
+
# 统一接口
|
|
15
|
+
model = QuantizationWrapper(model, strategy="dynamic_int8")
|
|
16
|
+
model = QuantizationWrapper(model, strategy="weight_only_int8")
|
|
17
|
+
model = QuantizationWrapper(model, strategy="fp16")
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import copy
|
|
23
|
+
import logging
|
|
24
|
+
import os
|
|
25
|
+
import warnings
|
|
26
|
+
from typing import Dict, List, Optional, Set, Tuple, Type
|
|
27
|
+
|
|
28
|
+
import torch
|
|
29
|
+
import torch.nn as nn
|
|
30
|
+
import torch.nn.functional as F
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class WeightOnlyInt8Linear(nn.Module):
    """Linear layer that stores its weight as symmetric per-row INT8 codes.

    Each output row of the source weight gets its own scale
    (``max(|row|) / 127``); the weight is dequantized on every forward
    call and the matmul itself runs in the dtype of the input.
    """

    def __init__(self, linear: nn.Linear):
        """Quantize ``linear``'s weight and copy its bias (if any) as float32."""
        super().__init__()
        self.in_features = linear.in_features
        self.out_features = linear.out_features

        fp32_weight = linear.weight.detach().to(torch.float32)
        # One scale per output row; the clamp avoids dividing by zero for
        # all-zero rows.
        row_absmax = fp32_weight.abs().amax(dim=1, keepdim=True)
        step = row_absmax.clamp(min=1e-8) / 127.0
        codes = torch.clamp(torch.round(fp32_weight / step), -127, 127)

        self.register_buffer("qweight", codes.to(torch.int8))
        self.register_buffer("weight_scale", step)

        if linear.bias is None:
            fp32_bias = None
        else:
            fp32_bias = linear.bias.detach().to(torch.float32).clone()
        self.register_buffer("bias", fp32_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the weight to ``x``'s device/dtype and apply ``F.linear``."""
        dequantized = self.qweight.float() * self.weight_scale
        dequantized = dequantized.to(device=x.device, dtype=x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(device=x.device, dtype=x.dtype)
        return F.linear(x, dequantized, bias)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _state_dict_nbytes(model: nn.Module) -> int:
    """Return the logical size in bytes of ``model.state_dict()``.

    Every tensor entry contributes ``numel() * element_size()``; entries
    that are not tensors are ignored.
    """
    return sum(
        entry.numel() * entry.element_size()
        for entry in model.state_dict().values()
        if torch.is_tensor(entry)
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _sum_nested_tensor_bytes(obj: object, visited: Optional[Set[int]] = None) -> int:
    """Estimate the bytes held by tensors reachable from ``obj``.

    Walks dicts, lists/tuples/sets and objects with a ``__dict__``.
    Objects are de-duplicated by ``id`` through ``visited`` so the same
    object is counted once per traversal.  For tensor subclasses defined
    in a ``torchao`` module the size of their nested attributes is
    preferred over the logical ``numel * element_size``.

    Args:
        obj: Any object; non-container, non-tensor leaves count as 0.
        visited: Ids already counted; a fresh set is created when omitted.

    Returns:
        Estimated byte count.
    """
    seen = set() if visited is None else visited
    key = id(obj)
    if key in seen:
        return 0
    seen.add(key)

    def _total(children) -> int:
        # Sum the estimate over an iterable of child objects.
        return sum(_sum_nested_tensor_bytes(child, seen) for child in children)

    if torch.is_tensor(obj):
        if type(obj).__module__.startswith("torchao"):
            # Prefer the subclass's internal (quantized) storage over the
            # logical float shape it presents.
            packed = 0
            for attr in ("tensor_impl", "original_weight_tensor"):
                if not hasattr(obj, attr):
                    continue
                try:
                    packed += _sum_nested_tensor_bytes(getattr(obj, attr), seen)
                except Exception:
                    # Best-effort probe: an attribute that cannot be
                    # inspected simply contributes nothing.
                    pass
            packed += _total(getattr(obj, "__dict__", {}).values())
            if packed > 0:
                return packed
        return int(obj.numel() * obj.element_size())

    if isinstance(obj, dict):
        return _total(obj.values())
    if isinstance(obj, (list, tuple, set)):
        return _total(obj)
    if hasattr(obj, "__dict__"):
        return _total(obj.__dict__.values())
    return 0
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _state_dict_effective_nbytes(model: nn.Module) -> int:
    """Return the estimated storage size in bytes of ``model.state_dict()``.

    Each entry is measured independently with ``_sum_nested_tensor_bytes``
    (a fresh visited-set per entry), so nested storage inside tensor
    subclasses is taken into account.
    """
    return sum(
        _sum_nested_tensor_bytes(entry)
        for entry in model.state_dict().values()
    )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _build_torchao_target_model(model: nn.Module, inplace: bool) -> Optional[nn.Module]:
    """Return the module that torchao quantization should mutate.

    With ``inplace=True`` the model itself is returned.  Otherwise a deep
    copy is attempted.  If the copy fails with a ``TypeError`` mentioning
    ``RLock``, a fallback re-creates the model as
    ``type(model)(deepcopy(model.config))``, puts it in eval mode and loads
    the original ``state_dict`` into it.  The fallback is disabled when the
    environment variable ``CORTEXNET_TORCHAO_ALLOW_REINIT_CLONE`` is set to
    anything other than ``"1"``.

    Args:
        model: Source model.
        inplace: Whether the caller wants the source model mutated.

    Returns:
        The module to quantize, or ``None`` when no faithful copy could be
        produced (fallback disabled, no ``config`` attribute, re-init
        failed, or the state dict did not match the re-created model).

    Raises:
        TypeError: Re-raised when ``deepcopy`` fails for a reason other
            than an ``RLock``.
    """
    if inplace:
        return model
    try:
        return copy.deepcopy(model)
    except TypeError as e:
        if "RLock" not in str(e):
            raise

    logger.warning("torchao deepcopy failed due RLock; trying re-init clone fallback")
    if os.getenv("CORTEXNET_TORCHAO_ALLOW_REINIT_CLONE", "1") != "1":
        return None
    cfg = getattr(model, "config", None)
    if cfg is None:
        return None

    try:
        cloned = type(model)(copy.deepcopy(cfg)).eval()
        load_result = cloned.load_state_dict(model.state_dict(), strict=False)
    except Exception as clone_exc:
        logger.warning("torchao re-init clone fallback failed: %s", clone_exc)
        return None

    # strict=False silently tolerates key mismatches.  A clone with missing
    # keys keeps freshly initialised parameters, and one with unexpected
    # keys lacks modules the source has, so neither is a faithful copy.
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        logger.warning(
            "torchao re-init clone rejected: state_dict mismatch "
            "(missing_keys=%s, unexpected_keys=%s)",
            missing,
            unexpected,
        )
        return None
    return cloned
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _select_effective_linear_modules(
    model: nn.Module,
) -> Tuple[List[Tuple[str, nn.Linear]], Dict[str, int]]:
    """Pick the ``nn.Linear`` sub-modules that are worth quantizing.

    Selection is driven by environment variables:

    * ``CORTEXNET_TORCHAO_MIN_LINEAR_PARAMS`` (default ``65536``): minimum
      number of weight elements.
    * ``CORTEXNET_TORCHAO_EXCLUDE`` / ``CORTEXNET_TORCHAO_INCLUDE``:
      comma-separated, case-insensitive substrings matched against the
      module's qualified name.  An empty include list disables the
      include filter.
    * ``CORTEXNET_TORCHAO_SKIP_TIED_LM_HEAD`` (default ``"1"``): skip a
      Linear whose weight object is also the weight of an ``nn.Embedding``.

    Returns:
        ``(selected, skipped)`` where ``selected`` is a list of
        ``(qualified_name, module)`` pairs in ``named_modules`` order and
        ``skipped`` maps each skip reason to the number of modules skipped
        for that reason.
    """

    def _env_keywords(var: str, default: str) -> Tuple[str, ...]:
        # Parse a comma-separated env value into lower-cased, non-empty tokens.
        raw = os.getenv(var, default)
        return tuple(tok.strip().lower() for tok in raw.split(",") if tok.strip())

    min_params = int(os.getenv("CORTEXNET_TORCHAO_MIN_LINEAR_PARAMS", "65536"))
    skip_tied_lm_head = os.getenv("CORTEXNET_TORCHAO_SKIP_TIED_LM_HEAD", "1") == "1"
    include_keywords = _env_keywords(
        "CORTEXNET_TORCHAO_INCLUDE",
        "proj,ffn,expert,attention,mlp,fc,dense",
    )
    exclude_keywords = _env_keywords(
        "CORTEXNET_TORCHAO_EXCLUDE",
        "router,gate,norm,score,bias,aux,calib",
    )

    embedding_weight_ids = set()
    for candidate in model.modules():
        if isinstance(candidate, nn.Embedding):
            embedding_weight_ids.add(id(candidate.weight))

    selected: List[Tuple[str, nn.Linear]] = []
    skipped: Dict[str, int] = {}

    for qualified_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        lowered = qualified_name.lower()
        if int(module.weight.numel()) < min_params:
            reason = "small_linear"
        elif any(word in lowered for word in exclude_keywords):
            reason = "exclude_keyword"
        elif include_keywords and not any(word in lowered for word in include_keywords):
            reason = "not_included"
        elif skip_tied_lm_head and id(module.weight) in embedding_weight_ids:
            # A head tied to an embedding usually compresses poorly (the
            # state_dict can even grow), so it is skipped by default.
            reason = "tied_embedding_weight"
        else:
            reason = None

        if reason is None:
            selected.append((qualified_name, module))
        else:
            skipped[reason] = skipped.get(reason, 0) + 1

    return selected, skipped
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _try_torchao_dynamic_int8(model: nn.Module, inplace: bool) -> Optional[nn.Module]:
    """Try dynamic INT8 quantization through torchao's eager ``quantize_`` API.

    torchao is an optional dependency.  Several config constructor names
    are probed on ``torchao.quantization`` in order; the first one that
    quantizes successfully wins.  By default only the Linear modules chosen
    by ``_select_effective_linear_modules`` are quantized, one at a time;
    setting ``CORTEXNET_TORCHAO_FULL_GRAPH=1`` quantizes the whole model in
    a single call instead.

    Args:
        model: Source model.
        inplace: Forwarded to ``_build_torchao_target_model``.

    Returns:
        The quantized module, tagged with ``_cortex_quant_torchao_*``
        attributes, or ``None`` when torchao is disabled
        (``CORTEXNET_DISABLE_TORCHAO=1``), unavailable, no target copy
        could be built, no Linear module was selected, or every candidate
        failed.
    """
    if os.getenv("CORTEXNET_DISABLE_TORCHAO", "0") == "1":
        return None
    try:
        import torchao.quantization as tq  # type: ignore
    except Exception:
        # Optional dependency: any import-time failure means "not available".
        return None

    quantize_fn = getattr(tq, "quantize_", None)
    if quantize_fn is None:
        return None

    full_graph = os.getenv("CORTEXNET_TORCHAO_FULL_GRAPH", "0") == "1"
    candidates = (
        "int8_dynamic_activation_int8_weight",
        "int8_dynamic_activation_int8_weight_per_channel",
        "int8_dynamic_per_token_weight",
        "Int8DynamicActivationInt8WeightConfig",
        "Int8DynamicActivationIntxWeightConfig",
    )
    for name in candidates:
        cfg_ctor = getattr(tq, name, None)
        if not callable(cfg_ctor):
            continue
        try:
            # A fresh target per candidate: a failed attempt may have left
            # the previous copy partially modified.
            target = _build_torchao_target_model(model, inplace=inplace)
            if target is None:
                return None

            if full_graph:
                quantize_fn(target, cfg_ctor())
                setattr(target, "_cortex_quant_torchao_mode", "full_graph")
                logger.info("Dynamic quantization backend: torchao (%s, full_graph)", name)
                return target

            selected, skipped = _select_effective_linear_modules(target)
            if not selected:
                # Selection does not depend on the candidate, so trying the
                # remaining ones cannot help.
                logger.warning("torchao selective quantization: no effective Linear modules selected")
                return None

            ok = 0
            fail = 0
            for module_name, linear in selected:
                try:
                    quantize_fn(linear, cfg_ctor())
                    ok += 1
                except Exception as exc:
                    fail += 1
                    logger.debug("torchao selective quantize failed for %s: %s", module_name, exc)

            if ok <= 0:
                # This candidate quantized nothing; fall through to the next
                # one, exactly as a failing full-graph candidate would.
                logger.debug(
                    "torchao backend candidate '%s' quantized no modules; trying next candidate",
                    name,
                )
                continue

            setattr(target, "_cortex_quant_torchao_mode", "selective_linear")
            setattr(target, "_cortex_quant_torchao_selected", len(selected))
            setattr(target, "_cortex_quant_torchao_quantized", ok)
            setattr(target, "_cortex_quant_torchao_failed", fail)
            logger.info(
                "Dynamic quantization backend: torchao (%s, selective) "
                "selected=%s quantized=%s failed=%s skipped=%s",
                name,
                len(selected),
                ok,
                fail,
                skipped,
            )
            return target
        except Exception as e:
            logger.debug("torchao backend candidate '%s' failed: %s", name, e)
            continue
    return None
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def quantize_dynamic(
    model: nn.Module,
    dtype: torch.dtype = torch.qint8,
    inplace: bool = False,
    target_modules: Optional[Set[Type[nn.Module]]] = None,
) -> nn.Module:
    """Dynamically quantize a model (INT8 Linear layers by default).

    For the standard case (``qint8``, ``{nn.Linear}``, not in place) the
    torchao backend is tried first and kept only if the effective
    state-dict size shrinks to at most ``CORTEXNET_TORCHAO_KEEP_RATIO``
    (default ``0.98``) of the baseline.  Otherwise
    ``torch.ao.quantization.quantize_dynamic`` is used.

    Note: SSM or custom attention modules may not be compatible; prefer
    applying this to plain Linear sub-modules.

    Args:
        model: Model to quantize.
        dtype: Quantized dtype, ``torch.qint8`` by default.
        inplace: Whether to modify ``model`` itself.
        target_modules: Module types to quantize, ``{nn.Linear}`` by default.

    Returns:
        The quantized model, tagged with ``_cortex_quant_backend``.  If the
        non-in-place ``torch.ao`` path fails with an ``RLock`` ``TypeError``
        the call is retried with ``inplace=True``, so in that case the
        returned model is the (mutated) input.
    """
    module_types = {nn.Linear} if target_modules is None else target_modules

    # torchao is the preferred backend, but only for the plain
    # Linear + qint8, copy-returning scenario.
    torchao_eligible = (
        (not inplace) and dtype == torch.qint8 and module_types == {nn.Linear}
    )
    if torchao_eligible:
        size_before = _state_dict_effective_nbytes(model)
        logical_before = _state_dict_nbytes(model)
        candidate = _try_torchao_dynamic_int8(model, inplace=inplace)
        if candidate is not None:
            size_after = _state_dict_effective_nbytes(candidate)
            logical_after = _state_dict_nbytes(candidate)
            keep_ratio = float(os.getenv("CORTEXNET_TORCHAO_KEEP_RATIO", "0.98"))
            if size_after > int(size_before * keep_ratio):
                mib = 1024 ** 2
                logger.warning(
                    "torchao quantization rejected due weak effective compression: "
                    "baseline=%.3fMB, torchao=%.3fMB (logical baseline=%.3fMB, logical torchao=%.3fMB)",
                    size_before / mib,
                    size_after / mib,
                    logical_before / mib,
                    logical_after / mib,
                )
            else:
                tags = (
                    ("_cortex_quant_backend", "torchao_dynamic_int8"),
                    ("_cortex_quant_baseline_bytes", size_before),
                    ("_cortex_quant_result_bytes", size_after),
                    ("_cortex_quant_baseline_logical_bytes", logical_before),
                    ("_cortex_quant_result_logical_bytes", logical_after),
                )
                for attr_name, attr_value in tags:
                    setattr(candidate, attr_name, attr_value)
                return candidate

    def _torch_ao_quantize(in_place: bool) -> nn.Module:
        # Single call site for the legacy torch.ao backend.
        return torch.ao.quantization.quantize_dynamic(
            model,
            module_types,
            dtype=dtype,
            inplace=in_place,
        )

    # Legacy torch.ao fallback; its known deprecation warning is silenced.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="torch.ao.quantization is deprecated.*",
            category=DeprecationWarning,
        )
        # Some CPU setups (e.g. macOS) start with engine "none"; switch to
        # qnnpack when it is available.
        if hasattr(torch.backends, "quantized"):
            engine_now = getattr(torch.backends.quantized, "engine", "none")
            engines = list(getattr(torch.backends.quantized, "supported_engines", []))
            if engine_now == "none" and "qnnpack" in engines:
                torch.backends.quantized.engine = "qnnpack"
        try:
            quantized = _torch_ao_quantize(inplace)
        except TypeError as e:
            if inplace or "RLock" not in str(e):
                raise
            # The model holds a thread lock, so the deepcopy path is
            # unavailable; quantize the caller's model in place instead.
            logger.warning(
                "quantize_dynamic deepcopy failed due RLock; retry with inplace=True"
            )
            quantized = _torch_ao_quantize(True)

    setattr(quantized, "_cortex_quant_backend", "torch_ao_dynamic_int8")
    return quantized
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def to_fp16(model: nn.Module) -> nn.Module:
    """Convert ``model`` to FP16 (half precision) via ``Module.half()``.

    Returns whatever ``model.half()`` returns.
    """
    half_model = model.half()
    return half_model
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def to_bf16(model: nn.Module) -> nn.Module:
    """Convert ``model`` to BF16 via ``Module.to(dtype=torch.bfloat16)``."""
    target_dtype = torch.bfloat16
    return model.to(dtype=target_dtype)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def weight_only_int8(model: nn.Module, inplace: bool = False) -> nn.Module:
    """Replace every ``nn.Linear`` with a ``WeightOnlyInt8Linear``.

    Only weights are stored as INT8; activations keep their precision.

    Args:
        model: Model to convert.
        inplace: Convert ``model`` itself instead of a deep copy.  If the
            deep copy fails with a ``TypeError`` mentioning ``RLock`` the
            conversion falls back to modifying ``model`` in place (a
            warning is logged).

    Returns:
        The converted module, tagged with
        ``_cortex_quant_backend = "weight_only_int8"``.  When the root
        module is itself an ``nn.Linear`` it cannot be swapped inside a
        parent, so a new ``WeightOnlyInt8Linear`` wrapping it is returned.

    Raises:
        TypeError: Re-raised when ``deepcopy`` fails for a reason other
            than an ``RLock``.
    """
    if inplace:
        target = model
    else:
        try:
            target = copy.deepcopy(model)
        except TypeError as e:
            if "RLock" not in str(e):
                raise
            logger.warning(
                "weight_only_int8 deepcopy failed due RLock; retry with inplace=True"
            )
            target = model

    replaced = 0
    if isinstance(target, nn.Linear):
        # The traversal below only swaps children, so a bare Linear root
        # would otherwise be returned unquantized yet tagged as quantized.
        target = WeightOnlyInt8Linear(target)
        replaced = 1
    else:
        pending = [target]
        while pending:
            parent = pending.pop()
            # Snapshot the children: the parent is mutated while iterating.
            for child_name, child in list(parent.named_children()):
                if isinstance(child, nn.Linear):
                    setattr(parent, child_name, WeightOnlyInt8Linear(child))
                    replaced += 1
                else:
                    pending.append(child)

    logger.info("Weight-only INT8 replaced %s Linear modules", replaced)
    setattr(target, "_cortex_quant_backend", "weight_only_int8")
    return target
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
class _SmoothInputScale:
    """Forward pre-hook that multiplies a Linear's input by its smoothing scale.

    The scale is read from the module's ``smooth_scale`` buffer when the
    module has one; the tensor captured at construction is the fallback
    for modules that received the hook but not the buffer.
    """

    def __init__(self, scale: torch.Tensor):
        self.scale = scale

    def __call__(self, module: nn.Module, args: tuple):
        if not args:
            # Input was passed by keyword; nothing positional to rescale.
            return None
        scale = getattr(module, "smooth_scale", None)
        if scale is None:
            scale = self.scale
        x = args[0]
        scaled = x * scale.to(device=x.device, dtype=x.dtype)
        return (scaled,) + tuple(args[1:])


def smooth_quantization_preprocess(model: nn.Module, alpha: float = 0.5) -> nn.Module:
    """Smooth Linear weights before quantization without changing the model's function.

    For every ``nn.Linear`` the per-input-channel maximum ``|w|`` is raised
    to ``alpha`` to form a smoothing factor ``s``.  The weight is replaced
    by ``W / s`` and a forward pre-hook multiplies the layer input by
    ``s``, so the layer output is unchanged up to floating-point rounding.
    This is a weight-only heuristic in the spirit of SmoothQuant; no
    activation statistics are used.

    Linears that already carry a ``smooth_scale`` buffer are left alone, so
    calling this twice does not compound the scaling.  A weight shared by
    several Linears is rescaled once.  A Linear whose weight is also owned
    by a non-Linear module (e.g. a tied embedding) is skipped, because
    rescaling it would alter that other module.

    Args:
        model: Model to preprocess; modified in place.
        alpha: Exponent applied to the per-input-channel max magnitude.
            ``0`` leaves the weights unchanged; ``1`` normalizes each input
            channel's max magnitude to 1.

    Returns:
        The same ``model`` object.
    """
    foreign_param_ids = {
        id(param)
        for owner in model.modules()
        if not isinstance(owner, nn.Linear)
        for param in owner.parameters(recurse=False)
    }
    scale_by_weight = {}

    for module in model.modules():
        if not isinstance(module, nn.Linear):
            continue
        if isinstance(getattr(module, "smooth_scale", None), torch.Tensor):
            continue  # already smoothed
        weight_key = id(module.weight)
        if weight_key in foreign_param_ids:
            continue

        scale = scale_by_weight.get(weight_key)
        if scale is None:
            weight = module.weight.data
            # weight is (out_features, in_features); reducing over dim 0
            # gives one max magnitude per input channel.
            w_max = weight.abs().max(dim=0, keepdim=True).values.clamp(min=1e-8)
            smooth_factor = w_max.pow(alpha)
            module.weight.data = weight / smooth_factor
            scale = smooth_factor.squeeze(0)
            scale_by_weight[weight_key] = scale

        # Record the factor and apply it to the input at call time.
        module.register_buffer('smooth_scale', scale)
        module.register_forward_pre_hook(_SmoothInputScale(scale))
    return model
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
class QuantizationWrapper(nn.Module):
    """Single entry point for the quantization strategies in this module.

    Args:
        model: Model to quantize.
        strategy: One of
            - "dynamic_int8": dynamic INT8 quantization
            - "weight_only_int8": INT8 weights only
            - "fp16": FP16 half precision
            - "bf16": BF16 half precision
            - "smooth_int8": smoothing preprocess followed by dynamic INT8
            - None / "none": no quantization
        smooth_alpha: ``alpha`` passed to the smoothing preprocess.
        use_fp16: Legacy flag; ``True`` forces "fp16".  When it is ``None``
            and ``strategy`` is ``None``, "fp16" is chosen if CUDA is
            available.
        use_dynamic_int8: Legacy flag; ``True`` forces "dynamic_int8" and
            takes precedence over everything else.

    Raises:
        ValueError: If the resolved strategy is not in ``STRATEGIES``.
    """

    STRATEGIES = {"dynamic_int8", "weight_only_int8", "fp16", "bf16", "smooth_int8", "none", None}

    def __init__(
        self,
        model: nn.Module,
        strategy: Optional[str] = None,
        smooth_alpha: float = 0.5,
        # Backward-compatible legacy flags.
        use_fp16: Optional[bool] = None,
        use_dynamic_int8: bool = False,
    ):
        super().__init__()

        # Resolve the legacy flags into a strategy name.
        if use_dynamic_int8:
            strategy = "dynamic_int8"
        elif use_fp16 is True or (
            use_fp16 is None and strategy is None and torch.cuda.is_available()
        ):
            strategy = "fp16"

        if strategy not in self.STRATEGIES:
            raise ValueError(
                f"Unknown quantization strategy: '{strategy}'. "
                f"Supported: {self.STRATEGIES - {None}}"
            )

        self.strategy = strategy or "none"

        if strategy in ("dynamic_int8", "smooth_int8"):
            if strategy == "smooth_int8":
                model = smooth_quantization_preprocess(model, alpha=smooth_alpha)
            wrapped = quantize_dynamic(model, inplace=False)
        elif strategy == "weight_only_int8":
            wrapped = weight_only_int8(model, inplace=False)
        elif strategy == "fp16":
            wrapped = model.half()
        elif strategy == "bf16":
            wrapped = to_bf16(model)
        else:
            wrapped = model
        self.model = wrapped

        # Backends tag the model they return; fall back to the strategy name.
        self.backend = getattr(self.model, "_cortex_quant_backend", self.strategy)
        logger.info(f"QuantizationWrapper initialized with strategy='{self.strategy}'")

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped model."""
        return self.model(*args, **kwargs)
|