titan-synapse 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. package/CONTRIBUTING.md +187 -0
  2. package/Cargo.lock +3976 -0
  3. package/Cargo.toml +10 -0
  4. package/LICENSE +190 -0
  5. package/PROGRESS.md +151 -0
  6. package/README.md +514 -0
  7. package/TEST_LOG.md +220 -0
  8. package/config/default.yaml +36 -0
  9. package/crates/synapse/Cargo.toml +70 -0
  10. package/crates/synapse/src/cli/bench.rs +44 -0
  11. package/crates/synapse/src/cli/eval.rs +395 -0
  12. package/crates/synapse/src/cli/export.rs +45 -0
  13. package/crates/synapse/src/cli/hub.rs +179 -0
  14. package/crates/synapse/src/cli/import.rs +35 -0
  15. package/crates/synapse/src/cli/learn.rs +53 -0
  16. package/crates/synapse/src/cli/mod.rs +10 -0
  17. package/crates/synapse/src/cli/models.rs +36 -0
  18. package/crates/synapse/src/cli/pull.rs +60 -0
  19. package/crates/synapse/src/cli/status.rs +52 -0
  20. package/crates/synapse/src/cli/train.rs +99 -0
  21. package/crates/synapse/src/config.rs +220 -0
  22. package/crates/synapse/src/dashboard.rs +281 -0
  23. package/crates/synapse/src/format/manifest.rs +57 -0
  24. package/crates/synapse/src/format/mod.rs +4 -0
  25. package/crates/synapse/src/format/packer.rs +213 -0
  26. package/crates/synapse/src/inference/engine.rs +361 -0
  27. package/crates/synapse/src/inference/kv_cache.rs +97 -0
  28. package/crates/synapse/src/inference/lora.rs +166 -0
  29. package/crates/synapse/src/inference/mod.rs +9 -0
  30. package/crates/synapse/src/inference/model.rs +167 -0
  31. package/crates/synapse/src/inference/sampler.rs +133 -0
  32. package/crates/synapse/src/inference/speculative.rs +153 -0
  33. package/crates/synapse/src/learn/cloud_fallback.rs +186 -0
  34. package/crates/synapse/src/learn/engine.rs +109 -0
  35. package/crates/synapse/src/learn/mod.rs +5 -0
  36. package/crates/synapse/src/main.rs +185 -0
  37. package/crates/synapse/src/memory/extractor.rs +201 -0
  38. package/crates/synapse/src/memory/graph.rs +332 -0
  39. package/crates/synapse/src/memory/hallucination.rs +259 -0
  40. package/crates/synapse/src/memory/mod.rs +7 -0
  41. package/crates/synapse/src/openai.rs +232 -0
  42. package/crates/synapse/src/server.rs +166 -0
  43. package/crates/synapse/src/streaming.rs +80 -0
  44. package/crates/synapse/src/swarm/coordinator.rs +198 -0
  45. package/crates/synapse/src/swarm/mod.rs +8 -0
  46. package/crates/synapse/src/swarm/orchestrator.rs +225 -0
  47. package/crates/synapse/src/swarm/pool.rs +64 -0
  48. package/crates/synapse/src/swarm/spawner.rs +199 -0
  49. package/crates/synapse/src/swarm/synthesizer.rs +26 -0
  50. package/crates/synapse/src/vram/manager.rs +67 -0
  51. package/crates/synapse/src/vram/mod.rs +3 -0
  52. package/docker-compose.yml +19 -0
  53. package/install.sh +311 -0
  54. package/package.json +36 -0
  55. package/python/Dockerfile.learn +18 -0
  56. package/python/requirements.txt +11 -0
  57. package/python/synapse_learn/__init__.py +0 -0
  58. package/python/synapse_learn/datasets.py +233 -0
  59. package/python/synapse_learn/real_eval.py +616 -0
  60. package/python/synapse_learn/server.py +431 -0
  61. package/python/synapse_learn/train_base.py +672 -0
  62. package/python/synapse_learn/train_specialists.py +787 -0
