amazingvmsloth 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amazingvmsloth-0.1.0/PKG-INFO +25 -0
- amazingvmsloth-0.1.0/amazingvmsloth/__init__.py +25 -0
- amazingvmsloth-0.1.0/amazingvmsloth/attention.py +236 -0
- amazingvmsloth-0.1.0/amazingvmsloth/bench.py +291 -0
- amazingvmsloth-0.1.0/amazingvmsloth/cli.py +601 -0
- amazingvmsloth-0.1.0/amazingvmsloth/cpu_trainer.py +359 -0
- amazingvmsloth-0.1.0/amazingvmsloth/gradient.py +127 -0
- amazingvmsloth-0.1.0/amazingvmsloth/lora.py +280 -0
- amazingvmsloth-0.1.0/amazingvmsloth/models/__init__.py +149 -0
- amazingvmsloth-0.1.0/amazingvmsloth/models/base.py +3 -0
- amazingvmsloth-0.1.0/amazingvmsloth/multi_gpu/__init__.py +193 -0
- amazingvmsloth-0.1.0/amazingvmsloth/multi_gpu/deepspeed_integration.py +84 -0
- amazingvmsloth-0.1.0/amazingvmsloth/multi_gpu/pipeline.py +68 -0
- amazingvmsloth-0.1.0/amazingvmsloth/offload.py +131 -0
- amazingvmsloth-0.1.0/amazingvmsloth/optimizer.py +231 -0
- amazingvmsloth-0.1.0/amazingvmsloth/packing.py +206 -0
- amazingvmsloth-0.1.0/amazingvmsloth/quantization.py +170 -0
- amazingvmsloth-0.1.0/amazingvmsloth/trainer.py +510 -0
- amazingvmsloth-0.1.0/amazingvmsloth/utils/__init__.py +16 -0
- amazingvmsloth-0.1.0/amazingvmsloth/utils/banner.py +128 -0
- amazingvmsloth-0.1.0/amazingvmsloth/utils/memory.py +70 -0
- amazingvmsloth-0.1.0/amazingvmsloth/utils/patching.py +69 -0
- amazingvmsloth-0.1.0/amazingvmsloth/utils/save_load.py +52 -0
- amazingvmsloth-0.1.0/amazingvmsloth/wizard.py +603 -0
- amazingvmsloth-0.1.0/amazingvmsloth.egg-info/PKG-INFO +25 -0
- amazingvmsloth-0.1.0/amazingvmsloth.egg-info/SOURCES.txt +33 -0
- amazingvmsloth-0.1.0/amazingvmsloth.egg-info/dependency_links.txt +1 -0
- amazingvmsloth-0.1.0/amazingvmsloth.egg-info/entry_points.txt +2 -0
- amazingvmsloth-0.1.0/amazingvmsloth.egg-info/requires.txt +24 -0
- amazingvmsloth-0.1.0/amazingvmsloth.egg-info/top_level.txt +1 -0
- amazingvmsloth-0.1.0/pyproject.toml +41 -0
- amazingvmsloth-0.1.0/setup.cfg +4 -0
- amazingvmsloth-0.1.0/tests/test_lora.py +111 -0
- amazingvmsloth-0.1.0/tests/test_quantization.py +70 -0
- amazingvmsloth-0.1.0/tests/test_trainer.py +91 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: amazingvmsloth
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Blazing-fast LLM fine-tuning with minimal VRAM — multi-GPU, manual LoRA gradients, flash attention, 4-bit quant
|
|
5
|
+
License: MIT
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: torch>=2.1.0
|
|
9
|
+
Requires-Dist: transformers>=4.36.0
|
|
10
|
+
Requires-Dist: bitsandbytes>=0.41.0
|
|
11
|
+
Requires-Dist: peft>=0.7.0
|
|
12
|
+
Requires-Dist: triton>=2.1.0; sys_platform == "linux"
|
|
13
|
+
Requires-Dist: safetensors>=0.4.0
|
|
14
|
+
Requires-Dist: accelerate>=0.25.0
|
|
15
|
+
Requires-Dist: datasets>=2.14.0
|
|
16
|
+
Requires-Dist: psutil>=5.9.0
|
|
17
|
+
Provides-Extra: flash-attn
|
|
18
|
+
Requires-Dist: flash-attn>=2.3.0; extra == "flash-attn"
|
|
19
|
+
Provides-Extra: multi-gpu
|
|
20
|
+
Requires-Dist: deepspeed>=0.12.0; extra == "multi-gpu"
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
23
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
24
|
+
Provides-Extra: all
|
|
25
|
+
Requires-Dist: amazingvmsloth[dev,flash-attn,multi-gpu]; extra == "all"
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
# Turn off Hugging Face progress bars (they can crash some Windows terminals).
# This runs before the submodule imports below, which pull in transformers.
# An explicit value already present in the environment is left untouched.
if "HF_HUB_DISABLE_PROGRESS_BARS" not in os.environ:
    os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
