openllava 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. openllava/__init__.py +96 -0
  2. openllava/api/__init__.py +31 -0
  3. openllava/api/callbacks.py +463 -0
  4. openllava/api/config.py +328 -0
  5. openllava/api/fast_model.py +406 -0
  6. openllava/api/loggers.py +501 -0
  7. openllava/api/strategies.py +433 -0
  8. openllava/api/trainer.py +519 -0
  9. openllava/backends/__init__.py +90 -0
  10. openllava/backends/cpu_simd.py +272 -0
  11. openllava/backends/gguf_v2.py +470 -0
  12. openllava/backends/mlx_backend.py +144 -0
  13. openllava/backends/onnx_export.py +272 -0
  14. openllava/backends/rocm_backend.py +190 -0
  15. openllava/backends/safetensors_io.py +137 -0
  16. openllava/backends/tpu_dataloader.py +364 -0
  17. openllava/backends/tpu_spmd.py +373 -0
  18. openllava/backends/tpu_xla.py +326 -0
  19. openllava/backends/xpu_backend.py +175 -0
  20. openllava/cli/__init__.py +18 -0
  21. openllava/cli/commands/__init__.py +9 -0
  22. openllava/cli/commands/benchmark.py +381 -0
  23. openllava/cli/commands/export.py +423 -0
  24. openllava/cli/commands/serve.py +313 -0
  25. openllava/cli/commands/train.py +214 -0
  26. openllava/cli/main.py +122 -0
  27. openllava/core/__init__.py +1 -0
  28. openllava/core/backend.py +485 -0
  29. openllava/core/model.py +939 -0
  30. openllava/core/patcher.py +529 -0
  31. openllava/data/__init__.py +109 -0
  32. openllava/data/collator.py +473 -0
  33. openllava/data/gpu_augmentation.py +303 -0
  34. openllava/data/pipeline.py +198 -0
  35. openllava/data/preprocessing.py +204 -0
  36. openllava/data/smart_batching.py +251 -0
  37. openllava/data/streaming.py +297 -0
  38. openllava/data/templates.py +159 -0
  39. openllava/distributed/__init__.py +223 -0
  40. openllava/distributed/auto_parallel.py +438 -0
  41. openllava/distributed/cluster_config.py +444 -0
  42. openllava/distributed/deepspeed.py +449 -0
  43. openllava/distributed/device_mesh.py +474 -0
  44. openllava/distributed/dtensor_ops.py +423 -0
  45. openllava/distributed/expert_parallel.py +492 -0
  46. openllava/distributed/fsdp.py +372 -0
  47. openllava/distributed/heterogeneous.py +357 -0
  48. openllava/distributed/init_process.py +414 -0
  49. openllava/distributed/parallel_4d.py +851 -0
  50. openllava/distributed/pipeline_parallel.py +481 -0
  51. openllava/distributed/placement.py +364 -0
  52. openllava/distributed/ring_attention.py +224 -0
  53. openllava/distributed/tensor_parallel.py +430 -0
  54. openllava/distributed/topology.py +406 -0
  55. openllava/distributed/zero_hpz.py +379 -0
  56. openllava/distributed/zero_offload.py +469 -0
  57. openllava/eval/__init__.py +1 -0
  58. openllava/eval/mmbench.py +127 -0
  59. openllava/eval/runner.py +312 -0
  60. openllava/eval/scienceqa.py +120 -0
  61. openllava/eval/textvqa.py +119 -0
  62. openllava/experts/__init__.py +50 -0
  63. openllava/experts/moe_layers.py +424 -0
  64. openllava/experts/moe_lora.py +509 -0
  65. openllava/experts/moe_trainer.py +588 -0
  66. openllava/experts/ocr.py +86 -0
  67. openllava/experts/router.py +192 -0
  68. openllava/experts/visual.py +194 -0
  69. openllava/export/__init__.py +24 -0
  70. openllava/export/gguf.py +251 -0
  71. openllava/export/hub.py +165 -0
  72. openllava/export/merge.py +117 -0
  73. openllava/export/quantize.py +152 -0
  74. openllava/inference/__init__.py +31 -0
  75. openllava/inference/continuous_batching.py +519 -0
  76. openllava/inference/engine.py +633 -0
  77. openllava/inference/paged_attention.py +617 -0
  78. openllava/inference/speculative_decoding.py +813 -0
  79. openllava/kernels/__init__.py +32 -0
  80. openllava/kernels/cuda_graphs/__init__.py +3 -0
  81. openllava/kernels/cuda_graphs/graph_trainer.py +161 -0
  82. openllava/kernels/streams.py +143 -0
  83. openllava/kernels/triton/__init__.py +91 -0
  84. openllava/kernels/triton/bitnet_gemm.py +280 -0
  85. openllava/kernels/triton/blindsight.py +348 -0
  86. openllava/kernels/triton/flash_attention.py +364 -0
  87. openllava/kernels/triton/flex_attention.py +248 -0
  88. openllava/kernels/triton/fused_attention.py +543 -0
  89. openllava/kernels/triton/fused_cross_entropy.py +286 -0
  90. openllava/kernels/triton/fused_projector.py +384 -0
  91. openllava/kernels/triton/fused_rmsnorm.py +129 -0
  92. openllava/kernels/triton/fused_rope.py +278 -0
  93. openllava/kernels/triton/fused_swiglu.py +406 -0
  94. openllava/kernels/triton/grouped_gemm.py +339 -0
  95. openllava/kernels/triton/online_softmax.py +249 -0
  96. openllava/kernels/triton/sparse_attention.py +421 -0
  97. openllava/optimizations/__init__.py +342 -0
  98. openllava/optimizations/async_io.py +253 -0
  99. openllava/optimizations/bitnet.py +342 -0
  100. openllava/optimizations/bitnet_a48.py +336 -0
  101. openllava/optimizations/chunked_prefill.py +367 -0
  102. openllava/optimizations/cpu_offload.py +153 -0
  103. openllava/optimizations/curriculum.py +319 -0
  104. openllava/optimizations/eagle_draft.py +430 -0
  105. openllava/optimizations/ema.py +179 -0
  106. openllava/optimizations/fast_nf4.py +312 -0
  107. openllava/optimizations/fp4_quant.py +589 -0
  108. openllava/optimizations/fp8_training.py +411 -0
  109. openllava/optimizations/full_finetune.py +373 -0
  110. openllava/optimizations/galore.py +384 -0
  111. openllava/optimizations/gptq_awq.py +320 -0
  112. openllava/optimizations/kv_compression.py +482 -0
  113. openllava/optimizations/kv_eviction.py +452 -0
  114. openllava/optimizations/kv_quantization.py +448 -0
  115. openllava/optimizations/medusa_heads.py +356 -0
  116. openllava/optimizations/memory_pool.py +90 -0
  117. openllava/optimizations/mixed_precision_quant.py +628 -0
  118. openllava/optimizations/mxfp8_moe.py +418 -0
  119. openllava/optimizations/ngram_draft.py +456 -0
  120. openllava/optimizations/packing.py +142 -0
  121. openllava/optimizations/padding_free.py +116 -0
  122. openllava/optimizations/qat.py +516 -0
  123. openllava/optimizations/schedulers.py +298 -0
  124. openllava/optimizations/selective_checkpoint.py +94 -0
  125. openllava/optimizations/sparse_attn_selector.py +385 -0
  126. openllava/optimizations/split_lora.py +623 -0
  127. openllava/optimizations/torch_compile.py +259 -0
  128. openllava/optimizations/torchao_integration.py +253 -0
  129. openllava/optimizations/tree_verification.py +431 -0
  130. openllava/optimizations/yadis_cross_attn.py +496 -0
  131. openllava/optimizations/yadis_moe_adaptive.py +459 -0
  132. openllava/optimizations/yadis_vq_ema.py +649 -0
  133. openllava/rl/__init__.py +61 -0
  134. openllava/rl/dpo.py +561 -0
  135. openllava/rl/grpo.py +524 -0
  136. openllava/rl/orpo.py +448 -0
  137. openllava/rl/ppo.py +674 -0
  138. openllava/rl/rewards.py +486 -0
  139. openllava/rl/vllm_integration.py +405 -0
  140. openllava/serve/__init__.py +47 -0
  141. openllava/serve/batch_manager.py +523 -0
  142. openllava/serve/metrics.py +387 -0
  143. openllava/serve/middleware.py +421 -0
  144. openllava/serve/openai_api.py +302 -0
  145. openllava/serve/server.py +499 -0
  146. openllava/training/__init__.py +54 -0
  147. openllava/training/bitnet_trainer.py +487 -0
  148. openllava/training/checkpointing.py +342 -0
  149. openllava/training/dora.py +366 -0
  150. openllava/training/lora.py +226 -0
  151. openllava/training/lora_fa.py +369 -0
  152. openllava/training/lora_ga.py +372 -0
  153. openllava/training/lora_plus.py +391 -0
  154. openllava/training/lora_registry.py +393 -0
  155. openllava/training/memory.py +309 -0
  156. openllava/training/trainer.py +614 -0
  157. openllava/utils/__init__.py +76 -0
  158. openllava/utils/auto_detect.py +388 -0
  159. openllava/utils/benchmark.py +433 -0
  160. openllava/utils/hardware_detect.py +393 -0
  161. openllava/utils/hub.py +436 -0
  162. openllava/utils/model_card.py +340 -0
  163. openllava/utils/profiler.py +478 -0
  164. openllava/utils/registry.py +431 -0
  165. openllava-3.0.0.dist-info/METADATA +1299 -0
  166. openllava-3.0.0.dist-info/RECORD +170 -0
  167. openllava-3.0.0.dist-info/WHEEL +5 -0
  168. openllava-3.0.0.dist-info/entry_points.txt +2 -0
  169. openllava-3.0.0.dist-info/licenses/LICENSE +201 -0
  170. openllava-3.0.0.dist-info/top_level.txt +1 -0
