comfygit-core 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfygit_core/analyzers/custom_node_scanner.py +109 -0
- comfygit_core/analyzers/git_change_parser.py +156 -0
- comfygit_core/analyzers/model_scanner.py +318 -0
- comfygit_core/analyzers/node_classifier.py +58 -0
- comfygit_core/analyzers/node_git_analyzer.py +77 -0
- comfygit_core/analyzers/status_scanner.py +362 -0
- comfygit_core/analyzers/workflow_dependency_parser.py +143 -0
- comfygit_core/caching/__init__.py +16 -0
- comfygit_core/caching/api_cache.py +210 -0
- comfygit_core/caching/base.py +212 -0
- comfygit_core/caching/comfyui_cache.py +100 -0
- comfygit_core/caching/custom_node_cache.py +320 -0
- comfygit_core/caching/workflow_cache.py +797 -0
- comfygit_core/clients/__init__.py +4 -0
- comfygit_core/clients/civitai_client.py +412 -0
- comfygit_core/clients/github_client.py +349 -0
- comfygit_core/clients/registry_client.py +230 -0
- comfygit_core/configs/comfyui_builtin_nodes.py +1614 -0
- comfygit_core/configs/comfyui_models.py +62 -0
- comfygit_core/configs/model_config.py +151 -0
- comfygit_core/constants.py +82 -0
- comfygit_core/core/environment.py +1635 -0
- comfygit_core/core/workspace.py +898 -0
- comfygit_core/factories/environment_factory.py +419 -0
- comfygit_core/factories/uv_factory.py +61 -0
- comfygit_core/factories/workspace_factory.py +109 -0
- comfygit_core/infrastructure/sqlite_manager.py +156 -0
- comfygit_core/integrations/__init__.py +7 -0
- comfygit_core/integrations/uv_command.py +318 -0
- comfygit_core/logging/logging_config.py +15 -0
- comfygit_core/managers/environment_git_orchestrator.py +316 -0
- comfygit_core/managers/environment_model_manager.py +296 -0
- comfygit_core/managers/export_import_manager.py +116 -0
- comfygit_core/managers/git_manager.py +667 -0
- comfygit_core/managers/model_download_manager.py +252 -0
- comfygit_core/managers/model_symlink_manager.py +166 -0
- comfygit_core/managers/node_manager.py +1378 -0
- comfygit_core/managers/pyproject_manager.py +1321 -0
- comfygit_core/managers/user_content_symlink_manager.py +436 -0
- comfygit_core/managers/uv_project_manager.py +569 -0
- comfygit_core/managers/workflow_manager.py +1944 -0
- comfygit_core/models/civitai.py +432 -0
- comfygit_core/models/commit.py +18 -0
- comfygit_core/models/environment.py +293 -0
- comfygit_core/models/exceptions.py +378 -0
- comfygit_core/models/manifest.py +132 -0
- comfygit_core/models/node_mapping.py +201 -0
- comfygit_core/models/protocols.py +248 -0
- comfygit_core/models/registry.py +63 -0
- comfygit_core/models/shared.py +356 -0
- comfygit_core/models/sync.py +42 -0
- comfygit_core/models/system.py +204 -0
- comfygit_core/models/workflow.py +914 -0
- comfygit_core/models/workspace_config.py +71 -0
- comfygit_core/py.typed +0 -0
- comfygit_core/repositories/migrate_paths.py +49 -0
- comfygit_core/repositories/model_repository.py +958 -0
- comfygit_core/repositories/node_mappings_repository.py +246 -0
- comfygit_core/repositories/workflow_repository.py +57 -0
- comfygit_core/repositories/workspace_config_repository.py +121 -0
- comfygit_core/resolvers/global_node_resolver.py +459 -0
- comfygit_core/resolvers/model_resolver.py +250 -0
- comfygit_core/services/import_analyzer.py +218 -0
- comfygit_core/services/model_downloader.py +422 -0
- comfygit_core/services/node_lookup_service.py +251 -0
- comfygit_core/services/registry_data_manager.py +161 -0
- comfygit_core/strategies/__init__.py +4 -0
- comfygit_core/strategies/auto.py +72 -0
- comfygit_core/strategies/confirmation.py +69 -0
- comfygit_core/utils/comfyui_ops.py +125 -0
- comfygit_core/utils/common.py +164 -0
- comfygit_core/utils/conflict_parser.py +232 -0
- comfygit_core/utils/dependency_parser.py +231 -0
- comfygit_core/utils/download.py +216 -0
- comfygit_core/utils/environment_cleanup.py +111 -0
- comfygit_core/utils/filesystem.py +178 -0
- comfygit_core/utils/git.py +1184 -0
- comfygit_core/utils/input_signature.py +145 -0
- comfygit_core/utils/model_categories.py +52 -0
- comfygit_core/utils/pytorch.py +71 -0
- comfygit_core/utils/requirements.py +211 -0
- comfygit_core/utils/retry.py +242 -0
- comfygit_core/utils/symlink_utils.py +119 -0
- comfygit_core/utils/system_detector.py +258 -0
- comfygit_core/utils/uuid.py +28 -0
- comfygit_core/utils/uv_error_handler.py +158 -0
- comfygit_core/utils/version.py +73 -0
- comfygit_core/utils/workflow_hash.py +90 -0
- comfygit_core/validation/resolution_tester.py +297 -0
- comfygit_core-0.2.0.dist-info/METADATA +939 -0
- comfygit_core-0.2.0.dist-info/RECORD +93 -0
- comfygit_core-0.2.0.dist-info/WHEEL +4 -0
- comfygit_core-0.2.0.dist-info/licenses/LICENSE.txt +661 -0
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
"""Model download service for fetching models from URLs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import tempfile
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
|
|
11
|
+
import requests
|
|
12
|
+
from blake3 import blake3
|
|
13
|
+
|
|
14
|
+
from ..configs.model_config import ModelConfig
|
|
15
|
+
from ..logging.logging_config import get_logger
|
|
16
|
+
from ..models.exceptions import DownloadErrorContext
|
|
17
|
+
from ..models.shared import ModelWithLocation
|
|
18
|
+
from ..utils.model_categories import get_model_category
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from ..repositories.model_repository import ModelRepository
|
|
22
|
+
from ..repositories.workspace_config_repository import WorkspaceConfigRepository
|
|
23
|
+
|
|
24
|
+
# Module-level logger for this module, obtained from the package's get_logger helper.
logger = get_logger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class DownloadRequest:
    """Description of a single model download to perform.

    Attributes:
        url: Source URL the model is fetched from.
        target_path: Full destination path of the file inside the global
            models directory.
        workflow_name: Optional name of the workflow this download is
            associated with.
    """

    url: str
    target_path: Path  # destination file, located under the global models directory
    workflow_name: str | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class DownloadResult:
    """Outcome of a single download attempt.

    A successful result carries the indexed ``model``. A failed result carries a
    human-readable ``error`` message together with the structured
    ``error_context`` describing why the download failed.
    """

    success: bool
    model: ModelWithLocation | None = None
    error: str | None = None
    error_context: DownloadErrorContext | None = None  # structured failure details
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ModelDownloader:
    """Handles model downloads with hashing and indexing.

    Responsibilities:
    - Download files from URLs with progress tracking
    - Compute hashes (short + full blake3)
    - Register in ModelRepository
    - Detect URL type (civitai/HF/direct)
    """

    # Hostnames (and any of their subdomains) that identify each known provider.
    # URLs are classified by their parsed hostname rather than by a substring
    # search over the whole URL. A URL like "https://example.com/?next=civitai.com"
    # is therefore treated as a custom URL and is never sent the Civitai API token.
    _PROVIDER_HOSTS: dict[str, tuple[str, ...]] = {
        "civitai": ("civitai.com",),
        "huggingface": ("huggingface.co", "hf.co"),
    }

    # Bytes requested per iter_content() call while streaming a download.
    _CHUNK_SIZE = 8192

    def __init__(
        self,
        model_repository: ModelRepository,
        workspace_config: WorkspaceConfigRepository,
        models_dir: Path | None = None,
    ):
        """Initialize ModelDownloader.

        Args:
            model_repository: Repository for indexing models
            workspace_config: Workspace config for API credentials and models directory
            models_dir: Optional override for models directory (defaults to workspace config)

        Raises:
            ValueError: If neither ``models_dir`` nor the workspace config provides
                a models directory.
        """
        self.repository = model_repository
        self.workspace_config = workspace_config

        # Use provided models_dir or get from workspace config
        self.models_dir = (
            models_dir if models_dir is not None else workspace_config.get_models_directory()
        )

        # The workspace is expected to always have a models directory configured.
        # Fail loudly here if it somehow does not, rather than later mid-download.
        if self.models_dir is None:
            raise ValueError(
                "No models directory available. Either provide models_dir parameter "
                "or ensure workspace config has a models directory configured."
            )

        self.model_config = ModelConfig.load()

    # ------------------------------------------------------------------
    # URL helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _url_hostname(url: str) -> str:
        """Return the lower-cased hostname of ``url``, or ``""`` if it has none.

        Scheme-less input such as ``"civitai.com/api/download/models/1"`` is
        re-parsed as a network-path reference so that its host is still found.
        """
        text = url.strip()
        try:
            parsed = urlparse(text)
            if not parsed.netloc:
                parsed = urlparse("//" + text)
            host = parsed.hostname or ""
        except ValueError:
            # e.g. a malformed IPv6 literal such as "http://[::1"
            return ""
        # A fully-qualified name may carry a trailing dot ("civitai.com.")
        return host.rstrip(".")

    def _is_provider_url(self, url: str, provider: str) -> bool:
        """Return True if ``url``'s host is one of ``provider``'s domains or a subdomain."""
        host = self._url_hostname(url)
        if not host:
            return False
        return any(
            host == domain or host.endswith("." + domain)
            for domain in self._PROVIDER_HOSTS[provider]
        )

    def detect_url_type(self, url: str) -> str:
        """Detect source type from URL.

        The URL's hostname is matched against the known provider domains,
        including subdomains such as ``cdn-lfs.huggingface.co``.

        Args:
            url: URL to analyze

        Returns:
            'civitai', 'huggingface', or 'custom'
        """
        for provider in ("civitai", "huggingface"):
            if self._is_provider_url(url, provider):
                return provider
        return "custom"

    def suggest_path(
        self,
        url: str,
        node_type: str | None = None,
        filename_hint: str | None = None,
    ) -> Path:
        """Suggest download path based on context.

        For known nodes: checkpoints/model.safetensors
        For unknown: Uses filename hint or extracts from URL

        Args:
            url: Download URL
            node_type: Optional node type for category mapping
            filename_hint: Optional filename hint from workflow

        Returns:
            Suggested relative path (including base directory)
        """
        # Extract filename from URL or use hint
        filename = self._extract_filename(url, filename_hint)

        # If node type is known, map to its first directory (e.g. "checkpoints").
        # Guard against an empty mapping instead of raising IndexError; in that
        # case fall through to the heuristics below.
        if node_type and self.model_config.is_model_loader_node(node_type):
            directories = self.model_config.get_directories_for_node(node_type)
            if directories:
                return Path(directories[0]) / filename

        # Fallback: try to extract category from filename hint
        if filename_hint:
            category = get_model_category(filename_hint)
            return Path(category) / filename

        # Last resort: use generic models directory
        return Path("models") / filename

    def _extract_filename(self, url: str, filename_hint: str | None = None) -> str:
        """Extract filename from URL or use hint.

        The last URL path segment is percent-decoded ("my%20model.safetensors"
        becomes "my model.safetensors"). ``.name`` is then applied again so that an
        encoded separator (``%2F``) cannot smuggle directory components into the
        result.

        Args:
            url: Download URL
            filename_hint: Optional filename from workflow

        Returns:
            Filename to use
        """
        from urllib.parse import unquote

        # Try to extract from URL path (query string and fragment are excluded)
        parsed = urlparse(url)
        url_filename = Path(unquote(Path(parsed.path).name)).name

        # Use URL filename if it looks valid (has an extension and is not "."/"..")
        if url_filename and "." in url_filename and url_filename not in {".", ".."}:
            return url_filename

        # Fall back to hint; keep only the final component of the hint path
        if filename_hint:
            return Path(filename_hint).name

        # Last resort: generate generic name
        return "downloaded_model.safetensors"

    # ------------------------------------------------------------------
    # Credentials
    # ------------------------------------------------------------------

    def _get_civitai_token(self) -> str | None:
        """Return the configured Civitai API token stripped of whitespace, or None if unset/blank."""
        if not self.workspace_config:
            return None
        token = self.workspace_config.get_civitai_token()
        if token is None:
            return None
        token = token.strip()
        return token or None

    @staticmethod
    def _get_hf_token() -> str | None:
        """Return the Hugging Face token from HF_TOKEN / HUGGING_FACE_HUB_TOKEN, or None if unset/blank."""
        import os

        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
        if token is None:
            return None
        token = token.strip()
        return token or None

    def _check_provider_auth(self, provider: str) -> bool:
        """Check if authentication is configured for a provider.

        Args:
            provider: Provider type ('civitai', 'huggingface', 'custom')

        Returns:
            True if auth credentials are configured
        """
        if provider == "civitai":
            return self._get_civitai_token() is not None
        if provider == "huggingface":
            return self._get_hf_token() is not None
        return False

    def _build_auth_headers(self, url: str) -> dict[str, str]:
        """Return an Authorization header for ``url``'s provider when a token is configured.

        A token is only attached when the URL's own host belongs to that
        provider (see ``detect_url_type``). For Hugging Face this means the
        HF_TOKEN that ``_check_provider_auth`` reports is actually sent, so a 401
        classified as "auth_invalid" really reflects a rejected token.
        """
        provider = self.detect_url_type(url)
        token: str | None = None
        if provider == "civitai":
            token = self._get_civitai_token()
            if token:
                logger.debug("Using Civitai API key for authentication")
        elif provider == "huggingface":
            token = self._get_hf_token()
            if token:
                logger.debug("Using Hugging Face token for authentication")
        return {"Authorization": f"Bearer {token}"} if token else {}

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    def _classify_download_error(
        self,
        error: Exception,
        url: str,
        provider: str,
        has_auth: bool,
    ) -> DownloadErrorContext:
        """Classify download error and create structured context.

        Args:
            error: The exception that occurred
            url: Download URL
            provider: Provider type
            has_auth: Whether auth was configured

        Returns:
            DownloadErrorContext with classification
        """
        from socket import timeout as SocketTimeout
        from urllib.error import URLError

        http_status: int | None = None
        error_category = "unknown"

        # Classify based on exception type
        if isinstance(error, requests.HTTPError):
            # HTTPError.response may be None (e.g. when raised manually), in
            # which case there is no status code to classify.
            response = error.response
            http_status = response.status_code if response is not None else None

            if http_status == 401:
                # Unauthorized - distinguish missing vs rejected credentials
                error_category = "auth_invalid" if has_auth else "auth_missing"
            elif http_status == 403:
                # Forbidden - could be rate limit, permissions, or invalid token
                if not has_auth and provider in ("civitai", "huggingface"):
                    error_category = "auth_missing"
                else:
                    error_category = "forbidden"
            elif http_status == 404:
                error_category = "not_found"
            elif http_status is not None and http_status >= 500:
                error_category = "server"

        elif isinstance(
            error,
            (
                URLError,
                SocketTimeout,
                requests.Timeout,
                requests.ConnectionError,
                # Raised by iter_content when the connection drops mid-transfer
                requests.exceptions.ChunkedEncodingError,
            ),
        ):
            error_category = "network"

        return DownloadErrorContext(
            provider=provider,
            error_category=error_category,
            http_status=http_status,
            url=url,
            has_configured_auth=has_auth,
            raw_error=str(error),
        )

    def _failure_result(
        self,
        error: Exception,
        url: str,
        log_prefix: str,
        has_auth: bool | None = None,
    ) -> DownloadResult:
        """Classify ``error``, log it, and wrap it in a failed DownloadResult.

        ``has_auth=None`` means "look up whether the provider has credentials";
        an explicit bool is passed through to the classifier unchanged.
        """
        provider = self.detect_url_type(url)
        if has_auth is None:
            has_auth = self._check_provider_auth(provider)
        error_context = self._classify_download_error(error, url, provider, has_auth)

        # Generate user-friendly message
        user_message = error_context.get_user_message()
        logger.error(f"{log_prefix}: {user_message}")

        return DownloadResult(
            success=False,
            error=user_message,
            error_context=error_context,
        )

    # ------------------------------------------------------------------
    # Streaming helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_content_length(response) -> int | None:
        """Return the Content-Length header as a non-negative int, or None if absent/invalid."""
        raw = response.headers.get("content-length")
        if raw is None:
            return None
        try:
            size = int(raw)
        except (ValueError, TypeError):
            return None
        return size if size >= 0 else None

    def _stream_to_file(self, response, sink, progress_callback, total_size) -> tuple[str, int]:
        """Copy the response body into ``sink`` while hashing it with blake3.

        Args:
            response: Streaming ``requests`` response
            sink: Writable binary file object
            progress_callback: Optional callback(bytes_downloaded, total_bytes)
            total_size: Expected total size, or None if unknown

        Returns:
            ``(blake3_hexdigest, bytes_written)``
        """
        hasher = blake3()
        written = 0
        for chunk in response.iter_content(chunk_size=self._CHUNK_SIZE):
            if not chunk:
                # Skip empty chunks
                continue
            sink.write(chunk)
            hasher.update(chunk)
            written += len(chunk)
            if progress_callback:
                progress_callback(written, total_size)
        return hasher.hexdigest(), written

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def download(
        self,
        request: DownloadRequest,
        progress_callback=None,
    ) -> DownloadResult:
        """Download and index a model.

        Flow:
        1. Check if URL already downloaded
        2. Validate target path (must lie inside the models directory)
        3. Download to temp file with progress
        4. Hash during download (streaming)
        5. Move to target location
        6. Register in repository
        7. Add source URL

        Args:
            request: Download request with URL and target path
            progress_callback: Optional callback(bytes_downloaded, total_bytes) for progress updates.
                total_bytes may be None if server doesn't provide Content-Length.

        Returns:
            DownloadResult with model or error. Exceptions derived from
            ``Exception`` are converted into a failed result rather than raised.
        """
        temp_path: Path | None = None
        try:
            # Step 1: Check if already downloaded
            existing = self.repository.find_by_source_url(request.url)
            if existing:
                logger.info(f"Model already downloaded from URL: {existing.relative_path}")
                return DownloadResult(success=True, model=existing)

            # Step 2: Validate target path before transferring any bytes. A target
            # outside models_dir raises ValueError here. Otherwise it would only be
            # detected after a (possibly multi-GB) file had been moved into place.
            relative_path = request.target_path.relative_to(self.models_dir)
            request.target_path.parent.mkdir(parents=True, exist_ok=True)

            # Step 3-4: Download with streaming hash calculation
            logger.info(f"Downloading from {request.url}")
            headers = self._build_auth_headers(request.url)

            # Timeout: (connect_timeout, read_timeout)
            # 30s to establish connection, None for read.
            # NOTE(review): requests' read timeout bounds the wait between bytes
            # received, not the total transfer time. A finite value would still
            # allow slow downloads while stopping a stalled server from blocking
            # forever.
            # The response is used as a context manager so the streamed
            # connection is released on success and on every error path.
            with requests.get(
                request.url, stream=True, timeout=(30, None), headers=headers
            ) as response:
                response.raise_for_status()
                total_size = self._parse_content_length(response)

                # Temp file lives in the target directory so the final move is a
                # same-filesystem rename
                with tempfile.NamedTemporaryFile(
                    delete=False, dir=request.target_path.parent
                ) as temp_file:
                    temp_path = Path(temp_file.name)
                    blake3_hash, file_size = self._stream_to_file(
                        response, temp_file, progress_callback, total_size
                    )

            # Step 5: Calculate short hash for indexing; the temp file has been
            # closed (and therefore flushed) at this point
            short_hash = self.repository.calculate_short_hash(temp_path)

            # Step 6: Atomic move to final location (replace handles existing files)
            temp_path.replace(request.target_path)
            temp_path = None  # Clear temp_path since file has been moved

            # Step 7: Register in repository
            mtime = request.target_path.stat().st_mtime

            self.repository.ensure_model(
                hash=short_hash,
                file_size=file_size,
                blake3_hash=blake3_hash,
            )

            self.repository.add_location(
                model_hash=short_hash,
                base_directory=self.models_dir,
                relative_path=str(relative_path),
                filename=request.target_path.name,
                mtime=mtime,
            )

            # Step 8: Add source URL
            source_type = self.detect_url_type(request.url)
            self.repository.add_source(
                model_hash=short_hash,
                source_type=source_type,
                source_url=request.url,
            )

            # Step 9: Create result model
            model = ModelWithLocation(
                hash=short_hash,
                file_size=file_size,
                blake3_hash=blake3_hash,
                sha256_hash=None,
                relative_path=str(relative_path),
                filename=request.target_path.name,
                mtime=mtime,
                last_seen=int(mtime),
                metadata={},
            )

            logger.info(f"Successfully downloaded and indexed: {relative_path}")
            return DownloadResult(success=True, model=model)

        except requests.HTTPError as e:
            # HTTP errors with status codes - classify them using configured auth
            return self._failure_result(e, request.url, "Download failed")

        except (
            requests.Timeout,
            requests.ConnectionError,
            requests.exceptions.ChunkedEncodingError,
        ) as e:
            # Network errors, including a connection dropped mid-transfer
            return self._failure_result(e, request.url, "Download failed", has_auth=False)

        except Exception as e:
            # Unexpected errors - still provide some context
            return self._failure_result(e, request.url, "Unexpected download error")

        finally:
            # Always clean up temp file if it still exists (download failed or was interrupted)
            if temp_path is not None and temp_path.exists():
                try:
                    temp_path.unlink()
                    logger.debug(f"Cleaned up temporary file: {temp_path}")
                except Exception as cleanup_error:
                    logger.warning(f"Failed to clean up temp file {temp_path}: {cleanup_error}")
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""NodeLookupService - Pure stateless service for finding nodes and analyzing requirements."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from comfygit_core.models.exceptions import CDNodeNotFoundError, CDRegistryError
|
|
9
|
+
from comfygit_core.models.shared import NodeInfo
|
|
10
|
+
|
|
11
|
+
from ..analyzers.custom_node_scanner import CustomNodeScanner
|
|
12
|
+
from ..caching import APICacheManager, CustomNodeCacheManager
|
|
13
|
+
from ..clients import ComfyRegistryClient, GitHubClient
|
|
14
|
+
from ..logging.logging_config import get_logger
|
|
15
|
+
from ..utils.git import is_git_url
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from comfygit_core.repositories.node_mappings_repository import NodeMappingsRepository
|
|
19
|
+
from comfygit_core.repositories.workspace_config_repository import WorkspaceConfigRepository
|
|
20
|
+
|
|
21
|
+
# Module-level logger for this module, obtained from the package's get_logger helper.
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NodeLookupService:
    """Pure stateless service for finding nodes and analyzing their requirements.

    Responsibilities:
    - Registry lookup (cache-first, then API)
    - GitHub API calls (validating repos, getting commit info)
    - Requirement scanning (analyzing node directories)
    - Cache management (API responses, downloaded node archives)
    """

    def __init__(
        self,
        cache_path: Path,
        node_mappings_repository: NodeMappingsRepository | None = None,
        workspace_config_repository: WorkspaceConfigRepository | None = None,
    ):
        """Initialize the node lookup service.

        Args:
            cache_path: Required path to workspace cache directory
            node_mappings_repository: Repository for cached node mappings
            workspace_config_repository: Repository for workspace config (cache preference)
        """
        self.scanner = CustomNodeScanner()
        self.api_cache = APICacheManager(cache_base_path=cache_path)
        self.custom_node_cache = CustomNodeCacheManager(cache_base_path=cache_path)
        self.registry_client = ComfyRegistryClient(cache_manager=self.api_cache)
        self.github_client = GitHubClient(cache_manager=self.api_cache)
        self.node_mappings_repository = node_mappings_repository
        self.workspace_config_repository = workspace_config_repository

    # ------------------------------------------------------------------
    # Identifier / strategy helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_identifier(identifier: str) -> tuple[str, str | None]:
        """Split an identifier into ``(base_identifier, requested_version)``.

        The text after the last ``@`` is treated as a version or git ref
        (``"pkg@1.0.0"``, ``"https://github.com/u/r@v1.2.3"``).

        Some git URLs contain an ``@`` that is not a version separator, e.g. the
        SSH form ``git@github.com:u/r``. If ``is_git_url`` accepts the whole
        identifier but rejects the part before the last ``@``, the identifier is
        kept whole. An empty suffix (``"pkg@"``) yields no version.
        """
        if "@" not in identifier:
            return identifier, None
        base, _, suffix = identifier.rpartition("@")
        if is_git_url(identifier) and not is_git_url(base):
            return identifier, None
        return base, suffix or None

    def _should_use_cache(self) -> bool:
        """Return True when the local node-mappings cache should be consulted first.

        Requires a node mappings repository. The workspace preference, when a
        workspace config repository is available, can disable it. The default is
        to prefer the cache.
        """
        if not self.node_mappings_repository:
            return False
        if self.workspace_config_repository:
            return self.workspace_config_repository.get_prefer_registry_cache()
        return True

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def find_node(self, identifier: str) -> NodeInfo | None:
        """Find node info from cache, registry, or git URL.

        Git URLs are resolved via GitHub only; they never consult the local
        cache or the registry, even when the repository is not found.

        Cache-first strategy for registry identifiers:
        1. If prefer_registry_cache=True and package in local cache → return from cache
        2. Otherwise → query live API

        Args:
            identifier: Registry ID, git URL, or name. Supports @version/@ref syntax:
                - registry-id@1.0.0 (registry version)
                - https://github.com/user/repo@v1.2.3 (git tag)
                - https://github.com/user/repo@main (git branch)
                - https://github.com/user/repo@abc123 (git commit)

        Returns:
            NodeInfo with metadata, or None if not found
        """
        base_identifier, requested_version = self._parse_identifier(identifier)

        # Git URLs bypass the cache and the registry entirely
        if is_git_url(base_identifier):
            return self._find_git_node(base_identifier, requested_version)

        # Strategy: Cache first, then API
        if self._should_use_cache():
            package = self.node_mappings_repository.get_package(base_identifier)
            if package:
                logger.debug(f"Found '{base_identifier}' in local cache")
                return NodeInfo.from_global_package(package, version=requested_version)
            logger.debug(f"'{base_identifier}' not in local cache, trying API...")

        # Fallback to registry API
        return self._find_registry_node(base_identifier, requested_version)

    def _find_git_node(self, url: str, ref: str | None) -> NodeInfo | None:
        """Resolve a git URL (optionally at ``ref``) via GitHub; None if invalid or not found."""
        try:
            if repo_info := self.github_client.get_repository_info(url, ref=ref):
                return NodeInfo(
                    name=repo_info.name,
                    repository=repo_info.clone_url,
                    source="git",
                    version=repo_info.latest_commit,  # This will be the requested ref's commit
                )
        except Exception as e:
            logger.warning(f"Invalid git URL: {e}")
            return None

        logger.debug(f"Node '{url}' not found")
        return None

    def _find_registry_node(
        self, base_identifier: str, requested_version: str | None
    ) -> NodeInfo | None:
        """Look up ``base_identifier`` in the registry API; None if absent or unreachable."""
        try:
            registry_node = self.registry_client.get_node(base_identifier)
            if registry_node:
                if requested_version:
                    version = requested_version
                    logger.debug(f"Using requested version: {version}")
                else:
                    version = (
                        registry_node.latest_version.version
                        if registry_node.latest_version
                        else None
                    )
                node_version = self.registry_client.install_node(registry_node.id, version)
                if node_version:
                    registry_node.latest_version = node_version
                elif requested_version:
                    # Previously silent: the caller asked for a specific version
                    # but receives the node's latest-version metadata instead.
                    logger.warning(
                        f"Could not resolve version '{requested_version}' of "
                        f"'{registry_node.id}' from registry; falling back to the "
                        f"node's latest version metadata"
                    )
                return NodeInfo.from_registry_node(registry_node)
        except CDRegistryError as e:
            logger.warning(f"Cannot reach registry: {e}")

        logger.debug(f"Node '{base_identifier}' not found")
        return None

    def get_node(self, identifier: str) -> NodeInfo:
        """Get a node - raises if not found.

        Args:
            identifier: Registry ID, node name, or git URL (``@version``/``@ref``
                suffixes are accepted, as for ``find_node``)

        Returns:
            NodeInfo with metadata

        Raises:
            CDNodeNotFoundError: If node not found in any source
        """
        node = self.find_node(identifier)
        if node:
            return node

        # Build a context-aware error from the same parsing/strategy find_node used
        base_identifier, _ = self._parse_identifier(identifier)
        if is_git_url(base_identifier):
            msg = f"Node '{identifier}' not found. GitHub repository is invalid or inaccessible."
        else:
            sources_tried = (
                ["local cache", "registry API"] if self._should_use_cache() else ["registry API"]
            )
            msg = f"Node '{identifier}' not found in {' or '.join(sources_tried)}"

        raise CDNodeNotFoundError(msg)

    # ------------------------------------------------------------------
    # Requirements / caching
    # ------------------------------------------------------------------

    def scan_requirements(self, node_path: Path) -> list[str]:
        """Scan a node directory for Python requirements.

        Args:
            node_path: Path to node directory

        Returns:
            List of requirement strings (empty if none found)
        """
        deps = self.scanner.scan_node(node_path)
        if deps and deps.requirements:
            logger.info(f"Found {len(deps.requirements)} requirements in {node_path.name}")
            return deps.requirements
        logger.info(f"No requirements found in {node_path.name}")
        return []

    def download_to_cache(self, node_info: NodeInfo) -> Path | None:
        """Download a node to cache and return the cached path.

        For a registry node without a CDN download URL, the repository is cloned
        instead. In that case ``node_info.source`` is set to ``"git"`` on the
        caller's object.

        Args:
            node_info: Node information

        Returns:
            Path to cached node directory, or None if download failed
        """
        import tempfile

        from ..utils.download import download_and_extract_archive
        from ..utils.git import git_clone

        # Check if already cached
        if cache_path := self.custom_node_cache.get_cached_path(node_info):
            logger.debug(f"Node '{node_info.name}' already in cache")
            return cache_path

        # Download to temp location
        with tempfile.TemporaryDirectory() as tmpdir:
            temp_path = Path(tmpdir) / node_info.name

            try:
                if node_info.source == "registry":
                    if node_info.download_url:
                        download_and_extract_archive(node_info.download_url, temp_path)
                    elif node_info.repository:
                        # Fallback: Clone from repository if download URL missing
                        logger.info(
                            f"No CDN package for '{node_info.name}', "
                            f"falling back to git clone from {node_info.repository}"
                        )
                        # Update source to git for this installation
                        node_info.source = "git"
                        # NOTE(review): node_info.version is a registry version here,
                        # which may not match a git tag name in the repository.
                        git_clone(
                            node_info.repository,
                            temp_path,
                            depth=1,
                            ref=node_info.version or None,
                            timeout=30,
                        )
                    else:
                        logger.error(
                            f"Cannot download '{node_info.name}': "
                            "no CDN package and no repository URL"
                        )
                        return None
                elif node_info.source == "git":
                    if not node_info.repository:
                        logger.error(f"No repository URL for git node '{node_info.name}'")
                        return None
                    git_clone(
                        node_info.repository,
                        temp_path,
                        depth=1,
                        ref=node_info.version or None,
                        timeout=30,
                    )
                else:
                    logger.error(f"Unsupported source: '{node_info.source}'")
                    return None

                # Cache it while the temp directory still exists
                logger.info(f"Caching node '{node_info.name}'")
                return self.custom_node_cache.cache_node(node_info, temp_path)

            except Exception as e:
                logger.error(f"Failed to download node '{node_info.name}': {e}")
                return None

    def search_nodes(self, query: str, limit: int = 10) -> list[NodeInfo] | None:
        """Search for nodes in the registry.

        Args:
            query: Search term
            limit: Maximum results

        Returns:
            List of matching NodeInfo objects, or None when nothing matched or
            the registry could not be reached
        """
        try:
            nodes = self.registry_client.search_nodes(query)
            if nodes:
                return [NodeInfo.from_registry_node(node) for node in nodes[:limit]]
        except CDRegistryError as e:
            logger.warning(f"Failed to search registry: {e}")
        return None
|