sie-sdk 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sie_sdk/__init__.py ADDED
@@ -0,0 +1,93 @@
1
+ """SIE SDK - Python client for Search Inference Engine.
2
+
3
+ Example:
4
+ >>> from sie_sdk import SIEClient
5
+ >>> client = SIEClient("http://localhost:8080")
6
+ >>> result = client.encode("bge-m3", {"text": "Hello world"})
7
+ >>> result["dense"] # np.ndarray, shape [1024]
8
+
9
+ For ColBERT/late interaction models, use the scoring module:
10
+ >>> from sie_sdk.scoring import maxsim
11
+ >>> scores = maxsim(query_multivector, doc_multivectors)
12
+ """
13
+
14
+ from sie_sdk.client import (
15
+ LoraLoadingError,
16
+ ModelLoadingError,
17
+ PoolError,
18
+ ProvisioningError,
19
+ RequestError,
20
+ ServerError,
21
+ SIEAsyncClient,
22
+ SIEClient,
23
+ SIEConnectionError,
24
+ SIEError,
25
+ )
26
+ from sie_sdk.types import (
27
+ # Response types
28
+ AssignedWorkerInfo,
29
+ CapacityInfo,
30
+ ClusterStatusMessage,
31
+ ClusterSummary,
32
+ EncodeResult,
33
+ Entity,
34
+ ExtractResult,
35
+ HealthResponse,
36
+ Item,
37
+ ModelInfo,
38
+ ModelSummary,
39
+ PoolInfo,
40
+ PoolListItem,
41
+ PoolResponse,
42
+ PoolSpec,
43
+ PoolSpecResponse,
44
+ PoolStatusInfo,
45
+ ScoreResult,
46
+ SparseResult,
47
+ StatusMessage,
48
+ TimingInfo,
49
+ WorkerInfo,
50
+ WorkerStatusMessage,
51
+ )
52
+
53
+ __version__ = "0.1.0"
54
+
55
+ __all__ = [
56
+ # Response types
57
+ "AssignedWorkerInfo",
58
+ "CapacityInfo",
59
+ "ClusterStatusMessage",
60
+ "ClusterSummary",
61
+ "EncodeResult",
62
+ "Entity",
63
+ "ExtractResult",
64
+ "HealthResponse",
65
+ # Input types
66
+ "Item",
67
+ # Exceptions
68
+ "LoraLoadingError",
69
+ "ModelInfo",
70
+ "ModelLoadingError",
71
+ "ModelSummary",
72
+ "PoolError",
73
+ "PoolInfo",
74
+ "PoolListItem",
75
+ "PoolResponse",
76
+ "PoolSpec",
77
+ "PoolSpecResponse",
78
+ "PoolStatusInfo",
79
+ "ProvisioningError",
80
+ "RequestError",
81
+ # Clients
82
+ "SIEAsyncClient",
83
+ "SIEClient",
84
+ "SIEConnectionError",
85
+ "SIEError",
86
+ "ScoreResult",
87
+ "ServerError",
88
+ "SparseResult",
89
+ "StatusMessage",
90
+ "TimingInfo",
91
+ "WorkerInfo",
92
+ "WorkerStatusMessage",
93
+ ]
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ import yaml
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def _scan_model_adapters(models_dir: Path) -> dict[str, set[str]]:
12
+ """Scan model config YAMLs and return adapter modules per model.
13
+
14
+ Args:
15
+ models_dir: Path to the models directory containing *.yaml configs.
16
+
17
+ Returns:
18
+ Dict mapping model name to set of adapter module paths.
19
+ """
20
+ result: dict[str, set[str]] = {}
21
+ if not models_dir.exists():
22
+ return result
23
+
24
+ for model_path in sorted(models_dir.glob("*.yaml")):
25
+ try:
26
+ model_data = yaml.safe_load(model_path.read_text()) or {}
27
+ except Exception:
28
+ logger.exception("Failed to parse model config %s", model_path.name)
29
+ continue
30
+ model_name = model_data.get("sie_id", model_path.stem.replace("__", "/"))
31
+ modules: set[str] = set()
32
+ for profile in model_data.get("profiles", {}).values():
33
+ adapter_path = profile.get("adapter_path", "")
34
+ module_path = adapter_path.split(":", maxsplit=1)[0]
35
+ if module_path:
36
+ modules.add(module_path)
37
+ if modules:
38
+ result[model_name] = modules
39
+
40
+ return result
41
+
42
+
43
def match_bundle_models(bundle_path: Path, models_dir: Path) -> list[str]:
    """Match models to a bundle by adapter module paths.

    Loads the bundle YAML to get its adapter module list, then scans
    model config YAMLs to find models whose adapter_path module matches.

    Args:
        bundle_path: Path to the bundle YAML file.
        models_dir: Path to the models directory containing *.yaml configs.

    Returns:
        List of model names (sie_id or derived from filename) whose adapters
        match the bundle's adapter list.
    """
    with bundle_path.open() as fh:
        bundle_data = yaml.safe_load(fh) or {}

    bundle_adapters = set(bundle_data.get("adapters", []))
    if not bundle_adapters:
        # A bundle without adapters cannot match anything.
        return []

    matched: list[str] = []
    for model_name, modules in _scan_model_adapters(models_dir).items():
        # Any overlap between the model's adapter modules and the bundle's
        # adapter list counts as a match.
        if modules & bundle_adapters:
            matched.append(model_name)
    return matched