openllava/__init__.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ OpenLLaVA — Open-Source Multimodal Vision Injection Framework.
3
+
4
+ Inject vision into any language model. Architecture-agnostic, multi-backend.
5
+
6
+ Usage:
7
+ from openllava import OpenLLaVA, Backend, experts
8
+
9
+ model = OpenLLaVA(
10
+ llm="meta-llama/Llama-3-8B",
11
+ vision_encoder="google/siglip2-so400m-patch14-384",
12
+ backend=Backend.CUDA,
13
+ )
14
+ model.lora(r=64, alpha=128)
15
+ model.train(phase1=dict(dataset="liuhaotian/LLaVA-Pretrain", samples=100_000))
16
+ model.push("my-org/my-model")
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from . import (
22
+ api,
23
+ backends,
24
+ cli,
25
+ data,
26
+ distributed,
27
+ eval,
28
+ experts,
29
+ export,
30
+ inference,
31
+ kernels,
32
+ optimizations,
33
+ rl,
34
+ serve,
35
+ training,
36
+ utils,
37
+ )
38
+ from .api import FastLanguageModel, FastVisionModel, OpenLLaVATrainer, TrainingConfig
39
+ from .core.backend import Backend, BackendManager, get_backend, is_cuda_available
40
+ from .core.model import OpenLLaVA
41
+ from .core.patcher import AnyResProcessor, ModelPatcher, YakiModel, YakiProjector
42
+ from .kernels import cuda_graphs
43
+ from .kernels import triton as triton_kernels
44
+ from .optimizations import (
45
+ EMAModel,
46
+ awq_quantize,
47
+ compile_model,
48
+ enable_fp8_training,
49
+ gptq_quantize,
50
+ )
51
+ from .utils import HardwareDetector, HardwareInfo, auto_configure, profile_model
52
+
53
+ __version__ = "3.0.0"
54
+ __author__ = "OpceanAI Research Team"
55
+
56
+ __all__ = [
57
+ "OpenLLaVA",
58
+ "Backend",
59
+ "BackendManager",
60
+ "get_backend",
61
+ "is_cuda_available",
62
+ "YakiProjector",
63
+ "YakiModel",
64
+ "AnyResProcessor",
65
+ "ModelPatcher",
66
+ "experts",
67
+ "training",
68
+ "data",
69
+ "rl",
70
+ "export",
71
+ "eval",
72
+ "kernels",
73
+ "optimizations",
74
+ "inference",
75
+ "distributed",
76
+ "serve",
77
+ "backends",
78
+ "api",
79
+ "cli",
80
+ "utils",
81
+ "triton_kernels",
82
+ "cuda_graphs",
83
+ "compile_model",
84
+ "enable_fp8_training",
85
+ "EMAModel",
86
+ "gptq_quantize",
87
+ "awq_quantize",
88
+ "FastVisionModel",
89
+ "FastLanguageModel",
90
+ "OpenLLaVATrainer",
91
+ "TrainingConfig",
92
+ "auto_configure",
93
+ "HardwareDetector",
94
+ "HardwareInfo",
95
+ "profile_model",
96
+ ]
@@ -0,0 +1,31 @@
1
+ """OpenLLaVA High-Level API — Unsloth-compatible fast model loading and training.
2
+
3
+ Provides drop-in replacements for common training frameworks with
4
+ auto-configuration, multi-backend logging, and callback orchestration.
5
+
6
+ Usage:
7
+ from openllava.api import FastVisionModel, OpenLLaVATrainer, TrainingConfig
8
+
9
+ model, tokenizer = FastVisionModel.from_pretrained("openllava/yaki-8b")
10
+ config = TrainingConfig(mode="lora", output_dir="./output")
11
+ trainer = OpenLLaVATrainer(model=model, tokenizer=tokenizer, args=config)
12
+ trainer.train()
13
+ trainer.save_model("./final")
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from .config import TrainingConfig
19
+ from .fast_model import FastLanguageModel, FastVisionModel
20
+ from .strategies import auto_configure, get_peft_model, load_dataset
21
+ from .trainer import OpenLLaVATrainer
22
+
23
+ __all__ = [
24
+ "FastVisionModel",
25
+ "FastLanguageModel",
26
+ "OpenLLaVATrainer",
27
+ "TrainingConfig",
28
+ "get_peft_model",
29
+ "load_dataset",
30
+ "auto_configure",
31
+ ]
@@ -0,0 +1,463 @@
1
+ """OpenLLaVA Callback System.
2
+
3
+ Complete training callback infrastructure with hooks for monitoring, checkpointing,
4
+ early stopping, and memory profiling. Designed for composition via CallbackList.
5
+
6
+ Usage:
7
+ from openllava.api.callbacks import EarlyStoppingCallback, CallbackList
8
+
9
+ callbacks = CallbackList([
10
+ EarlyStoppingCallback(monitor="loss", patience=3),
11
+ ModelCheckpointCallback(dirpath="./checkpoints", monitor="loss"),
12
+ ])
13
+ callbacks.on_train_begin()
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ import os
20
+ import time
21
+ from abc import ABC
22
+ from typing import Any
23
+
24
+ import torch
25
+
26
+ _HOOKS = [
27
+ "on_train_begin", "on_train_end",
28
+ "on_epoch_begin", "on_epoch_end",
29
+ "on_step_begin", "on_step_end",
30
+ "on_save", "on_log", "on_evaluate",
31
+ ]
32
+
33
+
34
class BaseCallback(ABC):
    """Abstract base for all training callbacks.

    Every hook defined here is a no-op; a subclass overrides whichever subset
    it cares about. All hooks take only ``**kwargs`` so new context keys can
    be introduced later without breaking existing callbacks.
    """

    def on_train_begin(self, **kwargs) -> None:
        """Hook invoked once when training starts."""

    def on_train_end(self, **kwargs) -> None:
        """Hook invoked once when training finishes."""

    def on_epoch_begin(self, **kwargs) -> None:
        """Hook invoked at the start of each epoch."""

    def on_epoch_end(self, **kwargs) -> None:
        """Hook invoked at the end of each epoch."""

    def on_step_begin(self, **kwargs) -> None:
        """Hook invoked before each training step."""

    def on_step_end(self, **kwargs) -> None:
        """Hook invoked after each training step."""

    def on_save(self, **kwargs) -> None:
        """Hook invoked when a checkpoint is saved."""

    def on_log(self, **kwargs) -> None:
        """Hook invoked when metrics are logged."""

    def on_evaluate(self, **kwargs) -> None:
        """Hook invoked after an evaluation pass."""