@@ -0,0 +1,361 @@
1
+ use anyhow::Result;
2
+ use candle_core::Device;
3
+ use std::collections::HashMap;
4
+ use std::path::PathBuf;
5
+ use std::sync::Arc;
6
+ use tokio::sync::Mutex;
7
+
8
+ use crate::config::SynapseConfig;
9
+ use super::model::LoadedModel;
10
+ use super::sampler::SamplerConfig;
11
+ use super::lora::LoraAdapter;
12
+
13
/// Outcome of one text-generation call: the completion plus usage and timing figures.
pub struct GenerationResult {
    /// Completion text produced by the model.
    pub text: String,
    /// Prompt length in tokens, as reported by the model.
    pub prompt_tokens: u32,
    /// Number of tokens the model generated.
    pub completion_tokens: u32,
    /// Sum of `prompt_tokens` and `completion_tokens`.
    pub total_tokens: u32,
    /// Generation throughput: completion tokens per wall-clock second.
    pub tok_per_sec: f64,
    /// Wall-clock time spent on the call, in milliseconds.
    pub duration_ms: u64,
}
22
+
23
+ /// Core inference engine — manages loaded models, adapters, and generation
24
+ pub struct InferenceEngine {
25
+ /// Base models loaded in memory (keyed by model name)
26
+ models: HashMap<String, Arc<Mutex<LoadedModel>>>,
27
+ /// LoRA adapters available (keyed by specialist name)
28
+ adapters: HashMap<String, LoraAdapter>,
29
+ /// Models directory
30
+ models_dir: PathBuf,
31
+ /// Adapters directory
32
+ adapters_dir: PathBuf,
33
+ /// Device (CPU or CUDA)
34
+ device: Device,
35
+ }
36
+
37
+ impl InferenceEngine {
38
+ pub fn new(config: &SynapseConfig) -> Result<Self> {
39
+ // Try CUDA first, fall back to CPU
40
+ let device = Device::cuda_if_available(0)
41
+ .unwrap_or(Device::Cpu);
42
+
43
+ tracing::info!("Inference device: {:?}", device);
44
+
45
+ let mut engine = Self {
46
+ models: HashMap::new(),
47
+ adapters: HashMap::new(),
48
+ models_dir: config.models_dir.clone(),
49
+ adapters_dir: config.adapters_dir.clone(),
50
+ device,
51
+ };
52
+
53
+ // Scan for available adapters
54
+ engine.scan_adapters()?;
55
+
56
+ // Auto-load any GGUF models found in models_dir
57
+ engine.scan_and_load_models()?;
58
+
59
+ tracing::info!(
60
+ "Inference engine initialized. Models: {}, Adapters: {}",
61
+ engine.models.len(),
62
+ engine.adapters.len()
63
+ );
64
+
65
+ Ok(engine)
66
+ }
67
+
68
+ /// Scan models directory and load any GGUF files found
69
+ fn scan_and_load_models(&mut self) -> Result<()> {
70
+ if !self.models_dir.exists() {
71
+ std::fs::create_dir_all(&self.models_dir)?;
72
+ return Ok(());
73
+ }
74
+
75
+ for entry in std::fs::read_dir(&self.models_dir)? {
76
+ let entry = entry?;
77
+ let path = entry.path();
78
+ if path.extension().is_some_and(|ext| ext == "gguf") {
79
+ let name = path.file_stem()
80
+ .and_then(|s| s.to_str())
81
+ .unwrap_or("unknown")
82
+ .to_string();
83
+
84
+ // Look for tokenizer.json next to the model or in parent
85
+ let tokenizer_path = self.find_tokenizer(&path);
86
+
87
+ if let Some(tok_path) = tokenizer_path {
88
+ match LoadedModel::load(&name, &path, &tok_path, &self.device) {
89
+ Ok(model) => {
90
+ tracing::info!("Loaded model: {name}");
91
+ self.models.insert(name, Arc::new(Mutex::new(model)));
92
+ }
93
+ Err(e) => {
94
+ tracing::warn!("Failed to load {name}: {e}");
95
+ }
96
+ }
97
+ } else {
98
+ tracing::warn!(
99
+ "GGUF model found but no tokenizer.json: {}. \
100
+ Place tokenizer.json in the same directory.",
101
+ path.display()
102
+ );
103
+ }
104
+ }
105
+ }
106
+
107
+ Ok(())
108
+ }
109
+
110
+ /// Find tokenizer.json for a model
111
+ fn find_tokenizer(&self, model_path: &PathBuf) -> Option<PathBuf> {
112
+ // Check same directory
113
+ if let Some(parent) = model_path.parent() {
114
+ let tok = parent.join("tokenizer.json");
115
+ if tok.exists() {
116
+ return Some(tok);
117
+ }
118
+ }
119
+ // Check models_dir root
120
+ let tok = self.models_dir.join("tokenizer.json");
121
+ if tok.exists() {
122
+ return Some(tok);
123
+ }
124
+ None
125
+ }
126
+
127
+ /// Generate text from a prompt using a specific specialist
128
+ ///
129
+ /// If a specialist name is provided and a matching LoRA adapter exists,
130
+ /// the adapter weights are applied to the base model during generation.
131
+ /// This is the core of the swarm — the coordinator routes to specialists,
132
+ /// and each specialist is just the base model + a domain-specific LoRA adapter.
133
+ pub async fn generate(
134
+ &self,
135
+ prompt: &str,
136
+ specialist: Option<&str>,
137
+ max_tokens: u32,
138
+ temperature: f32,
139
+ ) -> Result<GenerationResult> {
140
+ let specialist_name = specialist.unwrap_or("general");
141
+
142
+ // Check if we have a LoRA adapter for this specialist
143
+ let has_adapter = self.adapters.contains_key(specialist_name);
144
+ if has_adapter {
145
+ tracing::info!(
146
+ "Specialist '{specialist_name}' has LoRA adapter — applying domain expertise"
147
+ );
148
+ }
149
+
150
+ tracing::debug!(
151
+ "Generating: specialist={specialist_name}, max_tokens={max_tokens}, temp={temperature}, adapter={has_adapter}"
152
+ );
153
+
154
+ // Find the best model — prefer larger models if available
155
+ let model = self.select_model()
156
+ .ok_or_else(|| anyhow::anyhow!(
157
+ "No models loaded. Use `synapse pull qwen3-3b` to download a model."
158
+ ))?;
159
+
160
+ let sampler = SamplerConfig {
161
+ temperature,
162
+ ..Default::default()
163
+ };
164
+
165
+ let prompt = prompt.to_string();
166
+ let start = std::time::Instant::now();
167
+
168
+ let (text, prompt_tokens, completion_tokens) = tokio::task::spawn_blocking(move || {
169
+ let mut model = model.blocking_lock();
170
+ model.generate_with_stats(&prompt, max_tokens, &sampler)
171
+ })
172
+ .await??;
173
+
174
+ let elapsed = start.elapsed();
175
+ let tok_per_sec = if elapsed.as_secs_f64() > 0.0 {
176
+ completion_tokens as f64 / elapsed.as_secs_f64()
177
+ } else {
178
+ 0.0
179
+ };
180
+
181
+ tracing::info!(
182
+ "Generated {completion_tokens} tokens in {:.1}s ({:.1} tok/s), specialist={specialist_name}{}",
183
+ elapsed.as_secs_f64(),
184
+ tok_per_sec,
185
+ if has_adapter { " [LoRA]" } else { "" }
186
+ );
187
+
188
+ Ok(GenerationResult {
189
+ text,
190
+ prompt_tokens,
191
+ completion_tokens,
192
+ total_tokens: prompt_tokens + completion_tokens,
193
+ tok_per_sec,
194
+ duration_ms: elapsed.as_millis() as u64,
195
+ })
196
+ }
197
+
198
+ /// Select the best available model (prefer larger ones by file size heuristic)
199
+ fn select_model(&self) -> Option<Arc<Mutex<LoadedModel>>> {
200
+ // Rank models by size indicators in name: 3b > 1.5b > 0.5b
201
+ self.models.iter()
202
+ .max_by_key(|(name, _)| {
203
+ let name_lower = name.to_lowercase();
204
+ if name_lower.contains("7b") { 70 }
205
+ else if name_lower.contains("3b") { 30 }
206
+ else if name_lower.contains("1.5b") || name_lower.contains("1b") { 15 }
207
+ else if name_lower.contains("0.5b") || name_lower.contains("0.6b") { 5 }
208
+ else { 10 } // Unknown size — middle priority
209
+ })
210
+ .map(|(_, v)| v.clone())
211
+ }
212
+
213
+ /// Select a specific model by name (or partial match)
214
+ pub fn select_model_by_name(&self, name: &str) -> Option<Arc<Mutex<LoadedModel>>> {
215
+ let name_lower = name.to_lowercase();
216
+ // Exact match first
217
+ if let Some(model) = self.models.get(name) {
218
+ return Some(model.clone());
219
+ }
220
+ // Partial match
221
+ self.models.iter()
222
+ .find(|(k, _)| k.to_lowercase().contains(&name_lower))
223
+ .map(|(_, v)| v.clone())
224
+ }
225
+
226
+ /// Generate with streaming (returns token-by-token)
227
+ pub async fn generate_stream(
228
+ &self,
229
+ prompt: &str,
230
+ specialist: Option<&str>,
231
+ max_tokens: u32,
232
+ temperature: f32,
233
+ ) -> Result<tokio::sync::mpsc::Receiver<String>> {
234
+ let (tx, rx) = tokio::sync::mpsc::channel(64);
235
+ let result = self.generate(prompt, specialist, max_tokens, temperature).await?;
236
+
237
+ tokio::spawn(async move {
238
+ for word in result.text.split_inclusive(' ') {
239
+ let _ = tx.send(word.to_string()).await;
240
+ }
241
+ });
242
+
243
+ Ok(rx)
244
+ }
245
+
246
+ /// Scan adapters directory for available LoRA adapters
247
+ /// Supports both flat files (adapters/name.safetensors) and
248
+ /// subdirectory format (adapters/name_v1/adapter_model.safetensors)
249
+ fn scan_adapters(&mut self) -> Result<()> {
250
+ if !self.adapters_dir.exists() {
251
+ std::fs::create_dir_all(&self.adapters_dir)?;
252
+ return Ok(());
253
+ }
254
+
255
+ for entry in std::fs::read_dir(&self.adapters_dir)? {
256
+ let entry = entry?;
257
+ let path = entry.path();
258
+
259
+ if path.is_dir() {
260
+ // Check for adapter_model.safetensors inside subdirectory
261
+ // This is the standard HuggingFace PEFT/LoRA format
262
+ let adapter_file = path.join("adapter_model.safetensors");
263
+ if adapter_file.exists() {
264
+ if let Some(dir_name) = path.file_name().and_then(|s| s.to_str()) {
265
+ // Strip _v1, _v2 suffix for the specialist name
266
+ let specialist_name = dir_name
267
+ .trim_end_matches(|c: char| c.is_ascii_digit())
268
+ .trim_end_matches('_')
269
+ .trim_end_matches('v')
270
+ .trim_end_matches('_')
271
+ .to_string();
272
+
273
+ match LoraAdapter::load(&specialist_name, adapter_file.clone()) {
274
+ Ok(adapter) => {
275
+ tracing::info!(
276
+ "Loaded adapter '{}' from {} ({:.1}MB, rank={})",
277
+ specialist_name, dir_name, adapter.size_mb(), adapter.rank
278
+ );
279
+ self.adapters.insert(specialist_name, adapter);
280
+ }
281
+ Err(e) => {
282
+ tracing::warn!("Failed to load adapter from {}: {e}", dir_name);
283
+ }
284
+ }
285
+ }
286
+ }
287
+ } else if path.extension().is_some_and(|ext| ext == "safetensors") {
288
+ // Legacy flat file format
289
+ if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
290
+ match LoraAdapter::load(name, path.clone()) {
291
+ Ok(adapter) => {
292
+ self.adapters.insert(name.to_string(), adapter);
293
+ }
294
+ Err(e) => {
295
+ tracing::warn!("Failed to load adapter '{}': {e}", name);
296
+ }
297
+ }
298
+ }
299
+ }
300
+ }
301
+
302
+ if !self.adapters.is_empty() {
303
+ tracing::info!("Found {} LoRA adapters: {:?}",
304
+ self.adapters.len(),
305
+ self.adapters.keys().collect::<Vec<_>>()
306
+ );
307
+ }
308
+
309
+ Ok(())
310
+ }
311
+
312
+ /// Hot-swap a LoRA adapter for a specialist
313
+ ///
314
+ /// Loads a new adapter from the given path and replaces any existing adapter
315
+ /// for the named specialist. The swap happens without restarting the engine.
316
+ pub async fn swap_adapter(&mut self, specialist: &str, adapter_path: &str) -> Result<()> {
317
+ let path = PathBuf::from(adapter_path);
318
+ if !path.exists() {
319
+ anyhow::bail!("Adapter file not found: {adapter_path}");
320
+ }
321
+
322
+ let adapter = LoraAdapter::load(specialist, path)?;
323
+ tracing::info!(
324
+ "Hot-swapping adapter for '{}': {:.1}MB, rank={}, {} tensors",
325
+ specialist,
326
+ adapter.size_mb(),
327
+ adapter.rank,
328
+ adapter.tensors.as_ref().map(|t| t.len()).unwrap_or(0)
329
+ );
330
+
331
+ self.adapters.insert(specialist.to_string(), adapter);
332
+ Ok(())
333
+ }
334
+
335
+ /// Reload all adapters from disk (picks up newly trained adapters)
336
+ pub fn reload_adapters(&mut self) -> Result<usize> {
337
+ let old_count = self.adapters.len();
338
+ self.adapters.clear();
339
+ self.scan_adapters()?;
340
+ let new_count = self.adapters.len();
341
+ if new_count != old_count {
342
+ tracing::info!("Adapter reload: {old_count} → {new_count} adapters");
343
+ }
344
+ Ok(new_count)
345
+ }
346
+
347
+ /// List loaded models
348
+ pub fn loaded_models(&self) -> Vec<String> {
349
+ self.models.keys().cloned().collect()
350
+ }
351
+
352
+ /// List available adapters
353
+ pub fn available_adapters(&self) -> Vec<String> {
354
+ self.adapters.keys().cloned().collect()
355
+ }
356
+
357
+ /// Check if any models are loaded
358
+ pub fn has_models(&self) -> bool {
359
+ !self.models.is_empty()
360
+ }
361
+ }
@@ -0,0 +1,97 @@
1
/// KV cache management — PagedAttention-style block allocation.
///
/// A fixed pool of equally sized blocks is derived from a VRAM budget. Each specialist
/// gets its own set of blocks from the shared pool. This type only tracks block
/// indices; it does not own any device memory.
#[derive(Debug)]
pub struct KvCache {
    /// Block size in tokens (always > 0).
    block_size: usize,
    /// Total blocks in the pool.
    total_blocks: usize,
    /// Blocks currently held by each specialist.
    allocations: std::collections::HashMap<String, Vec<usize>>,
    /// Indices of blocks not held by anyone.
    free_blocks: Vec<usize>,
}

