orbit-torch 0.0.4a1-py3-none-any.whl → 0.1.0b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/model/block/embedding.py
@@ -0,0 +1,252 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ from orbit.model import BaseBlock, register_model
+
+
+ @register_model()
+ class RotaryPositionalEmbedding(BaseBlock):
+     '''
+     Rotary Positional Embedding (RoPE).
+     '''
+
+     def __init__(self, model_dim: int, max_len: int = 128000, base: int = 10000):
+         '''
+         Initialize the RoPE module.
+
+         Args:
+             model_dim (int): Model dimension (or head_dim). Must be even.
+             max_len (int, optional): Maximum sequence length for the precomputed embeddings. Defaults to 128000.
+             base (int, optional): Base for the frequency computation. Defaults to 10000.
+         '''
+         super(RotaryPositionalEmbedding, self).__init__()
+
+         self.model_dim = model_dim
+         self.max_len = max_len
+         self.base = base
+
+         inv_freq = 1.0 / (base ** (torch.arange(0, model_dim, 2).float() / model_dim))
+
+         t = torch.arange(max_len, dtype=torch.float)
+
+         freqs = torch.outer(t, inv_freq)
+
+         emb = torch.cat((freqs, freqs), dim=-1)
+
+         self.register_buffer('cos_cached', emb.cos())
+         self.register_buffer('sin_cached', emb.sin())
+
+     def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
+         '''
+         Split the vector into two halves and rotate: [-x2, x1].
+         Whether the input is 3D or 4D, the split acts on the last dimension (model_dim).
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Rotated tensor.
+         '''
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+
+     def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
+         '''
+         Apply rotary positional embedding.
+
+         Automatically adapts to two input layouts:
+         1. [Batch, Seq_Len, Dim]
+         2. [Batch, Head, Seq_Len, Head_Dim]
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             start_pos (int, optional): Starting position index, used for KV-cache inference. Defaults to 0.
+
+         Returns:
+             torch.Tensor: Tensor with positional information applied.
+         '''
+         ndim = x.ndim
+         seq_len = x.shape[-2]
+
+         cos = self.cos_cached[start_pos : start_pos + seq_len, :]
+         sin = self.sin_cached[start_pos : start_pos + seq_len, :]
+
+         shape = [1] * (ndim - 2) + [seq_len, -1]
+         cos = cos.view(*shape)
+         sin = sin.view(*shape)
+
+         return (x * cos) + (self._rotate_half(x) * sin)
+
+
+ @register_model()
+ class SinusoidalPositionalEmbedding(BaseBlock):
+
+     def __init__(self, model_dim: int, max_len: int = 128000):
+         '''
+         Initialize the absolute (sinusoidal) positional embedding module.
+
+         Args:
+             model_dim (int): Model dimension.
+             max_len (int, optional): Maximum sequence length. Defaults to 128000.
+         '''
+         super(SinusoidalPositionalEmbedding, self).__init__()
+
+         self.model_dim = model_dim
+         self.max_len = max_len
+
+         pe = torch.zeros(max_len, model_dim)
+
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+
+         div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim))
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         '''
+         Forward pass.
+
+         Args:
+             x (torch.Tensor): Input tensor. Shape: [Batch_Size, Seq_Len, model_dim].
+
+         Returns:
+             torch.Tensor: Tensor with the positional encoding added.
+         '''
+         x = x + self.pe[:, :x.size(1), :]
+         return x
+
+ @register_model()
+ class MRoPEInterleavedEmbedding(BaseBlock):
+     '''
+     Interleaved multimodal rotary positional embedding (MRoPE-Interleave).
+     Supports three position axes (time t, height h, width w); frequency channels are assigned round-robin (thw…thw…thw).
+     '''
+     def __init__(self, model_dim: int, max_len: int = 128000, base: int = 10000, num_axes: int = 3):
+         '''
+         Initialize the MRoPEInterleaved module.
+
+         Args:
+             model_dim (int): Model dimension. Must be even and divisible by num_axes.
+             max_len (int, optional): Maximum sequence length for the precomputed embeddings. Defaults to 128000.
+             base (int, optional): Base for the frequency computation. Defaults to 10000.
+             num_axes (int, optional): Number of position axes (e.g. 3 for time, height, width). Defaults to 3.
+         '''
+         super().__init__()
+         assert model_dim % 2 == 0, 'model_dim must be even'
+         assert model_dim % num_axes == 0, f'model_dim {model_dim} not divisible by num_axes {num_axes}'
+
+         self.model_dim = model_dim
+         self.max_len = max_len
+         self.base = base
+         self.num_axes = num_axes
+
+         inv_freq = 1.0 / (base ** (torch.arange(0, model_dim, 2).float() / model_dim))
+
+         t_range = torch.arange(max_len, dtype=torch.float)
+         freqs = torch.outer(t_range, inv_freq) # [max_len, dim/2]
+
+         emb = torch.cat((freqs, freqs), dim=-1) # [max_len, dim]
+
+         self.register_buffer('cos_cached', emb.cos())
+         self.register_buffer('sin_cached', emb.sin())
+
+         self.register_buffer(
+             'axis_mask',
+             torch.arange(model_dim) % num_axes,
+             persistent=False
+         )
+
+         k = model_dim // num_axes
+         idx = []
+         for p in range(model_dim):
+             j = p % num_axes
+             i = p // num_axes
+             pos_in_old = j * k + i
+             idx.append(pos_in_old)
+
+         self.register_buffer('interleave_idx', torch.tensor(idx, dtype=torch.long), persistent=False)
+
+     def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
+         '''
+         Split the vector into two halves and rotate: [-x2, x1].
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Rotated tensor.
+         '''
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+
+     def forward(self, x: torch.Tensor, positions: torch.Tensor = None, start_pos: int = 0) -> torch.Tensor:
+         '''
+         Apply multimodal rotary positional embedding.
+
+         Args:
+             x (torch.Tensor): Input tensor. Shape: [Batch, Seq_Len, Dim] or [Batch, Head, Seq_Len, Head_Dim].
+             positions (torch.Tensor, optional): Position index tensor. Shape: [Batch, Seq_Len] or [Batch, Seq_Len, num_axes].
+                 A 2D tensor is automatically expanded to [Batch, Seq_Len, num_axes].
+                 If None and num_axes=1, linear position indices are created automatically.
+             start_pos (int, optional): Starting position index. Defaults to 0.
+
+         Returns:
+             torch.Tensor: Tensor with positional information applied.
+
+         Raises:
+             ValueError: If positions is None and num_axes > 1.
+         '''
+         ndim = x.ndim
+         seq_len = x.shape[-2]
+         batch_size = x.shape[0]
+
+         if positions is None:
+             if self.num_axes == 1:
+                 positions = torch.arange(0, seq_len, device=x.device, dtype=torch.long)
+             else:
+                 raise ValueError("positions must be provided when num_axes > 1 (e.g. for vision/multimodal inputs)")
+
+         if positions.ndim == 1:
+             positions = positions.unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, self.num_axes)
+
+         if positions.ndim == 2:
+             positions = positions.unsqueeze(-1).expand(-1, -1, self.num_axes)
+
+         if positions.ndim == 3 and positions.shape[-1] == 1:
+             positions = positions.expand(-1, -1, self.num_axes)
+
+         batch_size = positions.shape[0]
+
+         cos_list, sin_list = [], []
+
+         for ax in range(self.num_axes):
+             pos_ax = positions[..., ax]
+             pos_ax = torch.clamp(pos_ax + start_pos, 0, self.max_len - 1).long()
+
+             cos_full = self.cos_cached[pos_ax]
+             sin_full = self.sin_cached[pos_ax]
+
+             mask = (self.axis_mask == ax)
+             cos_ax = cos_full[..., mask]
+             sin_ax = sin_full[..., mask]
+
+             cos_list.append(cos_ax)
+             sin_list.append(sin_ax)
+
+         cos_all = torch.cat(cos_list, dim=-1)
+         sin_all = torch.cat(sin_list, dim=-1)
+
+         cos_all = cos_all[..., self.interleave_idx]
+         sin_all = sin_all[..., self.interleave_idx]
+
+         if ndim == 4:
+             shape = [batch_size, 1, seq_len, -1]
+             cos_all = cos_all.view(*shape)
+             sin_all = sin_all.view(*shape)
+
+         return (x * cos_all) + (self._rotate_half(x) * sin_all)
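Usage note (not part of the diff): a minimal sketch of how these blocks compose, assuming the classes are importable from orbit.model.block.embedding as the file path above suggests. Shapes follow the docstrings: RoPE rotates channels in place and returns the input shape; MRoPE-Interleave expects one (t, h, w) index triple per token.

# Minimal usage sketch (assumption: import path taken from this diff's file layout).
import torch
from orbit.model.block.embedding import RotaryPositionalEmbedding, MRoPEInterleavedEmbedding

rope = RotaryPositionalEmbedding(model_dim=64)
q = torch.randn(2, 8, 16, 64)                    # [Batch, Head, Seq_Len, Head_Dim]
q_rot = rope(q)                                  # same shape, positions 0..15
q_new = rope(q[:, :, -1:, :], start_pos=15)      # KV-cache step: rotate one token at position 15

mrope = MRoPEInterleavedEmbedding(model_dim=96, num_axes=3)
# One (t, h, w) triple per token, e.g. a single frame of a 4x4 patch grid.
h, w = torch.meshgrid(torch.arange(4), torch.arange(4), indexing='ij')
thw = torch.stack([torch.zeros(16, dtype=torch.long), h.reshape(-1), w.reshape(-1)], dim=-1)
x = torch.randn(2, 16, 96)                       # [Batch, Seq_Len, Dim]
x_rot = mrope(x, positions=thw.unsqueeze(0).expand(2, -1, -1))

Because axis_mask assigns channels round-robin (arange(model_dim) % num_axes), each axis sees frequencies spread across the whole spectrum rather than one contiguous band, which is the point of the interleaved variant over a block-wise MRoPE split.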
orbit/model/block/film.py
@@ -0,0 +1,176 @@
+ import torch
+ import torch.nn as nn
+
+ from typing import Optional
+ from dataclasses import dataclass
+
+ from orbit.model import BaseBlock, register_model
+
+ @dataclass
+ class FiLMOutput:
+     ''' Output container for the FiLM module.
+
+     Attributes:
+         output (torch.Tensor): Features modulated by gamma and beta.
+         gate (Optional[torch.Tensor]): Gating values for residual connections.
+     '''
+     output: torch.Tensor
+     gate: Optional[torch.Tensor] = None
+
+     @property
+     def gated_output(self):
+         if self.gate is None: return self.output
+         return self.output * self.gate
+
+
+ @register_model()
+ class FiLM(BaseBlock):
+     ''' Feature-wise Linear Modulation (FiLM) module.
+
+     Applies an affine transform to the input features: FiLM(x) = (1 + gamma(z)) * x + beta(z),
+     where gamma and beta are generated from the conditioning input z.
+     At initialization, gamma and beta are 0, so the module is an identity mapping.
+
+     Args:
+         in_features (int): Input feature dimension.
+         cond_features (int): Conditioning feature dimension.
+         use_beta (bool, optional): Whether to use the shift term (beta). Defaults to True.
+         use_gamma (bool, optional): Whether to use the scale term (gamma). Defaults to True.
+         use_gate (bool, optional): Whether to use the gating term. Defaults to True.
+         use_context_gate (bool, optional): Whether to use a context gate.
+             If True, the gate is generated from the concatenation of the input and conditioning features, overriding the use_gate setting. Defaults to False.
+         channel_first (bool, optional): Whether the feature dimension is dim 1 (as in CNNs, [B, C, H, W]).
+             If False, the features are assumed to be in the last dimension (as in Transformers, [B, L, C]). Defaults to False.
+     '''
+     def __init__(
+         self,
+         in_features: int,
+         cond_features: int,
+         use_beta: bool = True,
+         use_gamma: bool = True,
+         use_gate: bool = True,
+         use_context_gate: bool = False,
+         channel_first: bool = False
+     ):
+         super(FiLM, self).__init__()
+
+         if use_context_gate: use_gate = False
+
+         self.in_features = in_features
+         self.cond_features = cond_features
+         self.use_beta = use_beta
+         self.use_gamma = use_gamma
+         self.use_gate = use_gate
+         self.use_context_gate = use_context_gate
+         self.channel_first = channel_first
+
+         self.out_dim = 0
+         if use_gamma: self.out_dim += in_features
+         if use_beta: self.out_dim += in_features
+         if use_gate: self.out_dim += in_features
+
+         self.gate_proj = nn.Linear(in_features + cond_features, in_features) if use_context_gate else nn.Identity()
+
+         if self.out_dim > 0:
+             self.proj = nn.Linear(cond_features, self.out_dim)
+         else: self.proj = None
+
+         self._init_weights(self)
+
+     def _init_weights(self, model: nn.Module):
+         ''' Initialize weights.
+
+         The projection layer's weights and bias are initialized to 0 to guarantee an identity mapping at initialization.
+         If a context gate is used, its projection is initialized with Xavier uniform.
+
+         Args:
+             model (nn.Module): Model to initialize.
+         '''
+         if model is self and self.proj is not None:
+             nn.init.constant_(self.proj.weight, 0)
+             nn.init.constant_(self.proj.bias, 0)
+         if isinstance(self.gate_proj, nn.Identity): return
+         nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.1)
+         nn.init.zeros_(self.gate_proj.bias)
+
+     def _reshape(self, param: torch.Tensor, ref_ndim: int) -> torch.Tensor:
+         ''' Reshape the parameter tensor to match the input features' dimensions for broadcasting.
+
+         Args:
+             param (torch.Tensor): Parameter tensor to reshape.
+             ref_ndim (int): Number of dimensions of the reference tensor (usually the input features x).
+
+         Returns:
+             torch.Tensor: Reshaped parameter tensor.
+         '''
+         if self.channel_first:
+             param = param.movedim(-1, 1)
+             for _ in range(ref_ndim - param.ndim):
+                 param = param.unsqueeze(-1)
+         else:
+             for _ in range(ref_ndim - param.ndim):
+                 param = param.unsqueeze(-2)
+         return param
+
+     def forward(self, x: torch.Tensor, cond: torch.Tensor) -> FiLMOutput:
+         ''' Forward pass.
+
+         Args:
+             x (torch.Tensor): Input features. Shape [B, C, ...] (if channel_first=True)
+                 or [B, ..., C] (if channel_first=False).
+             cond (torch.Tensor): Conditioning input. Shape [B, ..., cond_features].
+
+         Returns:
+             FiLMOutput: Modulated features.
+         '''
+         if self.proj is None: return FiLMOutput(output=x)
+
+         params = self.proj(cond)
+
+         count = sum([self.use_gamma, self.use_beta, self.use_gate])
+         if count > 1:
+             params_list = params.chunk(count, dim=-1)
+         else:
+             params_list = [params]
+
+         idx = 0
+         gamma, beta, gate = None, None, None
+         if self.use_gamma:
+             gamma = params_list[idx]
+             idx += 1
+         if self.use_beta:
+             beta = params_list[idx]
+             idx += 1
+         if self.use_gate:
+             gate = params_list[idx]
+             idx += 1
+
+         out = x
+         if gamma is not None:
+             out = out * (1 + self._reshape(gamma, x.ndim))
+         if beta is not None:
+             out = out + self._reshape(beta, x.ndim)
+
+         final_gate = None
+         if self.use_context_gate:
+             if cond.ndim < x.ndim:
+                 shape = list(x.shape)
+                 feat_dim = 1 if self.channel_first else -1
+                 shape[feat_dim] = -1
+                 cond_expanded = self._reshape(cond, x.ndim).expand(shape)
+             else:
+                 cond_expanded = cond
+
+             feat_dim = 1 if self.channel_first else -1
+             context_input = torch.cat([x, cond_expanded], dim=feat_dim)
+
+             if self.channel_first:
+                 context_input = context_input.movedim(1, -1)
+                 final_gate = self.gate_proj(context_input).movedim(-1, 1)
+             else:
+                 final_gate = self.gate_proj(context_input)
+
+         elif gate is not None:
+             final_gate = self._reshape(gate, x.ndim)
+
+         return FiLMOutput(output=out, gate=final_gate)
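Usage note (not part of the diff): a minimal sketch of the FiLM block on Transformer-style features, assuming the import path from this diff. Because proj is zero-initialized, params are all zero at the start, so the output equals x and the gate is zero, making a gated residual update a no-op at initialization.

# Minimal usage sketch (assumption: import path taken from this diff's file layout).
import torch
from orbit.model.block.film import FiLM

film = FiLM(in_features=256, cond_features=64)   # gamma, beta and gate all enabled
x = torch.randn(2, 10, 256)                      # [B, L, C] features
cond = torch.randn(2, 64)                        # one conditioning vector per sample

out = film(x, cond)                              # FiLMOutput(output=..., gate=...)
h = x + out.gated_output                         # gated residual; equals x at init (gate == 0)

# CNN-style layout: feature dim at position 1, so set channel_first=True.
film_cnn = FiLM(in_features=32, cond_features=64, channel_first=True)
feat = torch.randn(2, 32, 8, 8)                  # [B, C, H, W]
out_cnn = film_cnn(feat, cond)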