67
+
68
+
69
class CallbackList:
    """Dispatch hook calls to an ordered collection of callbacks.

    A failure inside any single callback is caught and reported rather than
    propagated, so one faulty callback cannot abort the whole training run.

    Usage:
        cbl = CallbackList([EarlyStoppingCallback(), ModelCheckpointCallback()])
        cbl.on_train_begin(trainer=trainer)
    """

    def __init__(self, callbacks: list[BaseCallback] | None = None):
        self._callbacks: list[BaseCallback] = list(callbacks) if callbacks else []

    def append(self, callback: BaseCallback) -> None:
        """Register one additional callback at the end of the list."""
        self._callbacks.append(callback)

    def extend(self, callbacks: list[BaseCallback]) -> None:
        """Register several callbacks at once, preserving their order."""
        self._callbacks.extend(callbacks)

    def dispatch(self, hook: str, **kwargs) -> None:
        """Invoke *hook* on every registered callback, isolating failures."""
        for registered in self._callbacks:
            try:
                getattr(registered, hook)(**kwargs)
            except Exception as exc:
                # Report and continue: a broken callback must not kill training.
                print(f"[CallbackList] Error in {registered.__class__.__name__}.{hook}: {exc}")

    # Thin forwarders so a CallbackList can stand in wherever a single
    # callback is expected.
    def on_train_begin(self, **kwargs) -> None:
        self.dispatch("on_train_begin", **kwargs)

    def on_train_end(self, **kwargs) -> None:
        self.dispatch("on_train_end", **kwargs)

    def on_epoch_begin(self, **kwargs) -> None:
        self.dispatch("on_epoch_begin", **kwargs)

    def on_epoch_end(self, **kwargs) -> None:
        self.dispatch("on_epoch_end", **kwargs)

    def on_step_begin(self, **kwargs) -> None:
        self.dispatch("on_step_begin", **kwargs)

    def on_step_end(self, **kwargs) -> None:
        self.dispatch("on_step_end", **kwargs)

    def on_save(self, **kwargs) -> None:
        self.dispatch("on_save", **kwargs)

    def on_log(self, **kwargs) -> None:
        self.dispatch("on_log", **kwargs)

    def on_evaluate(self, **kwargs) -> None:
        self.dispatch("on_evaluate", **kwargs)

    def __len__(self) -> int:
        return len(self._callbacks)

    def __iter__(self):
        return iter(self._callbacks)