|
5
|
+
|
|
6
|
+
from amazingvmsloth.lora import apply_lora, LoRAConfig
|
|
7
|
+
from amazingvmsloth.trainer import AmazingTrainer, TrainingConfig
|
|
8
|
+
from amazingvmsloth.quantization import quantize_model_4bit
|
|
9
|
+
from amazingvmsloth.attention import patch_attention
|
|
10
|
+
from amazingvmsloth.models import auto_patch_model
|
|
11
|
+
from amazingvmsloth.multi_gpu import setup_distributed
|
|
12
|
+
from amazingvmsloth.offload import apply_dispatch_offload, estimate_model_size_gb
|
|
13
|
+
from amazingvmsloth.cpu_trainer import CpuTrainer, CpuTrainingConfig
|
|
14
|
+
|
|
15
|
+
__version__ = "0.1.0"

# Public API re-exported by ``from amazingvmsloth import *``. Every name imported
# above is listed, so the re-export is explicit (linters such as ruff flag
# unlisted imports in ``__init__.py`` as unused, F401). This also means the
# offload and CPU-training helpers are included in star-imports.
__all__ = [
    "apply_lora",
    "LoRAConfig",
    "AmazingTrainer",
    "TrainingConfig",
    "quantize_model_4bit",
    "patch_attention",
    "auto_patch_model",
    "setup_distributed",
    "apply_dispatch_offload",
    "estimate_model_size_gb",
    "CpuTrainer",
    "CpuTrainingConfig",
]
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
from transformers import PreTrainedModel
|
|
6
|
+
|
|
7
|
+
# Optional fused-attention backends. Each probe records whether its import
# succeeded, so the attention code can fall back to PyTorch SDPA when the
# package is missing. Only ImportError counts as "not installed"; any other
# error raised while importing propagates.
try:
    from flash_attn import flash_attn_func
except ImportError:
    _flash_available = False
else:
    _flash_available = True

try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    _xformers_available = False
else:
    _xformers_available = True


