sie-server 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sie_server/__init__.py +3 -0
- sie_server/adapters/__init__.py +9 -0
- sie_server/adapters/_flash_base.py +160 -0
- sie_server/adapters/_utils.py +146 -0
- sie_server/adapters/base.py +491 -0
- sie_server/adapters/bert_flash/__init__.py +477 -0
- sie_server/adapters/bert_flash_cross_encoder/__init__.py +558 -0
- sie_server/adapters/bge_m3/__init__.py +428 -0
- sie_server/adapters/bge_m3_flag/__init__.py +305 -0
- sie_server/adapters/bge_m3_flash/__init__.py +549 -0
- sie_server/adapters/clip/__init__.py +343 -0
- sie_server/adapters/colbert/__init__.py +1147 -0
- sie_server/adapters/colbert_modernbert_flash/__init__.py +594 -0
- sie_server/adapters/colbert_rotary_flash/__init__.py +772 -0
- sie_server/adapters/colpali/__init__.py +610 -0
- sie_server/adapters/colqwen2/__init__.py +489 -0
- sie_server/adapters/cross_encoder/__init__.py +252 -0
- sie_server/adapters/donut/__init__.py +588 -0
- sie_server/adapters/florence2/__init__.py +649 -0
- sie_server/adapters/gliclass/__init__.py +237 -0
- sie_server/adapters/gliner/__init__.py +305 -0
- sie_server/adapters/glirel/__init__.py +276 -0
- sie_server/adapters/grounding_dino/__init__.py +415 -0
- sie_server/adapters/gte_sparse_flash/__init__.py +673 -0
- sie_server/adapters/jina_flash_cross_encoder/__init__.py +309 -0
- sie_server/adapters/modernbert_flash_cross_encoder/__init__.py +476 -0
- sie_server/adapters/nemo_colembed/__init__.py +556 -0
- sie_server/adapters/nli_classification/__init__.py +239 -0
- sie_server/adapters/nli_classification_flash/__init__.py +300 -0
- sie_server/adapters/nomic_flash/__init__.py +662 -0
- sie_server/adapters/owlv2/__init__.py +431 -0
- sie_server/adapters/peft_lora_mixin.py +264 -0
- sie_server/adapters/pytorch_embedding/__init__.py +430 -0
- sie_server/adapters/qwen2_flash/__init__.py +632 -0
- sie_server/adapters/qwen2_flash_cross_encoder/__init__.py +569 -0
- sie_server/adapters/rope_flash/__init__.py +535 -0
- sie_server/adapters/sentence_transformer/__init__.py +385 -0
- sie_server/adapters/sglang/__init__.py +628 -0
- sie_server/adapters/siglip/__init__.py +348 -0
- sie_server/adapters/splade_flash/__init__.py +619 -0
- sie_server/adapters/xlm_roberta_flash/__init__.py +519 -0
- sie_server/api/__init__.py +1 -0
- sie_server/api/encode.py +407 -0
- sie_server/api/extract.py +384 -0
- sie_server/api/health.py +47 -0
- sie_server/api/helpers.py +455 -0
- sie_server/api/metrics.py +30 -0
- sie_server/api/models.py +112 -0
- sie_server/api/openai_compat.py +441 -0
- sie_server/api/openapi.py +79 -0
- sie_server/api/options.py +51 -0
- sie_server/api/root.py +17 -0
- sie_server/api/score.py +281 -0
- sie_server/api/serialization.py +65 -0
- sie_server/api/validation.py +60 -0
- sie_server/api/ws.py +333 -0
- sie_server/app/__init__.py +0 -0
- sie_server/app/app_factory.py +292 -0
- sie_server/app/app_state_config.py +56 -0
- sie_server/cli.py +252 -0
- sie_server/config/__init__.py +0 -0
- sie_server/config/engine.py +192 -0
- sie_server/config/model.py +295 -0
- sie_server/core/__init__.py +10 -0
- sie_server/core/adaptive_batching.py +364 -0
- sie_server/core/batcher.py +507 -0
- sie_server/core/deps.py +230 -0
- sie_server/core/disk_cache.py +339 -0
- sie_server/core/encode_pipeline.py +120 -0
- sie_server/core/hot_reload.py +581 -0
- sie_server/core/inference.py +282 -0
- sie_server/core/inference_output.py +171 -0
- sie_server/core/loader.py +424 -0
- sie_server/core/logging.py +110 -0
- sie_server/core/memory.py +435 -0
- sie_server/core/model_loader.py +546 -0
- sie_server/core/postprocessor.py +568 -0
- sie_server/core/postprocessor_registry.py +268 -0
- sie_server/core/prepared.py +306 -0
- sie_server/core/preprocessor/__init__.py +45 -0
- sie_server/core/preprocessor/base.py +133 -0
- sie_server/core/preprocessor/image.py +129 -0
- sie_server/core/preprocessor/text.py +268 -0
- sie_server/core/preprocessor/vision.py +946 -0
- sie_server/core/preprocessor_registry.py +307 -0
- sie_server/core/readiness.py +52 -0
- sie_server/core/registry.py +1239 -0
- sie_server/core/shutdown.py +160 -0
- sie_server/core/timing.py +133 -0
- sie_server/core/tokenizer.py +49 -0
- sie_server/core/watcher.py +391 -0
- sie_server/core/worker/__init__.py +48 -0
- sie_server/core/worker/handlers/__init__.py +23 -0
- sie_server/core/worker/handlers/base.py +125 -0
- sie_server/core/worker/handlers/encode.py +237 -0
- sie_server/core/worker/handlers/extract.py +175 -0
- sie_server/core/worker/handlers/score.py +115 -0
- sie_server/core/worker/model_worker.py +976 -0
- sie_server/core/worker/types.py +184 -0
- sie_server/main.py +30 -0
- sie_server/nats_pull_loop.py +1423 -0
- sie_server/nats_subscriber.py +231 -0
- sie_server/observability/__init__.py +30 -0
- sie_server/observability/gpu.py +202 -0
- sie_server/observability/metrics.py +225 -0
- sie_server/observability/prometheus.py +88 -0
- sie_server/observability/tracing.py +121 -0
- sie_server/static/__init__.py +1 -0
- sie_server/static/index.html +37 -0
- sie_server/types/__init__.py +47 -0
- sie_server/types/inputs.py +124 -0
- sie_server/types/openapi.py +226 -0
- sie_server/types/outputs.py +93 -0
- sie_server/types/requests.py +56 -0
- sie_server/types/responses.py +205 -0
- sie_server-0.1.10.dist-info/METADATA +50 -0
- sie_server-0.1.10.dist-info/RECORD +120 -0
- sie_server-0.1.10.dist-info/WHEEL +4 -0
- sie_server-0.1.10.dist-info/entry_points.txt +2 -0
- sie_server-0.1.10.dist-info/licenses/LICENSE +201 -0
sie_server/__init__.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import gc
|
|
4
|
+
import importlib
|
|
5
|
+
import logging
|
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
7
|
+
|
|
8
|
+
from sie_server.adapters.base import ModelAdapter
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _import_adapter_class(adapter_path: str) -> type[ModelAdapter]:
    """Resolve a ``"module:ClassName"`` reference to the class it names.

    The module part may be given relative to ``sie_server.adapters`` (short
    form, e.g. ``"sentence_transformer:SentenceTransformerDenseAdapter"``) or
    as a full dotted path that already starts with ``sie_server.``. Short
    forms get the ``sie_server.adapters.`` prefix added before importing.

    Args:
        adapter_path: ``"module:ClassName"`` or
            ``"sie_server.adapters.module:ClassName"`` string.

    Returns:
        The attribute named ``ClassName`` looked up on the imported module.

    Raises:
        ImportError: If ``adapter_path`` has no ``":"`` separator, if the
            module cannot be imported, or if the module has no attribute
            with the requested class name.
    """
    if ":" not in adapter_path:
        msg = f"Invalid adapter_path '{adapter_path}': expected 'module:ClassName'"
        raise ImportError(msg)

    # Split on the LAST colon so everything before it is the module path.
    module_path, _, class_name = adapter_path.rpartition(":")

    # Short-form paths are relative to the built-in adapters package.
    if not module_path.startswith("sie_server."):
        module_path = f"sie_server.adapters.{module_path}"

    module = importlib.import_module(module_path)

    # A private sentinel distinguishes "attribute missing" from any real value.
    not_found = object()
    adapter_class = getattr(module, class_name, not_found)
    if adapter_class is not_found:
        msg = f"Adapter class '{class_name}' not found in module '{module_path}'"
        raise ImportError(msg)

    return adapter_class
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class FlashBaseAdapter(ModelAdapter):
    """Shared base for the flash-attention adapters.

    What it offers to subclasses:

    - A declarative fallback: set ``fallback_adapter_path`` (and, if needed,
      ``fallback_kwargs_overrides``) rather than overriding
      ``create_for_device``.
    - A generic ``unload()`` that drops references, runs the garbage
      collector and empties the CUDA cache when the adapter was on CUDA.
    - ``_resolve_dtype()`` to turn the compute-precision string into a
      ``torch.dtype``.
    - ``get_preprocessor()`` which returns a ``CharCountPreprocessor``.

    A subclass that needs its own fallback logic can still override
    ``create_for_device()`` itself.
    """

    # -- Declarative fallback ------------------------------------------------
    # ``None`` disables the fallback; otherwise a "module:ClassName" string
    # understood by ``_import_adapter_class``.
    fallback_adapter_path: ClassVar[str | None] = None
    # Keyword arguments forced onto the fallback class; they win over the
    # caller-supplied kwargs with the same name.
    fallback_kwargs_overrides: ClassVar[dict[str, Any]] = {}

    @classmethod
    def create_for_device(cls, device: str, **kwargs: Any) -> ModelAdapter:
        """Build the adapter for ``device``, or its declared fallback.

        With no ``fallback_adapter_path`` the class itself is always
        instantiated. Otherwise the class is used only when ``device`` starts
        with ``"cuda"`` and flash attention is reported available; in every
        other case the fallback class is imported and instantiated instead.

        Args:
            device: Device string; only its ``"cuda"`` prefix is inspected.
            **kwargs: Constructor arguments, forwarded to whichever class is
                instantiated (merged with ``fallback_kwargs_overrides`` for
                the fallback).

        Returns:
            An instance of ``cls`` or of the fallback adapter class.
        """
        fallback_path = cls.fallback_adapter_path
        if fallback_path is None:
            return cls(**kwargs)

        # Deferred import: avoids a cycle (core.inference -> core.loader -> base).
        from sie_server.core.inference import is_flash_attention_available

        on_cuda = device.startswith("cuda")
        if on_cuda and is_flash_attention_available():
            return cls(**kwargs)

        fallback_class = _import_adapter_class(fallback_path)
        # Overrides come last so they replace same-named caller kwargs.
        fallback_kwargs = {**kwargs, **cls.fallback_kwargs_overrides}

        if on_cuda:
            # CUDA device, but flash attention itself is not usable.
            logger.warning(
                "Flash Attention unavailable (requires Ampere+ GPU and flash-attn package). "
                "Using %s for device '%s'. "
                "To install on Linux: pip install sie-server[flash-attn]",
                fallback_class.__name__,
                device,
            )
        else:
            logger.info(
                "%s requires CUDA. Using %s for device '%s'. "
                "For optimal performance, use a Linux system with NVIDIA GPU.",
                cls.__name__,
                fallback_class.__name__,
                device,
            )

        return fallback_class(**fallback_kwargs)

    # -- Common unload -------------------------------------------------------
    def unload(self) -> None:
        """Drop model references and release cached GPU memory.

        Sets ``_model`` and ``_tokenizer`` to ``None`` when they currently
        hold a value, sets every existing attribute named by
        ``_extra_fields_to_clear()`` to ``None``, resets ``_device``, runs
        ``gc.collect()`` and, if the adapter had been on a CUDA device,
        calls ``torch.cuda.empty_cache()``.
        """
        import torch as _torch

        # Remember the device before it is reset below.
        previous_device = getattr(self, "_device", None)

        # Standard fields: only touched when they hold something.
        for name in ("_model", "_tokenizer"):
            if getattr(self, name, None) is None:
                continue
            setattr(self, name, None)

        # Subclass-declared fields: cleared whenever the attribute exists.
        for name in self._extra_fields_to_clear():
            if hasattr(self, name):
                setattr(self, name, None)

        self._device = None

        gc.collect()
        if previous_device and str(previous_device).startswith("cuda"):
            _torch.cuda.empty_cache()

    def _extra_fields_to_clear(self) -> list[str]:
        """Return names of extra instance attributes ``unload()`` should clear.

        The base implementation returns an empty list; subclasses override it.
        """
        return []

    # -- Shared utilities ----------------------------------------------------
    def _resolve_dtype(self) -> torch.dtype:
        """Translate ``self._compute_precision`` into a ``torch.dtype``.

        A missing attribute or an unrecognised value yields ``torch.float16``.
        """
        import torch as _torch

        precision = getattr(self, "_compute_precision", "float16")
        known_dtypes: dict[str, torch.dtype] = {
            "float16": _torch.float16,
            "bfloat16": _torch.bfloat16,
            "float32": _torch.float32,
        }
        return known_dtypes.get(precision, _torch.float16)

    def get_preprocessor(self) -> Any:
        """Build the ``CharCountPreprocessor`` used for cost estimation.

        The preprocessor is given ``self._model_name_or_path`` as its model
        name, or an empty string when that attribute is not set.
        """
        from sie_server.core.preprocessor import CharCountPreprocessor

        model_name = getattr(self, "_model_name_or_path", "")
        return CharCountPreprocessor(model_name=model_name)
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from sie_server.types.inputs import Item
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# ---------------------------------------------------------------------------
|
|
12
|
+
# RoPE utilities (eliminates 7 identical copies)
|
|
13
|
+
# ---------------------------------------------------------------------------
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    The last axis is split at ``x.shape[-1] // 2``; the result is the negated
    upper part followed by the lower part, as used by RoPE.
    """
    import torch as _torch

    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]
    return _torch.cat((-upper, lower), dim=-1)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with Rotary Position Embedding.

    Args:
        q: Query tensor, expected as ``[total_tokens, num_heads, head_dim]``.
        k: Key tensor, expected as ``[total_tokens, num_kv_heads, head_dim]``.
        cos: Cosine table, expected as ``[total_tokens, head_dim]``.
        sin: Sine table, expected as ``[total_tokens, head_dim]``.

    Returns:
        ``(q_embed, k_embed)``: the rotated query and key tensors.
    """
    # Insert an axis at position 1 so the tables broadcast over the heads
    # axis, and match the query dtype (the same tables are applied to ``k``).
    cos_b = cos.unsqueeze(1).to(q.dtype)
    sin_b = sin.unsqueeze(1).to(q.dtype)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ---------------------------------------------------------------------------