125
+
126
+
127
class EarlyStoppingCallback(BaseCallback):
    """Stop training when a monitored metric stops improving.

    Args:
        monitor: Metric name to monitor (default: "loss").
        patience: Number of validation checks with no improvement before stopping.
        min_delta: Minimum change to qualify as an improvement.
        mode: "min" (lower is better) or "max" (higher is better).
    """

    def __init__(
        self,
        monitor: str = "loss",
        patience: int = 3,
        min_delta: float = 0.0,
        mode: str = "min",
    ):
        # Validate eagerly so misconfiguration fails at construction time.
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if mode not in {"min", "max"}:
            raise ValueError(f"mode must be 'min' or 'max', got '{mode}'")
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self._best: float | None = None       # best metric value seen so far
        self._counter: int = 0                # consecutive non-improving checks
        self._should_stop: bool = False

    def _is_improvement(self, current: float) -> bool:
        """True when *current* beats the stored best by more than min_delta."""
        previous = self._best
        if previous is None:
            return True
        if self.mode == "min":
            return current < previous - self.min_delta
        return current > previous + self.min_delta

    def on_train_begin(self, **kwargs) -> None:
        """Reset all tracking state so the callback can be reused."""
        self._best = None
        self._counter = 0
        self._should_stop = False

    def on_evaluate(self, **kwargs) -> None:
        """Update best/patience bookkeeping from the latest eval metrics."""
        observed = kwargs.get("metrics", {}).get(self.monitor, None)
        if observed is None:
            return
        if self._is_improvement(observed):
            self._best = observed
            self._counter = 0
            return
        self._counter += 1
        if self._counter >= self.patience:
            self._should_stop = True
            print(f"[EarlyStopping] Stopping after {self._counter} "
                  f"checks without improvement in {self.monitor}")

    @property
    def should_stop(self) -> bool:
        """Whether the patience budget has been exhausted."""
        return self._should_stop

    @property
    def best_value(self) -> float | None:
        """Best monitored value observed so far, or None before any eval."""
        return self._best
