sie-sdk 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sie_sdk/__init__.py +93 -0
- sie_sdk/bundle_utils.py +122 -0
- sie_sdk/cache.py +374 -0
- sie_sdk/client/__init__.py +30 -0
- sie_sdk/client/_shared.py +311 -0
- sie_sdk/client/async_.py +1586 -0
- sie_sdk/client/errors.py +125 -0
- sie_sdk/client/sync.py +1782 -0
- sie_sdk/exceptions.py +41 -0
- sie_sdk/images.py +151 -0
- sie_sdk/scoring.py +131 -0
- sie_sdk/storage.py +528 -0
- sie_sdk/types.py +656 -0
- sie_sdk-0.1.7.dist-info/METADATA +25 -0
- sie_sdk-0.1.7.dist-info/RECORD +17 -0
- sie_sdk-0.1.7.dist-info/WHEEL +4 -0
- sie_sdk-0.1.7.dist-info/licenses/LICENSE +201 -0
sie_sdk/__init__.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""SIE SDK - Python client for Search Inference Engine.
|
|
2
|
+
|
|
3
|
+
Example:
|
|
4
|
+
>>> from sie_sdk import SIEClient
|
|
5
|
+
>>> client = SIEClient("http://localhost:8080")
|
|
6
|
+
>>> result = client.encode("bge-m3", {"text": "Hello world"})
|
|
7
|
+
>>> result["dense"] # np.ndarray, shape [1024]
|
|
8
|
+
|
|
9
|
+
For ColBERT/late interaction models, use the scoring module:
|
|
10
|
+
>>> from sie_sdk.scoring import maxsim
|
|
11
|
+
>>> scores = maxsim(query_multivector, doc_multivectors)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from sie_sdk.client import (
|
|
15
|
+
LoraLoadingError,
|
|
16
|
+
ModelLoadingError,
|
|
17
|
+
PoolError,
|
|
18
|
+
ProvisioningError,
|
|
19
|
+
RequestError,
|
|
20
|
+
ServerError,
|
|
21
|
+
SIEAsyncClient,
|
|
22
|
+
SIEClient,
|
|
23
|
+
SIEConnectionError,
|
|
24
|
+
SIEError,
|
|
25
|
+
)
|
|
26
|
+
from sie_sdk.types import (
|
|
27
|
+
# Response types
|
|
28
|
+
AssignedWorkerInfo,
|
|
29
|
+
CapacityInfo,
|
|
30
|
+
ClusterStatusMessage,
|
|
31
|
+
ClusterSummary,
|
|
32
|
+
EncodeResult,
|
|
33
|
+
Entity,
|
|
34
|
+
ExtractResult,
|
|
35
|
+
HealthResponse,
|
|
36
|
+
Item,
|
|
37
|
+
ModelInfo,
|
|
38
|
+
ModelSummary,
|
|
39
|
+
PoolInfo,
|
|
40
|
+
PoolListItem,
|
|
41
|
+
PoolResponse,
|
|
42
|
+
PoolSpec,
|
|
43
|
+
PoolSpecResponse,
|
|
44
|
+
PoolStatusInfo,
|
|
45
|
+
ScoreResult,
|
|
46
|
+
SparseResult,
|
|
47
|
+
StatusMessage,
|
|
48
|
+
TimingInfo,
|
|
49
|
+
WorkerInfo,
|
|
50
|
+
WorkerStatusMessage,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
__version__ = "0.1.0"
|
|
54
|
+
|
|
55
|
+
__all__ = [
|
|
56
|
+
# Response types
|
|
57
|
+
"AssignedWorkerInfo",
|
|
58
|
+
"CapacityInfo",
|
|
59
|
+
"ClusterStatusMessage",
|
|
60
|
+
"ClusterSummary",
|
|
61
|
+
"EncodeResult",
|
|
62
|
+
"Entity",
|
|
63
|
+
"ExtractResult",
|
|
64
|
+
"HealthResponse",
|
|
65
|
+
# Input types
|
|
66
|
+
"Item",
|
|
67
|
+
# Exceptions
|
|
68
|
+
"LoraLoadingError",
|
|
69
|
+
"ModelInfo",
|
|
70
|
+
"ModelLoadingError",
|
|
71
|
+
"ModelSummary",
|
|
72
|
+
"PoolError",
|
|
73
|
+
"PoolInfo",
|
|
74
|
+
"PoolListItem",
|
|
75
|
+
"PoolResponse",
|
|
76
|
+
"PoolSpec",
|
|
77
|
+
"PoolSpecResponse",
|
|
78
|
+
"PoolStatusInfo",
|
|
79
|
+
"ProvisioningError",
|
|
80
|
+
"RequestError",
|
|
81
|
+
# Clients
|
|
82
|
+
"SIEAsyncClient",
|
|
83
|
+
"SIEClient",
|
|
84
|
+
"SIEConnectionError",
|
|
85
|
+
"SIEError",
|
|
86
|
+
"ScoreResult",
|
|
87
|
+
"ServerError",
|
|
88
|
+
"SparseResult",
|
|
89
|
+
"StatusMessage",
|
|
90
|
+
"TimingInfo",
|
|
91
|
+
"WorkerInfo",
|
|
92
|
+
"WorkerStatusMessage",
|
|
93
|
+
]
|
sie_sdk/bundle_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _scan_model_adapters(models_dir: Path) -> dict[str, set[str]]:
|
|
12
|
+
"""Scan model config YAMLs and return adapter modules per model.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
models_dir: Path to the models directory containing *.yaml configs.
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
Dict mapping model name to set of adapter module paths.
|
|
19
|
+
"""
|
|
20
|
+
result: dict[str, set[str]] = {}
|
|
21
|
+
if not models_dir.exists():
|
|
22
|
+
return result
|
|
23
|
+
|
|
24
|
+
for model_path in sorted(models_dir.glob("*.yaml")):
|
|
25
|
+
try:
|
|
26
|
+
model_data = yaml.safe_load(model_path.read_text()) or {}
|
|
27
|
+
except Exception:
|
|
28
|
+
logger.exception("Failed to parse model config %s", model_path.name)
|
|
29
|
+
continue
|
|
30
|
+
model_name = model_data.get("sie_id", model_path.stem.replace("__", "/"))
|
|
31
|
+
modules: set[str] = set()
|
|
32
|
+
for profile in model_data.get("profiles", {}).values():
|
|
33
|
+
adapter_path = profile.get("adapter_path", "")
|
|
34
|
+
module_path = adapter_path.split(":", maxsplit=1)[0]
|
|
35
|
+
if module_path:
|
|
36
|
+
modules.add(module_path)
|
|
37
|
+
if modules:
|
|
38
|
+
result[model_name] = modules
|
|
39
|
+
|
|
40
|
+
return result
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def match_bundle_models(bundle_path: Path, models_dir: Path) -> list[str]:
    """Match models to a bundle by adapter module paths.

    Loads the bundle YAML to get its adapter module list, then scans
    model config YAMLs to find models whose adapter_path module matches.

    Args:
        bundle_path: Path to the bundle YAML file.
        models_dir: Path to the models directory containing *.yaml configs.

    Returns:
        List of model names (sie_id or derived from filename) whose adapters
        match the bundle's adapter list.
    """
    with bundle_path.open() as handle:
        bundle = yaml.safe_load(handle) or {}

    bundle_adapters = set(bundle.get("adapters", []))
    if not bundle_adapters:
        return []

    # A model matches if it shares at least one adapter module with the bundle.
    matched: list[str] = []
    for model_name, model_modules in _scan_model_adapters(models_dir).items():
        if model_modules & bundle_adapters:
            matched.append(model_name)
    return matched
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def find_bundle_for_models(
    model_names: list[str],
    bundles_dir: Path,
    models_dir: Path,
) -> str | None:
    """Find the best bundle whose adapters cover the given models.

    Scans all bundle YAMLs in bundles_dir and returns the one whose adapter
    set covers all requested models with the fewest extra adapters (most
    specific match). Ties are broken by bundle priority (lower = higher
    priority).

    Args:
        model_names: List of model names to match.
        bundles_dir: Path to the bundles directory.
        models_dir: Path to the models directory containing *.yaml configs.

    Returns:
        Bundle name (without .yaml) of the best match, or None if no bundle
        covers all requested models.
    """
    if not (model_names and bundles_dir.exists() and models_dir.exists()):
        return None

    # Union of adapter modules required by the requested models.
    adapters_by_model = _scan_model_adapters(models_dir)
    required: set[str] = set()
    for model_name in model_names:
        required |= adapters_by_model.get(model_name, set())

    if not required:
        return None

    # Best candidate so far as (extra_adapter_count, priority, stem);
    # lexicographic comparison of the first two fields reproduces the
    # "fewest extras, then lowest priority" tie-break.
    best: tuple[int, float, str] | None = None

    for bundle_path in sorted(bundles_dir.glob("*.yaml")):
        try:
            bundle = yaml.safe_load(bundle_path.read_text()) or {}
        except Exception:
            logger.exception("Failed to parse bundle %s", bundle_path.name)
            continue

        available = set(bundle.get("adapters", []))
        if not required.issubset(available):
            continue  # bundle is missing at least one needed adapter

        candidate = (
            len(available - required),
            bundle.get("priority", 50),
            bundle_path.stem,
        )
        if best is None or candidate[:2] < best[:2]:
            best = candidate

    return None if best is None else best[2]
|
sie_sdk/cache.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
"""Model weight caching with hierarchical lookup.
|
|
2
|
+
|
|
3
|
+
Implements the caching hierarchy for model weights:
|
|
4
|
+
1. Local cache (HF_HOME/hub by default)
|
|
5
|
+
2. Cluster cache (S3/GCS bucket)
|
|
6
|
+
3. HuggingFace Hub fallback (if enabled)
|
|
7
|
+
|
|
8
|
+
The caching is transparent to adapters - they always see files in local cache.
|
|
9
|
+
This module pre-populates local cache from cluster cache if needed.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
from sie_sdk.exceptions import GatedModelError
|
|
21
|
+
from sie_sdk.storage import get_storage_backend, join_path
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from sie_sdk.storage import StorageBackend
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# Glob patterns excluded from HuggingFace snapshot downloads: docs/CI
# metadata and ONNX exports are never needed at inference time, and the
# training-state artifacts below are only useful when resuming training.
HF_IGNORE_PATTERNS = [
    "*.md",
    "README*",
    "LICENSE*",
    "docs/*",
    ".github/*",
    "*.onnx",
    "onnx/*",
    # training-only
    "optimizer.*",
    "scheduler.*",
    "trainer_state.json",
    "training_args.bin",
    "checkpoint-*/*",
]
# Legacy (non-safetensors) weight formats. Added to the ignore list when the
# repo also ships safetensors, so duplicate weight files are not downloaded
# (see _get_hf_ignore_patterns).
HF_LEGACY_WEIGHT_PATTERNS = ["*.bin", "*.ckpt", "*.msgpack", "*.h5", "*.ot"]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class CacheConfig:
    """Configuration for the weight cache.

    Attributes:
        local_cache: Local cache directory (usually HF_HOME/hub).
        cluster_cache: Cluster cache URL (s3:// or gs://), or None if not
            configured.
        hf_fallback: Whether to fallback to HuggingFace Hub for downloads.
    """

    local_cache: Path
    cluster_cache: str | None = None
    hf_fallback: bool = True
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_cache_config() -> CacheConfig:
    """Get cache configuration from environment variables.

    Reads:
        SIE_LOCAL_CACHE: Local cache directory (default: HF_HOME/hub)
        SIE_CLUSTER_CACHE: Cluster cache URL (s3:// or gs://)
        SIE_HF_FALLBACK: Whether to enable HF Hub fallback (default: true)

    Returns:
        CacheConfig with resolved paths.
    """
    env = os.environ

    # Local cache resolution order: explicit SIE_LOCAL_CACHE, then
    # HF_HOME/hub, then the standard HuggingFace default location.
    # Empty-string values are treated the same as unset.
    if env.get("SIE_LOCAL_CACHE"):
        local_cache = Path(env["SIE_LOCAL_CACHE"])
    elif env.get("HF_HOME"):
        local_cache = Path(env["HF_HOME"]) / "hub"
    else:
        local_cache = Path.home() / ".cache" / "huggingface" / "hub"

    return CacheConfig(
        local_cache=local_cache,
        # Cluster cache: S3/GCS URL, or None when unset.
        cluster_cache=env.get("SIE_CLUSTER_CACHE"),
        # HF fallback defaults to enabled; any of "true"/"1"/"yes" enables it.
        hf_fallback=env.get("SIE_HF_FALLBACK", "true").lower() in ("true", "1", "yes"),
    )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def is_model_cached(model_id: str, config: CacheConfig | None = None) -> bool:
    """Check if a model is already in local cache.

    Uses HuggingFace Hub's cache structure to check for the model.

    Args:
        model_id: HuggingFace model ID (e.g., "BAAI/bge-m3").
        config: Cache configuration. If None, reads from environment.

    Returns:
        True if model appears to be cached locally.
    """
    if config is None:
        config = get_cache_config()

    # HF Hub cache layout: models--{org}--{model}/snapshots/{revision}/.
    # exists() on the nested path covers the missing-model-dir case too.
    snapshots_dir = (
        config.local_cache / f"models--{model_id.replace('/', '--')}" / "snapshots"
    )
    if not snapshots_dir.exists():
        return False

    # Cached iff at least one snapshot directory holds at least one entry.
    for snapshot in snapshots_dir.iterdir():
        if snapshot.is_dir() and any(snapshot.iterdir()):
            return True
    return False
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def ensure_model_cached(
    model_id: str,
    config: CacheConfig | None = None,
    revision: str | None = None,
) -> Path:
    """Ensure model weights are in local cache.

    Implements the caching hierarchy:
      1. Check local cache - return path if found
      2. Check cluster cache - download to local if found
      3. Download from HuggingFace Hub (if fallback enabled)

    Args:
        model_id: HuggingFace model ID (e.g., "BAAI/bge-m3").
        config: Cache configuration. If None, reads from environment.
        revision: Specific model revision (commit hash, branch, or tag).

    Returns:
        Path to cached model in local cache.

    Raises:
        GatedModelError: If model is gated and authentication fails.
        RuntimeError: If model not found in any cache tier and HF fallback disabled.
    """
    if config is None:
        config = get_cache_config()

    # Tier 1: already present locally — nothing to do.
    if is_model_cached(model_id, config):
        logger.debug("Model %s found in local cache", model_id)
        return _get_model_cache_path(model_id, config)

    # Tier 2: mirror from the cluster cache into the local cache.
    if config.cluster_cache and _download_from_cluster_cache(model_id, config):
        logger.info("Downloaded %s from cluster cache", model_id)
        return _get_model_cache_path(model_id, config)

    # Tier 3: fetch directly from the Hub when allowed.
    if config.hf_fallback:
        logger.info("Downloading %s from HuggingFace Hub", model_id)
        return _download_from_huggingface(model_id, config, revision)

    # Exhausted every tier with fallback disabled.
    raise RuntimeError(
        f"Model '{model_id}' not found in local or cluster cache, "
        f"and HuggingFace fallback is disabled (SIE_HF_FALLBACK=false)"
    )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _get_model_cache_path(model_id: str, config: CacheConfig) -> Path:
|
|
175
|
+
"""Get the local cache path for a model.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
model_id: HuggingFace model ID.
|
|
179
|
+
config: Cache configuration.
|
|
180
|
+
|
|
181
|
+
Returns:
|
|
182
|
+
Path to the model's cache directory.
|
|
183
|
+
"""
|
|
184
|
+
return config.local_cache / f"models--{model_id.replace('/', '--')}"
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _download_from_cluster_cache(model_id: str, config: CacheConfig) -> bool:
    """Download model from cluster cache to local cache.

    Mirrors the HuggingFace Hub cache structure from cluster to local.

    Args:
        model_id: HuggingFace model ID.
        config: Cache configuration (must have cluster_cache set).

    Returns:
        True if successfully downloaded, False if not found in cluster cache.
    """
    if not config.cluster_cache:
        return False

    backend = get_storage_backend(config.cluster_cache)

    model_folder = f"models--{model_id.replace('/', '--')}"
    remote_root = join_path(config.cluster_cache, model_folder)

    # A model counts as present only if it has a snapshots/ tree remotely.
    if not backend.exists(join_path(remote_root, "snapshots")):
        logger.debug("Model %s not found in cluster cache", model_id)
        return False

    # Mirror the whole model directory structure into the local cache.
    local_root = config.local_cache / model_folder
    local_root.mkdir(parents=True, exist_ok=True)
    _download_directory(backend, remote_root, local_root)
    return True
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _download_directory(backend: StorageBackend, src_path: str, dst_path: Path) -> None:
    """Recursively download a directory from cloud storage.

    Args:
        backend: Storage backend instance.
        src_path: Source path in cloud storage.
        dst_path: Destination local path.
    """
    # Fetch the files at this level first...
    for name in backend.list_files(src_path):
        backend.download_file(join_path(src_path, name), dst_path / name)

    # ...then recurse into each subdirectory, creating it locally as needed.
    for name in backend.list_dirs(src_path):
        child = dst_path / name
        child.mkdir(parents=True, exist_ok=True)
        _download_directory(backend, join_path(src_path, name), child)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _get_hf_ignore_patterns(
    model_id: str,
    revision: str | None,
    token: str | None,
) -> list[str]:
    """Build ignore patterns for downloading a HF repo.

    Checks whether the repo ships safetensors weights; if so, legacy weight
    formats are ignored as well. This avoids downloading duplicate weight
    files (e.g., both .safetensors and .bin).

    Args:
        model_id: HuggingFace model ID.
        revision: Specific model revision, or None for the default branch.
        token: HF auth token, or None for anonymous access.

    Returns:
        Glob patterns suitable for snapshot_download(ignore_patterns=...).
    """
    from huggingface_hub import file_exists
    from huggingface_hub.errors import (
        GatedRepoError,
        HfHubHTTPError,
    )

    try:
        has_safetensors = any(
            file_exists(model_id, filename, revision=revision, token=token)
            for filename in ("model.safetensors", "model.safetensors.index.json")
        )
        if has_safetensors:
            return HF_IGNORE_PATTERNS + HF_LEGACY_WEIGHT_PATTERNS
    except (HfHubHTTPError, GatedRepoError) as e:
        # BUGFIX: guard the .response access — not every hub error carries a
        # response attribute, and the old `e.response` could itself raise
        # AttributeError inside this handler. Mirrors the hasattr guard used
        # in _download_from_huggingface.
        response = getattr(e, "response", None)
        logger.warning(
            "Failed to list HF repo files; falling back to default ignore patterns",
            extra={
                "model_id": model_id,
                "revision": revision,
                "status_code": getattr(response, "status_code", None),
            },
        )
    return HF_IGNORE_PATTERNS
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _download_from_huggingface(
    model_id: str,
    config: CacheConfig,
    revision: str | None = None,
) -> Path:
    """Download model from HuggingFace Hub to local cache.

    Uses huggingface_hub's snapshot_download which automatically uses
    the standard HF cache structure (~/.cache/huggingface/hub/).

    Args:
        model_id: HuggingFace model ID.
        config: Cache configuration.
        revision: Specific model revision.

    Returns:
        Path to the model in local cache.

    Raises:
        GatedModelError: If model is gated and authentication fails.
    """
    from huggingface_hub import snapshot_download
    from huggingface_hub.errors import (
        GatedRepoError,
        HfHubHTTPError,
        RepositoryNotFoundError,
    )

    # Authentication token, if the caller's environment provides one.
    token = os.environ.get("HF_TOKEN")

    try:
        # snapshot_download writes into the standard HF cache layout;
        # cache_dir must point at HF_HOME/hub where model dirs are created.
        snapshot_download(
            model_id,
            revision=revision,
            token=token,
            cache_dir=str(config.local_cache),
            ignore_patterns=_get_hf_ignore_patterns(model_id, revision, token),
        )
    except GatedRepoError as e:
        # Translate to the SDK's user-friendly gated-model error.
        raise GatedModelError(model_id, e) from e
    except RepositoryNotFoundError as e:
        # Without a token, "not found" may really mean "gated or private".
        if token is None:
            msg = (
                f"Model '{model_id}' not found. This could mean:\n"
                f"  1. The model ID is incorrect\n"
                f"  2. The model is private/gated and requires authentication\n\n"
                f"If this is a gated model, set HF_TOKEN environment variable:\n"
                f"  export HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxx\n\n"
                f"Original error: {e}"
            )
            raise RuntimeError(msg) from e
        raise
    except HfHubHTTPError as e:
        # 401/403 indicate an auth problem rather than a generic HTTP failure.
        status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
        if status_code in (401, 403):
            raise GatedModelError(model_id, e) from e
        raise
    return _get_model_cache_path(model_id, config)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def populate_cluster_cache(
    model_id: str,
    config: CacheConfig | None = None,
) -> bool:
    """Upload model from local cache to cluster cache.

    Used by sie-admin to pre-populate cluster cache.

    Args:
        model_id: HuggingFace model ID.
        config: Cache configuration.

    Returns:
        True if successfully uploaded, False if model not in local cache.
    """
    cfg = config if config is not None else get_cache_config()

    # Preconditions log-and-bail rather than raise: this is a best-effort
    # admin convenience operation.
    if not cfg.cluster_cache:
        logger.warning("No cluster cache configured")
        return False

    if not is_model_cached(model_id, cfg):
        logger.warning("Model %s not in local cache, cannot populate cluster", model_id)
        return False

    # TODO: Implement upload to cluster cache
    # This requires adding upload methods to storage backends
    logger.info("populate_cluster_cache not yet implemented")
    return False
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""SIE SDK Client package.
|
|
2
|
+
|
|
3
|
+
Re-exports all client classes and errors for backwards compatibility.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from sie_sdk.client.async_ import SIEAsyncClient
|
|
7
|
+
from sie_sdk.client.errors import (
|
|
8
|
+
LoraLoadingError,
|
|
9
|
+
ModelLoadingError,
|
|
10
|
+
PoolError,
|
|
11
|
+
ProvisioningError,
|
|
12
|
+
RequestError,
|
|
13
|
+
ServerError,
|
|
14
|
+
SIEConnectionError,
|
|
15
|
+
SIEError,
|
|
16
|
+
)
|
|
17
|
+
from sie_sdk.client.sync import SIEClient
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"LoraLoadingError",
|
|
21
|
+
"ModelLoadingError",
|
|
22
|
+
"PoolError",
|
|
23
|
+
"ProvisioningError",
|
|
24
|
+
"RequestError",
|
|
25
|
+
"SIEAsyncClient",
|
|
26
|
+
"SIEClient",
|
|
27
|
+
"SIEConnectionError",
|
|
28
|
+
"SIEError",
|
|
29
|
+
"ServerError",
|
|
30
|
+
]
|