|
|
50
|
+
# Output type validation (eliminates 9+ copies)
|
|
51
|
+
# ---------------------------------------------------------------------------
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def validate_output_types(
    output_types: list[str],
    supported: set[str],
    adapter_name: str,
) -> None:
    """Check that every requested output type is one the adapter supports.

    Args:
        output_types: Output types requested by the caller.
        supported: Output types the adapter can produce.
        adapter_name: Adapter name used in the error message.

    Raises:
        ValueError: If ``output_types`` contains anything not in ``supported``.
    """
    unsupported = set(output_types) - supported
    if not unsupported:
        return

    msg = f"Unsupported output types: {unsupported}. {adapter_name} only supports {supported!r}."
    raise ValueError(msg)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
# Text extraction (eliminates 9+ copies)
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def extract_texts(
    items: list[Item],
    instruction: str | None,
    *,
    is_query: bool,
    query_template: str | None = None,
    doc_template: str | None = None,
    always_apply_template: bool = False,
    err_msg: str = "Item must have text",
) -> list[str]:
    """Collect the text of each item, formatted for a query or a document.

    The template is ``query_template`` for queries and ``doc_template``
    otherwise. It is applied when it is non-empty and either an instruction
    is given or ``always_apply_template`` is set. Without a usable template,
    a given instruction is prepended as ``"<instruction> <text>"``; with
    neither, the text is returned unchanged.

    Args:
        items: Input items; each must expose a non-``None`` ``text``.
        instruction: Optional instruction string.
        is_query: Selects the query template when true, else the doc template.
        query_template: Template for queries, e.g. ``"query: {text}"``.
        doc_template: Template for documents, e.g. ``"passage: {text}"``.
        always_apply_template: Apply the selected template even when no
            instruction is given.
        err_msg: Message of the ``ValueError`` raised for a text-less item.

    Returns:
        One formatted string per item, in input order.

    Raises:
        ValueError: If any item's ``text`` is ``None``.
    """
    template = query_template if is_query else doc_template
    # Neither operand depends on the item, so decide once up front.
    use_template = bool(template) and bool(instruction or always_apply_template)

    formatted: list[str] = []
    for item in items:
        raw = item.text
        if raw is None:
            raise ValueError(err_msg)

        if use_template:
            # Templates may reference {text} and/or {instruction}.
            formatted.append(template.format(text=raw, instruction=instruction or ""))
        elif instruction:
            formatted.append(f"{instruction} {raw}")
        else:
            formatted.append(raw)

    return formatted
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def extract_text(item: Item, *, err_msg: str = "Item must have text") -> str:
    """Return the text of one item (cross-encoder use).

    Raises:
        ValueError: With ``err_msg`` when ``item.text`` is ``None``.
    """
    text = item.text
    if text is not None:
        return text
    raise ValueError(err_msg)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# ---------------------------------------------------------------------------
|
|
123
|
+
# Runtime options resolution (eliminates 5+ copies)
|
|
124
|
+
# ---------------------------------------------------------------------------
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def resolve_embedding_options(
    options: dict[str, Any] | None,
    *,
    default_normalize: bool,
    default_pooling: str,
    default_query_template: str | None,
    default_doc_template: str | None,
) -> tuple[bool, str, str | None, str | None]:
    """Pick per-request embedding options, falling back to adapter defaults.

    A default is used only when the corresponding key is absent from
    ``options``; a key that is present is returned as-is, whatever its value.

    Args:
        options: Runtime options, or ``None`` / empty for "all defaults".
        default_normalize: Fallback for the ``"normalize"`` key.
        default_pooling: Fallback for the ``"pooling"`` key.
        default_query_template: Fallback for the ``"query_template"`` key.
        default_doc_template: Fallback for the ``"doc_template"`` key.

    Returns:
        ``(normalize, pooling, query_template, doc_template)``
    """
    opts = options if options else {}

    normalize = opts.get("normalize", default_normalize)
    pooling = opts.get("pooling", default_pooling)
    query_template = opts.get("query_template", default_query_template)
    doc_template = opts.get("doc_template", default_doc_template)

    return normalize, pooling, query_template, doc_template
|