189
+
190
+
191
class ModelCheckpointCallback(BaseCallback):
    """Save model checkpoints during training.

    Note: the constructor creates ``dirpath`` immediately as a side effect.

    Args:
        dirpath: Directory to save checkpoints.
        monitor: Metric to monitor for best model selection.
        save_top_k: Number of best checkpoints to keep (-1 for all).
        mode: "min" or "max" for metric comparison.
        save_last: Whether to save the last checkpoint.
    """

    def __init__(
        self,
        dirpath: str = "./checkpoints",
        monitor: str = "loss",
        save_top_k: int = 1,
        mode: str = "min",
        save_last: bool = True,
    ):
        self.dirpath = dirpath
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.mode = mode
        self.save_last = save_last
        # (score, path) pairs for checkpoints currently kept on disk.
        self._best_scores: list[tuple[float, str]] = []
        os.makedirs(dirpath, exist_ok=True)

    def on_evaluate(self, **kwargs) -> None:
        """Persist a checkpoint for this eval and prune beyond save_top_k."""
        value = kwargs.get("metrics", {}).get(self.monitor)
        if value is None:
            return
        model = kwargs.get("model")
        if model is None:
            return
        step = kwargs.get("step", 0)

        ckpt_path = os.path.join(self.dirpath, f"checkpoint-{step}")
        self._save_model(model, ckpt_path)

        if self.save_top_k <= 0:
            return
        # Keep the list ordered best-first, then drop the worst entries.
        self._best_scores.append((float(value), ckpt_path))
        self._best_scores.sort(key=lambda pair: pair[0], reverse=(self.mode == "max"))
        while len(self._best_scores) > self.save_top_k:
            _, stale_path = self._best_scores.pop()
            self._remove_checkpoint(stale_path)

    def on_train_end(self, **kwargs) -> None:
        """Optionally write a final "last" checkpoint when training ends."""
        if not self.save_last:
            return
        model = kwargs.get("model")
        if model is not None:
            self._save_model(model, os.path.join(self.dirpath, "last"))

    def _save_model(self, model, path: str) -> None:
        """Best-effort save: HF-style save_pretrained, else raw state_dict."""
        try:
            os.makedirs(path, exist_ok=True)
            if hasattr(model, "save_pretrained"):
                model.save_pretrained(path)
            else:
                torch.save(model.state_dict(), os.path.join(path, "pytorch_model.bin"))
        except Exception as exc:
            print(f"[ModelCheckpoint] Failed to save to {path}: {exc}")

    @staticmethod
    def _remove_checkpoint(path: str) -> None:
        """Delete a checkpoint directory, ignoring errors."""
        import shutil
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)
262
+
263
+
264
class GradientAccumulationCallback(BaseCallback):
    """Dynamic gradient accumulation step adjustment.

    Doubles ``current_accum`` (capped at ``max_accum``) whenever a NaN loss
    is observed.

    Args:
        target_effective_batch: Desired effective batch size.
        min_accum: Lower bound for accumulation steps.
        max_accum: Upper bound for accumulation steps.

    NOTE(review): ``target`` and ``min_accum`` are stored but never consulted
    by the visible logic — confirm whether they are read elsewhere.
    """

    def __init__(self, target_effective_batch: int = 32, min_accum: int = 1, max_accum: int = 128):
        self.target = target_effective_batch
        self.min_accum = min_accum
        self.max_accum = max_accum
        self.current_accum: int = 1
        self._oom_count: int = 0

    def on_train_begin(self, **kwargs) -> None:
        """Seed from the trainer's configured value when one is supplied.

        The ``{}`` default makes getattr fall back to 1 when no trainer is
        passed (a dict has no gradient_accumulation_steps attribute).
        """
        self.current_accum = getattr(
            kwargs.get("trainer", {}), "gradient_accumulation_steps", 1
        )

    def on_step_end(self, **kwargs) -> None:
        """Double the accumulation steps when the step's loss is NaN."""
        loss = kwargs.get("loss", None)
        if loss is None:
            return
        # BUG FIX: the original called torch.isnan(loss) directly, which
        # raises TypeError for plain Python floats (torch.isnan requires a
        # Tensor). Convert first, matching NaNMonitorCallback's handling.
        loss_val = loss.item() if hasattr(loss, "item") else float(loss)
        if math.isnan(loss_val):
            self.current_accum = min(self.current_accum * 2, self.max_accum)
            self._oom_count += 1
            print(f"[GradientAccum] NaN loss detected, increasing accumulation to {self.current_accum}")