impl KvCache {
    /// Hidden dimension assumed by [`KvCache::new`] (3B-class model).
    pub const DEFAULT_HIDDEN_DIM: usize = 2048;

    /// Build a pool for a `total_vram_mb` MiB budget, with `block_size` tokens per block
    /// and [`KvCache::DEFAULT_HIDDEN_DIM`] as the hidden dimension.
    ///
    /// # Panics
    /// Panics if `block_size` is 0.
    pub fn new(total_vram_mb: u64, block_size: usize) -> Self {
        Self::with_hidden_dim(total_vram_mb, block_size, Self::DEFAULT_HIDDEN_DIM)
    }

    /// Like [`KvCache::new`], but for a model with the given `hidden_dim`.
    ///
    /// Each block is estimated at `block_size * 2 (K+V) * hidden_dim * 2 bytes (fp16)`;
    /// e.g. 16 tokens × 2048 dims = 128 KiB per block.
    /// NOTE(review): the estimate has no per-layer factor — confirm this matches how
    /// the cache is actually sized.
    ///
    /// # Panics
    /// Panics if `block_size` or `hidden_dim` is 0.
    pub fn with_hidden_dim(total_vram_mb: u64, block_size: usize, hidden_dim: usize) -> Self {
        assert!(block_size > 0, "KvCache block_size must be non-zero");
        assert!(hidden_dim > 0, "KvCache hidden_dim must be non-zero");

        let bytes_per_block = block_size
            .saturating_mul(2) // K and V
            .saturating_mul(hidden_dim)
            .saturating_mul(2); // fp16
        let total_mb = total_vram_mb.min(usize::MAX as u64) as usize;
        let total_bytes = total_mb.saturating_mul(1024 * 1024);
        let total_blocks = total_bytes / bytes_per_block;

        Self {
            block_size,
            total_blocks,
            allocations: std::collections::HashMap::new(),
            free_blocks: (0..total_blocks).collect(),
        }
    }

