cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,496 @@
1
+ """
2
+ CortexNet 量化封装 (Quantization Wrapper)
3
+
4
+ 支持多种量化策略:
5
+ 1. Dynamic INT8: PyTorch 原生动态量化,适用于 CPU 推理加速
6
+ 2. FP16/BF16: 半精度推理,适用于 GPU
7
+ 3. Weight-Only INT8: 仅量化权重,激活保持原精度
8
+ 4. Smooth Quantization: 预处理权重使量化误差更均匀
9
+
10
+ 用法:
11
+ # 动态 INT8 量化
12
+ model = quantize_dynamic(model)
13
+
14
+ # 统一接口
15
+ model = QuantizationWrapper(model, strategy="dynamic_int8")
16
+ model = QuantizationWrapper(model, strategy="weight_only_int8")
17
+ model = QuantizationWrapper(model, strategy="fp16")
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import copy
23
+ import logging
24
+ import os
25
+ import warnings
26
+ from typing import Dict, List, Optional, Set, Tuple, Type
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
# Module-level logger for backend selection, fallback and conversion messages.
logger = logging.getLogger(__name__)
33
+
34
+
35
class WeightOnlyInt8Linear(nn.Module):
    """Linear layer with weight-only INT8 quantization (symmetric, per output channel).

    The wrapped ``nn.Linear`` weight is stored as an ``int8`` buffer ``qweight``
    plus a float32 buffer ``weight_scale`` of shape ``(out_features, 1)``, so that
    ``qweight * weight_scale`` approximates the original weight. The bias, if any,
    is kept as a float32 buffer. Activations are not quantized: the weight is
    dequantized on every forward call and a regular ``F.linear`` runs in the
    input's dtype.
    """

    def __init__(self, linear: nn.Linear):
        super().__init__()
        weight = linear.weight.detach().to(torch.float32)
        # One scale per output row mapping that row's max |w| to 127; the clamp
        # keeps all-zero rows from producing a zero scale.
        scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        qweight = (weight / scale).round().clamp(-127, 127).to(torch.int8)

        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.register_buffer("qweight", qweight)
        self.register_buffer("weight_scale", scale)
        # Registered even when None so ``self.bias`` always exists as an attribute.
        self.register_buffer(
            "bias",
            linear.bias.detach().to(torch.float32).clone()
            if linear.bias is not None
            else None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the weight and apply ``F.linear`` on ``x``'s device/dtype."""
        # Move the compact int8 tensor (1 byte per weight) and the small scale to
        # the input's device *before* dequantizing, instead of moving a full
        # float32 copy. The multiply still runs in float32 (or wider, following
        # type promotion with ``weight_scale``), so results are unchanged.
        qweight = self.qweight.to(device=x.device, dtype=torch.float32)
        scale = self.weight_scale.to(device=x.device)
        weight = (qweight * scale).to(dtype=x.dtype)
        bias = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
        return F.linear(x, weight, bias)

    def extra_repr(self) -> str:
        # Mirrors nn.Linear's repr so printed models stay informative.
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
61
+
62
+
63
+ def _state_dict_nbytes(model: nn.Module) -> int:
64
+ total = 0
65
+ for _, tensor in model.state_dict().items():
66
+ if torch.is_tensor(tensor):
67
+ total += tensor.numel() * tensor.element_size()
68
+ return total
69
+
70
+
71
def _sum_nested_tensor_bytes(obj: object, visited: Optional[Set[int]] = None) -> int:
    """Estimate the bytes held by tensors reachable from ``obj``.

    Supports torchao tensor subclasses by preferring their inner (quantized)
    storage over the logical float shape they expose.

    Walk rules:
      * a tensor whose type is defined in a ``torchao`` module is measured through
        its ``tensor_impl`` / ``original_weight_tensor`` attributes and the values
        of its ``__dict__``; if that inner walk totals 0 bytes, the tensor's own
        ``numel() * element_size()`` is used instead;
      * any other tensor contributes ``numel() * element_size()``;
      * dict values, and list/tuple/set elements, are summed recursively;
      * any other object with a ``__dict__`` is summed over its attribute values;
      * everything else contributes 0.

    Args:
        obj: object to measure.
        visited: ids of objects already seen in this walk, shared across the
            recursion so an object reached twice is counted only once. Callers
            normally leave it as ``None``.

    Returns:
        Estimated byte count.
    """
    if visited is None:
        visited = set()
    # NOTE(review): dedup is by id(); a temporary returned by a property getter
    # could be freed and its id reused later in the walk — confirm this cannot
    # cause objects to be skipped for the torchao types in use.
    oid = id(obj)
    if oid in visited:
        return 0
    visited.add(oid)

    if torch.is_tensor(obj):
        # torchao tensor subclass: count its internal quantized storage first,
        # rather than the logical float shape.
        module_name = type(obj).__module__
        if module_name.startswith("torchao"):
            inner = 0
            for attr_name in ("tensor_impl", "original_weight_tensor"):
                if hasattr(obj, attr_name):
                    try:
                        inner += _sum_nested_tensor_bytes(getattr(obj, attr_name), visited)
                    except Exception:
                        # Best effort: attribute access or the nested walk failing
                        # on an exotic subclass is ignored.
                        pass
            # Attributes already counted above are skipped via ``visited``.
            for v in getattr(obj, "__dict__", {}).values():
                inner += _sum_nested_tensor_bytes(v, visited)
            if inner > 0:
                return inner
        return int(obj.numel() * obj.element_size())

    if isinstance(obj, dict):
        # Only values are counted; keys are assumed not to hold tensors.
        total = 0
        for v in obj.values():
            total += _sum_nested_tensor_bytes(v, visited)
        return total
    if isinstance(obj, (list, tuple, set)):
        total = 0
        for v in obj:
            total += _sum_nested_tensor_bytes(v, visited)
        return total
    if hasattr(obj, "__dict__"):
        total = 0
        for v in obj.__dict__.values():
            total += _sum_nested_tensor_bytes(v, visited)
        return total
    return 0