289
+
290
+
291
class LearningRateMonitorCallback(BaseCallback):
    """Record the optimizer's learning rate after every training step."""

    def __init__(self):
        # (step, lr) samples collected during the current run.
        self._lrs: list[tuple[int, float]] = []

    def on_step_end(self, **kwargs) -> None:
        """Sample lr from the first param group, if an optimizer is given."""
        # Default step index: position in the history recorded so far.
        step_idx = kwargs.get("step", len(self._lrs))
        optimizer = kwargs.get("optimizer")
        if optimizer is None or not optimizer.param_groups:
            return
        current_lr = optimizer.param_groups[0]["lr"]
        self._lrs.append((step_idx, current_lr))

    def get_history(self) -> list[tuple[int, float]]:
        """Return all recorded (step, lr) pairs."""
        return self._lrs

    def on_train_end(self, **kwargs) -> None:
        """Drop the recorded history (cleared in place)."""
        self._lrs.clear()
309
+
310
+
311
class MemoryProfilerCallback(BaseCallback):
    """Log GPU/CPU memory usage periodically.

    Prefers pynvml for per-device GPU figures; falls back to torch's
    allocator counters when pynvml is unavailable. RAM usage comes from
    psutil when installed.

    Args:
        log_every_n_steps: Emit one log line every N training steps.
    """

    def __init__(self, log_every_n_steps: int = 100):
        self.log_every = log_every_n_steps
        self._step: int = 0
        # Highest GPU usage observed so far, in MiB (tracked via pynvml only).
        self._peak_gpu_mb: float = 0.0

    def on_step_end(self, **kwargs) -> None:
        """Every ``log_every`` steps, print a one-line memory summary."""
        self._step += 1
        if self._step % self.log_every != 0:
            return
        info_parts = []
        if torch.cuda.is_available():
            try:
                import pynvml
                pynvml.nvmlInit()
                for i in range(torch.cuda.device_count()):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                    gb = info.used / (1024 ** 3)
                    self._peak_gpu_mb = max(self._peak_gpu_mb, info.used / (1024 ** 2))
                    info_parts.append(f"GPU{i}: {gb:.1f}GB")
            except ImportError:
                # pynvml not installed: use torch's own allocator counters.
                alloc = torch.cuda.memory_allocated() / (1024 ** 3)
                reserved = torch.cuda.memory_reserved() / (1024 ** 3)
                info_parts.append(f"GPU alloc: {alloc:.1f}GB, reserved: {reserved:.1f}GB")
        try:
            # BUG FIX: the import itself was outside the try block, so a
            # missing psutil raised ImportError out of this hook instead of
            # being swallowed by the except clause below.
            import psutil
            ram = psutil.virtual_memory()
            info_parts.append(f"RAM: {ram.percent:.0f}%")
        except ImportError:
            pass
        if info_parts:
            print(f"[Memory] Step {self._step}: {' | '.join(info_parts)}")