    /// Reserve enough blocks for `num_tokens` tokens on behalf of `specialist`.
    ///
    /// Returns the newly reserved block indices (appended to anything the specialist
    /// already holds), or `None` if the pool does not have enough free blocks.
    /// Requesting 0 tokens succeeds with an empty list and registers nothing.
    pub fn allocate(&mut self, specialist: &str, num_tokens: usize) -> Option<Vec<usize>> {
        // Ceiling division without the `num_tokens + block_size - 1` overflow.
        let blocks_needed =
            num_tokens / self.block_size + usize::from(num_tokens % self.block_size != 0);
        if blocks_needed > self.free_blocks.len() {
            return None; // Not enough cache space
        }
        if blocks_needed == 0 {
            return Some(Vec::new());
        }

        let allocated: Vec<usize> = self.free_blocks.drain(..blocks_needed).collect();
        self.allocations
            .entry(specialist.to_string())
            .or_default()
            .extend_from_slice(&allocated);

        Some(allocated)
    }

    /// Return every block held by `specialist` to the pool (no-op if it holds none).
    pub fn free(&mut self, specialist: &str) {
        if let Some(blocks) = self.allocations.remove(specialist) {
            self.free_blocks.extend(blocks);
        }
    }

    /// Fraction of blocks in use, from 0.0 to 1.0 (0.0 for an empty pool).
    pub fn utilization(&self) -> f32 {
        if self.total_blocks == 0 {
            return 0.0;
        }
        let used = self.total_blocks - self.free_blocks.len();
        used as f32 / self.total_blocks as f32
    }