113
+
114
+
115
def _state_dict_effective_nbytes(model: nn.Module) -> int:
    """Return the estimated storage byte size of ``model.state_dict()``.

    Every entry is measured with ``_sum_nested_tensor_bytes``, so torchao tensor
    subclasses report the size of their inner quantized storage rather than their
    logical shape. Each entry is walked with a fresh ``visited`` set, so a tensor
    stored under several keys is counted once per key.
    """
    return sum(_sum_nested_tensor_bytes(value) for value in model.state_dict().values())
120
+
121
+
122
+ def _build_torchao_target_model(model: nn.Module, inplace: bool) -> Optional[nn.Module]:
123
+ if inplace:
124
+ return model
125
+ try:
126
+ return copy.deepcopy(model)
127
+ except TypeError as e:
128
+ if "RLock" in str(e):
129
+ logger.warning("torchao deepcopy failed due RLock; trying re-init clone fallback")
130
+ if os.getenv("CORTEXNET_TORCHAO_ALLOW_REINIT_CLONE", "1") != "1":
131
+ return None
132
+ cfg = getattr(model, "config", None)
133
+ model_cls = type(model)
134
+ if cfg is None:
135
+ return None
136
+ try:
137
+ cloned = model_cls(copy.deepcopy(cfg)).eval()
138
+ state = model.state_dict()
139
+ cloned.load_state_dict(state, strict=False)
140
+ return cloned
141
+ except Exception as clone_exc:
142
+ logger.warning("torchao re-init clone fallback failed: %s", clone_exc)
143
+ return None
144
+ raise
145
+
146
+
147
+ def _select_effective_linear_modules(
148
+ model: nn.Module,
149
+ ) -> Tuple[List[Tuple[str, nn.Linear]], Dict[str, int]]:
150
+ """选择更可能带来收益的 Linear 子图,避免量化无收益或高风险分支。"""
151
+ min_params = int(os.getenv("CORTEXNET_TORCHAO_MIN_LINEAR_PARAMS", "65536"))
152
+ skip_tied_lm_head = os.getenv("CORTEXNET_TORCHAO_SKIP_TIED_LM_HEAD", "1") == "1"
153
+ include_keywords = tuple(
154
+ s.strip().lower()
155
+ for s in os.getenv(
156
+ "CORTEXNET_TORCHAO_INCLUDE",
157
+ "proj,ffn,expert,attention,mlp,fc,dense",
158
+ ).split(",")
159
+ if s.strip()
160
+ )
161
+ exclude_keywords = tuple(
162
+ s.strip().lower()
163
+ for s in os.getenv(
164
+ "CORTEXNET_TORCHAO_EXCLUDE",
165
+ "router,gate,norm,score,bias,aux,calib",
166
+ ).split(",")
167
+ if s.strip()
168
+ )
169
+
170
+ embedding_weight_ids = {
171
+ id(m.weight) for m in model.modules() if isinstance(m, nn.Embedding)
172
+ }
173
+
174
+ selected: List[Tuple[str, nn.Linear]] = []
175
+ skipped: Dict[str, int] = {}
176
+
177
+ def _skip(reason: str):
178
+ skipped[reason] = skipped.get(reason, 0) + 1
179
+
180
+ for name, module in model.named_modules():
181
+ if not isinstance(module, nn.Linear):
182
+ continue
183
+
184
+ lname = name.lower()
185
+ numel = int(module.weight.numel())
186
+ if numel < min_params:
187
+ _skip("small_linear")
188
+ continue
189
+ if any(k in lname for k in exclude_keywords):
190
+ _skip("exclude_keyword")
191
+ continue
192
+ if include_keywords and not any(k in lname for k in include_keywords):
193
+ _skip("not_included")
194
+ continue
195
+
196
+ # tied embedding/lm_head 通常会导致压缩收益差甚至 state_dict 变大,默认跳过。
197
+ if skip_tied_lm_head and id(module.weight) in embedding_weight_ids:
198
+ _skip("tied_embedding_weight")
199
+ continue
200
+
201
+ selected.append((name, module))
202
+
203
+ return selected, skipped
204
+
205
+
206
def _try_torchao_dynamic_int8(model: nn.Module, inplace: bool) -> Optional[nn.Module]:
    """Try dynamic INT8 quantization through torchao's eager ``quantize_`` API.

    torchao is optional. ``None`` is returned when it cannot be imported, when it
    is disabled with ``CORTEXNET_DISABLE_TORCHAO=1``, or when no config candidate
    succeeds.

    Config constructors are looked up by name on ``torchao.quantization`` and
    tried in order; presumably the several names cover different torchao
    releases. For each candidate:
      * ``CORTEXNET_TORCHAO_FULL_GRAPH=1``: ``quantize_`` runs on the whole model;
      * otherwise (default): only the Linear modules chosen by
        ``_select_effective_linear_modules`` are quantized, one by one, and
        per-module failures are tolerated.
    A candidate that raises, or that quantizes no module at all, is skipped and
    the next candidate is tried.

    Args:
        model: model to quantize.
        inplace: forwarded to ``_build_torchao_target_model``.

    Returns:
        The quantized module, tagged with ``_cortex_quant_torchao_*`` attributes,
        or ``None``.
    """
    if os.getenv("CORTEXNET_DISABLE_TORCHAO", "0") == "1":
        return None
    try:
        import torchao.quantization as tq  # type: ignore
    except Exception:
        return None

    quantize_fn = getattr(tq, "quantize_", None)
    if quantize_fn is None:
        return None

    candidates = [
        "int8_dynamic_activation_int8_weight",
        "int8_dynamic_activation_int8_weight_per_channel",
        "int8_dynamic_per_token_weight",
        "Int8DynamicActivationInt8WeightConfig",
        "Int8DynamicActivationIntxWeightConfig",
    ]
    for name in candidates:
        cfg_ctor = getattr(tq, name, None)
        if not callable(cfg_ctor):
            continue
        try:
            # A fresh target per candidate, so a failed attempt does not leak
            # partially quantized modules into the next one.
            target = _build_torchao_target_model(model, inplace=inplace)
            if target is None:
                # Building the target does not depend on the candidate.
                return None

            if os.getenv("CORTEXNET_TORCHAO_FULL_GRAPH", "0") == "1":
                quantize_fn(target, cfg_ctor())
                setattr(target, "_cortex_quant_torchao_mode", "full_graph")
                logger.info("Dynamic quantization backend: torchao (%s, full_graph)", name)
                return target

            selected, skipped = _select_effective_linear_modules(target)
            if not selected:
                # Selection does not depend on the candidate either.
                logger.warning("torchao selective quantization: no effective Linear modules selected")
                return None

            ok = 0
            fail = 0
            for module_name, linear in selected:
                try:
                    quantize_fn(linear, cfg_ctor())
                    ok += 1
                except Exception as exc:
                    fail += 1
                    logger.debug("torchao selective quantize failed for %s: %s", module_name, exc)

            if ok <= 0:
                # This config quantized nothing; another candidate may still work.
                logger.debug(
                    "torchao backend candidate '%s' quantized no module (%d failures)",
                    name,
                    fail,
                )
                continue

            setattr(target, "_cortex_quant_torchao_mode", "selective_linear")
            setattr(target, "_cortex_quant_torchao_selected", len(selected))
            setattr(target, "_cortex_quant_torchao_quantized", ok)
            setattr(target, "_cortex_quant_torchao_failed", fail)
            logger.info(
                "Dynamic quantization backend: torchao (%s, selective) "
                "selected=%s quantized=%s failed=%s skipped=%s",
                name,
                len(selected),
                ok,
                fail,
                skipped,
            )
            return target
        except Exception as e:
            logger.debug("torchao backend candidate '%s' failed: %s", name, e)
            continue
    return None
