orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl

Files changed (62)
  1. orbit/__init__.py +3 -1
  2. orbit/callback.py +4 -3
  3. orbit/dataset/__init__.py +1 -0
  4. orbit/dataset/cogn.py +138 -0
  5. orbit/dataset/data/cogn_en.jsonl +45 -0
  6. orbit/dataset/data/cogn_zh.jsonl +113 -0
  7. orbit/engine.py +210 -146
  8. orbit/kit/__init__.py +2 -0
  9. orbit/kit/interface.py +154 -0
  10. orbit/kit/wrapper.py +157 -0
  11. orbit/model/__init__.py +5 -0
  12. orbit/model/base.py +125 -0
  13. orbit/model/block/__init__.py +34 -0
  14. orbit/model/block/attention.py +265 -0
  15. orbit/model/block/bio.py +537 -0
  16. orbit/model/block/codebook.py +122 -0
  17. orbit/model/block/conv.py +505 -0
  18. orbit/model/block/embedding.py +252 -0
  19. orbit/model/block/film.py +176 -0
  20. orbit/model/block/fusion.py +335 -0
  21. orbit/model/block/gate.py +334 -0
  22. orbit/model/block/lora.py +776 -0
  23. orbit/model/block/mlp.py +68 -0
  24. orbit/model/block/moe.py +94 -0
  25. orbit/model/block/tcn.py +99 -0
  26. orbit/model/config.py +62 -0
  27. orbit/model/kit/__init__.py +6 -0
  28. orbit/model/kit/discriminator.py +46 -0
  29. orbit/model/kit/losses.py +193 -0
  30. orbit/model/motif/__init__.py +0 -0
  31. orbit/model/motif/vision/__init__.py +0 -0
  32. orbit/model/motif/vision/v1.py +645 -0
  33. orbit/model/registry.py +53 -0
  34. orbit/optim/__init__.py +2 -2
  35. orbit/optim/sam.py +10 -3
  36. orbit/plugin/__init__.py +12 -8
  37. orbit/plugin/board.py +1 -2
  38. orbit/plugin/checkpoint.py +137 -62
  39. orbit/plugin/classification.py +2 -2
  40. orbit/plugin/display_model.py +1 -2
  41. orbit/plugin/early_stopping.py +1 -2
  42. orbit/plugin/ema.py +1 -2
  43. orbit/plugin/gradient_accumulation.py +1 -2
  44. orbit/plugin/lora.py +346 -0
  45. orbit/plugin/memory_estimator.py +1 -2
  46. orbit/plugin/warmup.py +1 -2
  47. orbit/utils/__init__.py +24 -1
  48. orbit/utils/cuda.py +10 -0
  49. orbit/utils/freeze.py +61 -17
  50. orbit/utils/image.py +164 -0
  51. orbit/utils/initialization.py +184 -94
  52. orbit/utils/layer_io.py +66 -7
  53. orbit/utils/lora.py +480 -0
  54. orbit/utils/moe.py +55 -0
  55. orbit/utils/seed.py +3 -19
  56. orbit/utils/sft.py +93 -0
  57. orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
  58. orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
  59. orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
  60. orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
  61. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
  62. {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/kit/interface.py ADDED
@@ -0,0 +1,154 @@
+ from threading import Thread
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from rich.console import Console
+ from rich.markdown import Markdown
+ from rich.panel import Panel
+ from rich.live import Live
+ from rich.prompt import Prompt
+
+ from .wrapper import AutoRegressiveWrapper
+
+ class ChatInterface:
+     '''
+     An interactive command-line chat interface.
+
+     Attributes:
+         model: The language model instance.
+         tokenizer: The tokenizer instance.
+         device: The device the model lives on.
+     '''
+
+     def __init__(self, model=None, tokenizer=None, model_id=None, device='auto', dtype='auto', model_role='assistant'):
+         '''
+         Initialize the chat interface.
+
+         Args:
+             model: A preloaded model instance. If None, model_id must be provided.
+             tokenizer: A preloaded tokenizer instance. If None, model_id must be provided.
+             model_id: A HuggingFace model ID or local path.
+             device: Device setting. Defaults to 'auto'.
+             dtype: Weight precision of the model. Defaults to 'auto'.
+             model_role: Role name used for the model's replies. Defaults to 'assistant'.
+         '''
+         self.console = Console()
+         self.model_role = model_role
+         if model is not None and tokenizer is not None:
+             self.model = model
+             self.tokenizer = tokenizer
+         elif model_id is not None:
+             self._load_model(model_id, device, dtype)
+         else:
+             raise ValueError('Either (model and tokenizer) or model_id must be provided')
+
+         self.device = self.model.device
+
+         # Wrap forward-only models so they expose a generate interface.
+         if not hasattr(self.model, 'generate'):
+             self.model = AutoRegressiveWrapper(self.model)
+
+     def _load_model(self, model_id, device, dtype):
+         '''
+         Load the model and tokenizer from the given model_id.
+
+         Args:
+             model_id: A HuggingFace model ID or local path.
+             device: Device setting.
+             dtype: Model precision.
+         '''
+         with self.console.status(f'[bold green]Loading model: {model_id} ...[/bold green]'):
+             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 dtype=dtype,
+                 device_map=device
+             )
+         self.console.print(f'[bold green]Model {model_id} loaded![/bold green]')
+
+     def stream_chat(
+         self,
+         messages: list,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_k: int = 50,
+         top_p: float = 0.9,
+         repetition_penalty: float = 1.0,
+         do_sample: bool = True
+     ):
+         '''
+         Stream a chat response.
+
+         Args:
+             messages (list): A list of messages in ChatML format.
+             max_new_tokens (int): Maximum number of new tokens to generate.
+             temperature (float): Sampling temperature.
+             top_k (int): Top-k sampling value.
+             top_p (float): Top-p sampling value.
+             repetition_penalty (float): Repetition penalty factor.
+             do_sample (bool): Whether to sample.
+
+         Yields:
+             str: Newly generated text fragments.
+         '''
+         text = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+         model_inputs = self.tokenizer([text], return_tensors='pt').to(self.device)
+
+         streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+         generation_kwargs = dict(
+             model_inputs,
+             streamer=streamer,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             do_sample=do_sample,
+             eos_token_id=self.tokenizer.eos_token_id
+         )
+
+         # Run generation on a background thread so the streamer can be
+         # consumed on the calling thread as tokens arrive.
+         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         for new_text in streamer:
+             yield new_text
+
+     def interact(self):
+         '''
+         Start an interactive command-line chat session.
+         '''
+         self.console.print(Panel(
+             '[bold]Chat interface ready[/bold]\nType [red]"exit"[/red] or [red]"quit"[/red] to leave',
+             title='[bold blue]Orbit Chat[/bold blue]',
+             border_style='blue',
+             expand=False
+         ))
+         history = []
+         while True:
+             try:
+                 self.console.print()
+                 user_input = Prompt.ask('[bold green]User[/bold green]', console=self.console)
+                 if user_input.lower() in ['exit', 'quit']:
+                     break
+                 if not user_input.strip():
+                     continue
+
+                 history.append({'role': 'user', 'content': user_input})
+
+                 self.console.print('[bold purple]Model:[/bold purple]\n')
+                 full_response = ''
+
+                 # Render the streamed response as live-updating Markdown.
+                 with Live(Markdown(""), console=self.console, refresh_per_second=12) as live:
+                     for chunk in self.stream_chat(history):
+                         full_response += chunk
+                         live.update(Markdown(full_response))
+
+                 history.append({'role': self.model_role, 'content': full_response})
+
+             except KeyboardInterrupt:
+                 self.console.print('\n[bold red][Session interrupted][/bold red]')
+                 break
+
+         self.console.print('[bold blue][Generation finished][/bold blue]')
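A minimal usage sketch of the ChatInterface added above (not part of the diff): the model ID is a placeholder, and the import path simply follows the file location orbit/kit/interface.py.

# Hypothetical usage sketch; the model ID is a placeholder for any
# causal LM with a chat template available locally or on the Hub.
from orbit.kit.interface import ChatInterface

chat = ChatInterface(model_id='Qwen/Qwen2.5-0.5B-Instruct', device='auto')

# Programmatic streaming without the interactive loop:
messages = [{'role': 'user', 'content': 'Hello!'}]
for chunk in chat.stream_chat(messages, max_new_tokens=128):
    print(chunk, end='', flush=True)

# Or start the interactive command-line session:
chat.interact()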
orbit/kit/wrapper.py ADDED
@@ -0,0 +1,157 @@
+ import torch
+ import torch.nn as nn
+ import inspect
+ from typing import Optional, Any, List, Union
+
+
+ class AutoRegressiveWrapper:
+     '''
+     Wraps a plain torch.nn.Module so it is compatible with the transformers generate interface.
+
+     Attributes:
+         model (nn.Module): The original model instance.
+         device (torch.device): The device the model lives on.
+         accepts_attention_mask (bool): Whether the model accepts an attention_mask argument.
+         accepts_mask (bool): Whether the model accepts a mask argument.
+     '''
+
+     def __init__(self, model: nn.Module):
+         '''
+         Initialize the wrapper.
+
+         Args:
+             model (nn.Module): A custom model that only has a forward method.
+         '''
+         self.model = model
+         try:
+             self.device = next(model.parameters()).device
+         except StopIteration:
+             self.device = torch.device('cpu')
+
+         # Inspect the forward signature once so the generation loop knows
+         # which mask keyword, if any, the model expects.
+         sig = inspect.signature(model.forward)
+         params = sig.parameters
+         self.accepts_attention_mask = 'attention_mask' in params
+         self.accepts_mask = 'mask' in params
+
+         self.model.eval()
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 512,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         top_p: float = 0.9,
+         repetition_penalty: float = 1.0,
+         do_sample: bool = True,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         streamer: Optional[Any] = None,
+         **kwargs
+     ) -> torch.Tensor:
+         '''
+         Autoregressive generation loop supporting several sampling strategies and TextIteratorStreamer.
+
+         Args:
+             input_ids (torch.Tensor): Input token ID sequence of shape [batch, seq_len].
+             max_new_tokens (int): Maximum number of new tokens to generate.
+             temperature (float): Sampling temperature.
+             top_k (int): The k value for top-k sampling.
+             top_p (float): The p value for top-p (nucleus) sampling.
+             repetition_penalty (float): Repetition penalty factor.
+             do_sample (bool): Whether to sample.
+             eos_token_id (Optional[Union[int, List[int]]]): Termination token ID(s).
+             streamer (Optional[Any]): A streamer instance from the transformers library.
+             **kwargs: Other transformers-related arguments, ignored.
+
+         Returns:
+             torch.Tensor: The full sequence including the generated content.
+         '''
+         curr_input_ids = input_ids.to(self.device)
+         batch_size = curr_input_ids.shape[0]
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+
+         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=self.device)
+
+         for _ in range(max_new_tokens):
+             model_inputs = {'input_ids': curr_input_ids}
+
+             if self.accepts_attention_mask:
+                 model_inputs['attention_mask'] = torch.ones_like(curr_input_ids)
+             elif self.accepts_mask:
+                 model_inputs['mask'] = torch.ones_like(curr_input_ids)
+
+             outputs = self.model(**model_inputs)
+
+             if isinstance(outputs, (tuple, list)):
+                 logits = outputs[0]
+             else:
+                 logits = outputs
+
+             next_token_logits = logits[:, -1, :]
+
+             # Penalize tokens that already appear in the context.
+             if repetition_penalty != 1.0:
+                 for i in range(batch_size):
+                     for previous_token in set(curr_input_ids[i].tolist()):
+                         if next_token_logits[i, previous_token] < 0:
+                             next_token_logits[i, previous_token] *= repetition_penalty
+                         else:
+                             next_token_logits[i, previous_token] /= repetition_penalty
+
+             if do_sample:
+                 if temperature != 1.0:
+                     next_token_logits = next_token_logits / temperature
+
+                 # Top-k: mask everything below the k-th largest logit.
+                 if top_k > 0:
+                     indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 # Top-p (nucleus): mask the tail of the cumulative distribution.
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     # Shift right so the first token crossing the threshold is kept.
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 probs = torch.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+             curr_input_ids = torch.cat([curr_input_ids, next_token], dim=-1)
+
+             if streamer is not None:
+                 if unfinished_sequences[0] == 1:
+                     streamer.put(next_token.cpu())
+
+             if eos_token_id is not None:
+                 for token_id in eos_token_id:
+                     unfinished_sequences = unfinished_sequences.mul(
+                         next_token.tile(1, 1).ne(token_id).all(dim=-1).long()
+                     )
+
+             if unfinished_sequences.max() == 0:
+                 break
+
+         if streamer is not None:
+             streamer.end()
+
+         return curr_input_ids
+
+     def __getattr__(self, name: str) -> Any:
+         '''
+         Forward undefined attribute access to the wrapped model.
+
+         Args:
+             name (str): Attribute name.
+
+         Returns:
+             Any: The attribute of the original model.
+         '''
+         return getattr(self.model, name)
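For reference, a minimal sketch (not part of the diff) of driving the generate loop above with a forward-only model; TinyLM and its sizes are illustrative only.

# Illustrative sketch: a forward-only model gains a transformers-style
# generate() via AutoRegressiveWrapper.
import torch
import torch.nn as nn
from orbit.kit.wrapper import AutoRegressiveWrapper

class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        # Returns logits of shape [batch, seq_len, vocab_size].
        return self.head(self.embed(input_ids))

model = AutoRegressiveWrapper(TinyLM())
prompt = torch.randint(0, 100, (1, 8))
out = model.generate(prompt, max_new_tokens=16, do_sample=False, eos_token_id=0)
print(out.shape)  # [1, 8 + n] with n <= 16, stopping early at token 0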
orbit/model/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .base import BaseBlock
+ from .registry import (
+     register_model, build_model, list_models, get_model_class
+ )
+ from .config import ModelConfig
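orbit/model/registry.py (53 lines) is not shown in this diff, so the sketch below only assumes the common pattern suggested by the exported names: register_model as a name-keyed class decorator and build_model as a by-name constructor. Treat every call here as unconfirmed.

# Hedged sketch only: signatures are assumptions, not confirmed by this diff.
from orbit.model import BaseBlock, register_model, build_model, list_models

@register_model('toy_block')          # assumed decorator form
class ToyBlock(BaseBlock):
    pass

print(list_models())                  # expected to include 'toy_block'
block = build_model('toy_block')      # expected to return a ToyBlock instance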
orbit/model/base.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+ from typing import Union, List, Optional, Iterable
+
+ from orbit.utils import (
+     auto_initialize,
+     freeze_layers,
+     unfreeze_layers,
+     count_params,
+     save_model,
+     load_model,
+ )
+
+
+ class BaseBlock(nn.Module):
+     ''' Base model block providing common model functionality.
+
+     Inherits from nn.Module and adds parameter counting, gradient
+     checkpointing, layer freezing/unfreezing, and model save/load helpers.
+     '''
+
+     def __init__(self):
+         ''' Initialize the BaseBlock. '''
+         super(BaseBlock, self).__init__()
+
+         self.gradient_checkpointing: bool = False
+
+     @property
+     def device(self):
+         ''' Get the device the model lives on.
+
+         Returns:
+             torch.device: The device of the model's parameters. Falls back to cpu if there are no parameters.
+         '''
+         try:
+             return next(self.parameters()).device
+         except StopIteration:
+             return torch.device('cpu')
+
+     def _init_weights(self, model: Union[nn.Module, 'BaseBlock', nn.Parameter, torch.Tensor]):
+         ''' Initialize model weights.
+
+         Args:
+             model (Union[nn.Module, 'BaseBlock', nn.Parameter, torch.Tensor]): The model, layer, or tensor to initialize.
+         '''
+         auto_initialize(model=model, verbose=False)
+
+     def set_checkpoint(self, value: bool):
+         ''' Enable or disable gradient checkpointing.
+
+         Args:
+             value (bool): Whether to enable gradient checkpointing.
+         '''
+         self.gradient_checkpointing = value
+         # Propagate the flag to all nested BaseBlock submodules.
+         for model in self.modules():
+             if isinstance(model, BaseBlock) and model is not self:
+                 model.gradient_checkpointing = value
+
+     def count_params(self, trainable_only=False):
+         ''' Count the model's parameters.
+
+         Args:
+             trainable_only (bool, optional): Whether to count only trainable parameters. Defaults to False.
+
+         Returns:
+             int: The number of parameters.
+         '''
+         if trainable_only:
+             return count_params(self).count
+
+         return count_params(self, mode='all').count
+
+     def checkpoint(self, function, *args, **kwargs):
+         ''' Apply gradient checkpointing.
+
+         Uses torch.utils.checkpoint.checkpoint when gradient checkpointing is
+         enabled and the module is in training mode; otherwise calls the
+         function directly.
+
+         Args:
+             function (Callable): The function to execute.
+             *args: Positional arguments passed to the function.
+             **kwargs: Keyword arguments passed to the function.
+
+         Returns:
+             Any: The function's return value.
+         '''
+         if self.gradient_checkpointing and self.training:
+             return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, **kwargs)
+         else:
+             return function(*args, **kwargs)
+
+     def freeze(self, targets: Optional[Union[str, List[str]]] = None):
+         ''' Freeze the parameters of the given layers.
+
+         Args:
+             targets (Optional[Union[str, List[str]]], optional): Layer name or list of layer names to freeze.
+                 If None, all layers are frozen. Defaults to None.
+         '''
+         freeze_layers(self, targets)
+
+     def unfreeze(self, targets: Optional[Union[str, List[str]]] = None):
+         ''' Unfreeze the parameters of the given layers.
+
+         Args:
+             targets (Optional[Union[str, List[str]]], optional): Layer name or list of layer names to unfreeze.
+                 If None, all layers are unfrozen. Defaults to None.
+         '''
+         unfreeze_layers(self, targets)
+
+     def save_pretrained(self, file_path: str):
+         ''' Save the model weights to a file.
+
+         Args:
+             file_path (str): Destination path.
+         '''
+         save_model(self, file_path)
+
+     def load_pretrained(self, file_path: str, strict: bool = True, map_location: Union[str, torch.device] = 'cpu'):
+         ''' Load model weights from a file.
+
+         Args:
+             file_path (str): Path to the weights file.
+             strict (bool, optional): Whether keys must match strictly. Defaults to True.
+             map_location (Union[str, torch.device], optional): Map location. Defaults to 'cpu'.
+         '''
+         load_model(self, file_path, strict, map_location)
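A minimal sketch (not part of the diff) of subclassing BaseBlock and exercising the helpers defined above; ToyNet and its layer names are illustrative only.

# Illustrative sketch of the BaseBlock helpers.
import torch
import torch.nn as nn
from orbit.model import BaseBlock

class ToyNet(BaseBlock):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 32)
        self.decoder = nn.Linear(32, 16)

    def forward(self, x):
        # Route the encoder through self.checkpoint so it participates
        # in gradient checkpointing when enabled.
        h = self.checkpoint(self.encoder, x)
        return self.decoder(h)

net = ToyNet()
net.set_checkpoint(True)           # enable checkpointing on nested blocks too
net.freeze(['encoder'])            # freeze the encoder's parameters
out = net(torch.randn(4, 16))
print(net.count_params())                      # total parameter count
print(net.count_params(trainable_only=True))   # decoder parameters only
print(net.device)                              # cpu unless the model was moved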
orbit/model/block/__init__.py ADDED
@@ -0,0 +1,34 @@
+ from .embedding import (
+     RotaryPositionalEmbedding,
+     SinusoidalPositionalEmbedding,
+     MRoPEInterleavedEmbedding
+ )
+ from .attention import (
+     MultiHeadAttention, apply_attention, AttentionOutput,
+     SpatialMultiHeadAttention
+ )
+ from .codebook import (
+     LFQ, QuantizerOutput
+ )
+ from .fusion import (
+     LowRankFusion, GatedMultimodalUnit, DiffusionMapsFusion, CompactMultimodalPooling
+ )
+ from .mlp import MLP
+ from .moe import MoE
+ from .tcn import TCN
+ from .bio import (
+     HebianLayer, PredictiveCodingLayer, PredictiveCodingOutput,
+     PredictiveCodingBlock
+ )
+ from .film import FiLM, FiLMOutput
+ from .gate import (
+     SigmoidGate, TanhGate, SoftmaxGate, GLUGate,
+     TopKGate, TopKGateOutput
+ )
+ from .conv import (
+     CausalConv1d, calculate_causal_layer, ConvBlock,
+     DepthwiseSeparableConv, ResBasicBlock
+ )
+ from .lora import (
+     LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA
+ )