amazingvmsloth 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. amazingvmsloth-0.1.0/PKG-INFO +25 -0
  2. amazingvmsloth-0.1.0/amazingvmsloth/__init__.py +25 -0
  3. amazingvmsloth-0.1.0/amazingvmsloth/attention.py +236 -0
  4. amazingvmsloth-0.1.0/amazingvmsloth/bench.py +291 -0
  5. amazingvmsloth-0.1.0/amazingvmsloth/cli.py +601 -0
  6. amazingvmsloth-0.1.0/amazingvmsloth/cpu_trainer.py +359 -0
  7. amazingvmsloth-0.1.0/amazingvmsloth/gradient.py +127 -0
  8. amazingvmsloth-0.1.0/amazingvmsloth/lora.py +280 -0
  9. amazingvmsloth-0.1.0/amazingvmsloth/models/__init__.py +149 -0
  10. amazingvmsloth-0.1.0/amazingvmsloth/models/base.py +3 -0
  11. amazingvmsloth-0.1.0/amazingvmsloth/multi_gpu/__init__.py +193 -0
  12. amazingvmsloth-0.1.0/amazingvmsloth/multi_gpu/deepspeed_integration.py +84 -0
  13. amazingvmsloth-0.1.0/amazingvmsloth/multi_gpu/pipeline.py +68 -0
  14. amazingvmsloth-0.1.0/amazingvmsloth/offload.py +131 -0
  15. amazingvmsloth-0.1.0/amazingvmsloth/optimizer.py +231 -0
  16. amazingvmsloth-0.1.0/amazingvmsloth/packing.py +206 -0
  17. amazingvmsloth-0.1.0/amazingvmsloth/quantization.py +170 -0
  18. amazingvmsloth-0.1.0/amazingvmsloth/trainer.py +510 -0
  19. amazingvmsloth-0.1.0/amazingvmsloth/utils/__init__.py +16 -0
  20. amazingvmsloth-0.1.0/amazingvmsloth/utils/banner.py +128 -0
  21. amazingvmsloth-0.1.0/amazingvmsloth/utils/memory.py +70 -0
  22. amazingvmsloth-0.1.0/amazingvmsloth/utils/patching.py +69 -0
  23. amazingvmsloth-0.1.0/amazingvmsloth/utils/save_load.py +52 -0
  24. amazingvmsloth-0.1.0/amazingvmsloth/wizard.py +603 -0
  25. amazingvmsloth-0.1.0/amazingvmsloth.egg-info/PKG-INFO +25 -0
  26. amazingvmsloth-0.1.0/amazingvmsloth.egg-info/SOURCES.txt +33 -0
  27. amazingvmsloth-0.1.0/amazingvmsloth.egg-info/dependency_links.txt +1 -0
  28. amazingvmsloth-0.1.0/amazingvmsloth.egg-info/entry_points.txt +2 -0
  29. amazingvmsloth-0.1.0/amazingvmsloth.egg-info/requires.txt +24 -0
  30. amazingvmsloth-0.1.0/amazingvmsloth.egg-info/top_level.txt +1 -0
  31. amazingvmsloth-0.1.0/pyproject.toml +41 -0
  32. amazingvmsloth-0.1.0/setup.cfg +4 -0
  33. amazingvmsloth-0.1.0/tests/test_lora.py +111 -0
  34. amazingvmsloth-0.1.0/tests/test_quantization.py +70 -0
  35. amazingvmsloth-0.1.0/tests/test_trainer.py +91 -0
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.4
2
+ Name: amazingvmsloth
3
+ Version: 0.1.0
4
+ Summary: Blazing-fast LLM fine-tuning with minimal VRAM — multi-GPU, manual LoRA gradients, flash attention, 4-bit quant
5
+ License: MIT
6
+ Requires-Python: >=3.9
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: torch>=2.1.0
9
+ Requires-Dist: transformers>=4.36.0
10
+ Requires-Dist: bitsandbytes>=0.41.0
11
+ Requires-Dist: peft>=0.7.0
12
+ Requires-Dist: triton>=2.1.0; sys_platform == "linux"
13
+ Requires-Dist: safetensors>=0.4.0
14
+ Requires-Dist: accelerate>=0.25.0
15
+ Requires-Dist: datasets>=2.14.0
16
+ Requires-Dist: psutil>=5.9.0
17
+ Provides-Extra: flash-attn
18
+ Requires-Dist: flash-attn>=2.3.0; extra == "flash-attn"
19
+ Provides-Extra: multi-gpu
20
+ Requires-Dist: deepspeed>=0.12.0; extra == "multi-gpu"
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest>=7.0; extra == "dev"
23
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
24
+ Provides-Extra: all
25
+ Requires-Dist: amazingvmsloth[dev,flash-attn,multi-gpu]; extra == "all"
@@ -0,0 +1,25 @@
1
"""amazingvmsloth: LLM fine-tuning with minimal VRAM (LoRA, 4-bit quantization, attention patching)."""
import os