276
+
277
+
278
def quantize_dynamic(
    model: nn.Module,
    dtype: torch.dtype = torch.qint8,
    inplace: bool = False,
    target_modules: Optional[Set[Type[nn.Module]]] = None,
) -> nn.Module:
    """Dynamically quantize ``model`` (INT8 ``nn.Linear`` layers by default).

    Meant for inference speed-ups. Custom modules such as SSM blocks or
    hand-written attention may not be compatible; applying this to plain
    ``nn.Linear`` sub-modules is the safest use.

    Backend selection:
      1. torchao (optional). It is attempted only when ``dtype == torch.qint8``,
         ``target_modules == {nn.Linear}`` and ``inplace`` is ``False``. Its
         result is accepted only if its effective state_dict size is at most
         ``CORTEXNET_TORCHAO_KEEP_RATIO`` (default ``0.98``) times the original's.
      2. Otherwise ``torch.ao.quantization.quantize_dynamic``. If that path fails
         to deep-copy the model because of an ``RLock``, it retries by quantizing
         ``model`` in place and logs a warning.

    Args:
        model: model to quantize.
        dtype: quantized dtype, default ``torch.qint8``.
        inplace: whether to modify ``model`` in place.
        target_modules: set of module types to quantize, default ``{nn.Linear}``.

    Returns:
        The quantized model, tagged with ``_cortex_quant_backend``
        (``"torchao_dynamic_int8"`` or ``"torch_ao_dynamic_int8"``).
    """
    if target_modules is None:
        target_modules = {nn.Linear}

    # Prefer torchao (the official migration path), only for standard Linear + qint8.
    if dtype == torch.qint8 and target_modules == {nn.Linear} and not inplace:
        torchao_model = _try_torchao_dynamic_int8(model, inplace=inplace)
        if torchao_model is not None:
            # Measure the baseline only when there is a candidate to compare.
            # With inplace=False the torchao path worked on a copy, so ``model``
            # is untouched and these numbers match a measurement taken earlier.
            baseline_bytes = _state_dict_effective_nbytes(model)
            baseline_logical_bytes = _state_dict_nbytes(model)
            torchao_bytes = _state_dict_effective_nbytes(torchao_model)
            torchao_logical_bytes = _state_dict_nbytes(torchao_model)
            keep_ratio = float(os.getenv("CORTEXNET_TORCHAO_KEEP_RATIO", "0.98"))
            if torchao_bytes <= int(baseline_bytes * keep_ratio):
                setattr(torchao_model, "_cortex_quant_backend", "torchao_dynamic_int8")
                setattr(torchao_model, "_cortex_quant_baseline_bytes", baseline_bytes)
                setattr(torchao_model, "_cortex_quant_result_bytes", torchao_bytes)
                setattr(torchao_model, "_cortex_quant_baseline_logical_bytes", baseline_logical_bytes)
                setattr(torchao_model, "_cortex_quant_result_logical_bytes", torchao_logical_bytes)
                return torchao_model
            logger.warning(
                "torchao quantization rejected due weak effective compression: "
                "baseline=%.3fMB, torchao=%.3fMB (logical baseline=%.3fMB, logical torchao=%.3fMB)",
                baseline_bytes / (1024 ** 2),
                torchao_bytes / (1024 ** 2),
                baseline_logical_bytes / (1024 ** 2),
                torchao_logical_bytes / (1024 ** 2),
            )
            # Release the rejected model copy before torch.ao makes another one.
            del torchao_model

    # Fall back to torch.ao (older environments) and silence its known deprecation warning.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="torch.ao.quantization is deprecated.*",
            category=DeprecationWarning,
        )
        # On macOS/CPU the quantized engine often defaults to "none"; pick
        # qnnpack explicitly when it is available.
        if hasattr(torch.backends, "quantized"):
            current_engine = getattr(torch.backends.quantized, "engine", "none")
            supported = list(getattr(torch.backends.quantized, "supported_engines", []))
            if current_engine == "none" and "qnnpack" in supported:
                torch.backends.quantized.engine = "qnnpack"
        try:
            quantized = torch.ao.quantization.quantize_dynamic(
                model,
                target_modules,
                dtype=dtype,
                inplace=inplace,
            )
        except TypeError as e:
            if (not inplace) and "RLock" in str(e):
                # CortexNet models hold thread locks, so the deepcopy path is
                # unavailable; fall back to quantizing in place.
                logger.warning(
                    "quantize_dynamic deepcopy failed due RLock; retry with inplace=True"
                )
                quantized = torch.ao.quantization.quantize_dynamic(
                    model,
                    target_modules,
                    dtype=dtype,
                    inplace=True,
                )
            else:
                raise
    setattr(quantized, "_cortex_quant_backend", "torch_ao_dynamic_int8")
    return quantized