346
+
347
+
348
class TimingCallback(BaseCallback):
    """Track epoch and batch timing statistics."""

    def __init__(self):
        self._epoch_start: float = 0.0   # perf_counter at epoch begin
        self._step_start: float = 0.0    # perf_counter at step begin
        self.epoch_times: list[float] = []
        self.step_times: list[float] = []

    def on_train_begin(self, **kwargs) -> None:
        """Reset accumulated timings; lists are cleared in place so any
        external references keep observing the live history."""
        self.step_times.clear()
        self.epoch_times.clear()

    def on_epoch_begin(self, **kwargs) -> None:
        self._epoch_start = time.perf_counter()

    def on_epoch_end(self, **kwargs) -> None:
        self.epoch_times.append(time.perf_counter() - self._epoch_start)

    def on_step_begin(self, **kwargs) -> None:
        self._step_start = time.perf_counter()

    def on_step_end(self, **kwargs) -> None:
        self.step_times.append(time.perf_counter() - self._step_start)

    def summary(self) -> dict[str, float]:
        """Aggregate timing stats; empty dict before any completed step."""
        if not self.step_times:
            return {}
        step_count = len(self.step_times)
        if self.epoch_times:
            mean_epoch = sum(self.epoch_times) / len(self.epoch_times)
        else:
            mean_epoch = 0.0
        return {
            "mean_step_s": sum(self.step_times) / step_count,
            "min_step_s": min(self.step_times),
            "max_step_s": max(self.step_times),
            "mean_epoch_s": mean_epoch,
        }
