ollamadiffuser 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,384 @@
+"""
+Dynamic Model Registry - Similar to Ollama's approach
+Fetches model information from external sources with local fallbacks
+"""
+
+import json
+import logging
+import requests
+import yaml
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+from datetime import datetime, timedelta
+import hashlib
+
+logger = logging.getLogger(__name__)
+
+class ModelRegistry:
+    """
+    Dynamic model registry that can fetch from external sources
+    Similar to how Ollama manages their model library
+    """
+
+    def __init__(self, cache_dir: Optional[Path] = None):
+        self.cache_dir = cache_dir or Path.home() / ".ollamadiffuser" / "registry"
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Load configuration
+        self._load_config()
+
+        # Cache settings from config
+        self.cache_duration = timedelta(hours=self.config.get('cache_duration_hours', 24))
+        self.registry_cache_file = self.cache_dir / "models.json"
+        self.last_update_file = self.cache_dir / "last_update.txt"
+
+        # Registry sources from config
+        self.registry_sources = self.config.get('sources', [])
+
+        # Local models (built-in fallback)
+        self._builtin_models = self._load_builtin_models()
+
+        # Cached models
+        self._cached_models = {}
+        self._load_cache()
+
+    def _load_config(self):
+        """Load registry configuration from YAML file"""
+        try:
+            # Try to find config file
+            config_paths = [
+                Path(__file__).parent.parent.parent / "config" / "registry.yaml",
+                Path.home() / ".ollamadiffuser" / "registry.yaml",
+                Path("/etc/ollamadiffuser/registry.yaml")
+            ]
+
+            config = {}
+            for config_path in config_paths:
+                if config_path.exists():
+                    with open(config_path, 'r') as f:
+                        config = yaml.safe_load(f)
+                    logger.debug(f"Loaded config from {config_path}")
+                    break
+
+            # Use registry section if it exists
+            self.config = config.get('registry', {})
+
+            # Set defaults if no config found
+            if not self.config:
+                logger.warning("No registry config found, using defaults")
+                self.config = {
+                    'cache_duration_hours': 24,
+                    'sources': [
+                        {
+                            "name": "builtin",
+                            "url": None,
+                            "timeout": 10,
+                            "enabled": True,
+                            "description": "Built-in models only"
+                        }
+                    ]
+                }
+
+        except Exception as e:
+            logger.warning(f"Failed to load registry config: {e}")
+            self.config = {'cache_duration_hours': 24, 'sources': []}
+
+    def _load_builtin_models(self) -> Dict[str, Any]:
+        """Load built-in model definitions as fallback"""
+        return {
+            # FLUX.1 models
+            "flux.1-dev": {
+                "name": "flux.1-dev",
+                "repo_id": "black-forest-labs/FLUX.1-dev",
+                "model_type": "flux",
+                "description": "High-quality text-to-image model from Black Forest Labs",
+                "license": {"type": "Non-commercial", "commercial_use": False},
+                "size_gb": 23.8,
+                "hardware_requirements": {
+                    "min_vram_gb": 12,
+                    "recommended_vram_gb": 24,
+                    "min_ram_gb": 16,
+                    "recommended_ram_gb": 32
+                },
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 3.5,
+                    "max_sequence_length": 512
+                },
+                "tags": ["flux", "high-quality", "non-commercial"],
+                "downloads": 250000,
+                "updated": "2024-12-01"
+            },
+
+            "flux.1-schnell": {
+                "name": "flux.1-schnell",
+                "repo_id": "black-forest-labs/FLUX.1-schnell",
+                "model_type": "flux",
+                "description": "Fast text-to-image model optimized for speed",
+                "license": {"type": "Apache 2.0", "commercial_use": True},
+                "size_gb": 23.8,
+                "hardware_requirements": {
+                    "min_vram_gb": 12,
+                    "recommended_vram_gb": 24,
+                    "min_ram_gb": 16,
+                    "recommended_ram_gb": 32
+                },
+                "parameters": {
+                    "num_inference_steps": 4,
+                    "guidance_scale": 0.0,
+                    "max_sequence_length": 512
+                },
+                "tags": ["flux", "fast", "commercial", "apache"],
+                "downloads": 180000,
+                "updated": "2024-12-01"
+            },
+
+            # GGUF variants - generated dynamically
+            **self._generate_gguf_variants()
+        }
+
+    def _generate_gguf_variants(self) -> Dict[str, Any]:
+        """Generate GGUF model variants dynamically"""
+        base_gguf = {
+            "repo_id": "city96/FLUX.1-dev-gguf",
+            "model_type": "flux_gguf",
+            "description": "Quantized FLUX.1-dev model for efficient inference",
+            "license": {"type": "Non-commercial", "commercial_use": False},
+            "tags": ["flux", "gguf", "quantized", "efficient"],
+            "updated": "2024-12-01"
+        }
+
+        variants = {
+            "q2_k": {"size_gb": 4.03, "vram_gb": 4, "description": "Ultra-light quantization"},
+            "q3_k_s": {"size_gb": 5.23, "vram_gb": 5, "description": "Light quantization"},
+            "q4_k_s": {"size_gb": 6.81, "vram_gb": 6, "description": "Recommended quantization", "recommended": True},
+            "q4_0": {"size_gb": 6.79, "vram_gb": 6, "description": "Alternative Q4 quantization"},
+            "q4_1": {"size_gb": 7.53, "vram_gb": 7, "description": "Higher quality Q4"},
+            "q5_k_s": {"size_gb": 8.29, "vram_gb": 8, "description": "High quality quantization"},
+            "q5_0": {"size_gb": 8.27, "vram_gb": 8, "description": "Alternative Q5 quantization"},
+            "q5_1": {"size_gb": 9.01, "vram_gb": 9, "description": "Highest Q5 quality"},
+            "q6_k": {"size_gb": 9.86, "vram_gb": 10, "description": "Very high quality"},
+            "q8_0": {"size_gb": 12.7, "vram_gb": 12, "description": "Near-original quality"},
+            "f16": {"size_gb": 23.8, "vram_gb": 24, "description": "Full precision"}
+        }
+
+        gguf_models = {}
+        for variant, info in variants.items():
+            model_name = f"flux.1-dev-gguf:{variant}"
+            gguf_models[model_name] = {
+                **base_gguf,
+                "name": model_name,
+                "variant": variant,
+                "file_name": f"flux1-dev-{variant.upper()}.gguf",
+                "quantization": variant.upper(),
+                "size_gb": info["size_gb"],
+                "description": f"{base_gguf['description']} - {info['description']}",
+                "hardware_requirements": {
+                    "min_vram_gb": info["vram_gb"],
+                    "recommended_vram_gb": info["vram_gb"] + 2,
+                    "min_ram_gb": 8,
+                    "recommended_ram_gb": 16
+                },
+                "parameters": {
+                    "num_inference_steps": 16,
+                    "guidance_scale": 2.0,
+                    "max_sequence_length": 512
+                },
+                "downloads": 50000 - (info["vram_gb"] * 1000),  # Simulate popularity
+                "recommended": info.get("recommended", False)
+            }
+
+        return gguf_models
+
+    def _load_cache(self):
+        """Load cached model registry"""
+        try:
+            if self.registry_cache_file.exists():
+                with open(self.registry_cache_file, 'r') as f:
+                    self._cached_models = json.load(f)
+                logger.debug(f"Loaded {len(self._cached_models)} models from cache")
+        except Exception as e:
+            logger.warning(f"Failed to load model cache: {e}")
+            self._cached_models = {}
+
+    def _save_cache(self, models: Dict[str, Any]):
+        """Save model registry to cache"""
+        try:
+            with open(self.registry_cache_file, 'w') as f:
+                json.dump(models, f, indent=2)
+
+            with open(self.last_update_file, 'w') as f:
+                f.write(datetime.now().isoformat())
+
+            logger.debug(f"Saved {len(models)} models to cache")
+        except Exception as e:
+            logger.warning(f"Failed to save model cache: {e}")
+
+    def _is_cache_expired(self) -> bool:
+        """Check if cache is expired"""
+        try:
+            if not self.last_update_file.exists():
+                return True
+
+            with open(self.last_update_file, 'r') as f:
+                last_update = datetime.fromisoformat(f.read().strip())
+
+            return datetime.now() - last_update > self.cache_duration
+        except Exception:
+            return True
+
+    def _fetch_from_source(self, source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        """Fetch models from a specific source"""
+        try:
+            logger.debug(f"Fetching models from {source['name']}: {source['url']}")
+
+            response = requests.get(
+                source['url'],
+                timeout=source['timeout'],
+                headers={'User-Agent': 'OllamaDiffuser/1.0'}
+            )
+            response.raise_for_status()
+
+            data = response.json()
+
+            # Normalize the data format
+            if 'models' in data:
+                models = data['models']
+            elif isinstance(data, dict):
+                models = data
+            else:
+                logger.warning(f"Unexpected data format from {source['name']}")
+                return None
+
+            logger.info(f"Fetched {len(models)} models from {source['name']}")
+            return models
+
+        except requests.exceptions.Timeout:
+            logger.warning(f"Timeout fetching from {source['name']}")
+        except requests.exceptions.RequestException as e:
+            logger.warning(f"Failed to fetch from {source['name']}: {e}")
+        except json.JSONDecodeError as e:
+            logger.warning(f"Invalid JSON from {source['name']}: {e}")
+        except Exception as e:
+            logger.warning(f"Unexpected error fetching from {source['name']}: {e}")
+
+        return None
+
+    def refresh(self, force: bool = False) -> bool:
+        """Refresh model registry from external sources"""
+        if not force and not self._is_cache_expired():
+            logger.debug("Cache is still fresh, skipping refresh")
+            return True
+
+        logger.info("Refreshing model registry...")
+
+        # Try each source in priority order
+        for source in self.registry_sources:
+            if not source.get('enabled', True):
+                continue
+
+            models = self._fetch_from_source(source)
+            if models:
+                # Merge with built-in models
+                combined_models = {**self._builtin_models, **models}
+
+                # Update cache
+                self._cached_models = combined_models
+                self._save_cache(combined_models)
+
+                logger.info(f"Successfully updated registry with {len(combined_models)} models")
+                return True
+
+        logger.warning("Failed to fetch from any source, using cached/builtin models")
+        return False
+
+    def get_models(self, refresh: bool = False) -> Dict[str, Any]:
+        """Get all available models"""
+        if refresh or not self._cached_models:
+            self.refresh()
+
+        # Return cached models if available, otherwise built-in
+        return self._cached_models if self._cached_models else self._builtin_models
+
+    def get_model(self, model_name: str, refresh: bool = False) -> Optional[Dict[str, Any]]:
+        """Get specific model information"""
+        models = self.get_models(refresh=refresh)
+        return models.get(model_name)
+
+    def search_models(self, query: str = "", tags: Optional[List[str]] = None, model_type: Optional[str] = None) -> Dict[str, Any]:
+        """Search models by query, tags, or type"""
+        models = self.get_models()
+        results = {}
+
+        query_lower = query.lower()
+        tags = tags or []
+
+        for name, model in models.items():
+            # Check query match
+            query_match = (
+                not query or
+                query_lower in name.lower() or
+                query_lower in model.get('description', '').lower()
+            )
+
+            # Check type match
+            type_match = not model_type or model.get('model_type') == model_type
+
+            # Check tags match
+            model_tags = model.get('tags', [])
+            tags_match = not tags or any(tag in model_tags for tag in tags)
+
+            if query_match and type_match and tags_match:
+                results[name] = model
+
+        return results
+
+    def get_popular_models(self, limit: int = 10) -> Dict[str, Any]:
+        """Get most popular models"""
+        models = self.get_models()
+
+        # Sort by downloads
+        sorted_models = sorted(
+            models.items(),
+            key=lambda x: x[1].get('downloads', 0),
+            reverse=True
+        )
+
+        return dict(sorted_models[:limit])
+
+    def get_recommended_models(self) -> Dict[str, Any]:
+        """Get recommended models"""
+        models = self.get_models()
+
+        return {
+            name: model for name, model in models.items()
+            if model.get('recommended', False)
+        }
+
+    def add_local_model(self, model_name: str, model_config: Dict[str, Any]):
+        """Add a local model configuration"""
+        # Add to current models
+        current_models = self.get_models()
+        current_models[model_name] = model_config
+
+        # Save to cache
+        self._save_cache(current_models)
+        self._cached_models = current_models
+
+        logger.info(f"Added local model: {model_name}")
+
+    def remove_local_model(self, model_name: str) -> bool:
+        """Remove a local model configuration"""
+        current_models = self.get_models()
+
+        if model_name in current_models:
+            del current_models[model_name]
+            self._save_cache(current_models)
+            self._cached_models = current_models
+            logger.info(f"Removed local model: {model_name}")
+            return True
+
+        return False
+
+# Global registry instance
+model_registry = ModelRegistry()
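
The file above is the core of the new dynamic registry: built-in FLUX.1 definitions plus GGUF variants generated programmatically, with an on-disk JSON cache that can be refreshed from configurable sources. A minimal usage sketch, assuming ModelRegistry is imported from the new module (its module path is not shown in this diff); the calls match the methods defined above:

    registry = ModelRegistry()       # caches under ~/.ollamadiffuser/registry by default
    registry.refresh(force=True)     # try configured sources, fall back to cached/builtin models

    flux_dev = registry.get_model("flux.1-dev")
    gguf_variants = registry.search_models(query="flux", model_type="flux_gguf")
    top_five = registry.get_popular_models(limit=5)
    recommended = registry.get_recommended_models()
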
@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download, hf_hub_download, HfApi
 from tqdm import tqdm
 import threading
 import requests
+import fnmatch
 
 logger = logging.getLogger(__name__)
 
@@ -201,7 +202,9 @@ def robust_snapshot_download(
     max_retries: int = 3,
     initial_workers: int = 2,
     force_download: bool = False,
-    progress_callback: Optional[Callable] = None
+    progress_callback: Optional[Callable] = None,
+    allow_patterns: Optional[list] = None,
+    ignore_patterns: Optional[list] = None
 ) -> str:
     """
     Download repository snapshot with robust error handling and detailed progress tracking
@@ -214,6 +217,8 @@ def robust_snapshot_download(
         initial_workers: Initial number of workers (reduced on retries)
         force_download: Force re-download
         progress_callback: Optional progress callback function
+        allow_patterns: List of file patterns to include (e.g., ["*.gguf", "*.safetensors"])
+        ignore_patterns: List of file patterns to exclude (e.g., ["*.txt", "*.md"])
 
     Returns:
         Path to downloaded repository
@@ -225,6 +230,32 @@ def robust_snapshot_download(
         progress_callback("pulling manifest")
 
     file_sizes = get_repo_file_list(repo_id)
+
+    # Filter files based on patterns if provided
+    if allow_patterns or ignore_patterns:
+        filtered_files = {}
+        for filename, size in file_sizes.items():
+            # Check allow patterns (if provided, file must match at least one)
+            if allow_patterns:
+                allowed = any(fnmatch.fnmatch(filename, pattern) for pattern in allow_patterns)
+                if not allowed:
+                    continue
+
+            # Check ignore patterns (if file matches any, skip it)
+            if ignore_patterns:
+                ignored = any(fnmatch.fnmatch(filename, pattern) for pattern in ignore_patterns)
+                if ignored:
+                    continue
+
+            filtered_files[filename] = size
+
+        file_sizes = filtered_files
+
+    if progress_callback and allow_patterns:
+        progress_callback(f"🔍 Filtering files with patterns: {allow_patterns}")
+    if progress_callback and ignore_patterns:
+        progress_callback(f"🚫 Ignoring patterns: {ignore_patterns}")
+
     total_size = sum(file_sizes.values())
 
     if progress_callback and file_sizes:
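
The filtering above applies plain fnmatch glob matching to the file names returned by get_repo_file_list(repo_id): a file must match at least one allow pattern (when allow patterns are given) and must not match any ignore pattern. A standalone sketch of the same semantics, with illustrative file names that are not taken from the repository:

    import fnmatch

    # Hypothetical {filename: size} map, as get_repo_file_list() returns
    file_sizes = {"flux1-dev-Q4_K_S.gguf": 6_810_000_000,
                  "flux1-dev-Q8_0.gguf": 12_700_000_000,
                  "README.md": 4_096}
    allow_patterns = ["*.gguf"]      # keep only files matching at least one allow pattern
    ignore_patterns = ["*Q8_0*"]     # then drop files matching any ignore pattern

    kept = {
        name: size for name, size in file_sizes.items()
        if (not allow_patterns or any(fnmatch.fnmatch(name, p) for p in allow_patterns))
        and not any(fnmatch.fnmatch(name, p) for p in (ignore_patterns or []))
    }
    # kept == {"flux1-dev-Q4_K_S.gguf": 6_810_000_000}
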
@@ -310,7 +341,9 @@ def robust_snapshot_download(
                 resume_download=True,  # Enable resume
                 etag_timeout=300 + (attempt * 60),  # Increase timeout on retries
                 force_download=force_download,
-                tqdm_class=OllamaStyleTqdm if progress_callback else None
+                tqdm_class=OllamaStyleTqdm if progress_callback else None,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns
             )
 
             if progress_callback:
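
Taken together, the new parameters let callers restrict a snapshot to specific files and are passed straight through to snapshot_download. A hedged call sketch, assuming repo_id is the first positional argument of robust_snapshot_download (as the docstring and the get_repo_file_list(repo_id) call suggest) and that the keyword arguments match the signature shown above; the pattern values are illustrative, not package defaults:

    local_path = robust_snapshot_download(
        "city96/FLUX.1-dev-gguf",
        max_retries=3,
        initial_workers=2,
        progress_callback=print,
        allow_patterns=["*Q4_K_S*.gguf", "*.json"],
        ignore_patterns=["*.md", "*.txt"],
    )
    print(f"Snapshot available at {local_path}")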