361
+
362
+
363
def to_fp16(model: nn.Module) -> nn.Module:
    """Cast the floating-point parameters and buffers of ``model`` to float16.

    ``Module.half`` converts the module in place; the same object is returned so
    the call can be chained.
    """
    half_model = model.half()
    return half_model
366
+
367
+
368
def to_bf16(model: nn.Module) -> nn.Module:
    """Convert ``model`` to bfloat16 in place and return it.

    Only floating-point parameters and buffers are cast (``Module.bfloat16``),
    the same rule ``Module.half`` uses for FP16. ``model.to(dtype=torch.bfloat16)``
    would also cast complex tensors to a real dtype, discarding their imaginary
    part.
    """
    return model.bfloat16()
371
+
372
+
373
def weight_only_int8(model: nn.Module, inplace: bool = False) -> nn.Module:
    """Replace ``nn.Linear`` layers with ``WeightOnlyInt8Linear`` (INT8 weights only).

    Activations keep their original precision. This reduces weight memory while
    staying closer to the original accuracy than activation quantization.

    Linears of type ``NonDynamicallyQuantizableLinear`` (the ``out_proj`` of
    ``nn.MultiheadAttention``) are left untouched. Their owner reads
    ``out_proj.weight`` / ``out_proj.bias`` directly rather than calling the
    sub-module, and ``WeightOnlyInt8Linear`` exposes no ``weight`` attribute.

    A Linear registered under several parents is converted once, and the
    replacement is shared, preserving the original sharing.

    Args:
        model: model to convert.
        inplace: modify ``model`` directly. When ``False`` a deep copy is
            converted; if the copy fails because the model holds an ``RLock``,
            ``model`` itself is converted and a warning is logged.

    Returns:
        The converted model, tagged ``_cortex_quant_backend = "weight_only_int8"``.
    """
    if inplace:
        target = model
    else:
        try:
            target = copy.deepcopy(model)
        except TypeError as e:
            if "RLock" not in str(e):
                raise
            logger.warning(
                "weight_only_int8 deepcopy failed due RLock; retry with inplace=True"
            )
            target = model

    # Linear subclass whose parent reads its weight directly (see docstring).
    unsupported = getattr(nn.modules.linear, "NonDynamicallyQuantizableLinear", None)
    # id(original Linear) -> (original, replacement); keeping the original alive
    # guarantees its id cannot be reused during the walk.
    converted: Dict[int, Tuple[nn.Module, nn.Module]] = {}

    def _replace(module: nn.Module) -> None:
        for name, child in list(module.named_children()):
            if not isinstance(child, nn.Linear):
                _replace(child)
                continue
            if unsupported is not None and isinstance(child, unsupported):
                continue
            entry = converted.get(id(child))
            if entry is None:
                entry = (child, WeightOnlyInt8Linear(child))
                converted[id(child)] = entry
            setattr(module, name, entry[1])

    _replace(target)
    logger.info("Weight-only INT8 replaced %d Linear modules", len(converted))
    setattr(target, "_cortex_quant_backend", "weight_only_int8")
    return target