384
+
385
+
386
class ProgressBarCallback(BaseCallback):
    """Rich progress bar display for training progress."""

    def __init__(self, total_steps: int = 0, use_rich: bool = False):
        self.total_steps = total_steps
        self._current: int = 0
        self._pbar: Any = None      # tqdm handle when rich is not in use
        self._use_rich = use_rich

    def on_train_begin(self, **kwargs) -> None:
        """Create the display: rich when requested and importable, else tqdm."""
        if self._use_rich:
            try:
                from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
                self._rich_progress = Progress(
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TextColumn("{task.completed}/{task.total}"),
                    TimeRemainingColumn(),
                )
                self._rich_progress.start()
                self._task = self._rich_progress.add_task("[cyan]Training...", total=self.total_steps)
            except ImportError:
                # rich missing: fall through to the tqdm branch below.
                self._use_rich = False
        if not self._use_rich:
            from tqdm import tqdm
            self._pbar = tqdm(total=self.total_steps, desc="Training", unit="step")

    def on_step_end(self, **kwargs) -> None:
        """Advance the bar by one step and show the latest loss, if any."""
        self._current += 1
        loss = kwargs.get("loss", None)
        postfix = {}
        if loss is not None:
            if isinstance(loss, (int, float)):
                postfix["loss"] = f"{loss:.4f}"
            else:
                postfix["loss"] = str(loss)
        if self._use_rich and hasattr(self, "_rich_progress"):
            self._rich_progress.update(self._task, advance=1)
        elif self._pbar is not None:
            if postfix:
                self._pbar.set_postfix(postfix)
            self._pbar.update(1)

    def on_train_end(self, **kwargs) -> None:
        """Tear down whichever display was created (no-op if neither was)."""
        if self._use_rich and hasattr(self, "_rich_progress"):
            self._rich_progress.stop()
        elif self._pbar is not None:
            self._pbar.close()
431
+
432
+
433
class NaNMonitorCallback(BaseCallback):
    """Stop training if NaN values are detected in loss or gradients.

    Args:
        max_nan_count: Number of non-finite observations tolerated before
            ``should_stop`` becomes True.
    """

    def __init__(self, max_nan_count: int = 5):
        self.max_nan_count = max_nan_count
        self._nan_count: int = 0
        self._should_stop: bool = False

    def on_train_begin(self, **kwargs) -> None:
        """Reset counters so the callback can be reused across runs."""
        self._nan_count = 0
        self._should_stop = False

    def on_step_end(self, **kwargs) -> None:
        """Check loss (and grad_norm, when supplied) for NaN/Inf."""
        loss = kwargs.get("loss", None)
        if loss is None:
            return
        # Accept both tensors (via .item()) and plain Python numbers.
        loss_val = loss.item() if hasattr(loss, "item") else float(loss)
        non_finite = math.isnan(loss_val) or math.isinf(loss_val)
        grad_norm = kwargs.get("grad_norm", None)
        if grad_norm is not None:
            # BUG FIX: the original only inspected grad_norm when it had an
            # .item() attribute, silently ignoring plain-float grad norms
            # (its float(grad_norm) fallback branch was dead code).
            gn = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
            non_finite = non_finite or math.isnan(gn) or math.isinf(gn)
        if non_finite:
            self._nan_count += 1
            print(f"[NaNMonitor] NaN/Inf detected ({self._nan_count}/{self.max_nan_count})")
            if self._nan_count >= self.max_nan_count:
                self._should_stop = True

    @property
    def should_stop(self) -> bool:
        """True once the tolerated number of non-finite steps is reached."""
        return self._should_stop