orbit-torch 0.0.4a1__py3-none-any.whl → 0.1.0b1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- orbit/__init__.py +3 -1
- orbit/callback.py +4 -3
- orbit/dataset/__init__.py +1 -0
- orbit/dataset/cogn.py +138 -0
- orbit/dataset/data/cogn_en.jsonl +45 -0
- orbit/dataset/data/cogn_zh.jsonl +113 -0
- orbit/engine.py +210 -146
- orbit/kit/__init__.py +2 -0
- orbit/kit/interface.py +154 -0
- orbit/kit/wrapper.py +157 -0
- orbit/model/__init__.py +5 -0
- orbit/model/base.py +125 -0
- orbit/model/block/__init__.py +34 -0
- orbit/model/block/attention.py +265 -0
- orbit/model/block/bio.py +537 -0
- orbit/model/block/codebook.py +122 -0
- orbit/model/block/conv.py +505 -0
- orbit/model/block/embedding.py +252 -0
- orbit/model/block/film.py +176 -0
- orbit/model/block/fusion.py +335 -0
- orbit/model/block/gate.py +334 -0
- orbit/model/block/lora.py +776 -0
- orbit/model/block/mlp.py +68 -0
- orbit/model/block/moe.py +94 -0
- orbit/model/block/tcn.py +99 -0
- orbit/model/config.py +62 -0
- orbit/model/kit/__init__.py +6 -0
- orbit/model/kit/discriminator.py +46 -0
- orbit/model/kit/losses.py +193 -0
- orbit/model/motif/__init__.py +0 -0
- orbit/model/motif/vision/__init__.py +0 -0
- orbit/model/motif/vision/v1.py +645 -0
- orbit/model/registry.py +53 -0
- orbit/optim/__init__.py +2 -2
- orbit/optim/sam.py +10 -3
- orbit/plugin/__init__.py +12 -8
- orbit/plugin/board.py +1 -2
- orbit/plugin/checkpoint.py +137 -62
- orbit/plugin/classification.py +2 -2
- orbit/plugin/display_model.py +1 -2
- orbit/plugin/early_stopping.py +1 -2
- orbit/plugin/ema.py +1 -2
- orbit/plugin/gradient_accumulation.py +1 -2
- orbit/plugin/lora.py +346 -0
- orbit/plugin/memory_estimator.py +1 -2
- orbit/plugin/warmup.py +1 -2
- orbit/utils/__init__.py +24 -1
- orbit/utils/cuda.py +10 -0
- orbit/utils/freeze.py +61 -17
- orbit/utils/image.py +164 -0
- orbit/utils/initialization.py +184 -94
- orbit/utils/layer_io.py +66 -7
- orbit/utils/lora.py +480 -0
- orbit/utils/moe.py +55 -0
- orbit/utils/seed.py +3 -19
- orbit/utils/sft.py +93 -0
- orbit_torch-0.1.0b1.dist-info/METADATA +208 -0
- orbit_torch-0.1.0b1.dist-info/RECORD +65 -0
- orbit_torch-0.0.4a1.dist-info/METADATA +0 -25
- orbit_torch-0.0.4a1.dist-info/RECORD +0 -29
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/WHEEL +0 -0
- {orbit_torch-0.0.4a1.dist-info → orbit_torch-0.1.0b1.dist-info}/top_level.txt +0 -0
orbit/kit/interface.py
ADDED
@@ -0,0 +1,154 @@
+from threading import Thread
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+from rich.live import Live
+from rich.prompt import Prompt
+
+from .wrapper import AutoRegressiveWrapper
+
+class ChatInterface:
+    '''
+    A chat interface class for real-time interaction on the command line.
+
+    Attributes:
+        model: the language model instance
+        tokenizer: the tokenizer instance
+        device: the device the model lives on
+    '''
+
+    def __init__(self, model=None, tokenizer=None, model_id=None, device='auto', dtype='auto', model_role='assistant'):
+        '''
+        Initialize the chat interface.
+
+        Args:
+            model: a preloaded model instance. If None, model_id must be provided
+            tokenizer: a preloaded tokenizer instance. If None, model_id must be provided
+            model_id: HuggingFace ID or local path of the model
+            device: device setting, defaults to 'auto'
+            dtype: weight precision of the model, defaults to 'auto'
+            model_role: role name used for the model's replies, defaults to 'assistant'
+        '''
+        self.console = Console()
+        self.model_role = model_role
+        if model is not None and tokenizer is not None:
+            self.model = model
+            self.tokenizer = tokenizer
+        elif model_id is not None:
+            self._load_model(model_id, device, dtype)
+        else:
+            raise ValueError('Either (model and tokenizer) or model_id must be provided')
+
+        self.device = self.model.device
+
+        if not hasattr(self.model, 'generate'):
+            self.model = AutoRegressiveWrapper(self.model)
+
+    def _load_model(self, model_id, device, dtype):
+        '''
+        Load the model and tokenizer from the given model_id.
+
+        Args:
+            model_id: HuggingFace ID or local path of the model
+            device: device setting
+            dtype: model precision
+        '''
+        with self.console.status(f'[bold green]Loading model: {model_id} ...[/bold green]'):
+            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                dtype=dtype,
+                device_map=device
+            )
+        self.console.print(f'[bold green]Model {model_id} loaded![/bold green]')
+
+    def stream_chat(
+        self,
+        messages: list,
+        max_new_tokens: int = 512,
+        temperature: float = 0.7,
+        top_k: int = 50,
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.0,
+        do_sample: bool = True
+    ):
+        '''
+        Stream a generated chat response.
+
+        Args:
+            messages (list): a list of messages in ChatML format
+            max_new_tokens (int): maximum number of new tokens to generate
+            temperature (float): generation temperature
+            top_k (int): top-k sampling value
+            top_p (float): top-p sampling value
+            repetition_penalty (float): repetition penalty factor
+            do_sample (bool): whether to use sampling
+
+        Yields:
+            str: newly generated text fragments
+        '''
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = self.tokenizer([text], return_tensors='pt').to(self.device)
+
+        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        generation_kwargs = dict(
+            model_inputs,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            do_sample=do_sample,
+            eos_token_id=self.tokenizer.eos_token_id
+        )
+
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        for new_text in streamer:
+            yield new_text
+
+    def interact(self):
+        '''
+        Start a real-time interactive session on the command line.
+        '''
+        self.console.print(Panel(
+            '[bold]Chat interface ready[/bold]\nType [red]"exit"[/red] or [red]"quit"[/red] to leave',
+            title='[bold blue]Orbit Chat[/bold blue]',
+            border_style='blue',
+            expand=False
+        ))
+        history = []
+        while True:
+            try:
+                self.console.print()
+                user_input = Prompt.ask('[bold green]User[/bold green]', console=self.console)
+                if user_input.lower() in ['exit', 'quit']:
+                    break
+                if not user_input.strip():
+                    continue
+
+                history.append({'role': 'user', 'content': user_input})
+
+                self.console.print('[bold purple]Model:[/bold purple]\n')
+                full_response = ''
+
+                with Live(Markdown(""), console=self.console, refresh_per_second=12) as live:
+                    for chunk in self.stream_chat(history):
+                        full_response += chunk
+                        live.update(Markdown(full_response))
+
+                history.append({'role': self.model_role, 'content': full_response})
+
+            except KeyboardInterrupt:
+                self.console.print('\n[bold red][Session interrupted][/bold red]')
+                break
+
+        self.console.print('[bold blue][Generation finished][/bold blue]')
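Taken together, `ChatInterface` either adopts a preloaded model/tokenizer pair or loads one from `model_id`, then streams tokens through a background `generate` thread. A minimal usage sketch follows; the model ID is a placeholder, not one this package ships or pins:

```python
# Hypothetical usage of ChatInterface; the model ID below is a placeholder.
from orbit.kit.interface import ChatInterface

chat = ChatInterface(model_id='Qwen/Qwen2.5-0.5B-Instruct')

# Consume stream_chat directly instead of the interactive loop:
messages = [{'role': 'user', 'content': 'Hello!'}]
for chunk in chat.stream_chat(messages, max_new_tokens=64):
    print(chunk, end='', flush=True)

# Or start the rich-powered REPL; type "exit" or "quit" to leave.
chat.interact()
```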
orbit/kit/wrapper.py
ADDED
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+import inspect
+from typing import Optional, Any, List, Union
+
+
+class AutoRegressiveWrapper:
+    '''
+    A wrapper that makes a plain torch.nn.Module compatible with the transformers generate interface.
+
+    Attributes:
+        model (nn.Module): the original model instance
+        device (torch.device): the device the model lives on
+        accepts_attention_mask (bool): whether the model accepts an attention_mask argument
+        accepts_mask (bool): whether the model accepts a mask argument
+    '''
+
+    def __init__(self, model: nn.Module):
+        '''
+        Initialize the wrapper.
+
+        Args:
+            model (nn.Module): a custom model that only has a forward method
+        '''
+        self.model = model
+        try:
+            self.device = next(model.parameters()).device
+        except StopIteration:
+            self.device = torch.device('cpu')
+
+        sig = inspect.signature(model.forward)
+        params = sig.parameters
+        self.accepts_attention_mask = 'attention_mask' in params
+        self.accepts_mask = 'mask' in params
+
+        self.model.eval()
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        max_new_tokens: int = 512,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.0,
+        do_sample: bool = True,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
+        streamer: Optional[Any] = None,
+        **kwargs
+    ) -> torch.Tensor:
+        '''
+        Autoregressive generation loop supporting several sampling strategies and compatible with TextIteratorStreamer.
+
+        Args:
+            input_ids (torch.Tensor): input token ID sequence [batch, seq_len]
+            max_new_tokens (int): maximum number of new tokens to generate
+            temperature (float): sampling temperature
+            top_k (int): k for top-k sampling
+            top_p (float): p for top-p (nucleus) sampling
+            repetition_penalty (float): repetition penalty factor
+            do_sample (bool): whether to use sampling
+            eos_token_id (Optional[Union[int, List[int]]]): end-of-sequence token ID(s)
+            streamer (Optional[Any]): a streamer instance from the transformers library
+            **kwargs: other transformers-related arguments are ignored
+
+        Returns:
+            torch.Tensor: the full sequence including the generated content
+        '''
+        curr_input_ids = input_ids.to(self.device)
+        batch_size = curr_input_ids.shape[0]
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=self.device)
+
+        for _ in range(max_new_tokens):
+            model_inputs = {'input_ids': curr_input_ids}
+
+            if self.accepts_attention_mask:
+                model_inputs['attention_mask'] = torch.ones_like(curr_input_ids)
+            elif self.accepts_mask:
+                model_inputs['mask'] = torch.ones_like(curr_input_ids)
+
+            outputs = self.model(**model_inputs)
+
+            if isinstance(outputs, (tuple, list)):
+                logits = outputs[0]
+            else:
+                logits = outputs
+
+            next_token_logits = logits[:, -1, :]
+
+            if repetition_penalty != 1.0:
+                for i in range(batch_size):
+                    for previous_token in set(curr_input_ids[i].tolist()):
+                        if next_token_logits[i, previous_token] < 0:
+                            next_token_logits[i, previous_token] *= repetition_penalty
+                        else:
+                            next_token_logits[i, previous_token] /= repetition_penalty
+
+            if do_sample:
+                if temperature != 1.0:
+                    next_token_logits = next_token_logits / temperature
+
+                if top_k > 0:
+                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                probs = torch.softmax(next_token_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+            else:
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+            curr_input_ids = torch.cat([curr_input_ids, next_token], dim=-1)
+
+            if streamer is not None:
+                if unfinished_sequences[0] == 1:
+                    streamer.put(next_token.cpu())
+
+            if eos_token_id is not None:
+                for token_id in eos_token_id:
+                    unfinished_sequences = unfinished_sequences.mul(
+                        next_token.tile(1, 1).ne(token_id).all(dim=-1).long()
+                    )
+
+            if unfinished_sequences.max() == 0:
+                break
+
+        if streamer is not None:
+            streamer.end()
+
+        return curr_input_ids
+
+    def __getattr__(self, name: str) -> Any:
+        '''
+        Forward undefined attribute access to the wrapped model.
+
+        Args:
+            name (str): attribute name
+
+        Returns:
+            Any: the attribute on the original model
+        '''
+        return getattr(self.model, name)
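Because the wrapper inspects `forward`'s signature and accepts a `streamer`, any module that maps token IDs to `[batch, seq, vocab]` logits can be driven by `ChatInterface`. A self-contained sketch with a toy model (`TinyLM` is illustrative, not part of the package):

```python
# Wrapping a toy LM that only defines forward; TinyLM is illustrative only.
import torch
import torch.nn as nn
from orbit.kit.wrapper import AutoRegressiveWrapper

class TinyLM(nn.Module):
    def __init__(self, vocab=100, dim=32):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids):
        # Returns logits of shape [batch, seq, vocab].
        return self.head(self.emb(input_ids))

wrapped = AutoRegressiveWrapper(TinyLM())
out = wrapped.generate(
    torch.tensor([[1, 2, 3]]),
    max_new_tokens=8,
    temperature=0.8,
    eos_token_id=0,   # generation stops early if token 0 is sampled
)
print(out.shape)      # [1, 3 + number_of_generated_tokens]
```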
orbit/model/__init__.py
ADDED
orbit/model/base.py
ADDED
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+from typing import Union, List, Optional, Iterable
+
+from orbit.utils import (
+    auto_initialize,
+    freeze_layers,
+    unfreeze_layers,
+    count_params,
+    save_model,
+    load_model,
+)
+
+
+class BaseBlock(nn.Module):
+    ''' Base model block providing common model functionality.
+
+    Inherits from nn.Module and adds parameter counting, gradient checkpointing,
+    freezing/unfreezing layers, and saving/loading weights.
+    '''
+
+    def __init__(self):
+        ''' Initialize the BaseBlock. '''
+        super(BaseBlock, self).__init__()
+
+        self.gradient_checkpointing: bool = False
+
+    @property
+    def device(self):
+        ''' Get the device the model lives on.
+
+        Returns:
+            torch.device: the device of the model's parameters. Falls back to cpu if there are none.
+        '''
+        try:
+            return next(self.parameters()).device
+        except StopIteration:
+            return torch.device('cpu')
+
+    def _init_weights(self, model: Union[nn.Module, 'BaseBlock', nn.Parameter, torch.Tensor]):
+        ''' Initialize model weights.
+
+        Args:
+            model (Union[nn.Module, 'BaseBlock', nn.Parameter, torch.Tensor]): the model, layer, or tensor to initialize.
+        '''
+        auto_initialize(model=model, verbose=False)
+
+    def set_checkpoint(self, value: bool):
+        ''' Enable or disable gradient checkpointing.
+
+        Args:
+            value (bool): whether to enable gradient checkpointing.
+        '''
+        self.gradient_checkpointing = value
+        for model in self.modules():
+            if isinstance(model, BaseBlock) and model is not self:
+                model.gradient_checkpointing = value
+
+    def count_params(self, trainable_only=False):
+        ''' Count the model's parameters.
+
+        Args:
+            trainable_only (bool, optional): whether to count trainable parameters only. Defaults to False.
+
+        Returns:
+            int: the number of parameters.
+        '''
+        if trainable_only:
+            return count_params(self).count
+
+        return count_params(self, mode='all').count
+
+    def checkpoint(self, function, *args, **kwargs):
+        ''' Apply gradient checkpointing.
+
+        Uses torch.utils.checkpoint.checkpoint when gradient checkpointing is enabled
+        and the model is in training mode; otherwise calls the function directly.
+
+        Args:
+            function (Callable): the function to run.
+            *args: positional arguments passed to the function.
+            **kwargs: keyword arguments passed to the function.
+
+        Returns:
+            Any: the function's return value.
+        '''
+        if self.gradient_checkpointing and self.training:
+            return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, **kwargs)
+        else:
+            return function(*args, **kwargs)
+
+    def freeze(self, targets: Optional[Union[str, List[str]]] = None):
+        ''' Freeze the parameters of the given layers.
+
+        Args:
+            targets (Optional[Union[str, List[str]]], optional): layer name or list of names to freeze.
+                Freezes all layers when None. Defaults to None.
+        '''
+        freeze_layers(self, targets)
+
+    def unfreeze(self, targets: Optional[Union[str, List[str]]] = None):
+        ''' Unfreeze the parameters of the given layers.
+
+        Args:
+            targets (Optional[Union[str, List[str]]], optional): layer name or list of names to unfreeze.
+                Unfreezes all layers when None. Defaults to None.
+        '''
+        unfreeze_layers(self, targets)
+
+    def save_pretrained(self, file_path: str):
+        ''' Save the model weights to a file.
+
+        Args:
+            file_path (str): the save path.
+        '''
+        save_model(self, file_path)
+
+    def load_pretrained(self, file_path: str, strict: bool = True, map_location: Union[str, torch.device] = 'cpu'):
+        ''' Load model weights from a file.
+
+        Args:
+            file_path (str): path to the weights file.
+            strict (bool, optional): whether to strictly match keys. Defaults to True.
+            map_location (Union[str, torch.device], optional): map location. Defaults to 'cpu'.
+        '''
+        load_model(self, file_path, strict, map_location)
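A short sketch of how a subclass might use these hooks; `Encoder` and its layers are illustrative, and the exact matching behavior of `freeze` depends on `orbit.utils.freeze_layers`:

```python
# Illustrative BaseBlock subclass; Encoder and its layer names are made up.
import torch.nn as nn
from orbit.model.base import BaseBlock

class Encoder(BaseBlock):
    def __init__(self, dim=64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Re-materializes activations only when checkpointing is enabled
        # and the module is in training mode; otherwise runs directly.
        return self.checkpoint(lambda t: self.norm(self.proj(t)), x)

enc = Encoder()
enc.set_checkpoint(True)    # propagates to nested BaseBlock modules
print(enc.device)           # cpu until the module is moved
print(enc.count_params())   # total parameter count
enc.freeze(['proj'])        # delegates to orbit.utils.freeze_layers
```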
orbit/model/block/__init__.py
ADDED
@@ -0,0 +1,34 @@
+from .embedding import (
+    RotaryPositionalEmbedding,
+    SinusoidalPositionalEmbedding,
+    MRoPEInterleavedEmbedding
+)
+from .attention import (
+    MultiHeadAttention, apply_attention, AttentionOutput,
+    SpatialMultiHeadAttention
+)
+from .codebook import (
+    LFQ, QuantizerOutput
+)
+from .fusion import (
+    LowRankFusion, GatedMultimodalUnit, DiffusionMapsFusion, CompactMultimodalPooling
+)
+from .mlp import MLP
+from .moe import MoE
+from .tcn import TCN
+from .bio import (
+    HebianLayer, PredictiveCodingLayer, PredictiveCodingOutput,
+    PredictiveCodingBlock
+)
+from .film import FiLM, FiLMOutput
+from .gate import (
+    SigmoidGate, TanhGate, SoftmaxGate, GLUGate,
+    TopKGate, TopKGateOutput
+)
+from .conv import (
+    CausalConv1d, calculate_causal_layer, ConvBlock,
+    DepthwiseSeparableConv, ResBasicBlock
+)
+from .lora import (
+    LinearLoRA, Conv2dLoRA, Conv1dLoRA, EmbeddingLoRA
+)