def is_flash_available() -> bool:
    """Report whether ``flash_attn.flash_attn_func`` imported at module load."""
    return _flash_available
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class EfficientAttention(nn.Module):
    """Wrapper that keeps an attention module and exposes fused-kernel helpers.

    ``forward`` delegates unchanged to the wrapped module.
    ``_try_flash_attention`` and ``_sdpa_attention`` are stand-alone helpers
    that run a fused kernel on already-projected tensors.

    Args:
        original_module: the attention module being wrapped.
        config: optional model config. It is consulted for head counts and
            ``head_dim`` when ``original_module`` does not expose them.
    """

    def __init__(self, original_module, config=None):
        super().__init__()
        self.original_module = original_module
        self.config = config
        self._use_flash = False

        self.num_heads = (
            getattr(original_module, "num_heads", None)
            or getattr(config, "num_attention_heads", None)
            or 32
        )

        num_kv_heads = getattr(original_module, "num_key_value_heads", None) or getattr(
            config, "num_key_value_heads", None
        )
        if not num_kv_heads:
            # ``num_key_value_groups`` is num_heads // num_kv_heads (see
            # PatchedLlamaAttention), i.e. query heads per KV head, and
            # ``num_kv_groups`` is treated the same way. Convert the ratio to a
            # head count instead of using it as one.
            groups = getattr(original_module, "num_key_value_groups", None) or getattr(
                original_module, "num_kv_groups", None
            )
            num_kv_heads = max(1, self.num_heads // groups) if groups else self.num_heads
        self.num_kv_heads = num_kv_heads

        self.head_dim = (
            getattr(original_module, "head_dim", None)
            or getattr(config, "head_dim", None)
            or 128
        )
        self.hidden_size = self.num_heads * self.head_dim

    def _try_flash_attention(self, q, k, v, attention_mask=None):
        """Run flash-attn on projected ``q``/``k``/``v``; return None if unusable.

        ``q`` is viewed as ``(q.size(0), q.size(1), num_heads, head_dim)`` and
        ``k``/``v`` as ``(..., num_kv_heads, head_dim)``. Returns None, so the
        caller can fall back, when:

        - flash-attn is missing,
        - the inputs are not fp16/bf16 CUDA tensors, or
        - the kernel raises.

        ``attention_mask`` is not used; the kernel always runs causally.
        """
        if not _flash_available:
            return None
        if q.dtype not in (torch.float16, torch.bfloat16):
            return None
        # flash-attn kernels only run on CUDA, regardless of sequence length.
        if q.device.type != "cuda":
            return None

        try:
            bsz, seq_len = q.size(0), q.size(1)
            q_4d = q.view(bsz, seq_len, self.num_heads, self.head_dim)
            k_4d = k.view(k.size(0), k.size(1), self.num_kv_heads, self.head_dim)
            v_4d = v.view(v.size(0), v.size(1), self.num_kv_heads, self.head_dim)
            # flash-attn supports GQA/MQA natively when K/V carry fewer heads
            # than Q: query head h attends KV head h // (num_heads // num_kv_heads).
            # So K/V are passed unexpanded. The previous manual expansion
            # interleaved heads in the wrong order.
            output = flash_attn_func(q_4d, k_4d, v_4d, causal=True)
            return output.reshape(bsz, seq_len, -1)
        except Exception:
            return None

    def _sdpa_attention(self, q, k, v, attention_mask=None):
        """Run ``F.scaled_dot_product_attention``; causal when no mask is given.

        Returns None if this torch build has no SDPA.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            return None
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            is_causal=attention_mask is None,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Delegate unchanged to the wrapped attention module."""
        return self.original_module(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class PatchedLlamaAttention(nn.Module):
    """Llama-style attention that routes scores through ``_efficient_attn``.

    Reuses the projections (``q_proj``/``k_proj``/``v_proj``/``o_proj``) and
    ``rotary_emb`` of the module it replaces.

    Args:
        original_attn: the attention module being replaced.
        config: model config. Shape hyper-parameters are read from
            ``original_attn.config`` when it exists, otherwise from this.
    """

    def __init__(self, original_attn, config=None):
        super().__init__()
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj
        self.rotary_emb = original_attn.rotary_emb
        self.config = config

        model_config = getattr(original_attn, "config", None)
        if model_config is None:
            # Fall back to the model config instead of silently using defaults.
            model_config = config

        # A missing attribute and an attribute set to None are both treated as
        # "unset", so the defaults below apply.
        self.hidden_size = getattr(model_config, "hidden_size", None) or 4096
        self.num_heads = getattr(model_config, "num_attention_heads", None) or 32
        self.head_dim = getattr(model_config, "head_dim", None) or self.hidden_size // self.num_heads
        self.num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = getattr(model_config, "max_position_embeddings", None) or 4096

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` of shape ``(batch, seq, hidden)``.

        Returns ``(attn_output, attn_weights, past_key_value)``.

        - ``attn_weights`` is always None, because attention probabilities are
          never materialised, even when ``output_attentions`` is True.
        - ``past_key_value`` is a ``(key, value)`` tuple when ``use_cache`` is
          set, else None.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Tuple-style cache: [0] is key, [1] is value; append along seq axis.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        if self.num_key_value_groups > 1:
            key_states = _repeat_kv(key_states, self.num_key_value_groups)
            value_states = _repeat_kv(value_states, self.num_key_value_groups)

        attn_output = _efficient_attn(query_states, key_states, value_states, attention_mask)

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim). Use
        # the projection width, not hidden_size: the two differ whenever
        # num_heads * head_dim != hidden_size.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        attn_weights = None
        return attn_output, attn_weights, past_key_value
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _apply_rotary_pos_emb(q, k, cos, sin):
|
|
149
|
+
def _rotate_half(x):
|
|
150
|
+
x1 = x[..., : x.shape[-1] // 2]
|
|
151
|
+
x2 = x[..., x.shape[-1] // 2 :]
|
|
152
|
+
return torch.cat((-x2, x1), dim=-1)
|
|
153
|
+
|
|
154
|
+
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
|
155
|
+
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
|
156
|
+
return q_embed, k_embed
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _repeat_kv(hidden_states, n_rep):
|
|
160
|
+
if n_rep == 1:
|
|
161
|
+
return hidden_states
|
|
162
|
+
bsz, num_heads, seq_len, head_dim = hidden_states.shape
|
|
163
|
+
hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_heads, n_rep, seq_len, head_dim)
|
|
164
|
+
return hidden_states.reshape(bsz, num_heads * n_rep, seq_len, head_dim)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _efficient_attn(q, k, v, mask=None):
    """Scaled dot-product attention over ``(batch, heads, seq, head_dim)`` tensors.

    Backends are tried in this order:

    1. flash-attn,
    2. xformers,
    3. PyTorch SDPA,
    4. a plain matmul implementation.

    With ``mask=None`` attention is causal. Otherwise ``mask`` (additive float
    or boolean keep-mask, broadcastable to ``(batch, heads, q_len, kv_len)``)
    is applied and no implicit causal mask is added, matching
    ``F.scaled_dot_product_attention``.

    The fused kernels are only used when ``mask`` is None, because neither call
    below forwards an arbitrary mask: flash-attn runs with ``causal=True``, and
    xformers uses a lower-triangular bias.
    """
    _, _, q_len, head_dim = q.shape
    kv_len = k.shape[-2]
    fused_ok = mask is None and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16)

    if _flash_available and fused_ok:
        try:
            # flash-attn expects (batch, seq, heads, head_dim).
            out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True)
            return out.transpose(1, 2)
        except Exception:
            pass

    if _xformers_available and fused_ok and q_len == kv_len:
        try:
            from xformers.ops import LowerTriangularMask

            # xformers also takes (batch, seq, heads, head_dim). The previous
            # (bsz * seq, heads, head_dim) reshape made it attend across heads
            # instead of tokens. memory_efficient_attention is non-causal
            # unless given a causal bias.
            out = memory_efficient_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                attn_bias=LowerTriangularMask(),
            )
            return out.transpose(1, 2)
        except Exception:
            pass

    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=mask is None)

    # Manual fallback; must honour the same causal-by-default contract as SDPA.
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / (head_dim ** 0.5)
    if mask is None:
        causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device).tril()
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))
    elif mask.dtype == torch.bool:
        attn_weights = attn_weights.masked_fill(~mask, float("-inf"))
    else:
        attn_weights = attn_weights + mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(attn_weights, v)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def patch_attention(model: PreTrainedModel) -> PreTrainedModel:
    """Replace rotary attention modules with ``PatchedLlamaAttention`` in place.

    First sets ``model.config._attn_implementation = "sdpa"`` (when the model
    has a config), so modules that are not replaced use PyTorch SDPA. A module
    is replaced when all of the following hold:

    - its class name contains ``"Attention"``;
    - it exposes ``rotary_emb``;
    - it is not already a ``PatchedLlamaAttention``, so calling this twice is
      harmless.

    Replacement failures are printed and the module is left untouched.

    Returns:
        The same ``model`` object.
    """
    config = getattr(model, "config", None)
    if config is not None:
        config._attn_implementation = "sdpa"

    # Snapshot before mutating, so replacements don't disturb iteration.
    modules_dict = dict(model.named_modules())
    patched_count = 0
    skipped = 0

    for name, module in modules_dict.items():
        if "Attention" not in type(module).__name__ or isinstance(module, PatchedLlamaAttention):
            continue
        if not hasattr(module, "rotary_emb"):
            # Not replaced; runs through the implementation selected above.
            skipped += 1
            continue
        # "a.b.attn" -> parent "a.b"; a direct child "attn" -> parent "" (the
        # root model itself). The root (name "") cannot be replaced.
        parent_name, _, attr_name = name.rpartition(".")
        parent = modules_dict.get(parent_name) if attr_name else None
        if parent is None:
            continue
        try:
            setattr(parent, attr_name, PatchedLlamaAttention(module, config))
            patched_count += 1
        except Exception as e:
            print(f"[amazingvmsloth] Could not patch attention {name}: {e}")

    print(f"[amazingvmsloth] Patched {patched_count} attention layers, {skipped} using native SDPA (no rotary_emb on module)")
    flash_status = "available" if _flash_available else "not available (install flash-attn for speedup)"
    xformers_status = "available" if _xformers_available else "not available (pip install xformers for speedup)"
    print(f"[amazingvmsloth] Flash attention: {flash_status}")
    print(f"[amazingvmsloth] XFormers attention: {xformers_status}")
    # _efficient_attn tries flash-attn first, then xformers, then SDPA. This
    # reports the preferred backend; the kernel actually used still depends on
    # device/dtype at call time.
    if _flash_available:
        backend = "flash-attn"
    elif _xformers_available:
        backend = "xformers"
    else:
        backend = "SDPA"
    print(f"[amazingvmsloth] Attention implementation: {backend}")
    return model
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import torch
|
|
3
|
+
import argparse
|
|
4
|
+
import os
|
|
5
|
+
from typing import Dict, Any
|
|
6
|
+
|
|
7
|
+
# Suppress Hugging Face progress bars (they can crash on Windows and cause UI
# issues). A value already present in the environment is kept.
if "HF_HUB_DISABLE_PROGRESS_BARS" not in os.environ:
    os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

# Cached outcome of the lazy unsloth probe below:
# - ``_unsloth_available`` becomes True after a successful import;
# - ``_unsloth_error`` holds the failure message otherwise.
_unsloth_available = False
_unsloth_error = None


def _try_import_unsloth():
    """Import unsloth on first use and cache the outcome.

    Only called when the unsloth benchmark is requested, so the amazingvmsloth
    benchmark works without unsloth installed. Any exception raised by the
    import is stored in ``_unsloth_error`` rather than propagated.
    UserWarnings emitted during the import are silenced.

    Returns:
        True if ``unsloth`` and ``unsloth.FastLanguageModel`` imported.
    """
    global _unsloth_available, _unsloth_error
    already_probed = _unsloth_available or _unsloth_error is not None
    if already_probed:
        return _unsloth_available

    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        try:
            import unsloth  # noqa: F401 - only whether the import succeeds matters
            from unsloth import FastLanguageModel  # noqa: F401
        except Exception as exc:
            _unsloth_error = str(exc)
        else:
            _unsloth_available = True
    return _unsloth_available
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def run_amazingvmsloth(model_name: str, dataset_name: str, max_samples: int, max_seq_length: int) -> Dict[str, Any]:
    """Fine-tune ``model_name`` for one epoch with amazingvmsloth and time it.

    Steps:

    1. Load the model in 4-bit. On a loading RuntimeError that is not
       OOM/CUDA related, retry with ``unsloth/<name>-bnb-4bit``.
    2. Apply LoRA (r=16, alpha=16, rsLoRA).
    3. Tokenize an Alpaca-style dataset (``instruction``/``output`` columns).
    4. Pre-pack it into ``max_seq_length`` chunks and train.

    Args:
        model_name: Hugging Face model id.
        dataset_name: dataset id, loaded with ``split="train"``.
        max_samples: number of examples to keep; ``<= 0`` keeps all.
        max_seq_length: truncation length and packing chunk size.

    Returns:
        A dict with:

        - ``time_s``: training wall time;
        - ``vram_peak_gb``: peak CUDA memory during training, 0 without CUDA;
        - ``steps``;
        - ``loss``: last logged loss, -1 if none.

    Raises:
        RuntimeError: re-raised when loading fails with an out-of-memory or
            CUDA error.
    """
    from transformers import AutoTokenizer
    from datasets import load_dataset
    from amazingvmsloth import quantize_model_4bit, auto_patch_model, LoRAConfig, AmazingTrainer, TrainingConfig
    from amazingvmsloth.packing import prepack_dataset, BatchCollator
    from amazingvmsloth.utils.banner import print_banner, get_system_info

    info = get_system_info()
    print_banner(info, model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model = quantize_model_4bit(model_name)
    except RuntimeError as e:
        # CUDA OOM (torch.cuda.OutOfMemoryError) is a RuntimeError subclass, so
        # this clause covers it. Naming torch.OutOfMemoryError here would itself
        # raise AttributeError on older torch releases still allowed by
        # torch>=2.1.
        message = str(e).lower()
        if "out of memory" in message or "cuda" in message:
            print(f"\n[bench] ERROR: GPU out of memory while loading {model_name}")
            print("[bench] This model may be too large for your GPU.")
            print("[bench] Try: --model Qwen/Qwen2.5-0.5B (smaller model)")
            print("[bench] Or: --low-vram (auto-tune settings)")
            raise
        # Try pre-quantized unsloth model
        unsloth_model = f"unsloth/{model_name.split('/')[-1]}-bnb-4bit"
        print(f"[bench] Loading failed, trying pre-quantized: {unsloth_model}")
        model = quantize_model_4bit(unsloth_model)

    lora_config = LoRAConfig(r=16, lora_alpha=16, use_rslora=True, use_manual_gradients=False)
    model = auto_patch_model(model, apply_lora_patch=True, lora_config=lora_config)

    dataset = load_dataset(dataset_name, split="train")
    if max_samples > 0:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    def tokenize_fn(examples):
        texts = [
            f"### Instruction:\n{inst}\n\n### Response:\n{out}"
            for inst, out in zip(examples["instruction"], examples["output"])
        ]
        return tokenizer(texts, truncation=True, max_length=max_seq_length)

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
    tokenized = tokenized.map(lambda x: {"labels": x["input_ids"]}, batched=True)
    # Don't set_format("torch"): datasets' torch formatter crashes on Windows,
    # because it tries to import torchvision.io.VideoReader, which isn't
    # available. The collator does the list -> tensor conversion itself.

    # Pre-pack: concatenate all sequences and split into max_seq_length chunks,
    # which needs fewer forward passes than padding short sequences one by one.
    packed = prepack_dataset(tokenized, tokenizer, max_seq_length=max_seq_length)
    print(f"[bench] Pre-packed {len(tokenized)} samples into {len(packed)} chunks ({max_seq_length} tokens each)")

    config = TrainingConfig(
        output_dir="./bench_amazingvmsloth",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        learning_rate=2e-4,
        bf16=True,
        optim="paged_adamw_8bit",
        max_seq_length=max_seq_length,
        logging_steps=1,
        save_steps=999999,
        packing=False,  # Already pre-packed above
        compile_model=False,  # Compile overhead dominates on short runs despite pre-packing
        gradient_checkpointing=True,  # Enable for VRAM optimization (match unsloth's 2.3GB)
        chunked_loss=False,  # 512 seq fits full logits in 4GB, skip chunked overhead
        silent=True,  # Skip tqdm/postfix overhead in benchmark
    )

    trainer = AmazingTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=packed,
        config=config,
        data_collator=BatchCollator(),
    )

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Measure the training peak only, as run_unsloth does (it resets right
        # before train()), so the two VRAM numbers are comparable.
        torch.cuda.reset_peak_memory_stats()
    start = time.time()
    result = trainer.train()
    elapsed = time.time() - start

    vram_peak = torch.cuda.max_memory_allocated(0) / 1024**3 if use_cuda else 0
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    # Newest history entry that carries a "loss" key; entries without one are skipped.
    loss = next((entry["loss"] for entry in reversed(result["log_history"]) if "loss" in entry), -1)

    return {
        "time_s": round(elapsed, 1),
        "vram_peak_gb": round(vram_peak, 2),
        "steps": result["global_step"],
        "loss": round(loss, 4) if isinstance(loss, float) else loss,
    }
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def run_unsloth(model_name: str, dataset_name: str, max_samples: int, max_seq_length: int) -> Dict[str, Any]:
    """Fine-tune ``model_name`` with unsloth and trl ``SFTTrainer``, and time it.

    Uses the same LoRA hyper-parameters as run_amazingvmsloth (r=16, alpha=16,
    rsLoRA) and the same Alpaca-style prompt.

    Instead of raising, it returns a dict with ``-1`` metrics and an
    ``"error"`` message when:

    - unsloth is unavailable;
    - model loading fails with OSError/MemoryError;
    - setup or training raises.

    Returns:
        A dict with ``time_s``, ``vram_peak_gb`` (peak CUDA memory during
        training), ``steps`` and ``loss``, plus ``error`` on failure.
    """
    if not _try_import_unsloth():
        msg = "unsloth not installed" if _unsloth_error is None else f"unsloth import failed: {_unsloth_error}"
        print(f"[bench] {msg}. Install with: pip install unsloth")
        return {"time_s": -1, "vram_peak_gb": -1, "steps": -1, "loss": -1, "error": msg}

    from unsloth import FastLanguageModel
    from trl import SFTTrainer
    from transformers import TrainingArguments
    from datasets import load_dataset

    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=None,
            load_in_4bit=True,
        )
    except (OSError, MemoryError) as e:
        err_msg = str(e)
        if "paging file" in err_msg.lower() or "memory" in err_msg.lower():
            print("[bench] unsloth failed: System out of memory (Windows paging file too small)")
            print("[bench] unsloth requires more RAM/virtual memory than your system has.")
            print("[bench] Tip: Increase Windows virtual memory or close other applications.")
        else:
            print(f"[bench] unsloth model loading failed: {err_msg}")
        return {"time_s": -1, "vram_peak_gb": -1, "steps": -1, "loss": -1, "error": err_msg}

    use_cuda = torch.cuda.is_available()
    # Initialised up front so the error path can report partial progress
    # without inspecting locals().
    trainer = None
    start = None
    try:
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            lora_alpha=16,
            lora_dropout=0,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            use_rslora=True,
        )

        dataset = load_dataset(dataset_name, split="train")
        if max_samples > 0:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        def formatting_func(examples):
            return [
                f"### Instruction:\n{inst}\n\n### Response:\n{out}"
                for inst, out in zip(examples["instruction"], examples["output"])
            ]

        # NOTE(review): gradient_accumulation_steps=8 here vs 1 in
        # run_amazingvmsloth, so the reported step counts are not directly
        # comparable.
        training_args = TrainingArguments(
            output_dir="./bench_unsloth",
            num_train_epochs=1,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=8,
            learning_rate=2e-4,
            bf16=True,
            logging_steps=1,
            save_steps=999999,
            optim="paged_adamw_8bit",
            max_grad_norm=1.0,
            warmup_ratio=0.1,
            report_to="none",
        )

        # NOTE(review): the tokenizer=/max_seq_length= keywords and the
        # list-returning (batched) formatting_func follow older trl releases.
        # Confirm them against the installed trl version.
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            formatting_func=formatting_func,
            max_seq_length=max_seq_length,
            args=training_args,
        )

        if use_cuda:
            torch.cuda.reset_peak_memory_stats()
        start = time.time()
        trainer.train()
        elapsed = time.time() - start
    except Exception as e:
        elapsed = time.time() - start if start is not None else 0
        err_msg = str(e)
        lowered = err_msg.lower()
        if "paging file" in lowered or "memory" in lowered or "os error 1455" in lowered:
            print("[bench] unsloth failed: System out of memory")
            print("[bench] unsloth requires more RAM/virtual memory than your system has.")
        else:
            print(f"[bench] unsloth training failed: {err_msg}")
        return {
            "time_s": round(elapsed, 1) if elapsed > 0 else -1,
            "vram_peak_gb": -1,
            "steps": getattr(trainer.state, "global_step", 0) if trainer is not None else 0,
            "loss": -1,
            "error": err_msg,
        }

    vram_peak = torch.cuda.max_memory_allocated(0) / 1024**3 if use_cuda else 0
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    # Newest history entry that carries a "loss" key.
    loss = next((entry["loss"] for entry in reversed(trainer.state.log_history) if "loss" in entry), -1)

    return {
        "time_s": round(elapsed, 1),
        "vram_peak_gb": round(vram_peak, 2),
        "steps": trainer.state.global_step,
        "loss": round(loss, 4) if isinstance(loss, float) else -1,
    }
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _format_cell(result, key):
    """Return ``result[key]``, or ``"N/A"`` when there is no result."""
    return result[key] if result else "N/A"


def _print_results(av_result, us_result):
    """Print the comparison table and, when both runs were timed, the speedup.

    ``us_result`` is None when unsloth was skipped or failed.
    """
    print(f"\n{'='*60}")
    print(" Results")
    print(f"{'='*60}")
    print(f" {'Metric':<20} {'amazingvmsloth':>15} {'unsloth':>15}")
    print(f" {'-'*20} {'-'*15} {'-'*15}")
    rows = (
        ("Time (s)", "time_s"),
        ("Peak VRAM (GB)", "vram_peak_gb"),
        ("Final Loss", "loss"),
        ("Steps", "steps"),
    )
    for label, key in rows:
        print(f" {label:<20} {av_result[key]:>15} {_format_cell(us_result, key):>15}")

    if us_result and us_result["time_s"] > 0 and av_result["time_s"] > 0:
        speedup = us_result["time_s"] / av_result["time_s"]
        vram_saved = us_result["vram_peak_gb"] - av_result["vram_peak_gb"]
        print(f"\n Speedup: {speedup:.2f}x")
        print(f" VRAM saved: {vram_saved:.2f} GB")
        # Times are rounded to 0.1 s, so exact ties are realistic. Don't call
        # either side "faster" then.
        if speedup > 1:
            print(f" amazingvmsloth is {speedup:.2f}x faster!")
        elif speedup < 1:
            print(f" unsloth is {1/speedup:.2f}x faster")
        else:
            print(" Both runs took the same time")

    print(f"{'='*60}\n")


def main():
    """Parse CLI args, run both benchmarks and print a comparison.

    The unsloth run is skipped with ``--skip-unsloth``.
    """
    parser = argparse.ArgumentParser(description="Benchmark: amazingvmsloth vs unsloth")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B", help="Model name")
    parser.add_argument("--dataset", default="tatsu-lab/alpaca", help="Dataset")
    parser.add_argument("--max-samples", type=int, default=500, help="Samples to use")
    parser.add_argument("--max-seq-length", type=int, default=512, help="Max seq length")
    parser.add_argument("--skip-unsloth", action="store_true", help="Skip unsloth benchmark")
    args = parser.parse_args()

    print(f"\n{'='*60}")
    print(" Benchmark: amazingvmsloth vs unsloth")
    print(f"{'='*60}")
    print(f" Model: {args.model}")
    print(f" Dataset: {args.dataset}")
    print(f" Samples: {args.max_samples}")
    print(f" Seq length: {args.max_seq_length}")
    print(f"{'='*60}\n")

    print("[bench] Running amazingvmsloth...")
    av_result = run_amazingvmsloth(args.model, args.dataset, args.max_samples, args.max_seq_length)
    print(f"[bench] amazingvmsloth done: {av_result['time_s']}s, {av_result['vram_peak_gb']}GB peak VRAM\n")

    us_result = None
    if not args.skip_unsloth:
        print("[bench] Running unsloth...")
        us_result = run_unsloth(args.model, args.dataset, args.max_samples, args.max_seq_length)
        if us_result.get("error"):
            print(f"[bench] unsloth skipped: {us_result['error']}\n")
            us_result = None
        else:
            print(f"[bench] unsloth done: {us_result['time_s']}s, {us_result['vram_peak_gb']}GB peak VRAM\n")

    _print_results(av_result, us_result)


if __name__ == "__main__":
    main()
|