66
+
67
+
68
def find_bundle_for_models(
    model_names: list[str],
    bundles_dir: Path,
    models_dir: Path,
) -> str | None:
    """Find the best bundle whose adapters cover the given models.

    Scans all bundle YAMLs in bundles_dir and returns the one whose adapter
    set covers all requested models with the fewest extra adapters (most
    specific match). Ties are broken by bundle priority (lower = higher
    priority).

    Args:
        model_names: List of model names to match.
        bundles_dir: Path to the bundles directory.
        models_dir: Path to the models directory containing *.yaml configs.

    Returns:
        Bundle name (without .yaml) of the best match, or None if no bundle
        covers all requested models.
    """
    if not (model_names and bundles_dir.exists() and models_dir.exists()):
        return None

    # Union of adapter modules required by the requested models.
    adapters_by_model = _scan_model_adapters(models_dir)
    required: set[str] = set()
    for model_name in model_names:
        required |= adapters_by_model.get(model_name, set())
    if not required:
        return None

    # Best candidate so far as (extra_count, priority, bundle_name); the
    # lexicographic tuple order encodes "fewest extras, then lowest priority".
    best: tuple[float, float, str] | None = None

    for bundle_path in sorted(bundles_dir.glob("*.yaml")):
        try:
            bundle_data = yaml.safe_load(bundle_path.read_text()) or {}
        except Exception:
            logger.exception("Failed to parse bundle %s", bundle_path.name)
            continue
        provided = set(bundle_data.get("adapters", []))
        if not required <= provided:
            continue  # bundle is missing at least one required adapter
        score = (len(provided - required), bundle_data.get("priority", 50))
        # Strict < keeps the earliest (alphabetically first) bundle on ties.
        if best is None or score < best[:2]:
            best = (score[0], score[1], bundle_path.stem)

    return None if best is None else best[2]
