orbit_torch-0.0.4a1-py3-none-any.whl → orbit_torch-0.1.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/attention.py
@@ -0,0 +1,265 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from orbit.model import BaseBlock, register_model
+from orbit.model.block.embedding import RotaryPositionalEmbedding, MRoPEInterleavedEmbedding
+from orbit.model.block.gate import SigmoidGate
+
+
+@dataclass
+class AttentionOutput:
+    output: torch.Tensor
+    attention_weights: Optional[torch.Tensor] = None
+    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+def apply_attention(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: torch.Tensor = None,
+    output_attentions: bool = False,
+    is_causal: bool = None
+) -> AttentionOutput:
+    ''' Compute scaled dot-product attention.
+
+    Args:
+        query_states (torch.Tensor): Query states tensor.
+        key_states (torch.Tensor): Key states tensor.
+        value_states (torch.Tensor): Value states tensor.
+        attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+        output_attentions (bool, optional): Whether to return the attention weights. Defaults to False.
+        is_causal (bool, optional): Whether to apply a causal mask. Defaults to None.
+
+    Returns:
+        AttentionOutput: Object containing the attention output and optional weights.
+    '''
+
+    if not output_attentions:
+        if attention_mask is None and is_causal is None: is_causal = True
+
+        try:
+            with sdpa_kernel([
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.CUDNN_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION
+            ]):
+                output = F.scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=attention_mask if not is_causal else None,
+                    is_causal=is_causal,
+                    dropout_p=0.0
+                )
+            return AttentionOutput(output=output, attention_weights=None)
+        except Exception:
+            print('Error at attn cal')
+            pass
+
+    d_k = query_states.size(-1)
+    scores = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(d_k)
+
+    if attention_mask is not None:
+        scores = scores.masked_fill(attention_mask == 0, float('-inf'))
+
+    attention_weights = torch.softmax(scores, dim=-1)
+    output = torch.matmul(attention_weights, value_states)
+
+    return AttentionOutput(output=output, attention_weights=attention_weights)
+
+
+@register_model()
+class MultiHeadAttention(BaseBlock):
+    ''' Multi-head attention module.
+
+    Supports grouped-query attention (GQA), QK normalization, and a gating mechanism.
+
+    Args:
+        hidden_size (int): Hidden size.
+        num_heads (int): Number of attention heads.
+        num_kv_heads (int, optional): Number of key/value heads for GQA. If None, equals num_heads. Defaults to None.
+        use_qk_norm (bool, optional): Whether to apply RMSNorm to the queries and keys. Defaults to True.
+        use_gate (bool, optional): Whether to apply the gating mechanism. Defaults to False.
+    '''
+
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        num_kv_heads=None,
+        use_qk_norm=True,
+        use_gate=False
+    ):
+        super(MultiHeadAttention, self).__init__()
+
+        if num_kv_heads is None: num_kv_heads = num_heads
+
+        assert hidden_size % num_heads == 0
+        assert num_heads % num_kv_heads == 0
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.num_kv_queries = num_heads // num_kv_heads
+        self.head_dim = hidden_size // num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.use_qk_norm = use_qk_norm
+        self.use_gate = use_gate
+
+        if use_qk_norm:
+            self.q_norm = nn.RMSNorm(self.head_dim)
+            self.k_norm = nn.RMSNorm(self.head_dim)
+
+        if use_gate:
+            self.g_proj = SigmoidGate(hidden_size, hidden_size)
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size)
+        self.k_proj = nn.Linear(hidden_size, self.kv_dim)
+        self.v_proj = nn.Linear(hidden_size, self.kv_dim)
+        self.o_proj = nn.Linear(hidden_size, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_states: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        output_attentions: bool = False,
+        rotary_emb: RotaryPositionalEmbedding = None,
+        rotary_pos: int = 0,
+        past_key_value: tuple[torch.Tensor, torch.Tensor] = None,
+        use_cache: bool = False
+    ) -> AttentionOutput:
+        ''' Run the forward pass of multi-head attention.
+
+        Args:
+            hidden_states (torch.Tensor): Input hidden states.
+            kv_states (torch.Tensor, optional): Hidden states used for the keys/values. If None, hidden_states is used. Defaults to None.
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            output_attentions (bool, optional): Whether to return the attention weights. Defaults to False.
+            rotary_emb (RotaryPositionalEmbedding, optional): Rotary positional embedding module. Defaults to None.
+            rotary_pos (int, optional): Starting position for the rotary embedding. Defaults to 0.
+            past_key_value (tuple[torch.Tensor, torch.Tensor], optional): Cached key/value pair from previous steps. Defaults to None.
+            use_cache (bool, optional): Whether to use the KV cache. Defaults to False.
+
+        Returns:
+            AttentionOutput: Object containing the output, attention weights, and KV cache.
+        '''
+
+        if kv_states is None:
+            kv_states = hidden_states
+
+        batch_size, q_len, _ = hidden_states.shape
+        kv_len_input = kv_states.shape[1]
+
+        if self.use_gate:
+            G = self.g_proj(hidden_states)
+
+        Q = self.q_proj(hidden_states)
+        K = self.k_proj(kv_states)
+        V = self.v_proj(kv_states)
+
+        Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+        if self.use_qk_norm:
+            Q = self.q_norm(Q)
+            K = self.k_norm(K)
+
+        if rotary_emb is not None:
+            Q = rotary_emb(Q, start_pos=rotary_pos)
+            K = rotary_emb(K, start_pos=rotary_pos)
+
+        current_key_value = None
+        if use_cache:
+            if past_key_value is not None:
+                past_k, past_v = past_key_value
+                K = torch.cat((past_k, K), dim=2)
+                V = torch.cat((past_v, V), dim=2)
+            current_key_value = (K, V)
+
+        kv_seq_len_total = K.shape[2]
+
+        if self.num_kv_queries > 1:
+            # [B, H_kv, 1, L, D] -> [B, H_kv, G, L, D]
+            K = K[:, :, None, :, :].expand(batch_size, self.num_kv_heads, self.num_kv_queries, kv_seq_len_total, self.head_dim)
+            V = V[:, :, None, :, :].expand(batch_size, self.num_kv_heads, self.num_kv_queries, kv_seq_len_total, self.head_dim)
+
+            K = K.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)
+            V = V.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)
+
+        attn_output = apply_attention(Q, K, V, attention_mask, output_attentions)
+        output = attn_output.output
+        attention_weights = attn_output.attention_weights
+
+        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_size)
+
+        output = self.o_proj(output)
+
+        if self.use_gate:
+            output = output * G
+
+        return AttentionOutput(
+            output=output,
+            attention_weights=attention_weights,
+            past_key_value=current_key_value
+        )
+
+
+class SpatialMultiHeadAttention(MultiHeadAttention):
+    '''
+    Extended MultiHeadAttention that accepts 2D position indices for MRoPE.
+    '''
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        rotary_emb: MRoPEInterleavedEmbedding = None,
+        output_attentions: bool = False,
+    ) -> AttentionOutput:
+
+        batch_size, q_len, _ = hidden_states.shape
+
+        Q = self.q_proj(hidden_states)
+        K = self.k_proj(hidden_states)
+        V = self.v_proj(hidden_states)
+
+        Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+        if self.use_qk_norm:
+            Q = self.q_norm(Q)
+            K = self.k_norm(K)
+
+        if rotary_emb is not None and positions is not None:
+            Q = rotary_emb(Q, positions=positions)
+            K = rotary_emb(K, positions=positions)
+
+        if self.num_kv_queries > 1:
+            K = K.repeat_interleave(self.num_kv_queries, dim=1)
+            V = V.repeat_interleave(self.num_kv_queries, dim=1)
+
+        attn_output = apply_attention(Q, K, V, attention_mask, output_attentions)
+
+        output = attn_output.output
+        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_size)
+
+        output = self.o_proj(output)
+
+        if self.use_gate:
+            G = self.g_proj(hidden_states)
+            output = output * G
+
+        return AttentionOutput(
+            output=output,
+            attention_weights=attn_output.attention_weights
+        )
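
For readers skimming the hunk, here is a minimal sketch of the two paths in apply_attention: the fused SDPA path used when attention weights are not requested, and the manual matmul/softmax path when output_attentions=True. It is not part of the package diff; it assumes the module is importable as orbit.model.block.attention and only illustrates the fields of the returned AttentionOutput.

# Sketch (not from the diff): exercising both paths of apply_attention.
import torch
from orbit.model.block.attention import apply_attention

q = torch.randn(2, 8, 16, 64)   # [batch, num_heads, seq_len, head_dim]
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Default call: tries the fused SDPA kernels; attention_weights is None when that path succeeds.
fast = apply_attention(q, k, v)
print(fast.output.shape)                    # torch.Size([2, 8, 16, 64])

# output_attentions=True skips SDPA and materializes the full weight matrix.
slow = apply_attention(q, k, v, output_attentions=True)
print(slow.attention_weights.shape)         # torch.Size([2, 8, 16, 16])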
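
Similarly, a sketch of the new MultiHeadAttention block in a grouped-query configuration (8 query heads sharing 2 KV heads) with incremental decoding through the KV cache. This is illustrative only: it assumes @register_model() returns the class unchanged and leaves rotary_emb at its default, since the embedding constructors are not shown in this hunk.

# Sketch (not from the diff): GQA configuration with incremental decoding.
import torch
from orbit.model.block.attention import MultiHeadAttention

attn = MultiHeadAttention(hidden_size=512, num_heads=8, num_kv_heads=2)  # head_dim=64, 4 query heads per KV head
attn.eval()

prompt = torch.randn(1, 16, 512)                 # [batch, seq_len, hidden_size]
with torch.no_grad():
    out = attn(prompt, use_cache=True)           # prefill; the cache stores K/V in the num_kv_heads layout
    print(out.output.shape)                      # torch.Size([1, 16, 512])
    print(out.past_key_value[0].shape)           # torch.Size([1, 2, 16, 64])

    step = torch.randn(1, 1, 512)                # one new token reuses the cached K/V
    out = attn(step, past_key_value=out.past_key_value, use_cache=True)
    print(out.past_key_value[0].shape)           # torch.Size([1, 2, 17, 64])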