orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/attention.py
@@ -0,0 +1,265 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+ from orbit.model import BaseBlock, register_model
+ from orbit.model.block.embedding import RotaryPositionalEmbedding, MRoPEInterleavedEmbedding
+ from orbit.model.block.gate import SigmoidGate
+
+
+ @dataclass
+ class AttentionOutput:
+     output: torch.Tensor
+     attention_weights: Optional[torch.Tensor] = None
+     past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+ def apply_attention(
+     query_states: torch.Tensor,
+     key_states: torch.Tensor,
+     value_states: torch.Tensor,
+     attention_mask: torch.Tensor = None,
+     output_attentions: bool = False,
+     is_causal: bool = None
+ ) -> AttentionOutput:
+     ''' Compute scaled dot-product attention.
+
+     Args:
+         query_states (torch.Tensor): Query states tensor.
+         key_states (torch.Tensor): Key states tensor.
+         value_states (torch.Tensor): Value states tensor.
+         attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+         output_attentions (bool, optional): Whether to return the attention weights. Defaults to False.
+         is_causal (bool, optional): Whether to apply a causal mask. Defaults to None.
+
+     Returns:
+         AttentionOutput: Object containing the attention output and, optionally, the attention weights.
+     '''
+
+     if not output_attentions:
+         if attention_mask is None and is_causal is None: is_causal = True
+
+         try:
+             with sdpa_kernel([
+                 SDPBackend.FLASH_ATTENTION,
+                 SDPBackend.CUDNN_ATTENTION,
+                 SDPBackend.EFFICIENT_ATTENTION
+             ]):
+                 output = F.scaled_dot_product_attention(
+                     query_states,
+                     key_states,
+                     value_states,
+                     attn_mask=attention_mask if not is_causal else None,
+                     is_causal=bool(is_causal),
+                     dropout_p=0.0
+                 )
+             return AttentionOutput(output=output, attention_weights=None)
+         except Exception:
+             print('scaled_dot_product_attention failed, falling back to the manual attention implementation')
+
+     # Manual path: used as a fallback and whenever attention weights are requested.
+     d_k = query_states.size(-1)
+     scores = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(d_k)
+
+     if attention_mask is not None:
+         scores = scores.masked_fill(attention_mask == 0, float('-inf'))
+
+     attention_weights = torch.softmax(scores, dim=-1)
+     output = torch.matmul(attention_weights, value_states)
+
+     return AttentionOutput(output=output, attention_weights=attention_weights)
76
+
77
+
78
+ @register_model()
79
+ class MultiHeadAttention(BaseBlock):
80
+ ''' 多头注意力机制模块。
81
+
82
+ 支持分组查询注意力 (GQA)、QK 归一化和门控机制。
83
+
84
+ Args:
85
+ hidden_size (int): 隐藏层大小。
86
+ num_heads (int): 注意力头数。
87
+ num_kv_heads (int, optional): 键/值头数,用于 GQA。如果为 None,则等于 num_heads。默认为 None。
88
+ use_qk_norm (bool, optional): 是否对查询和键应用 RMSNorm。默认为 True。
89
+ use_gate (bool, optional): 是否应用门控机制。默认为 False。
90
+ '''
91
+
92
+ def __init__(
93
+ self,
94
+ hidden_size,
95
+ num_heads,
96
+ num_kv_heads=None,
97
+ use_qk_norm=True,
98
+ use_gate=False
99
+ ):
100
+ super(MultiHeadAttention, self).__init__()
101
+
102
+ if num_kv_heads is None: num_kv_heads = num_heads
103
+
104
+ assert hidden_size % num_heads == 0
105
+ assert num_heads % num_kv_heads == 0
106
+
107
+ self.hidden_size = hidden_size
108
+ self.num_heads = num_heads
109
+ self.num_kv_heads = num_kv_heads
110
+ self.num_kv_queries = num_heads // num_kv_heads
111
+ self.head_dim = hidden_size // num_heads
112
+ self.kv_dim = self.num_kv_heads * self.head_dim
113
+ self.use_qk_norm = use_qk_norm
114
+ self.use_gate = use_gate
115
+
116
+ if use_qk_norm:
117
+ self.q_norm = nn.RMSNorm(self.head_dim)
118
+ self.k_norm = nn.RMSNorm(self.head_dim)
119
+
120
+ if use_gate:
121
+ self.g_proj = SigmoidGate(hidden_size, hidden_size)
122
+
123
+ self.q_proj = nn.Linear(hidden_size, hidden_size)
124
+ self.k_proj = nn.Linear(hidden_size, self.kv_dim)
125
+ self.v_proj = nn.Linear(hidden_size, self.kv_dim)
126
+ self.o_proj = nn.Linear(hidden_size, hidden_size)
127
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         kv_states: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         output_attentions: bool = False,
+         rotary_emb: RotaryPositionalEmbedding = None,
+         rotary_pos: int = 0,
+         past_key_value: tuple[torch.Tensor, torch.Tensor] = None,
+         use_cache: bool = False
+     ) -> AttentionOutput:
+         ''' Run the multi-head attention forward pass.
+
+         Args:
+             hidden_states (torch.Tensor): Input hidden states.
+             kv_states (torch.Tensor, optional): Hidden states used for the keys/values. If None, hidden_states is used. Defaults to None.
+             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+             output_attentions (bool, optional): Whether to return the attention weights. Defaults to False.
+             rotary_emb (RotaryPositionalEmbedding, optional): Rotary positional embedding module. Defaults to None.
+             rotary_pos (int, optional): Starting position for the rotary positional embedding. Defaults to 0.
+             past_key_value (tuple[torch.Tensor, torch.Tensor], optional): Cached key/value pair from previous steps. Defaults to None.
+             use_cache (bool, optional): Whether to use the KV cache. Defaults to False.
+
+         Returns:
+             AttentionOutput: Object containing the output, the attention weights and the KV cache.
+         '''
+
+         if kv_states is None:
+             kv_states = hidden_states
+
+         batch_size, q_len, _ = hidden_states.shape
+         kv_len_input = kv_states.shape[1]
+
+         if self.use_gate:
+             G = self.g_proj(hidden_states)
+
+         Q = self.q_proj(hidden_states)
+         K = self.k_proj(kv_states)
+         V = self.v_proj(kv_states)
+
+         Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         V = V.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         if self.use_qk_norm:
+             Q = self.q_norm(Q)
+             K = self.k_norm(K)
+
+         if rotary_emb is not None:
+             Q = rotary_emb(Q, start_pos=rotary_pos)
+             K = rotary_emb(K, start_pos=rotary_pos)
+
+         current_key_value = None
+         if use_cache:
+             if past_key_value is not None:
+                 past_k, past_v = past_key_value
+                 K = torch.cat((past_k, K), dim=2)
+                 V = torch.cat((past_v, V), dim=2)
+             current_key_value = (K, V)
+
+         kv_seq_len_total = K.shape[2]
+
+         if self.num_kv_queries > 1:
+             # [B, H_kv, 1, L, D] -> [B, H_kv, G, L, D]
+             K = K[:, :, None, :, :].expand(batch_size, self.num_kv_heads, self.num_kv_queries, kv_seq_len_total, self.head_dim)
+             V = V[:, :, None, :, :].expand(batch_size, self.num_kv_heads, self.num_kv_queries, kv_seq_len_total, self.head_dim)
+
+             K = K.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)
+             V = V.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)
+
+         attn_output = apply_attention(Q, K, V, attention_mask, output_attentions)
+         output = attn_output.output
+         attention_weights = attn_output.attention_weights
+
+         output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_size)
+
+         output = self.o_proj(output)
+
+         if self.use_gate:
+             output = output * G
+
+         return AttentionOutput(
+             output=output,
+             attention_weights=attention_weights,
+             past_key_value=current_key_value
+         )
+
+
+ class SpatialMultiHeadAttention(MultiHeadAttention):
+     '''
+     Extended MultiHeadAttention that accepts 2D position indices for MRoPE.
+     '''
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         positions: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         rotary_emb: MRoPEInterleavedEmbedding = None,
+         output_attentions: bool = False,
+     ) -> AttentionOutput:
+
+         batch_size, q_len, _ = hidden_states.shape
+
+         Q = self.q_proj(hidden_states)
+         K = self.k_proj(hidden_states)
+         V = self.v_proj(hidden_states)
+
+         Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(batch_size, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         V = V.view(batch_size, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         if self.use_qk_norm:
+             Q = self.q_norm(Q)
+             K = self.k_norm(K)
+
+         if rotary_emb is not None and positions is not None:
+             Q = rotary_emb(Q, positions=positions)
+             K = rotary_emb(K, positions=positions)
+
+         if self.num_kv_queries > 1:
+             K = K.repeat_interleave(self.num_kv_queries, dim=1)
+             V = V.repeat_interleave(self.num_kv_queries, dim=1)
+
+         attn_output = apply_attention(Q, K, V, attention_mask, output_attentions)
+
+         output = attn_output.output
+         output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_size)
+
+         output = self.o_proj(output)
+
+         if self.use_gate:
+             G = self.g_proj(hidden_states)
+             output = output * G
+
+         return AttentionOutput(
+             output=output,
+             attention_weights=attn_output.attention_weights
+         )
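
For orientation, below is a minimal usage sketch of the new MultiHeadAttention block, written only against the signatures visible in the hunk above. The import path follows the file list; the sizes (hidden_size=256, num_heads=8, num_kv_heads=2) and the batch/sequence shapes are illustrative assumptions, not values taken from the package, and rotary embeddings and masks are omitted to avoid guessing at their constructors.

    import torch
    from orbit.model.block.attention import MultiHeadAttention

    # Grouped-query attention: 8 query heads share 2 key/value heads (head_dim = 256 / 8 = 32).
    attn = MultiHeadAttention(hidden_size=256, num_heads=8, num_kv_heads=2)

    x = torch.randn(2, 16, 256)                 # (batch, seq_len, hidden_size)

    # Prefill: attend over the full sequence and keep the KV cache.
    step = attn(x, use_cache=True)
    print(step.output.shape)                    # torch.Size([2, 16, 256])
    print(step.past_key_value[0].shape)         # cached keys: (2, num_kv_heads=2, 16, head_dim=32)

    # Decode one new token, reusing the cached keys/values.
    x_next = torch.randn(2, 1, 256)
    step = attn(x_next, past_key_value=step.past_key_value, use_cache=True)
    print(step.output.shape)                    # torch.Size([2, 1, 256])
    print(step.past_key_value[0].shape)         # cached keys now span 17 positions

Note that the returned past_key_value holds the keys/values before GQA expansion, so the cache stays at num_kv_heads heads regardless of num_heads.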