# Suppress transformers loading progress bars that can crash on Windows terminal.
# setdefault keeps any value the user already exported; it runs before the
# submodule imports below, presumably so the hub library sees it at import time.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

from amazingvmsloth.lora import apply_lora, LoRAConfig
from amazingvmsloth.trainer import AmazingTrainer, TrainingConfig
from amazingvmsloth.quantization import quantize_model_4bit
from amazingvmsloth.attention import patch_attention
from amazingvmsloth.models import auto_patch_model
from amazingvmsloth.multi_gpu import setup_distributed
from amazingvmsloth.offload import apply_dispatch_offload, estimate_model_size_gb
from amazingvmsloth.cpu_trainer import CpuTrainer, CpuTrainingConfig

__version__ = "0.1.0"

# Public API exposed by ``from amazingvmsloth import *``; keep it in sync with
# every name re-exported by the imports above.
__all__ = [
    "apply_lora",
    "LoRAConfig",
    "AmazingTrainer",
    "TrainingConfig",
    "quantize_model_4bit",
    "patch_attention",
    "auto_patch_model",
    "setup_distributed",
    "apply_dispatch_offload",
    "estimate_model_size_gb",
    "CpuTrainer",
    "CpuTrainingConfig",
]
@@ -0,0 +1,236 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple
5
+ from transformers import PreTrainedModel
6
+
7
# Optional fused-attention backends. Each flag records whether the matching
# import succeeded; the imported callables are only bound when it did.
try:
    from flash_attn import flash_attn_func
except ImportError:
    _flash_available = False
else:
    _flash_available = True

try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    _xformers_available = False
else:
    _xformers_available = True


def is_flash_available() -> bool:
    """Return True when the optional ``flash_attn`` package imported successfully."""
    return _flash_available
