langtune 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/model/hub.py ADDED
@@ -0,0 +1,109 @@
+ """
+ Hub Resolver for efficient model downloads.
+
+ Handles interactions with the Hugging Face Hub, including:
+ - Snapshot downloads with caching
+ - Authentication
+ - Offline mode support
+ - File resolution
+ """
+
+ import os
+ import logging
+ from typing import Optional, List
+ from pathlib import Path as PathLib
+ from huggingface_hub import snapshot_download, get_token
+ from huggingface_hub.utils import LocalEntryNotFoundError, EntryNotFoundError, RevisionNotFoundError, RepositoryNotFoundError
+
+ logger = logging.getLogger(__name__)
+
+ class HubResolver:
+     """
+     Resolves and downloads model files from the Hugging Face Hub.
+
+     Optimized for:
+     - Speed (concurrent downloads)
+     - Caching (avoid re-downloading)
+     - Offline usage (finding local files)
+     """
+
+     def __init__(
+         self,
+         cache_dir: Optional[str] = None,
+         local_files_only: bool = False,
+         token: Optional[str] = None,
+     ):
+         self.cache_dir = cache_dir or os.environ.get("LANGTUNE_CACHE_DIR")
+         self.local_files_only = local_files_only
+         self.token = token or get_token()
+
+     def resolve(
+         self,
+         model_id: str,
+         revision: str = "main",
+         allow_patterns: Optional[List[str]] = None,
+         ignore_patterns: Optional[List[str]] = None,
+     ) -> PathLib:
+         """
+         Download or find a model snapshot.
+
+         Args:
+             model_id: The model ID on the HF Hub (e.g. 'meta-llama/Llama-2-7b-hf')
+             revision: Branch name or commit hash
+             allow_patterns: Files to include (default: ['*.safetensors', '*.json', '*.model', 'tokenizer*'])
+             ignore_patterns: Files to exclude (default: ['*.bin', '*.pth', '*.pt'])
+
+         Returns:
+             Path to the directory containing the model files.
+         """
+         # Default patterns prioritize safetensors weights plus config/tokenizer files
+         if allow_patterns is None:
+             allow_patterns = ["*.safetensors", "*.json", "*.model", "tokenizer*"]
+
+         if ignore_patterns is None:
+             # Explicitly ignore pytorch/pickle weights to enforce safetensors
+             ignore_patterns = ["*.bin", "*.pth", "*.pt"]
+
+         logger.info(f"Resolving model {model_id} (revision={revision})...")
+
+         try:
+             download_path = snapshot_download(
+                 repo_id=model_id,
+                 revision=revision,
+                 cache_dir=self.cache_dir,
+                 local_files_only=self.local_files_only,
+                 token=self.token,
+                 allow_patterns=allow_patterns,
+                 ignore_patterns=ignore_patterns,
+                 resume_download=True,
+                 max_workers=8,  # Parallel downloads
+                 tqdm_class=None,  # We might want to hook a custom progress bar later
+             )
+
+             logger.info(f"Model resolved to: {download_path}")
+             return PathLib(download_path)
+
+         except (LocalEntryNotFoundError, EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) as e:
+             if self.local_files_only:
+                 logger.error(f"Model {model_id} not found locally and offline mode is enabled.")
+                 raise RuntimeError(f"Model {model_id} not found locally. Disable offline mode to download.") from e
+             else:
+                 logger.error(f"Failed to download model {model_id}: {e}")
+                 raise RuntimeError(f"Failed to resolve model {model_id}: {e}") from e
+
+         except Exception as e:
+             logger.error(f"Unexpected error resolving model {model_id}: {e}")
+             raise
+
+     @property
+     def is_offline(self) -> bool:
+         return self.local_files_only
+
+ # Module-level default resolver
+ _default_resolver = None
+
+ def get_resolver(offline: bool = False, cache_dir: Optional[str] = None) -> HubResolver:
+     global _default_resolver
+     if _default_resolver is None or _default_resolver.local_files_only != offline:
+         _default_resolver = HubResolver(local_files_only=offline, cache_dir=cache_dir)
+     return _default_resolver
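
For orientation, a minimal usage sketch of the resolver added above (the model ID and cache settings are only illustrative):

    from langtune.model.hub import get_resolver

    resolver = get_resolver(offline=False, cache_dir=None)
    # Downloads (or locates in the cache) only safetensors/config/tokenizer files
    model_dir = resolver.resolve("meta-llama/Llama-2-7b-hf", revision="main")
    print(model_dir)  # pathlib.Path to the local snapshot directory
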
@@ -0,0 +1,84 @@
+ """
+ Model Loader.
+
+ Orchestrates the loading pipeline:
+ 1. Hub Resolution
+ 2. Architecture Construction
+ 3. Streaming Weight Loading
+ 4. Kernel Injection
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ import logging
+
+ from .hub import HubResolver, get_resolver
+ from .safetensors import TensorStreamer
+ from .weights import WeightLoader
+
+ logger = logging.getLogger(__name__)
+
+ class ModelLoader:
+     """
+     Main entry point for loading models with high performance.
+     """
+
+     def __init__(
+         self,
+         offline: bool = False,
+         cache_dir: Optional[str] = None
+     ):
+         self.resolver = get_resolver(offline=offline, cache_dir=cache_dir)
+
+     def load(
+         self,
+         model_id: str,
+         quantization: Optional[str] = None,  # "nf4", "fp4"
+         dtype: str = "bf16",
+         device: str = "cuda"
+     ) -> nn.Module:
+         """
+         Load a model with the optimized pipeline.
+         """
+         # 1. Resolve
+         model_path = self.resolver.resolve(model_id)
+
+         # 2. Config & Architecture
+         # For now, we rely on the HF config to build the skeleton;
+         # in the future, we will use our "architectures" registry.
+         from transformers import AutoConfig, AutoModelForCausalLM
+
+         config = AutoConfig.from_pretrained(model_path)
+
+         logger.info(f"Building model skeleton for {model_id}...")
+         # Creating the skeleton on the 'meta' device avoids allocating any weight memory.
+         with torch.device("meta"):
+             model = AutoModelForCausalLM.from_config(config)
+
+         # Meta tensors hold no storage, so weights cannot be loaded into them directly;
+         # `to_empty` materializes the empty skeleton on the target device,
+         # allocating (uninitialized) memory that the streamer will fill.
+         model = model.to_empty(device=device)
+
+         # 3. Stream Weights
+         logger.info("Streaming weights...")
+         streamer = TensorStreamer(model_path)
+
+         # 4. Load & Quantize
+         compute_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16
+         loader = WeightLoader(streamer, quantization_mode=quantization, compute_dtype=compute_dtype, device=device)
+
+         loader.load_into_module(model)
+
+         # 5. Kernel Injection (Placeholder)
+         # self._inject_kernels(model)
+
+         return model
+
+     def _inject_kernels(self, model: nn.Module):
+         """
+         Replace layers with Langtrain Custom Kernels.
+         """
+         # Implementation to be added in a later step
+         pass
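
A rough end-to-end sketch of the loader above. The import path assumes the file lives at langtune/model/loader.py, which this diff does not show, and the default settings require a CUDA device:

    from langtune.model.loader import ModelLoader  # assumed module path

    loader = ModelLoader(offline=False)
    model = loader.load("meta-llama/Llama-2-7b-hf", quantization="nf4", dtype="bf16", device="cuda")
    model.eval()
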
langtune/model/safetensors.py ADDED
@@ -0,0 +1,104 @@
+ """
+ Safetensors Streamer.
+
+ Efficiently streams tensors from disk using memory mapping.
+ """
+
+ import os
+ import logging
+ import torch
+ from typing import Dict, List, Optional, Union, Iterator, Tuple
+ from pathlib import Path
+ from safetensors import safe_open
+ from dataclasses import dataclass
+
+ logger = logging.getLogger(__name__)
+
+ @dataclass
+ class TensorInfo:
+     """Metadata about a tensor in a safetensors file."""
+     name: str
+     shape: List[int]
+     dtype: str
+     file_path: str
+
+ class TensorStreamer:
+     """
+     Manages lazy loading of tensors from a directory of safetensors files.
+     """
+
+     def __init__(self, model_dir: Union[str, Path]):
+         self.model_dir = Path(model_dir)
+         self.files = sorted(self.model_dir.glob("*.safetensors"))
+
+         if not self.files:
+             raise FileNotFoundError(f"No .safetensors files found in {self.model_dir}")
+
+         self.index: Dict[str, TensorInfo] = {}
+         self._build_index()
+
+     def _build_index(self):
+         """Scan files and build an index of tensor locations."""
+         logger.info(f"Indexing {len(self.files)} safetensors files...")
+
+         for file_path in self.files:
+             try:
+                 with safe_open(file_path, framework="pt", device="cpu") as f:
+                     for key in f.keys():
+                         # We only inspect keys here: safe_open memory-maps the file,
+                         # so no tensor payloads are read during indexing.
+                         # Shape/dtype are left empty for now; they can be populated
+                         # lazily later if a caller actually needs them.
+                         # We store just the key -> file mapping to keep indexing fast.
+
+                         # Conflict check: the same key should not appear in more
+                         # than one shard; keep the first occurrence.
+                         if key in self.index:
+                             logger.warning(f"Duplicate tensor {key} found in {file_path}, ignoring duplicate.")
+                             continue
+
+                         self.index[key] = TensorInfo(
+                             name=key,
+                             shape=[],  # populated lazily if ever required
+                             dtype="",
+                             file_path=str(file_path)
+                         )
+             except Exception as e:
+                 logger.error(f"Failed to read {file_path}: {e}")
+                 raise
+
+     def has_tensor(self, name: str) -> bool:
+         return name in self.index
+
+     def get_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
+         """
+         Load a single tensor.
+         """
+         if name not in self.index:
+             raise KeyError(f"Tensor {name} not found in model files.")
+
+         info = self.index[name]
+
+         with safe_open(info.file_path, framework="pt", device=device) as f:
+             return f.get_tensor(name)
+
+     def load_state_dict(self, device: str = "cpu") -> Dict[str, torch.Tensor]:
+         """
+         Load all tensors into a state dict (WARNING: high memory usage).
+         Use for small models or debugging only.
+         """
+         state_dict = {}
+         for name in self.index:
+             state_dict[name] = self.get_tensor(name, device)
+         return state_dict
+
+     def stream(self, device: str = "cpu") -> Iterator[Tuple[str, torch.Tensor]]:
+         """
+         Yields (name, tensor) pairs one by one to minimize memory usage.
+         """
+         # Open each shard once and yield its tensors before moving on
+         for file_path in self.files:
+             with safe_open(file_path, framework="pt", device=device) as f:
+                 for key in f.keys():
+                     yield key, f.get_tensor(key)
+
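
A small sketch of driving the streamer above, assuming a local directory of *.safetensors shards (the path and tensor name are illustrative):

    from langtune.model.safetensors import TensorStreamer

    streamer = TensorStreamer("/path/to/model_dir")
    print(streamer.has_tensor("model.embed_tokens.weight"))
    # Iterate tensors one at a time instead of materializing a full state dict
    for name, tensor in streamer.stream(device="cpu"):
        print(name, tuple(tensor.shape), tensor.dtype)
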
langtune/model/weights.py ADDED
@@ -0,0 +1,100 @@
+ """
+ Weight Loader.
+
+ Handles reading weights from the streamer and applying:
+ - Quantize-on-Read (NF4/FP4)
+ - Precision casting (BF16/FP16)
+ - Kernel weight format definition
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ import logging
+
+ from .safetensors import TensorStreamer
+
+ logger = logging.getLogger(__name__)
+
+ class WeightLoader:
+     """
+     Loads weights into a model structure.
+
+     Supports:
+     - Direct quantization (NF4) during loading to save memory
+     - BF16/FP16 casting
+     - Kernel-specific formatting
+     """
+
+     def __init__(
+         self,
+         streamer: TensorStreamer,
+         quantization_mode: Optional[str] = None,  # "nf4", "fp4", "int8"
+         compute_dtype: torch.dtype = torch.bfloat16,
+         device: str = "cuda"
+     ):
+         self.streamer = streamer
+         self.quantization_mode = quantization_mode
+         self.compute_dtype = compute_dtype
+         self.device = device
+
+     def load_into_module(self, module: nn.Module, prefix: str = ""):
+         """
+         Populate a module's weights from the streamer.
+         """
+         # named_parameters() yields names relative to this module, while the
+         # streamer indexes tensors by their full state-dict keys
+         # (e.g. "model.layers.0.self_attn.q_proj.weight").
+
+         # We therefore rebuild the full key from the accumulated prefix and
+         # assume the model structure matches the streamer keys for now.
+
+         for name, param in module.named_parameters(recurse=False):
+             full_name = f"{prefix}.{name}" if prefix else name
+
+             if self.streamer.has_tensor(full_name):
+                 tensor = self.streamer.get_tensor(full_name, device="cpu")
+
+                 # Apply Quantize-on-Read
+                 if self.quantization_mode and "lora" not in full_name and "norm" not in full_name:
+                     self._quantize_and_assign(module, name, tensor)
+                 else:
+                     # Standard loading
+                     with torch.no_grad():
+                         param.data = tensor.to(self.device, dtype=self.compute_dtype)
+
+         # Recursively load children
+         for child_name, child in module.named_children():
+             child_prefix = f"{prefix}.{child_name}" if prefix else child_name
+             self.load_into_module(child, prefix=child_prefix)
+
+     def _quantize_and_assign(self, module: nn.Module, param_name: str, tensor: torch.Tensor):
+         """
+         Quantize a tensor and replace the corresponding parameter on the module.
+         Currently uses bitsandbytes 4-bit parameters when available; custom
+         kernel formats may be wired in later.
+         """
+         # Placeholder for the full NF4 quantization path.
+         # In a complete implementation we would use bitsandbytes.functional.quantize_4bit
+         # and replace the nn.Parameter with a Params4bit directly.
+
+         if self.quantization_mode == "nf4":
+             try:
+                 from bitsandbytes.nn import Params4bit
+
+                 # Quantize to NF4. Params4bit performs the actual quantization
+                 # when it is moved to a CUDA device, so this generally requires a GPU.
+                 param_4bit = Params4bit(tensor, requires_grad=False, compress_statistics=True, quant_type="nf4")
+                 param_4bit = param_4bit.to(self.device)
+
+                 # Replace parameter
+                 setattr(module, param_name, param_4bit)
+
+             except ImportError:
+                 logger.warning("bitsandbytes not found, falling back to BF16/FP16")
+                 with torch.no_grad():
+                     getattr(module, param_name).data = tensor.to(self.device, dtype=self.compute_dtype)
+         else:
+             with torch.no_grad():
+                 getattr(module, param_name).data = tensor.to(self.device, dtype=self.compute_dtype)
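
A minimal sketch combining the streamer and the weight loader above, with quantization disabled so it can run on CPU; my_model stands in for any nn.Module whose parameter names match the checkpoint keys:

    import torch
    from langtune.model.safetensors import TensorStreamer
    from langtune.model.weights import WeightLoader

    streamer = TensorStreamer("/path/to/model_dir")
    loader = WeightLoader(streamer, quantization_mode=None, compute_dtype=torch.bfloat16, device="cpu")
    loader.load_into_module(my_model)  # my_model: an already-constructed skeleton, e.g. from ModelLoader
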
langtune/models.py ADDED
@@ -0,0 +1,19 @@
+ """
+ models.py: LoRA-enabled transformer models for Langtune.
+
+ This module now re-exports components from `langtune.nn` for backward compatibility.
+ """
+
+ from .nn.layers import LoRALinear, MultiHeadAttention
+ from .nn.transformer import TransformerBlock, LoRALanguageModel
+ from .nn.fast_transformer import FastMultiHeadAttention, FastTransformerBlock, FastLoRALanguageModel
+
+ __all__ = [
+     "LoRALinear",
+     "MultiHeadAttention",
+     "TransformerBlock",
+     "LoRALanguageModel",
+     "FastMultiHeadAttention",
+     "FastTransformerBlock",
+     "FastLoRALanguageModel",
+ ]
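
Because this module only re-exports from langtune.nn, the old and new import paths resolve to the same classes:

    # Old path (kept for backward compatibility)
    from langtune.models import LoRALinear, FastLoRALanguageModel

    # New canonical location
    from langtune.nn.layers import LoRALinear
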