aigc-core 0.0.1 (aigc_core-0.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aigc_core-0.0.1.dist-info/METADATA ADDED
@@ -0,0 +1,12 @@
+ Metadata-Version: 2.4
+ Name: aigc-core
+ Version: 0.0.1
+ Summary: Add your description here
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ Requires-Dist: python-dotenv
+ Requires-Dist: datasets>=4.5.0
+ Requires-Dist: peft>=0.18.1
+ Requires-Dist: pip>=25.3
+ Requires-Dist: swanlab>=0.7.6
+ Requires-Dist: torch>=2.10.0
aigc_core-0.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
+ aigcore/__init__.py,sha256=xfZFqDDVqXTRUglmrVcqZqYxc6xsWuWY3pK1SNJqFRY,411
+ aigcore/_logger.py,sha256=74DTl4hbAZdoIZ_WL5k4dcXqUQ2-VM0SEaDdNQM5yv0,2367
+ aigcore/agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aigcore/llm/__init__.py,sha256=YUDY2RxdCPX954l7pnB0UptTtx5BxdM7ZJC0boX0OI4,143
+ aigcore/llm/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aigcore/llm/attention/_self_attention.py,sha256=vi9H1rJPv1E68nOt6IlyiKDDm1hu7iEgNEQ78STTINw,1085
+ aigcore/llm/embed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aigcore/llm/embed/_rope.py,sha256=PoP7FaemsbJmxx6uuM4yfHNMySi7akbEeH3ewBexxs0,3606
+ aigcore/llm/lora/__init__.py,sha256=16mRx18XLQkfDnIB2_dUJRdoRlQ0WCX_fJfNtmQhJDQ,143
+ aigcore/llm/lora/_lora_base.py,sha256=qGBicjOkkp71kGQqYaySmiJ9To8fHn1kk60glOlidF8,2635
+ aigcore/llm/model/__init__.py,sha256=kl-sn2CMGuZxr_5T6gehYqdi6j5fq7mFL0B2Jt4cnEw,89
+ aigcore/llm/model/_minimind.py,sha256=NIILze7afzSMt7FTzUap7g2XM7On071yh_mgQO7RQJU,3848
+ aigcore/llm/norm/__init__.py,sha256=emv2DeHDgM5JYY1AJWzB-oirJiGyxlNW01fe8ITzDnc,85
+ aigcore/llm/norm/_norm.py,sha256=5qd8KgZ2n3sFL7SVyHVeEOVZ2tvpPErzMMAzhbiF6_8,3592
+ aigc_core-0.0.1.dist-info/METADATA,sha256=ImOBo8RmCsxhTVY4iaZRiv0DZKQYj1KwQ5oisW24tyw,324
+ aigc_core-0.0.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ aigc_core-0.0.1.dist-info/top_level.txt,sha256=Sg7mTVSn-QTQ740jamTJRvF1W7bzwS_mNiBJqxbcUE4,8
+ aigc_core-0.0.1.dist-info/RECORD,,
aigc_core-0.0.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.10.2)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
aigc_core-0.0.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ aigcore
aigcore/__init__.py ADDED
@@ -0,0 +1,26 @@
+ """
+ An AIGC library that includes LLM and agent components.
+
+ PyPI: https://pypi.org/project/aigcore/
+ GitHub: https://github.com/torrentbrave/aigcore
+ """
+
+ __author__ = "BoHaoChen"
+
+ __connect__ = "X @TorrentBrave"
+
+ __version__ = "0.0.1"
+
+ from ._logger import logger, print
+ from . import llm
+ from . import agent
+
+
+ # Sentinel type; its single instance is exported below as NULL.
+ class Null:
+     pass
+
+
+ NULL = Null()
+
+ __all__ = ["logger", "print", "NULL"]
+ # agent/__init__.py ships empty, so fall back to an empty list if it defines no __all__
+ __all__.extend(llm.__all__ + getattr(agent, "__all__", []))
aigcore/_logger.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import datetime
+ import logging
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+
+ __all__ = ["logger", "print"]
+
+ _print = print
+
+
+ def print(
+     *args,
+     **kwargs,
+ ) -> None:
+     """
+     Prints, and then flushes instantly.
+
+     The usage is the same as the built-in `print`.
+
+     Parameters
+     ----------
+     See also the built-in `print`.
+
+     Returns
+     -------
+     None
+
+     Notes
+     -----
+     `args` and `kwargs` are passed to the built-in `print`. `flush` is
+     overridden to True no matter what.
+     """
+     kwargs["flush"] = True
+     _print(*args, **kwargs)
+
+
+ def logger(
+     *,
+     name: str | None = None,
+     dir: str | None = None,
+ ) -> logging.Logger:
+     """
+     Returns a pre-configured `logging.Logger` object.
+
+     INFO and above are written to the .log file.
+
+     WARNING and above are additionally echoed to the console.
+
+     Parameters
+     ----------
+     name: str | None = None
+         `logging.Logger.name`. If None, it is set to 'aigcore' followed by
+         the current datetime. The filename of the .log file is set to
+         `name` followed by '.log'.
+         Specify a fixed name to keep everything in a single log file instead
+         of creating a new file on every call.
+     dir: str | None = None
+         The directory where the .log file is stored. If None, the LOGDIR
+         environment variable is used, falling back to 'logs'. The directory
+         is created if it does not exist.
+
+     Returns
+     -------
+     logging.Logger
+         A pre-configured `logging.Logger` object.
+     """
+     if dir is None:
+         dir = os.getenv("LOGDIR", "logs")
+
+     if name is None:
+         name = "aigcore" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+
+     os.makedirs(dir, exist_ok=True)
+     log_path = os.path.join(dir, f"{name}.log")
+
+     logger = logging.getLogger(name=name)
+     logger.setLevel(logging.INFO)
+     handler1 = logging.FileHandler(log_path, mode="w", encoding="utf-8")
+     handler1.setLevel(logging.INFO)
+     handler1.setFormatter(
+         logging.Formatter(fmt="# %(levelname)s %(asctime)s\n%(message)s\n")
+     )
+     logger.addHandler(handler1)
+     handler2 = logging.StreamHandler()
+     handler2.setLevel(logging.WARNING)
+     handler2.setFormatter(
+         logging.Formatter(fmt="# %(levelname)s %(asctime)s\n%(message)s\n")
+     )
+     logger.addHandler(handler2)
+     return logger
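
For context, a minimal usage sketch (editorial, not part of the wheel), assuming the `logger` and `print` helpers defined above are in scope and that writing to a local `logs/` directory is acceptable; the run name is illustrative:

log = logger(name="train-run")   # creates logs/train-run.log
log.info("written to the log file")
log.warning("written to the log file and echoed to the console")
print("flushed to stdout immediately")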
File without changes
aigcore/llm/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from . import lora
+ from . import model
+ from . import embed
+ from . import norm
+
+ __all__ = [
+     "lora",
+     "model",
+     "embed",
+     "norm",
+ ]
File without changes
aigcore/llm/attention/_self_attention.py ADDED
@@ -0,0 +1,35 @@
+ """
+ Visualizing the Self-Attention Mechanism - https://codingowen.github.io/blog/2025/02/27/self_attention_intuition/
+ Building the Self-Attention Mechanism from scratch
+ (1) https://codingowen.github.io/projects/self_attention/
+ (2) https://mohdfaraaz.medium.com/implementing-self-attention-from-scratch-in-pytorch-776ef7b8f13e
+ """
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, d, d_k, d_q, d_v):
+         super().__init__()  # must be called, not merely referenced
+         self.d = d
+         self.d_k = d_k
+         self.d_q = d_q
+         self.d_v = d_v
+
+         # projection matrices for keys, queries, and values
+         self.W_K = nn.Parameter(torch.rand(d, d_k))
+         self.W_Q = nn.Parameter(torch.rand(d, d_q))
+         self.W_V = nn.Parameter(torch.rand(d, d_v))
+
+     def forward(self, X):
+         # X: [seq_len, d] -> K/Q: [seq_len, d_k], V: [seq_len, d_v]
+         K = X @ self.W_K
+         Q = X @ self.W_Q
+         V = X @ self.W_V
+
+         # scaled dot-product attention over a single (unbatched) sequence
+         attention_scores = Q @ K.T / math.sqrt(self.d_k)
+         attention_weights = F.softmax(attention_scores, dim=-1)
+         context_vector = attention_weights @ V
+
+         return context_vector
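
As an editorial illustration (not shipped with the package), a quick shape check assuming the `SelfAttention` class above is in scope; note that `d_q` must equal `d_k` for `Q @ K.T` to be defined, and `forward` expects a single unbatched sequence:

import torch

attn = SelfAttention(d=16, d_k=8, d_q=8, d_v=8)
X = torch.rand(5, 16)   # 5 tokens with feature dimension d = 16
out = attn(X)           # context vectors, shape [5, 8]
print(out.shape)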
File without changes
aigcore/llm/embed/_rope.py ADDED
@@ -0,0 +1,93 @@
+ import math
+ import torch
+ import torch.nn.init as init
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from typing import Optional, Tuple, List, Union
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+ def precompute_freqs_cis(
+     dim: int,
+     end: int = int(32 * 1024),
+     rope_base: float = 1e6,
+     rope_scaling: Optional[dict] = None,
+ ):
+     """
+     Precompute the cos and sin matrices required by rotary position embeddings (RoPE).
+     YaRN (Yet another RoPE extensioN): dynamically extends the model's context window
+     at inference time (extrapolation).
+
+     torch.arange(0, dim, 2) takes every second index from 0 up to dim;
+     the [: dim // 2] slice guarantees exactly dim // 2 frequencies.
+     """
+     freqs, attn_factor = (
+         1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)),
+         1.0,
+     )
+     if rope_scaling is not None:
+         orig_max, factor, beta_fast, beta_slow, attn_factor = (
+             rope_scaling.get("original_max_position_embeddings", 2048),
+             rope_scaling.get("factor", 16),
+             rope_scaling.get("beta_fast", 32.0),
+             rope_scaling.get("beta_slow", 1.0),
+             rope_scaling.get("attention_factor", 1.0),
+         )
+         if end / orig_max > 1.0:
+             # YaRN: f'(i) = f(i) * ((1 - γ) + γ / s), where γ ∈ [0, 1] is a linear ramp
+             inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (
+                 2 * math.log(rope_base)
+             )
+             low, high = (
+                 max(math.floor(inv_dim(beta_fast)), 0),
+                 min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1),
+             )
+             ramp = torch.clamp(
+                 (torch.arange(dim // 2, device=freqs.device).float() - low)
+                 / max(high - low, 0.001),
+                 0,
+                 1,
+             )
+             freqs = freqs * (1 - ramp + ramp / factor)
+
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs).float()
+     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
+     freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
+     return freqs_cos, freqs_sin
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     def rotate_half(x):
+         return torch.cat(
+             (-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1
+         )
+
+     q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (
+         rotate_half(q) * sin.unsqueeze(unsqueeze_dim)
+     )
+     k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (
+         rotate_half(k) * sin.unsqueeze(unsqueeze_dim)
+     )
+     return q_embed, k_embed
+
+
+ """
+ Absolute position encoding: builds a 1-D vector and injects position information
+ into the embedding by addition.
+ Rotary position encoding: builds per-pair rotation matrices and injects position
+ information into the feature components by matrix multiplication.
+
+ P(m, i): the 2-D pair for position m and angle index i; d is the attention feature dimension
+ P(m, i) = [sin(m * theta_i), cos(m * theta_i)]
+
+ theta: defines the rotation angle
+ (1) controls how fast each pair rotates: theta_i = 1 / b^(2(i - 1) / d)
+ i in [1, d/2]: RoPE does not rotate all dimensions uniformly; it groups them in pairs,
+ treats each pair as a point in a 2-D plane, and i indexes those pairs.
+
+ Q1: Why do the slowly rotating (higher-index) pairs capture long-range distances?
+ A1: They avoid angle aliasing.
+ (1) Phase wrap-around: low-index pairs rotate fast, so for distant tokens the "pointer"
+     may have completed dozens of full turns; the model cannot tell whether two tokens
+     are 1, 101, or 201 positions apart.
+ (2) Unique position-to-angle mapping: with slow rotation, tokens even 1000 positions
+     apart may differ by only ~30 degrees, so the angle stays unambiguous.
+ """
aigcore/llm/lora/__init__.py ADDED
@@ -0,0 +1,9 @@
+ "The lora module"
+
+ from ._lora_base import load_lora, save_lora, apply_lora
+
+ __all__ = [
+     "load_lora",
+     "save_lora",
+     "apply_lora",
+ ]
aigcore/llm/lora/_lora_base.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ from torch import nn
+
+
+ class LoRA(nn.Module):
+     """
+     output = W_0 x + (alpha / rank) * B(A(x))
+
+     1. matrix B (all zeros):
+        ensures the initial state of the LoRA path is B(A(x)) = 0,
+        so the wrapped layer starts out identical to the original one
+     2. matrix A (random / Gaussian):
+        breaks symmetry.
+        If both A and B were zero, the gradients for all neurons would be identical.
+        Randomizing A ensures that different neurons can learn different features
+        once training begins and the weights start to update.
+     """
+
+     def __init__(self, in_features, out_features, rank):
+         super().__init__()
+         self.A = nn.Linear(in_features, rank, bias=False)
+         self.B = nn.Linear(rank, out_features, bias=False)
+         self.A.weight.data.normal_(0, std=0.02)
+         self.B.weight.data.zero_()
+
+     def forward(self, x):
+         return self.B(self.A(x))
+
+
+ def apply_lora(model, rank=8):
+     """
+     Explicitly binds a LoRA adapter to every square nn.Linear layer and
+     patches its forward pass to add the low-rank path.
+     """
+     for name, module in model.named_modules():
+         if (
+             isinstance(module, nn.Linear)
+             and module.weight.shape[0] == module.weight.shape[1]
+         ):
+             lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(
+                 model.device
+             )
+             setattr(module, "lora", lora)
+             original_forward = module.forward
+
+             # default arguments freeze this layer's original forward and adapter
+             def forward_with_lora(x, layer1=original_forward, layer2=lora):
+                 return layer1(x) + layer2(x)
+
+             module.forward = forward_with_lora
+
+
+ def load_lora(model, path):
+     state_dict = torch.load(path, map_location=model.device)
+     # strip the "module." prefix left over from DataParallel/DDP checkpoints
+     state_dict = {
+         (k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()
+     }
+
+     for name, module in model.named_modules():
+         if hasattr(module, "lora"):
+             lora_state = {
+                 k.replace(f"{name}.lora.", ""): v
+                 for k, v in state_dict.items()
+                 if f"{name}.lora." in k
+             }
+             module.lora.load_state_dict(lora_state)
+
+
+ def save_lora(model, path):
+     """
+     raw_model = getattr(model, "_orig_mod", model)
+     If compiled with torch.compile: grabs the hidden original model (_orig_mod).
+     If not compiled: just uses the model as-is.
+     """
+     raw_model = getattr(model, "_orig_mod", model)
+     state_dict = {}
+     for name, module in raw_model.named_modules():
+         if hasattr(module, "lora"):
+             clean_name = name[7:] if name.startswith("module.") else name
+             lora_state = {
+                 f"{clean_name}.lora.{k}": v for k, v in module.lora.state_dict().items()
+             }
+             state_dict.update(lora_state)
+     torch.save(state_dict, path)
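
The intended workflow, sketched editorially (not shipped with the package), assuming the three helpers above are in scope; `TinyModel` is a hypothetical stand-in that exposes the `.device` attribute which `apply_lora` and `load_lora` rely on, and the checkpoint filename is illustrative:

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)      # square layer -> receives a LoRA adapter
        self.head = nn.Linear(16, 4)       # non-square -> left untouched
        self.device = torch.device("cpu")  # apply_lora/load_lora expect this attribute

    def forward(self, x):
        return self.head(self.proj(x))

model = TinyModel()
apply_lora(model, rank=4)                  # patch proj.forward to add the LoRA path
out = model(torch.rand(2, 16))             # forward now includes the adapter
save_lora(model, "lora_adapter.pt")        # persist only the adapter weights
load_lora(model, "lora_adapter.pt")        # restore them into a wrapped model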
aigcore/llm/model/__init__.py ADDED
@@ -0,0 +1,4 @@
+ "The model module."
+
+ from ._minimind import minimind
+ # from ._minimindv import minimindv
aigcore/llm/model/_minimind.py ADDED
@@ -0,0 +1,103 @@
+ import math
+ import torch
+ import torch.nn.init as init
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers.activations import ACT2FN
+ from typing import Optional, Tuple, List, Union
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from ..norm import RMSNorm, LayerNorm
+
+
+ class MiniMindConfig(PretrainedConfig):
+     """
+     Base parameters:
+     (1) vocab_size: 6400
+     (2) hidden_size: 512
+     (3) num_hidden_layers: 8
+
+     YaRN (RoPE extension):
+     inference_rope_scaling: True
+     extends the context window so the model can handle sequences longer than it was trained on
+
+     MoE:
+     use_moe: True
+     """
+
+     model_type = "minimind"
+
+     def __init__(
+         self,
+         dropout: float = 0.0,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         hidden_act: str = "silu",
+         hidden_size: int = 512,
+         intermediate_size: Optional[int] = None,
+         max_position_embeddings: int = 32768,
+         num_attention_heads: int = 8,
+         num_hidden_layers: int = 8,
+         num_key_value_heads: int = 2,
+         vocab_size: int = 6400,
+         rms_norm_eps: float = 1e-05,
+         rope_theta: float = 1000000.0,
+         inference_rope_scaling: bool = False,
+         flash_attn: bool = True,
+         ####################################################
+         # Here are the specific configurations of MoE.
+         # When use_moe is False, the following is invalid.
+         ####################################################
+         use_moe: bool = False,
+         num_experts_per_tok: int = 2,
+         n_routed_experts: int = 4,
+         n_shared_experts: int = 1,
+         scoring_func: str = "softmax",
+         aux_loss_alpha: float = 0.01,
+         seq_aux: bool = True,
+         norm_topk_prob: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.dropout = dropout
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.hidden_act = hidden_act
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.max_position_embeddings = max_position_embeddings
+         self.num_attention_heads = num_attention_heads
+         self.num_hidden_layers = num_hidden_layers
+         self.num_key_value_heads = num_key_value_heads
+         self.vocab_size = vocab_size
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+         self.inference_rope_scaling = inference_rope_scaling
+         # extrapolated length = factor * original_max_position_embeddings = 32768
+         self.rope_scaling = (
+             {
+                 "beta_fast": 32,
+                 "beta_slow": 1,
+                 "factor": 16,
+                 "original_max_position_embeddings": 2048,
+                 "attention_factor": 1.0,
+                 "type": "yarn",
+             }
+             if self.inference_rope_scaling
+             else None
+         )
+         self.flash_attn = flash_attn
+         ####################################################
+         # Here are the specific configurations of MoE.
+         # When use_moe is False, the following is invalid.
+         ####################################################
+         self.use_moe = use_moe
+         self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
+         self.n_routed_experts = n_routed_experts  # total number of routed experts
+         self.n_shared_experts = n_shared_experts  # number of shared experts
+         self.scoring_func = scoring_func  # scoring function, defaults to 'softmax'
+         self.aux_loss_alpha = aux_loss_alpha  # weight of the auxiliary (load-balancing) loss
+         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at sequence level
+         self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
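
For orientation, an editorial sketch of constructing the config above with YaRN scaling enabled, assuming `MiniMindConfig` as shown is in scope (the values are its declared defaults):

cfg = MiniMindConfig(inference_rope_scaling=True)  # defaults: hidden_size=512, num_hidden_layers=8
print(cfg.rope_scaling["type"])                    # "yarn"
# extrapolated window = factor * original_max_position_embeddings
print(cfg.rope_scaling["factor"] * cfg.rope_scaling["original_max_position_embeddings"])  # 32768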
aigcore/llm/norm/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from ._norm import RMSNorm, LayerNorm
+
+ __all__ = [
+     "RMSNorm",
+     "LayerNorm",
+ ]
aigcore/llm/norm/_norm.py ADDED
@@ -0,0 +1,56 @@
+ """
+ Q1: Why do RMSNorm and LayerNorm both operate over the token's feature dimension rather than across the batch?
+ A1: BatchNorm is the normalization commonly used for image data.
+     Images have strong spatial correlation: neighboring pixels tend to share similar values or patterns, so pixel features within a batch follow similar distributions, which makes normalizing over the whole batch reasonable. By computing the mean and variance of each feature (e.g. each channel), BatchNorm dampens the effect of this spatial correlation and keeps each layer's input distribution stable during training, which speeds up convergence.
+     In NLP, each token is a unit with its own semantics and context, e.g. a word. A token's features come from the embedding layer or the Transformer layers and encode that token's meaning. Different tokens carry different content, so their features should be normalized independently.
+     Normalizing over the batch dimension ignores this independence: the statistics mix tokens with different content, coupling them together. Padding tokens, for instance, carry no real meaning yet would still enter the normalization statistics and degrade what the model learns.
+
+ Q2: Why RMSNorm instead of LayerNorm?
+ A2: (1) The computation is simpler: it skips the mean and has one fewer learnable parameter.
+         LayerNorm needs each token's mean and variance to standardize the input,
+         whereas RMSNorm only needs the sum of squared features, cutting compute and memory.
+     (2) In large models the feature dimension can be very large, so computing means and variances is relatively expensive. Dropping the mean saves compute, especially in high dimensions.
+     (3) Across a range of settings, RMSNorm has been reported to reduce computation time by roughly 7%-64%.
+
+ Q3: What does per-token normalization buy us?
+ A3: (1) Variable-length sequences: sequence length is dynamic at inference time. If normalization depended on other tokens (the batch dimension), the mean and variance would fluctuate sharply as sequences grow.
+     (2) Parallelism: each token's normalization can be computed independently and in parallel, without waiting for statistics from the rest of the batch.
+ """
+
+ import torch
+ from torch import nn
+
+
+ class RMSNorm(nn.Module):
+     """
+     x.shape: [batch_size, seq_length, embedding_dim]
+     gamma: learnable per-feature scale parameter, stored as `weight`
+
+     torch.rsqrt: reciprocal square root of x.pow(2).mean(-1, keepdim=True) + self.eps;
+         calling rsqrt directly is faster on GPU than sqrt followed by a division
+     keepdim: e.g. [1, 2, 4] -> [1, 2, 1] instead of [1, 2]
+
+     The standard normalization layer of the Llama family: compared with LayerNorm it drops the mean subtraction, so it is simpler to compute and trains stably.
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()  # required so that `weight` is registered as a parameter
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         return self.weight * self._norm(x.float()).type_as(x)
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()  # required so that `weight` and `bias` are registered
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.bias = nn.Parameter(torch.zeros(dim))
+
+     def forward(self, x):
+         mean = x.mean(dim=-1, keepdim=True)
+         var = x.var(dim=-1, keepdim=True, unbiased=False)
+         return self.weight * (x - mean) / (var + self.eps).sqrt() + self.bias
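
A quick editorial shape check for the two layers above (assuming the classes, with `super().__init__()` called as fixed here, are in scope):

import torch

x = torch.randn(2, 4, 512)         # [batch, seq_len, hidden]
rms = RMSNorm(512)
ln = LayerNorm(512)
print(rms(x).shape, ln(x).shape)   # both torch.Size([2, 4, 512])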