24
+
25
+
26
class EfficientAttention(nn.Module):
    """Wrapper holding an original attention module plus fused-kernel helpers.

    ``forward`` delegates unchanged to the wrapped module. ``_try_flash_attention``
    and ``_sdpa_attention`` are standalone helpers that ``forward`` does not call.

    Args:
        original_module: attention module being wrapped. Head geometry is read
            from its ``num_heads`` (default 32), ``num_kv_groups`` or
            ``num_key_value_heads`` (default: ``num_heads``) and ``head_dim``
            (default 128) attributes.
        config: optional model config, stored as-is.
    """

    def __init__(self, original_module, config=None):
        super().__init__()
        self.original_module = original_module
        self.config = config
        self._use_flash = False
        self.num_heads = getattr(original_module, "num_heads", 32)
        # NOTE(review): ``num_kv_groups`` is read first and used as a KV-head
        # *count*; confirm that matches the wrapped modules (HF's
        # ``num_key_value_groups`` is a heads-per-KV-head ratio instead).
        self.num_kv_heads = getattr(original_module, "num_kv_groups", None) or getattr(
            original_module, "num_key_value_heads", self.num_heads
        )
        self.head_dim = getattr(original_module, "head_dim", 128)
        self.hidden_size = self.num_heads * self.head_dim

    def _expand_kv(self, kv):
        """Repeat the KV heads of a (batch, seq, kv_heads, head_dim) tensor up to num_heads.

        Each KV head is repeated ``n_rep`` times back-to-back, so query head ``h``
        pairs with KV head ``h // n_rep``. This is the same grouping ``_repeat_kv``
        produces for the (batch, heads, seq, head_dim) layout.
        """
        if self.num_kv_heads == self.num_heads:
            return kv
        n_rep = self.num_heads // self.num_kv_heads
        bsz, seq_len, kv_heads, head_dim = kv.shape
        # Insert the repeat axis *after* the KV-head axis so copies stay adjacent.
        return kv.unsqueeze(3).expand(bsz, seq_len, kv_heads, n_rep, head_dim).reshape(
            bsz, seq_len, kv_heads * n_rep, head_dim
        )

    def _try_flash_attention(self, q, k, v, attention_mask=None):
        """Run causal flash-attn on projection outputs, or return ``None`` if unusable.

        ``q`` is viewed as (batch, seq, num_heads, head_dim) and ``k``/``v`` as
        (batch, seq, num_kv_heads, head_dim). Each is therefore expected to be a
        (batch, seq, heads * head_dim) projection. ``attention_mask`` is ignored;
        the kernel always runs with ``causal=True``.

        Returns:
            A (batch, seq, num_heads * head_dim) tensor. Returns ``None`` when the
            tensors are not on CUDA, flash-attn is not installed, the dtype is not
            fp16/bf16, or the kernel call fails.
        """
        # flash-attn only ships CUDA kernels, whatever the sequence length.
        if q.device.type != "cuda" or not _flash_available:
            return None
        if q.dtype not in (torch.float16, torch.bfloat16):
            return None

        try:
            bsz, seq_len = q.size(0), q.size(1)
            q_4d = q.view(bsz, seq_len, self.num_heads, self.head_dim)
            k_4d = self._expand_kv(k.view(k.size(0), k.size(1), self.num_kv_heads, self.head_dim))
            v_4d = self._expand_kv(v.view(v.size(0), v.size(1), self.num_kv_heads, self.head_dim))
            output = flash_attn_func(q_4d, k_4d, v_4d, causal=True)
            return output.reshape(bsz, seq_len, -1)
        except Exception:
            # Best-effort fast path: any kernel or shape failure means "not usable".
            return None

    def _sdpa_attention(self, q, k, v, attention_mask=None):
        """Run PyTorch SDPA; return ``None`` if this torch build lacks it.

        Causal masking is used only when no mask is given and there is more than
        one query position. SDPA's ``is_causal`` mask is top-left aligned, so a
        single cached-decoding query would otherwise see only the first key.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            return None
        is_causal = attention_mask is None and q.shape[-2] > 1
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            is_causal=is_causal,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Delegate to the wrapped attention module unchanged."""
        return self.original_module(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
81
+
82
+
83
class PatchedLlamaAttention(nn.Module):
    """Llama-style attention that reuses an existing module's weights.

    The original module's q/k/v/o projections and ``rotary_emb`` are shared, not
    copied. Only the attention product is rerouted, through ``_efficient_attn``.
    Head geometry comes from ``original_attn.config``, with Llama-7B-like
    fallbacks when attributes are missing.

    Args:
        original_attn: attention module exposing q_proj, k_proj, v_proj, o_proj
            and rotary_emb (and optionally ``config`` and ``layer_idx``).
        config: optional model config, stored as-is.
    """

    def __init__(self, original_attn, config=None):
        super().__init__()
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj
        self.rotary_emb = original_attn.rotary_emb
        self.config = config
        # Index of this layer inside a transformers ``Cache``; None if unknown.
        self.layer_idx = getattr(original_attn, "layer_idx", None)

        model_config = getattr(original_attn, "config", None)
        self.hidden_size = getattr(model_config, "hidden_size", 4096)
        self.num_heads = getattr(model_config, "num_attention_heads", 32)
        self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = getattr(model_config, "num_key_value_heads", self.num_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = getattr(model_config, "max_position_embeddings", 4096)

    def _shape(self, tensor, seq_len, bsz):
        """Reshape (bsz, seq, num_heads*head_dim) into (bsz, num_heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention for (batch, seq, hidden) ``hidden_states``.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is
            always None because the fused kernels never materialise the
            probability matrix. ``past_key_value`` is the updated transformers
            ``Cache`` when one was passed in. Otherwise it is a ``(key, value)``
            tuple when ``use_cache`` is set, or None.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None and hasattr(past_key_value, "update") and self.layer_idx is not None:
            # A transformers ``Cache`` stores every layer: indexing it with [0]/[1]
            # yields whole layers' (key, value) pairs, so append via ``update``.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        else:
            # Legacy tuple cache: (key, value), each (batch, kv_heads, past_len, head_dim).
            if past_key_value is not None:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            past_key_value = (key_states, value_states) if use_cache else None

        if self.num_key_value_groups > 1:
            key_states = _repeat_kv(key_states, self.num_key_value_groups)
            value_states = _repeat_kv(value_states, self.num_key_value_groups)

        attn_output = _efficient_attn(query_states, key_states, value_states, attention_mask)

        attn_output = attn_output.transpose(1, 2).contiguous()
        # num_heads * head_dim may differ from hidden_size when head_dim is explicit.
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        # Fused kernels cannot return attention probabilities, even if requested.
        attn_weights = None

        return attn_output, attn_weights, past_key_value
146
+
147
+
148
def _apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q, k: per-head projections in (batch, heads, seq, head_dim) layout, as
            built by ``PatchedLlamaAttention.forward``.
        cos, sin: rotary tables. A 2-D (seq, head_dim) table broadcasts as-is.
            A 3-D (batch, seq, head_dim) table, as returned by newer transformers
            rotary modules, gets a heads axis inserted at ``unsqueeze_dim`` so it
            broadcasts per head instead of against the heads axis.
        unsqueeze_dim: axis where the heads dimension is inserted for 3-D tables.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    if cos.dim() == 3:
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate_half(x):
        # (x1, x2) -> (-x2, x1) across the two halves of the last dimension.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed
157
+
158
+
159
def _repeat_kv(hidden_states, n_rep):
    """Expand grouped KV heads so every query head has its own KV head.

    Input is (batch, kv_heads, seq, head_dim). The output has kv_heads * n_rep
    heads, with each original head repeated n_rep times consecutively. With
    n_rep == 1 the input object itself is returned.
    """
    if n_rep == 1:
        return hidden_states
    batch, kv_heads, seq, dim = hidden_states.shape
    # New repeat axis right after the head axis, then fold it into the heads.
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq, dim)
    return widened.reshape(batch, kv_heads * n_rep, seq, dim)
165
+
166
+
167
def _efficient_attn(q, k, v, mask=None):
    """Scaled dot-product attention with optional fused-kernel fast paths.

    Args:
        q: queries, (batch, heads, q_len, head_dim).
        k, v: keys/values, (batch, heads, kv_len, head_dim) with the same head
            count as ``q`` (``PatchedLlamaAttention`` expands GQA heads first).
        mask: optional mask broadcastable to (batch, heads, q_len, kv_len). A 4-D
            mask wider than kv_len is trimmed to its first kv_len columns.

    Returns:
        (batch, heads, q_len, head_dim) attention output.

    Backends, in order of preference:
      1. flash-attn (CUDA fp16/bf16). Always causal; ``mask`` is not applied.
      2. xformers (CUDA fp16/bf16, no mask, q_len == kv_len) with a causal bias.
      3. torch SDPA, causal when there is no mask and more than one query.
      4. explicit matmul/softmax, with a bottom-right causal mask when unmasked.
    """
    q_len, head_dim = q.shape[-2], q.shape[-1]
    kv_len = k.shape[-2]
    if mask is not None and mask.dim() == 4 and mask.shape[-1] > kv_len:
        # Masks may be sized for a longer (cache) length; keep the live columns.
        mask = mask[:, :, :, :kv_len]
    half_precision = q.dtype in (torch.float16, torch.bfloat16)

    # Device check first: CPU calls never consult the optional-import flags.
    if q.is_cuda and half_precision and _flash_available:
        try:
            out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True)
            return out.transpose(1, 2)
        except Exception:
            pass  # fall through to the next backend

    if q.is_cuda and half_precision and _xformers_available and mask is None and q_len == kv_len:
        try:
            from xformers.ops import LowerTriangularMask

            # xformers expects (batch, seq, heads, head_dim). Folding heads into
            # the batch axis would make each token attend across heads.
            out = memory_efficient_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                attn_bias=LowerTriangularMask(),
            )
            return out.transpose(1, 2)
        except Exception:
            pass  # fall through to SDPA

    # SDPA's is_causal mask is top-left aligned. For a single query (cached
    # decoding) it would hide every key but the first, so it is only enabled for
    # q_len > 1. NOTE(review): q_len > 1 with a longer cache and no mask is still
    # top-left aligned; callers presumably pass a mask in that case.
    is_causal = mask is None and q_len > 1
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)

    attn_weights = torch.matmul(q, k.transpose(2, 3)) / (head_dim ** 0.5)
    if mask is not None:
        attn_weights = attn_weights + mask
    elif is_causal:
        # Bottom-right aligned: query i sees keys up to i + (kv_len - q_len).
        keep = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device).tril(diagonal=kv_len - q_len)
        attn_weights = attn_weights.masked_fill(~keep, float("-inf"))
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(attn_weights, v)
200
+
201
+
202
def patch_attention(model: PreTrainedModel) -> PreTrainedModel:
    """Swap rotary-embedding attention modules for ``PatchedLlamaAttention``.

    When the model has a config, this first sets
    ``model.config._attn_implementation = "sdpa"``. It then walks
    ``model.named_modules()`` and replaces every module whose class name contains
    ``"Attention"`` and that exposes ``rotary_emb``. Attention modules without
    ``rotary_emb`` are left in place and only counted. A module that fails to
    patch is reported and skipped. The model is modified in place and returned.

    Args:
        model: model to patch.

    Returns:
        The same ``model`` object.
    """
    if hasattr(model, "config"):
        model.config._attn_implementation = "sdpa"

    # Snapshot taken before any swap, so replacements are never revisited and
    # parents can be looked up by dotted name.
    modules_dict = dict(model.named_modules())
    model_config = getattr(model, "config", None)
    patched_count = 0
    skipped = 0

    for name, module in modules_dict.items():
        # "SdpaAttention" and similar class names all contain "Attention".
        if "Attention" not in type(module).__name__:
            continue
        if not hasattr(module, "rotary_emb"):
            skipped += 1
            continue
        if not name:
            # The root module has no parent to attach a replacement to.
            continue
        try:
            patched = PatchedLlamaAttention(module, model_config)
            # "" for direct children of the root, which maps to the model itself.
            parent_name, _, child_name = name.rpartition(".")
            parent = modules_dict.get(parent_name)
            if parent is not None:
                setattr(parent, child_name, patched)
                patched_count += 1
        except Exception as e:
            print(f"[amazingvmsloth] Could not patch attention {name}: {e}")

    print(f"[amazingvmsloth] Patched {patched_count} attention layers, {skipped} using native SDPA (no rotary_emb on module)")
    flash_status = "available" if _flash_available else "not available (install flash-attn for speedup)"
    xformers_status = "available" if _xformers_available else "not available (pip install xformers for speedup)"
    print(f"[amazingvmsloth] Flash attention: {flash_status}")
    print(f"[amazingvmsloth] XFormers attention: {xformers_status}")
    # Report the first backend _efficient_attn tries for patched layers.
    if _flash_available:
        impl = "flash-attn"
    elif _xformers_available:
        impl = "xformers"
    else:
        impl = "SDPA"
    print(f"[amazingvmsloth] Attention implementation: {impl}")
    return model
@@ -0,0 +1,291 @@
1
+ import time
2
+ import torch
3
+ import argparse
4
+ import os
5
+ from typing import Dict, Any
6
+
7
# Suppress transformers loading progress bars (can crash on Windows + cause UI issues);
# setdefault leaves an explicit user setting untouched.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

# Memoised outcome of the one-time unsloth import attempt.
_unsloth_available = False
_unsloth_error = None


def _try_import_unsloth():
    """Lazy import unsloth only when unsloth benchmark is requested.

    The first call attempts the import, with UserWarnings silenced, and records
    either success or the error text. Later calls return the recorded result
    without retrying.

    Returns:
        True if unsloth (and ``unsloth.FastLanguageModel``) imported, else False.
    """
    global _unsloth_available, _unsloth_error
    attempted_before = _unsloth_available or _unsloth_error is not None
    if attempted_before:
        return _unsloth_available

    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        try:
            import unsloth  # noqa: F401  (imported for availability check)
            from unsloth import FastLanguageModel  # noqa: F401
        except Exception as exc:  # unsloth can fail with non-ImportError errors
            _unsloth_error = str(exc)
        else:
            _unsloth_available = True
    return _unsloth_available
29
+
30
+
31
def run_amazingvmsloth(model_name: str, dataset_name: str, max_samples: int, max_seq_length: int) -> Dict[str, Any]:
    """Fine-tune ``model_name`` with amazingvmsloth and collect benchmark metrics.

    Steps:
      1. Load the model in 4-bit. If loading fails with an error that does not
         look like OOM/CUDA, fall back to the ``unsloth/<name>-bnb-4bit``
         checkpoint.
      2. Apply rsLoRA (r=16).
      3. Tokenize an Alpaca-style dataset (``instruction``/``output`` columns).
      4. Pre-pack the dataset into ``max_seq_length`` chunks.
      5. Train for one epoch.

    Args:
        model_name: Hugging Face model id.
        dataset_name: Hugging Face dataset id (``train`` split is used).
        max_samples: number of samples to keep; ``<= 0`` keeps the whole split.
        max_seq_length: truncation / packing length in tokens.

    Returns:
        Dict with the following keys:
          - ``time_s``: training wall time.
          - ``vram_peak_gb``: peak CUDA memory during training, or 0 without CUDA.
          - ``steps``: number of training steps.
          - ``loss``: last logged loss rounded to 4 places, or -1 if none was logged.

    Raises:
        RuntimeError: re-raised when model loading fails with an OOM/CUDA error.
    """
    from transformers import AutoTokenizer
    from datasets import load_dataset
    from amazingvmsloth import quantize_model_4bit, auto_patch_model, LoRAConfig, AmazingTrainer, TrainingConfig
    from amazingvmsloth.utils.banner import print_banner, get_system_info

    info = get_system_info()
    print_banner(info, model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model = quantize_model_4bit(model_name)
    except RuntimeError as e:
        # torch.cuda.OutOfMemoryError subclasses RuntimeError. Naming
        # torch.OutOfMemoryError here would raise AttributeError on torch < 2.5
        # and mask the real error.
        lowered = str(e).lower()
        if "out of memory" in lowered or "cuda" in lowered:
            print(f"\n[bench] ERROR: GPU out of memory while loading {model_name}")
            print("[bench] This model may be too large for your GPU.")
            print("[bench] Try: --model Qwen/Qwen2.5-0.5B (smaller model)")
            print("[bench] Or: --low-vram (auto-tune settings)")
            raise
        # Try pre-quantized unsloth model
        unsloth_model = f"unsloth/{model_name.split('/')[-1]}-bnb-4bit"
        print(f"[bench] Loading failed, trying pre-quantized: {unsloth_model}")
        model = quantize_model_4bit(unsloth_model)
    lora_config = LoRAConfig(r=16, lora_alpha=16, use_rslora=True, use_manual_gradients=False)
    model = auto_patch_model(model, apply_lora_patch=True, lora_config=lora_config)

    dataset = load_dataset(dataset_name, split="train")
    if max_samples > 0:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    def tokenize_fn(examples):
        texts = [
            f"### Instruction:\n{inst}\n\n### Response:\n{out}"
            for inst, out in zip(examples["instruction"], examples["output"])
        ]
        return tokenizer(texts, truncation=True, max_length=max_seq_length)

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
    tokenized = tokenized.map(lambda x: {"labels": x["input_ids"]}, batched=True)
    # Don't set_format("torch") — datasets' torch formatter crashes on Windows
    # because it tries to import torchvision.io.VideoReader which isn't available.
    # Our collator handles list->tensor conversion manually.

    # Pre-pack: concatenate all sequences, split into max_seq_length chunks.
    # This gives 2-5x fewer forward passes vs padding single short sequences.
    from amazingvmsloth.packing import prepack_dataset, BatchCollator
    packed = prepack_dataset(tokenized, tokenizer, max_seq_length=max_seq_length)
    print(f"[bench] Pre-packed {len(tokenized)} samples into {len(packed)} chunks ({max_seq_length} tokens each)")

    config = TrainingConfig(
        output_dir="./bench_amazingvmsloth",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        learning_rate=2e-4,
        bf16=True,
        optim="paged_adamw_8bit",
        max_seq_length=max_seq_length,
        logging_steps=1,
        save_steps=999999,
        packing=False,  # Already pre-packed above
        compile_model=False,  # Compile overhead dominates on short runs despite pre-packing
        gradient_checkpointing=True,  # Enable for VRAM optimization (match unsloth's 2.3GB)
        chunked_loss=False,  # 512 seq fits full logits in 4GB, skip chunked overhead
        silent=True,  # Skip tqdm/postfix overhead in benchmark
    )

    trainer = AmazingTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=packed,
        config=config,
        data_collator=BatchCollator(),
    )

    cuda = torch.cuda.is_available()
    if cuda:
        # Measure the training peak only, the same window run_unsloth measures,
        # instead of including the model-loading spike.
        torch.cuda.reset_peak_memory_stats()
    start = time.time()
    result = trainer.train()
    elapsed = time.time() - start

    vram_peak = torch.cuda.max_memory_allocated(0) / 1024**3 if cuda else 0
    if cuda:
        torch.cuda.reset_peak_memory_stats()

    # The final log entry need not carry a loss, so search backwards for one.
    loss = -1
    for entry in reversed(result["log_history"]):
        if "loss" in entry:
            loss = entry["loss"]
            break

    return {
        "time_s": round(elapsed, 1),
        "vram_peak_gb": round(vram_peak, 2),
        "steps": result["global_step"],
        "loss": round(loss, 4) if isinstance(loss, float) else loss,
    }
123
+
124
+
125
def run_unsloth(model_name: str, dataset_name: str, max_samples: int, max_seq_length: int) -> Dict[str, Any]:
    """Fine-tune ``model_name`` with unsloth + TRL's SFTTrainer as the baseline.

    Uses the same 4-bit load, rsLoRA (r=16) and Alpaca-style prompt as
    ``run_amazingvmsloth``. Unlike that run, it trains on the unpacked dataset
    with gradient accumulation 8. Every unsloth failure (missing package, model
    loading, or training) is returned under an ``"error"`` key instead of being
    raised, so the caller can still report the other benchmark.

    Args:
        model_name: model id passed to ``FastLanguageModel.from_pretrained``.
        dataset_name: Hugging Face dataset id (``train`` split is used).
        max_samples: number of samples to keep; ``<= 0`` keeps the whole split.
        max_seq_length: maximum sequence length for loading and training.

    Returns:
        On success, a dict with the following keys:
          - ``time_s``: training wall time.
          - ``vram_peak_gb``: peak CUDA memory, or 0 without CUDA.
          - ``steps``: number of training steps.
          - ``loss``: last logged loss rounded to 4 places, or -1 if none.
        On failure the metrics are -1 (``steps`` may hold the steps already
        completed) and ``error`` holds the message.
    """
    if not _try_import_unsloth():
        msg = "unsloth not installed" if _unsloth_error is None else f"unsloth import failed: {_unsloth_error}"
        print(f"[bench] {msg}. Install with: pip install unsloth")
        return {"time_s": -1, "vram_peak_gb": -1, "steps": -1, "loss": -1, "error": msg}

    from unsloth import FastLanguageModel
    from trl import SFTTrainer
    from transformers import TrainingArguments
    from datasets import load_dataset

    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=None,
            load_in_4bit=True,
        )
    except Exception as e:
        # Any loading failure (OOM, CUDA, bad checkpoint, ...) is reported rather
        # than raised, matching how training failures are handled below.
        err_msg = str(e)
        if "paging file" in err_msg.lower() or "memory" in err_msg.lower():
            print("[bench] unsloth failed: System out of memory (Windows paging file too small)")
            print("[bench] unsloth requires more RAM/virtual memory than your system has.")
            print("[bench] Tip: Increase Windows virtual memory or close other applications.")
        else:
            print(f"[bench] unsloth model loading failed: {err_msg}")
        return {"time_s": -1, "vram_peak_gb": -1, "steps": -1, "loss": -1, "error": err_msg}

    cuda = torch.cuda.is_available()
    trainer = None
    start = None
    try:
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            lora_alpha=16,
            lora_dropout=0,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            use_rslora=True,
        )

        dataset = load_dataset(dataset_name, split="train")
        if max_samples > 0:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        def formatting_func(examples):
            return [
                f"### Instruction:\n{inst}\n\n### Response:\n{out}"
                for inst, out in zip(examples["instruction"], examples["output"])
            ]

        training_args = TrainingArguments(
            output_dir="./bench_unsloth",
            num_train_epochs=1,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=8,
            learning_rate=2e-4,
            bf16=True,
            logging_steps=1,
            save_steps=999999,
            optim="paged_adamw_8bit",
            max_grad_norm=1.0,
            warmup_ratio=0.1,
            report_to="none",
        )

        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            formatting_func=formatting_func,
            max_seq_length=max_seq_length,
            args=training_args,
        )

        if cuda:
            torch.cuda.reset_peak_memory_stats()
        start = time.time()
        trainer.train()
        elapsed = time.time() - start
    except Exception as e:
        elapsed = time.time() - start if start is not None else 0
        err_msg = str(e)
        lowered = err_msg.lower()
        if "paging file" in lowered or "memory" in lowered or "os error 1455" in lowered:
            print("[bench] unsloth failed: System out of memory")
            print("[bench] unsloth requires more RAM/virtual memory than your system has.")
        else:
            print(f"[bench] unsloth training failed: {err_msg}")
        return {
            "time_s": round(elapsed, 1) if elapsed > 0 else -1,
            "vram_peak_gb": -1,
            "steps": getattr(trainer.state, "global_step", 0) if trainer is not None else 0,
            "loss": -1,
            "error": err_msg,
        }

    vram_peak = torch.cuda.max_memory_allocated(0) / 1024**3 if cuda else 0
    if cuda:
        torch.cuda.reset_peak_memory_stats()

    # Find last logged loss
    loss = -1
    for entry in reversed(trainer.state.log_history):
        if "loss" in entry:
            loss = entry["loss"]
            break

    return {
        "time_s": round(elapsed, 1),
        "vram_peak_gb": round(vram_peak, 2),
        "steps": trainer.state.global_step,
        "loss": round(loss, 4) if isinstance(loss, float) else -1,
    }
233
+
234
+
235
def main(argv=None):
    """Command-line entry point: benchmark amazingvmsloth against unsloth.

    Runs ``run_amazingvmsloth``, then ``run_unsloth`` unless ``--skip-unsloth``
    is given, and prints a side-by-side table. Speedup and VRAM savings are
    printed only when both runs report a positive time.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Benchmark: amazingvmsloth vs unsloth")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B", help="Model name")
    parser.add_argument("--dataset", default="tatsu-lab/alpaca", help="Dataset")
    parser.add_argument("--max-samples", type=int, default=500, help="Samples to use")
    parser.add_argument("--max-seq-length", type=int, default=512, help="Max seq length")
    parser.add_argument("--skip-unsloth", action="store_true", help="Skip unsloth benchmark")
    args = parser.parse_args(argv)

    print(f"\n{'='*60}")
    print(" Benchmark: amazingvmsloth vs unsloth")
    print(f"{'='*60}")
    print(f" Model: {args.model}")
    print(f" Dataset: {args.dataset}")
    print(f" Samples: {args.max_samples}")
    print(f" Seq length: {args.max_seq_length}")
    print(f"{'='*60}\n")

    print("[bench] Running amazingvmsloth...")
    av_result = run_amazingvmsloth(args.model, args.dataset, args.max_samples, args.max_seq_length)
    print(f"[bench] amazingvmsloth done: {av_result['time_s']}s, {av_result['vram_peak_gb']}GB peak VRAM\n")

    us_result = None
    if not args.skip_unsloth:
        print("[bench] Running unsloth...")
        us_result = run_unsloth(args.model, args.dataset, args.max_samples, args.max_seq_length)
        if us_result.get("error"):
            print(f"[bench] unsloth skipped: {us_result['error']}\n")
            us_result = None
        else:
            print(f"[bench] unsloth done: {us_result['time_s']}s, {us_result['vram_peak_gb']}GB peak VRAM\n")

    print(f"\n{'='*60}")
    print(" Results")
    print(f"{'='*60}")
    print(f" {'Metric':<20} {'amazingvmsloth':>15} {'unsloth':>15}")
    print(f" {'-'*20} {'-'*15} {'-'*15}")
    print(f" {'Time (s)':<20} {av_result['time_s']:>15} {us_result['time_s'] if us_result else 'N/A':>15}")
    print(f" {'Peak VRAM (GB)':<20} {av_result['vram_peak_gb']:>15} {us_result['vram_peak_gb'] if us_result else 'N/A':>15}")
    print(f" {'Final Loss':<20} {av_result['loss']:>15} {us_result['loss'] if us_result else 'N/A':>15}")
    print(f" {'Steps':<20} {av_result['steps']:>15} {us_result['steps'] if us_result else 'N/A':>15}")

    if us_result and us_result["time_s"] > 0 and av_result["time_s"] > 0:
        speedup = us_result["time_s"] / av_result["time_s"]
        vram_saved = us_result["vram_peak_gb"] - av_result["vram_peak_gb"]
        print(f"\n Speedup: {speedup:.2f}x")
        print(f" VRAM saved: {vram_saved:.2f} GB")
        if speedup > 1:
            print(f" amazingvmsloth is {speedup:.2f}x faster!")
        elif speedup < 1:
            print(f" unsloth is {1/speedup:.2f}x faster")
        else:
            # Times are rounded to 0.1 s, so exact ties are plausible.
            print(" Both frameworks took the same time")

    print(f"{'='*60}\n")


if __name__ == "__main__":
    main()