406
+
407
+
408
class _SmoothInputScaler:
    """Forward pre-hook that undoes ``smooth_quantization_preprocess`` on the input side.

    The weight columns of a smoothed Linear were divided by ``scale``. Multiplying
    the input's last dimension by the same ``scale`` keeps ``x @ W.T`` unchanged.
    It is a module-level class, not a closure, so hooked models remain picklable.
    """

    def __init__(self, scale: torch.Tensor):
        self.scale = scale

    def __call__(self, module: nn.Module, args: tuple):
        if not args:
            # Input passed by keyword: a plain pre-hook cannot see it.
            return None
        # Prefer the module's own buffer (it follows .to()/.half()). A module
        # swapped in by a later quantization pass may keep this hook but not the
        # buffer, in which case the captured copy is used.
        scale = getattr(module, "smooth_scale", None)
        if not torch.is_tensor(scale):
            scale = self.scale
        x = args[0]
        x = x * scale.to(device=x.device, dtype=x.dtype)
        return (x,) + tuple(args[1:])


def smooth_quantization_preprocess(model: nn.Module, alpha: float = 0.5) -> nn.Module:
    """Output-preserving smoothing pass that evens out per-input-channel weight ranges.

    Inspired by SmoothQuant, but it uses weight statistics only. For every
    ``nn.Linear``:
      * compute ``s_j = max_i |W[i, j]| ** alpha`` for each input channel ``j``;
      * divide weight column ``j`` by ``s_j``;
      * register a forward pre-hook that multiplies input feature ``j`` by ``s_j``.
    ``x @ W.T`` is therefore unchanged, up to float rounding. ``s`` is also stored
    as the buffer ``smooth_scale`` with shape ``(in_features,)``.

    The model is modified in place. Some Linears are skipped:
      * those that already carry ``smooth_scale``, so repeated calls do not
        compound the scaling;
      * those whose weight Parameter is shared with another module (e.g. an
        lm_head tied to an embedding), since rescaling the shared tensor would
        corrupt the other user.

    Args:
        model: model to preprocess; modified in place.
        alpha: smoothing strength. ``0`` leaves weights unchanged (``s = 1``);
            ``1`` scales every weight column to max-abs 1, moving the whole range
            into the activations.

    Returns:
        ``model`` itself.
    """
    # How many modules own each Parameter, to detect shared/tied weights.
    owners = {}
    for sub in model.modules():
        for param in sub.parameters(recurse=False):
            owners[id(param)] = owners.get(id(param), 0) + 1

    for module in model.modules():
        if not isinstance(module, nn.Linear):
            continue
        if hasattr(module, "smooth_scale") or owners.get(id(module.weight), 0) > 1:
            continue
        with torch.no_grad():
            weight = module.weight
            # Max |w| over the output dimension -> one value per input channel.
            w_max = weight.abs().amax(dim=0, keepdim=True).clamp(min=1e-8)
            smooth_factor = w_max.pow(alpha)
            weight.div_(smooth_factor)
        scale = smooth_factor.squeeze(0)
        module.register_buffer("smooth_scale", scale)
        module.register_forward_pre_hook(_SmoothInputScaler(scale.detach().clone()))
    return model