sie_sdk/cache.py ADDED
@@ -0,0 +1,374 @@
1
+ """Model weight caching with hierarchical lookup.
2
+
3
+ Implements the caching hierarchy for model weights:
4
+ 1. Local cache (HF_HOME/hub by default)
5
+ 2. Cluster cache (S3/GCS bucket)
6
+ 3. HuggingFace Hub fallback (if enabled)
7
+
8
+ The caching is transparent to adapters - they always see files in local cache.
9
+ This module pre-populates local cache from cluster cache if needed.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import os
16
+ from dataclasses import dataclass
17
+ from pathlib import Path
18
+ from typing import TYPE_CHECKING
19
+
20
+ from sie_sdk.exceptions import GatedModelError
21
+ from sie_sdk.storage import get_storage_backend, join_path
22
+
23
+ if TYPE_CHECKING:
24
+ from sie_sdk.storage import StorageBackend
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
# Glob patterns excluded from HuggingFace snapshot downloads: documentation,
# ONNX exports, and training-only artifacts that inference never reads.
HF_IGNORE_PATTERNS = [
    # documentation / repo metadata
    "*.md",
    "README*",
    "LICENSE*",
    "docs/*",
    ".github/*",
    # ONNX exports (not used by this runtime)
    "*.onnx",
    "onnx/*",
    # training-only
    "optimizer.*",
    "scheduler.*",
    "trainer_state.json",
    "training_args.bin",
    "checkpoint-*/*",
]
# Legacy/duplicate weight formats, additionally skipped when the repo ships
# safetensors weights (see _get_hf_ignore_patterns below).
HF_LEGACY_WEIGHT_PATTERNS = ["*.bin", "*.ckpt", "*.msgpack", "*.h5", "*.ot"]
44
+
45
+
46
@dataclass
class CacheConfig:
    """Configuration for the weight cache."""

    # Local cache directory (usually HF_HOME/hub).
    local_cache: Path
    # Cluster cache URL (s3:// or gs://), or None if not configured.
    cluster_cache: str | None = None
    # Whether to fallback to HuggingFace Hub for downloads.
    hf_fallback: bool = True
58
+
59
+
60
def get_cache_config() -> CacheConfig:
    """Get cache configuration from environment variables.

    Reads:
        SIE_LOCAL_CACHE: Local cache directory (default: HF_HOME/hub)
        SIE_CLUSTER_CACHE: Cluster cache URL (s3:// or gs://)
        SIE_HF_FALLBACK: Whether to enable HF Hub fallback (default: true)

    Returns:
        CacheConfig with resolved paths.
    """
    env = os.environ

    # Resolve the local cache directory: an explicit SIE_LOCAL_CACHE wins,
    # then HF_HOME/hub, then the standard HuggingFace default location.
    if env.get("SIE_LOCAL_CACHE"):
        local_cache = Path(env["SIE_LOCAL_CACHE"])
    elif env.get("HF_HOME"):
        local_cache = Path(env["HF_HOME"]) / "hub"
    else:
        local_cache = Path.home() / ".cache" / "huggingface" / "hub"

    return CacheConfig(
        local_cache=local_cache,
        # S3/GCS URL of the shared cluster cache, if configured.
        cluster_cache=env.get("SIE_CLUSTER_CACHE"),
        # Any of "true"/"1"/"yes" (case-insensitive) enables HF fallback.
        hf_fallback=env.get("SIE_HF_FALLBACK", "true").lower() in ("true", "1", "yes"),
    )
94
+
95
+
96
def is_model_cached(model_id: str, config: CacheConfig | None = None) -> bool:
    """Check if a model is already in local cache.

    Follows HuggingFace Hub's on-disk layout:
    models--{org}--{model}/snapshots/{revision}/.

    Args:
        model_id: HuggingFace model ID (e.g., "BAAI/bge-m3").
        config: Cache configuration. If None, reads from environment.

    Returns:
        True if model appears to be cached locally, i.e. at least one
        non-empty snapshot directory exists.
    """
    cfg = config if config is not None else get_cache_config()

    model_dir = cfg.local_cache / f"models--{model_id.replace('/', '--')}"
    snapshots_dir = model_dir / "snapshots"
    if not snapshots_dir.exists():
        return False

    # Cached means: some snapshot directory containing at least one entry.
    for entry in snapshots_dir.iterdir():
        if entry.is_dir() and any(entry.iterdir()):
            return True
    return False
122
+
123
+
124
def ensure_model_cached(
    model_id: str,
    config: CacheConfig | None = None,
    revision: str | None = None,
) -> Path:
    """Ensure model weights are in local cache.

    Walks the caching hierarchy in order:
    1. Local cache - return path if found.
    2. Cluster cache - download to local if found.
    3. HuggingFace Hub download (if fallback enabled).

    Args:
        model_id: HuggingFace model ID (e.g., "BAAI/bge-m3").
        config: Cache configuration. If None, reads from environment.
        revision: Specific model revision (commit hash, branch, or tag).

    Returns:
        Path to cached model in local cache.

    Raises:
        GatedModelError: If model is gated and authentication fails.
        RuntimeError: If model not found in any cache tier and HF fallback disabled.
    """
    cfg = config if config is not None else get_cache_config()

    # Tier 1: already on local disk.
    if is_model_cached(model_id, cfg):
        logger.debug("Model %s found in local cache", model_id)
        return _get_model_cache_path(model_id, cfg)

    # Tier 2: shared cluster cache (S3/GCS), mirrored into local cache.
    if cfg.cluster_cache and _download_from_cluster_cache(model_id, cfg):
        logger.info("Downloaded %s from cluster cache", model_id)
        return _get_model_cache_path(model_id, cfg)

    # Tier 3: public hub, when fallback is allowed.
    if cfg.hf_fallback:
        logger.info("Downloading %s from HuggingFace Hub", model_id)
        return _download_from_huggingface(model_id, cfg, revision)

    raise RuntimeError(
        f"Model '{model_id}' not found in local or cluster cache, "
        f"and HuggingFace fallback is disabled (SIE_HF_FALLBACK=false)"
    )
172
+
173
+
174
+ def _get_model_cache_path(model_id: str, config: CacheConfig) -> Path:
175
+ """Get the local cache path for a model.
176
+
177
+ Args:
178
+ model_id: HuggingFace model ID.
179
+ config: Cache configuration.
180
+
181
+ Returns:
182
+ Path to the model's cache directory.
183
+ """
184
+ return config.local_cache / f"models--{model_id.replace('/', '--')}"
185
+
186
+
187
+ def _download_from_cluster_cache(model_id: str, config: CacheConfig) -> bool:
188
+ """Download model from cluster cache to local cache.
189
+
190
+ Mirrors the HuggingFace Hub cache structure from cluster to local.
191
+
192
+ Args:
193
+ model_id: HuggingFace model ID.
194
+ config: Cache configuration (must have cluster_cache set).
195
+
196
+ Returns:
197
+ True if successfully downloaded, False if not found in cluster cache.
198
+ """
199
+ if not config.cluster_cache:
200
+ return False
201
+
202
+ backend = get_storage_backend(config.cluster_cache)
203
+
204
+ # Construct cluster cache path
205
+ model_folder = f"models--{model_id.replace('/', '--')}"
206
+ cluster_model_path = join_path(config.cluster_cache, model_folder)
207
+
208
+ # Check if model exists in cluster cache
209
+ if not backend.exists(join_path(cluster_model_path, "snapshots")):
210
+ logger.debug("Model %s not found in cluster cache", model_id)
211
+ return False
212
+
213
+ # Create local cache directory
214
+ local_model_path = config.local_cache / model_folder
215
+ local_model_path.mkdir(parents=True, exist_ok=True)
216
+
217
+ # Download the entire model directory structure
218
+ _download_directory(backend, cluster_model_path, local_model_path)
219
+
220
+ return True
221
+
222
+
223
def _download_directory(backend: StorageBackend, src_path: str, dst_path: Path) -> None:
    """Recursively download a directory from cloud storage.

    Files in the current directory are fetched first, then each
    subdirectory is created locally and descended into.

    Args:
        backend: Storage backend instance.
        src_path: Source path in cloud storage.
        dst_path: Destination local path.
    """
    for name in backend.list_files(src_path):
        backend.download_file(join_path(src_path, name), dst_path / name)

    for name in backend.list_dirs(src_path):
        sub_dst = dst_path / name
        sub_dst.mkdir(parents=True, exist_ok=True)
        _download_directory(backend, join_path(src_path, name), sub_dst)
243
+
244
+
245
def _get_hf_ignore_patterns(
    model_id: str,
    revision: str | None,
    token: str | None,
) -> list[str]:
    """Build snapshot_download ignore patterns for a HF repo.

    Checks whether the repo ships safetensors weights; if so, legacy weight
    formats are also ignored, which avoids downloading duplicate weight
    files (e.g., both .safetensors and .bin).

    Args:
        model_id: HuggingFace model ID.
        revision: Specific model revision, or None for the default branch.
        token: HF auth token, or None for anonymous access.

    Returns:
        List of glob patterns to pass as ignore_patterns.
    """
    from huggingface_hub import file_exists
    from huggingface_hub.errors import (
        GatedRepoError,
        HfHubHTTPError,
    )

    try:
        has_safetensors = any(
            file_exists(model_id, filename, revision=revision, token=token)
            for filename in ("model.safetensors", "model.safetensors.index.json")
        )
        if has_safetensors:
            return HF_IGNORE_PATTERNS + HF_LEGACY_WEIGHT_PATTERNS
    except (HfHubHTTPError, GatedRepoError) as e:
        # Fix: some hub errors carry no .response attribute; guard the
        # access (as _download_from_huggingface does) so this graceful
        # fallback path cannot itself raise AttributeError.
        response = getattr(e, "response", None)
        logger.warning(
            "Failed to list HF repo files; falling back to default ignore patterns",
            extra={
                "model_id": model_id,
                "revision": revision,
                "status_code": getattr(response, "status_code", None),
            },
        )
    return HF_IGNORE_PATTERNS
274
+
275
+
276
def _download_from_huggingface(
    model_id: str,
    config: CacheConfig,
    revision: str | None = None,
) -> Path:
    """Download model from HuggingFace Hub to local cache.

    Delegates to huggingface_hub's snapshot_download, which writes using
    the standard HF cache structure (~/.cache/huggingface/hub/).

    Args:
        model_id: HuggingFace model ID.
        config: Cache configuration.
        revision: Specific model revision.

    Returns:
        Path to the model in local cache.

    Raises:
        GatedModelError: If model is gated and authentication fails.
    """
    from huggingface_hub import snapshot_download
    from huggingface_hub.errors import (
        GatedRepoError,
        HfHubHTTPError,
        RepositoryNotFoundError,
    )

    # Auth token from the environment, if any.
    hf_token = os.environ.get("HF_TOKEN")

    try:
        # cache_dir points at HF_HOME/hub, so snapshot_download lays files
        # out exactly where is_model_cached()/_get_model_cache_path() look.
        snapshot_download(
            model_id,
            revision=revision,
            token=hf_token,
            cache_dir=str(config.local_cache),
            ignore_patterns=_get_hf_ignore_patterns(model_id, revision, hf_token),
        )
    except GatedRepoError as e:
        # Gated model: surface a dedicated, user-friendly error.
        raise GatedModelError(model_id, e) from e
    except RepositoryNotFoundError as e:
        # Without a token, "not found" may really mean "private/gated".
        if hf_token is None:
            msg = (
                f"Model '{model_id}' not found. This could mean:\n"
                f" 1. The model ID is incorrect\n"
                f" 2. The model is private/gated and requires authentication\n\n"
                f"If this is a gated model, set HF_TOKEN environment variable:\n"
                f" export HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxx\n\n"
                f"Original error: {e}"
            )
            raise RuntimeError(msg) from e
        raise
    except HfHubHTTPError as e:
        # A 401/403 from the hub indicates an auth problem -> treat as gated.
        status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
        if status_code in (401, 403):
            raise GatedModelError(model_id, e) from e
        raise

    return _get_model_cache_path(model_id, config)
343
+
344
+
345
def populate_cluster_cache(
    model_id: str,
    config: CacheConfig | None = None,
) -> bool:
    """Upload model from local cache to cluster cache.

    Used by sie-admin to pre-populate cluster cache. Currently a stub:
    preconditions are validated, but the upload itself is not implemented
    yet, so this always returns False.

    Args:
        model_id: HuggingFace model ID.
        config: Cache configuration.

    Returns:
        True if successfully uploaded, False if model not in local cache.
    """
    cfg = get_cache_config() if config is None else config

    if not cfg.cluster_cache:
        logger.warning("No cluster cache configured")
        return False

    if not is_model_cached(model_id, cfg):
        logger.warning("Model %s not in local cache, cannot populate cluster", model_id)
        return False

    # TODO: Implement upload to cluster cache
    # This requires adding upload methods to storage backends
    logger.info("populate_cluster_cache not yet implemented")
    return False
@@ -0,0 +1,30 @@
1
+ """SIE SDK Client package.
2
+
3
+ Re-exports all client classes and errors for backwards compatibility.
4
+ """
5
+
6
+ from sie_sdk.client.async_ import SIEAsyncClient
7
+ from sie_sdk.client.errors import (
8
+ LoraLoadingError,
9
+ ModelLoadingError,
10
+ PoolError,
11
+ ProvisioningError,
12
+ RequestError,
13
+ ServerError,
14
+ SIEConnectionError,
15
+ SIEError,
16
+ )
17
+ from sie_sdk.client.sync import SIEClient
18
+
19
+ __all__ = [
20
+ "LoraLoadingError",
21
+ "ModelLoadingError",
22
+ "PoolError",
23
+ "ProvisioningError",
24
+ "RequestError",
25
+ "SIEAsyncClient",
26
+ "SIEClient",
27
+ "SIEConnectionError",
28
+ "SIEError",
29
+ "ServerError",
30
+ ]