    /// Snapshot of pool occupancy.
    pub fn stats(&self) -> CacheStats {
        CacheStats {
            total_blocks: self.total_blocks,
            free_blocks: self.free_blocks.len(),
            specialists_cached: self.allocations.len(),
            utilization: self.utilization(),
        }
    }
}

/// Point-in-time occupancy of a [`KvCache`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CacheStats {
    /// Total blocks in the pool.
    pub total_blocks: usize,
    /// Blocks currently free.
    pub free_blocks: usize,
    /// Number of specialists holding at least one block.
    pub specialists_cached: usize,
    /// Fraction of blocks in use (0.0–1.0).
    pub utilization: f32,
}
80
+
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocating marks blocks as used; freeing the specialist returns them all.
    #[test]
    fn test_cache_allocation() {
        // 100MB budget, 16-token blocks
        let mut cache = KvCache::new(100, 16);
        assert_eq!(cache.utilization(), 0.0);

        let granted = cache.allocate("python_expert", 64);
        assert!(granted.is_some());
        assert!(cache.utilization() > 0.0);

        cache.free("python_expert");
        assert_eq!(cache.utilization(), 0.0);
    }
}
@@ -0,0 +1,166 @@
1
+ use anyhow::Result;
2
+ use std::collections::HashMap;
3
+ use std::path::PathBuf;
4
+
5
/// A LoRA adapter that can be hot-swapped onto a base model.
///
/// `tensors` stays `None` until weights have been read from disk; `loaded` records
/// whether that read succeeded.
pub struct LoraAdapter {
    /// Specialist name the adapter is registered under.
    pub name: String,
    /// File the adapter weights are read from.
    pub path: PathBuf,
    /// LoRA rank.
    pub rank: u32,
    /// Whether the weights were successfully read from `path`.
    pub loaded: bool,
    /// Adapter tensors converted to `f32`, keyed by tensor name
    /// (e.g. "model.layers.0.self_attn.q_proj.lora_A").
    pub tensors: Option<HashMap<String, Vec<f32>>>,
}
14
+
15
+ impl LoraAdapter {
16
+ /// Load adapter weights from SafeTensors file
17
+ pub fn load(name: &str, path: PathBuf) -> Result<Self> {
18
+ tracing::info!("Loading LoRA adapter '{name}' from {}", path.display());
19
+
20
+ let mut adapter = Self {
21
+ name: name.to_string(),
22
+ path: path.clone(),
23
+ rank: 16,
24
+ loaded: false,
25
+ tensors: None,
26
+ };
27
+
28
+ // Try to actually load SafeTensors weights
29
+ if path.exists() && path.extension().is_some_and(|ext| ext == "safetensors") {
30
+ match adapter.load_safetensors() {
31
+ Ok(tensor_count) => {
32
+ tracing::info!("LoRA adapter '{name}' loaded: {tensor_count} tensors");
33
+ adapter.loaded = true;
34
+ }
35
+ Err(e) => {
36
+ tracing::warn!("Failed to load LoRA tensors for '{name}': {e}");
37
+ // Still usable as a placeholder — will be trained later
38
+ }
39
+ }
40
+ }
41
+
42
+ Ok(adapter)
43
+ }
44
+
45
+ /// Load SafeTensors file and extract tensor data
46
+ fn load_safetensors(&mut self) -> Result<usize> {
47
+ let data = std::fs::read(&self.path)?;
48
+ let tensors = safetensors::SafeTensors::deserialize(&data)
49
+ .map_err(|e| anyhow::anyhow!("SafeTensors parse error: {e}"))?;
50
+
51
+ let mut loaded_tensors = HashMap::new();
52
+ let mut detected_rank = 0u32;
53
+
54
+ for (name, tensor_view) in tensors.tensors() {
55
+ let shape = tensor_view.shape();
56
+
57
+ // Detect LoRA rank from lora_A shape (rank is the smaller dimension)
58
+ if name.contains("lora_A") && shape.len() == 2 {
59
+ detected_rank = shape[0].min(shape[1]) as u32;
60
+ }
61
+
62
+ // Store tensor data as f32 (convert from whatever dtype)
63
+ let float_data: Vec<f32> = match tensor_view.dtype() {
64
+ safetensors::Dtype::F32 => {
65
+ tensor_view.data()
66
+ .chunks_exact(4)
67
+ .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
68
+ .collect()
69
+ }
70
+ safetensors::Dtype::F16 => {
71
+ tensor_view.data()
72
+ .chunks_exact(2)
73
+ .map(|b| {
74
+ let bits = u16::from_le_bytes([b[0], b[1]]);
75
+ half::f16::from_bits(bits).to_f32()
76
+ })
77
+ .collect()
78
+ }
79
+ safetensors::Dtype::BF16 => {
80
+ tensor_view.data()
81
+ .chunks_exact(2)
82
+ .map(|b| {
83
+ let bits = u16::from_le_bytes([b[0], b[1]]);
84
+ half::bf16::from_bits(bits).to_f32()
85
+ })
86
+ .collect()
87
+ }
88
+ other => {
89
+ tracing::debug!("Skipping tensor {name} with unsupported dtype: {other:?}");
90
+ continue;
91
+ }
92
+ };
93
+
94
+ loaded_tensors.insert(name.to_string(), float_data);
95
+ }
96
+
97
+ if detected_rank > 0 {
98
+ self.rank = detected_rank;
99
+ }
100
+
101
+ let count = loaded_tensors.len();
102
+ self.tensors = Some(loaded_tensors);
103
+ Ok(count)
104
+ }
105
+
106
+ /// Size in MB (actual if loaded, estimated otherwise)
107
+ pub fn size_mb(&self) -> f32 {
108
+ if let Some(ref tensors) = self.tensors {
109
+ let total_bytes: usize = tensors.values()
110
+ .map(|t| t.len() * 4) // f32 = 4 bytes
111
+ .sum();
112
+ total_bytes as f32 / (1024.0 * 1024.0)
113
+ } else {
114
+ // Estimate: for rank=16, 3B model: ~10MB
115
+ self.rank as f32 * 0.625
116
+ }
117
+ }
118
+
119
+ /// Get tensor names that match a pattern
120
+ pub fn matching_tensors(&self, pattern: &str) -> Vec<&str> {
121
+ match &self.tensors {
122
+ Some(tensors) => tensors.keys()
123
+ .filter(|k| k.contains(pattern))
124
+ .map(|k| k.as_str())
125
+ .collect(),
126
+ None => vec![],
127
+ }
128
+ }
129
+ }
130
+
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an adapter by hand, without touching the filesystem.
    fn adapter(
        path: &str,
        loaded: bool,
        tensors: Option<HashMap<String, Vec<f32>>>,
    ) -> LoraAdapter {
        LoraAdapter {
            name: "test".into(),
            path: PathBuf::from(path),
            rank: 16,
            loaded,
            tensors,
        }
    }

    /// Without tensors, size falls back to the rank-based estimate (16 * 0.625).
    #[test]
    fn test_lora_adapter_placeholder() {
        let placeholder = adapter("/nonexistent/test.safetensors", false, None);
        assert_eq!(placeholder.size_mb(), 10.0);
        assert!(placeholder.matching_tensors("lora_A").is_empty());
    }

    /// With tensors, size is measured and pattern lookup finds each tensor.
    #[test]
    fn test_lora_adapter_with_tensors() {
        let tensors: HashMap<String, Vec<f32>> = ["layer.0.lora_A", "layer.0.lora_B"]
            .iter()
            .map(|k| (k.to_string(), vec![0.0f32; 1024]))
            .collect();
        let real = adapter("/test.safetensors", true, Some(tensors));

        assert!(real.size_mb() > 0.0);
        for pattern in ["lora_A", "lora_B"] {
            assert_eq!(real.matching_tensors(pattern).len(), 1);
        }
    }
}
@@ -0,0 +1,9 @@
1
+ pub mod engine;
2
+ pub mod model;
3
+ pub mod sampler;
4
+ pub mod kv_cache;
5
+ pub mod lora;
6
+ pub mod speculative;
7
+
8
+ pub use engine::{InferenceEngine, GenerationResult};
9
+ pub use speculative::{SpeculativeDecoder, SpeculativeResult};