432
+
433
+
434
class QuantizationWrapper(nn.Module):
    """Single entry point for the quantization strategies in this module.

    Args:
        model: model to quantize.
        strategy: one of
            - "dynamic_int8": dynamic INT8 quantization (``quantize_dynamic``)
            - "weight_only_int8": INT8 weights, original-precision activations
            - "fp16": ``model.half()``
            - "bf16": bfloat16 via ``to_bf16``
            - "smooth_int8": ``smooth_quantization_preprocess`` then dynamic INT8
            - None / "none": leave the model unchanged
        smooth_alpha: ``alpha`` passed to ``smooth_quantization_preprocess``.
        use_fp16: legacy flag. ``True`` forces "fp16". When left as ``None`` and
            no ``strategy`` is given, "fp16" is picked if CUDA is available.
        use_dynamic_int8: legacy flag. ``True`` forces "dynamic_int8" and takes
            precedence over both ``use_fp16`` and ``strategy``.

    Raises:
        ValueError: if the resolved strategy is not supported.

    Attributes:
        strategy: resolved strategy name ("none" instead of ``None``).
        backend: the model's ``_cortex_quant_backend`` tag if set, else ``strategy``.
        model: the (possibly converted) wrapped model; ``forward`` delegates to it.
    """

    STRATEGIES = {"dynamic_int8", "weight_only_int8", "fp16", "bf16", "smooth_int8", "none", None}

    def __init__(
        self,
        model: nn.Module,
        strategy: Optional[str] = None,
        smooth_alpha: float = 0.5,
        # Legacy interface, kept for backward compatibility.
        use_fp16: Optional[bool] = None,
        use_dynamic_int8: bool = False,
    ):
        super().__init__()

        # Resolve the legacy flags first; when one applies it overrides ``strategy``.
        if use_dynamic_int8:
            legacy_choice = "dynamic_int8"
        elif use_fp16 is True or (
            use_fp16 is None and strategy is None and torch.cuda.is_available()
        ):
            legacy_choice = "fp16"
        else:
            legacy_choice = None
        if legacy_choice is not None:
            strategy = legacy_choice

        if strategy not in self.STRATEGIES:
            supported = self.STRATEGIES - {None}
            raise ValueError(
                f"Unknown quantization strategy: '{strategy}'. "
                f"Supported: {supported}"
            )

        self.strategy = strategy or "none"

        # Strategy -> converter; lambdas defer name lookup until actually used.
        converters = {
            "dynamic_int8": lambda m: quantize_dynamic(m, inplace=False),
            "weight_only_int8": lambda m: weight_only_int8(m, inplace=False),
            "fp16": lambda m: m.half(),
            "bf16": lambda m: to_bf16(m),
            "smooth_int8": lambda m: quantize_dynamic(
                smooth_quantization_preprocess(m, alpha=smooth_alpha), inplace=False
            ),
        }
        convert = converters.get(self.strategy)
        self.model = model if convert is None else convert(model)

        self.backend = getattr(self.model, "_cortex_quant_backend", self.strategy)
        logger.info(f"QuantizationWrapper initialized with strategy='{self.strategy}'")

    def forward(self, *args, **kwargs):
        """Delegate every call to the wrapped model."""
        wrapped = self.model
        return wrapped(*args, **kwargs)