sie-server 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120) hide show
  1. sie_server/__init__.py +3 -0
  2. sie_server/adapters/__init__.py +9 -0
  3. sie_server/adapters/_flash_base.py +160 -0
  4. sie_server/adapters/_utils.py +146 -0
  5. sie_server/adapters/base.py +491 -0
  6. sie_server/adapters/bert_flash/__init__.py +477 -0
  7. sie_server/adapters/bert_flash_cross_encoder/__init__.py +558 -0
  8. sie_server/adapters/bge_m3/__init__.py +428 -0
  9. sie_server/adapters/bge_m3_flag/__init__.py +305 -0
  10. sie_server/adapters/bge_m3_flash/__init__.py +549 -0
  11. sie_server/adapters/clip/__init__.py +343 -0
  12. sie_server/adapters/colbert/__init__.py +1147 -0
  13. sie_server/adapters/colbert_modernbert_flash/__init__.py +594 -0
  14. sie_server/adapters/colbert_rotary_flash/__init__.py +772 -0
  15. sie_server/adapters/colpali/__init__.py +610 -0
  16. sie_server/adapters/colqwen2/__init__.py +489 -0
  17. sie_server/adapters/cross_encoder/__init__.py +252 -0
  18. sie_server/adapters/donut/__init__.py +588 -0
  19. sie_server/adapters/florence2/__init__.py +649 -0
  20. sie_server/adapters/gliclass/__init__.py +237 -0
  21. sie_server/adapters/gliner/__init__.py +305 -0
  22. sie_server/adapters/glirel/__init__.py +276 -0
  23. sie_server/adapters/grounding_dino/__init__.py +415 -0
  24. sie_server/adapters/gte_sparse_flash/__init__.py +673 -0
  25. sie_server/adapters/jina_flash_cross_encoder/__init__.py +309 -0
  26. sie_server/adapters/modernbert_flash_cross_encoder/__init__.py +476 -0
  27. sie_server/adapters/nemo_colembed/__init__.py +556 -0
  28. sie_server/adapters/nli_classification/__init__.py +239 -0
  29. sie_server/adapters/nli_classification_flash/__init__.py +300 -0
  30. sie_server/adapters/nomic_flash/__init__.py +662 -0
  31. sie_server/adapters/owlv2/__init__.py +431 -0
  32. sie_server/adapters/peft_lora_mixin.py +264 -0
  33. sie_server/adapters/pytorch_embedding/__init__.py +430 -0
  34. sie_server/adapters/qwen2_flash/__init__.py +632 -0
  35. sie_server/adapters/qwen2_flash_cross_encoder/__init__.py +569 -0
  36. sie_server/adapters/rope_flash/__init__.py +535 -0
  37. sie_server/adapters/sentence_transformer/__init__.py +385 -0
  38. sie_server/adapters/sglang/__init__.py +628 -0
  39. sie_server/adapters/siglip/__init__.py +348 -0
  40. sie_server/adapters/splade_flash/__init__.py +619 -0
  41. sie_server/adapters/xlm_roberta_flash/__init__.py +519 -0
  42. sie_server/api/__init__.py +1 -0
  43. sie_server/api/encode.py +407 -0
  44. sie_server/api/extract.py +384 -0
  45. sie_server/api/health.py +47 -0
  46. sie_server/api/helpers.py +455 -0
  47. sie_server/api/metrics.py +30 -0
  48. sie_server/api/models.py +112 -0
  49. sie_server/api/openai_compat.py +441 -0
  50. sie_server/api/openapi.py +79 -0
  51. sie_server/api/options.py +51 -0
  52. sie_server/api/root.py +17 -0
  53. sie_server/api/score.py +281 -0
  54. sie_server/api/serialization.py +65 -0
  55. sie_server/api/validation.py +60 -0
  56. sie_server/api/ws.py +333 -0
  57. sie_server/app/__init__.py +0 -0
  58. sie_server/app/app_factory.py +292 -0
  59. sie_server/app/app_state_config.py +56 -0
  60. sie_server/cli.py +252 -0
  61. sie_server/config/__init__.py +0 -0
  62. sie_server/config/engine.py +192 -0
  63. sie_server/config/model.py +295 -0
  64. sie_server/core/__init__.py +10 -0
  65. sie_server/core/adaptive_batching.py +364 -0
  66. sie_server/core/batcher.py +507 -0
  67. sie_server/core/deps.py +230 -0
  68. sie_server/core/disk_cache.py +339 -0
  69. sie_server/core/encode_pipeline.py +120 -0
  70. sie_server/core/hot_reload.py +581 -0
  71. sie_server/core/inference.py +282 -0
  72. sie_server/core/inference_output.py +171 -0
  73. sie_server/core/loader.py +424 -0
  74. sie_server/core/logging.py +110 -0
  75. sie_server/core/memory.py +435 -0
  76. sie_server/core/model_loader.py +546 -0
  77. sie_server/core/postprocessor.py +568 -0
  78. sie_server/core/postprocessor_registry.py +268 -0
  79. sie_server/core/prepared.py +306 -0
  80. sie_server/core/preprocessor/__init__.py +45 -0
  81. sie_server/core/preprocessor/base.py +133 -0
  82. sie_server/core/preprocessor/image.py +129 -0
  83. sie_server/core/preprocessor/text.py +268 -0
  84. sie_server/core/preprocessor/vision.py +946 -0
  85. sie_server/core/preprocessor_registry.py +307 -0
  86. sie_server/core/readiness.py +52 -0
  87. sie_server/core/registry.py +1239 -0
  88. sie_server/core/shutdown.py +160 -0
  89. sie_server/core/timing.py +133 -0
  90. sie_server/core/tokenizer.py +49 -0
  91. sie_server/core/watcher.py +391 -0
  92. sie_server/core/worker/__init__.py +48 -0
  93. sie_server/core/worker/handlers/__init__.py +23 -0
  94. sie_server/core/worker/handlers/base.py +125 -0
  95. sie_server/core/worker/handlers/encode.py +237 -0
  96. sie_server/core/worker/handlers/extract.py +175 -0
  97. sie_server/core/worker/handlers/score.py +115 -0
  98. sie_server/core/worker/model_worker.py +976 -0
  99. sie_server/core/worker/types.py +184 -0
  100. sie_server/main.py +30 -0
  101. sie_server/nats_pull_loop.py +1423 -0
  102. sie_server/nats_subscriber.py +231 -0
  103. sie_server/observability/__init__.py +30 -0
  104. sie_server/observability/gpu.py +202 -0
  105. sie_server/observability/metrics.py +225 -0
  106. sie_server/observability/prometheus.py +88 -0
  107. sie_server/observability/tracing.py +121 -0
  108. sie_server/static/__init__.py +1 -0
  109. sie_server/static/index.html +37 -0
  110. sie_server/types/__init__.py +47 -0
  111. sie_server/types/inputs.py +124 -0
  112. sie_server/types/openapi.py +226 -0
  113. sie_server/types/outputs.py +93 -0
  114. sie_server/types/requests.py +56 -0
  115. sie_server/types/responses.py +205 -0
  116. sie_server-0.1.10.dist-info/METADATA +50 -0
  117. sie_server-0.1.10.dist-info/RECORD +120 -0
  118. sie_server-0.1.10.dist-info/WHEEL +4 -0
  119. sie_server-0.1.10.dist-info/entry_points.txt +2 -0
  120. sie_server-0.1.10.dist-info/licenses/LICENSE +201 -0
sie_server/__init__.py ADDED
@@ -0,0 +1,3 @@
"""SIE Server - Search Inference Engine."""

from importlib import metadata as _metadata

try:
    # Derive the version from the installed distribution metadata so the
    # runtime-reported version cannot drift from the packaged release.
    __version__ = _metadata.version("sie_server")
except _metadata.PackageNotFoundError:
    # Not installed as a distribution (e.g. imported straight from a source
    # checkout): fall back to the release this source corresponds to.
    # Keep in sync with the packaging metadata.
    __version__ = "0.1.10"
@@ -0,0 +1,9 @@
"""SIE Server model adapters.

Re-exports the core adapter base types from ``sie_server.adapters.base`` so
they can be imported directly from ``sie_server.adapters``.
"""

from sie_server.adapters.base import ModelAdapter, ModelCapabilities, ModelDims

# Public names exported by ``from sie_server.adapters import *``.
__all__ = [
    "ModelAdapter",
    "ModelCapabilities",
    "ModelDims",
]
@@ -0,0 +1,160 @@
1
"""Shared base class for flash-attention model adapters.

Also provides ``_import_adapter_class``, which resolves ``"module:ClassName"``
adapter references for the declarative flash -> non-flash fallback.
"""

from __future__ import annotations

import gc
import importlib
import logging
from typing import TYPE_CHECKING, Any, ClassVar

from sie_server.adapters.base import ModelAdapter

if TYPE_CHECKING:
    # Type-only import: the methods below import torch lazily at call time.
    import torch

logger = logging.getLogger(__name__)
14
+
15
+
16
+ def _import_adapter_class(adapter_path: str) -> type[ModelAdapter]:
17
+ """Import an adapter class from a 'module:ClassName' string.
18
+
19
+ Uses the same resolution logic as ``loader._import_builtin_adapter`` but
20
+ accepts the **short form** used in the fallback table (e.g.
21
+ ``"sentence_transformer:SentenceTransformerDenseAdapter"``). The
22
+ ``sie_server.adapters.`` prefix is prepended automatically.
23
+
24
+ Args:
25
+ adapter_path: ``"module:ClassName"`` or
26
+ ``"sie_server.adapters.module:ClassName"`` string.
27
+
28
+ Returns:
29
+ The adapter class.
30
+
31
+ Raises:
32
+ ImportError: If module or class cannot be found.
33
+ """
34
+ if ":" not in adapter_path:
35
+ msg = f"Invalid adapter_path '{adapter_path}': expected 'module:ClassName'"
36
+ raise ImportError(msg)
37
+
38
+ module_path, class_name = adapter_path.rsplit(":", 1)
39
+
40
+ # Allow both short form ("sentence_transformer:Foo") and full form
41
+ if not module_path.startswith("sie_server."):
42
+ module_path = f"sie_server.adapters.{module_path}"
43
+
44
+ module = importlib.import_module(module_path)
45
+ if not hasattr(module, class_name):
46
+ msg = f"Adapter class '{class_name}' not found in module '{module_path}'"
47
+ raise ImportError(msg)
48
+
49
+ return getattr(module, class_name)
50
+
51
+
52
class FlashBaseAdapter(ModelAdapter):
    """Common base for adapters built on flash-attention kernels.

    What subclasses get for free:

    - Device-aware construction. A subclass that sets
      ``fallback_adapter_path`` (and optionally ``fallback_kwargs_overrides``)
      is transparently replaced by that non-flash adapter when CUDA or flash
      attention is unavailable, without overriding ``create_for_device``.
    - A shared ``unload()`` that drops model/tokenizer references, runs the
      garbage collector and empties the CUDA cache for CUDA devices.
    - ``_resolve_dtype()``, mapping the compute-precision string to a dtype.
    - ``get_preprocessor()``, which returns a ``CharCountPreprocessor``.

    Adapters with bespoke fallback behaviour (e.g. SPLADEFlashAdapter) can
    still override ``create_for_device()`` themselves.
    """

    # -- Declarative fallback ------------------------------------------------
    # ``None`` disables fallback; otherwise a "module:ClassName" reference
    # resolved by ``_import_adapter_class``.
    fallback_adapter_path: ClassVar[str | None] = None
    # Keyword arguments layered over the caller's kwargs when falling back;
    # these win on key collisions.
    fallback_kwargs_overrides: ClassVar[dict[str, Any]] = {}

    @classmethod
    def create_for_device(cls, device: str, **kwargs: Any) -> ModelAdapter:
        """Build the adapter best suited to ``device``.

        Args:
            device: Target device string, e.g. ``"cuda:0"`` or ``"cpu"``.
            **kwargs: Constructor arguments for the adapter.

        Returns:
            An instance of ``cls`` when no fallback is declared, or when
            ``device`` is a CUDA device and flash attention is available.
            Otherwise, an instance of the declared fallback class built from
            ``kwargs`` merged with ``fallback_kwargs_overrides``.

        Raises:
            ImportError: If the declared fallback class cannot be resolved.
        """
        fallback_path = cls.fallback_adapter_path
        if fallback_path is None:
            # No fallback declared: the flash adapter is always used.
            return cls(**kwargs)

        # Deferred import avoids the core.inference -> core.loader -> base cycle.
        from sie_server.core.inference import is_flash_attention_available

        on_cuda = device.startswith("cuda")
        # Short-circuit: the availability probe only runs for CUDA devices.
        if on_cuda and is_flash_attention_available():
            return cls(**kwargs)

        fallback_class = _import_adapter_class(fallback_path)
        fallback_kwargs = {**kwargs, **cls.fallback_kwargs_overrides}

        if on_cuda:
            # CUDA present, but flash-attn is missing or unsupported.
            logger.warning(
                "Flash Attention unavailable (requires Ampere+ GPU and flash-attn package). "
                "Using %s for device '%s'. "
                "To install on Linux: pip install sie-server[flash-attn]",
                fallback_class.__name__,
                device,
            )
        else:
            logger.info(
                "%s requires CUDA. Using %s for device '%s'. "
                "For optimal performance, use a Linux system with NVIDIA GPU.",
                cls.__name__,
                fallback_class.__name__,
                device,
            )

        return fallback_class(**fallback_kwargs)

    # -- Common unload -------------------------------------------------------
    def unload(self) -> None:
        """Drop model references and release cached GPU memory.

        Sets ``_model``, ``_tokenizer``, every attribute named by
        ``_extra_fields_to_clear()`` that exists, and ``_device`` to ``None``,
        then runs ``gc.collect()``. The CUDA allocator cache is emptied only
        when the device recorded before unloading was a CUDA device.
        """
        import torch as _torch

        # Capture the device before it is reset below.
        previous_device = getattr(self, "_device", None)

        for field_name in ("_model", "_tokenizer"):
            if getattr(self, field_name, None) is None:
                continue
            setattr(self, field_name, None)

        for field_name in self._extra_fields_to_clear():
            if not hasattr(self, field_name):
                continue
            setattr(self, field_name, None)

        self._device = None

        gc.collect()
        if not previous_device:
            return
        if str(previous_device).startswith("cuda"):
            _torch.cuda.empty_cache()

    def _extra_fields_to_clear(self) -> list[str]:
        """Return extra attribute names for ``unload()`` to reset to ``None``.

        The base implementation lists none; subclasses override this.
        """
        return []

    # -- Shared utilities ----------------------------------------------------
    def _resolve_dtype(self) -> torch.dtype:
        """Translate ``self._compute_precision`` into a ``torch.dtype``.

        A missing attribute or an unrecognised precision string yields
        ``torch.float16``.
        """
        import torch as _torch

        precision_to_dtype: dict[str, torch.dtype] = {
            "float16": _torch.float16,
            "bfloat16": _torch.bfloat16,
            "float32": _torch.float32,
        }
        requested = getattr(self, "_compute_precision", "float16")
        return precision_to_dtype.get(requested, _torch.float16)

    def get_preprocessor(self) -> Any:
        """Build a ``CharCountPreprocessor`` for request cost estimation."""
        from sie_server.core.preprocessor import CharCountPreprocessor

        model_name = getattr(self, "_model_name_or_path", "")
        return CharCountPreprocessor(model_name=model_name)
@@ -0,0 +1,146 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ if TYPE_CHECKING:
6
+ import torch
7
+
8
+ from sie_server.types.inputs import Item
9
+
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # RoPE utilities (eliminates 7 identical copies)
13
+ # ---------------------------------------------------------------------------
14
+
15
+
16
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with its last dimension split in two and recombined as ``[-second, first]``.

    For an odd-sized last dimension, the first part holds ``size // 2``
    elements and the second part holds the remainder.
    """
    import torch as _torch

    split_at = x.shape[-1] // 2
    first = x[..., :split_at]
    second = x[..., split_at:]
    return _torch.cat((-second, first), dim=-1)
23
+
24
+
25
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with precomputed RoPE cos/sin tables.

    Args:
        q: Query tensor ``[total_tokens, num_heads, head_dim]``.
        k: Key tensor ``[total_tokens, num_kv_heads, head_dim]``.
        cos: Cosine table ``[total_tokens, head_dim]``.
        sin: Sine table ``[total_tokens, head_dim]``.

    Returns:
        ``(q_rotated, k_rotated)``. Each is ``t * cos + rotate_half(t) * sin``
        with ``cos``/``sin`` expanded along dim 1 and cast to ``q.dtype``.
    """
    # Insert a heads axis so the tables broadcast across every head; both
    # tables follow the query dtype (also used for the key rotation).
    cos_heads = cos.unsqueeze(1).to(q.dtype)
    sin_heads = sin.unsqueeze(1).to(q.dtype)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_heads) + (rotate_half(t) * sin_heads)

    return _rotate(q), _rotate(k)
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Output type validation (eliminates 9+ copies)
51
+ # ---------------------------------------------------------------------------
52
+
53
+
54
def validate_output_types(
    output_types: list[str],
    supported: set[str],
    adapter_name: str,
) -> None:
    """Raise ``ValueError`` if any requested output type is unsupported.

    Args:
        output_types: Output types requested by the caller.
        supported: Output types the adapter can produce. Any iterable of
            strings is accepted.
        adapter_name: Adapter name included in the error message.

    Raises:
        ValueError: If ``output_types`` contains a value not in ``supported``.
            Both collections appear in the message as set literals with
            elements in sorted order. Raw ``set`` formatting would depend on
            string-hash randomization and vary between processes.
    """

    def _format(values: set[str]) -> str:
        # Set-literal rendering with a deterministic element order.
        if not values:
            return "set()"
        return "{" + ", ".join(repr(v) for v in sorted(values, key=str)) + "}"

    supported_set = set(supported)
    unsupported = set(output_types) - supported_set
    if unsupported:
        msg = (
            f"Unsupported output types: {_format(unsupported)}. "
            f"{adapter_name} only supports {_format(supported_set)}."
        )
        raise ValueError(msg)
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Text extraction (eliminates 9+ copies)
68
+ # ---------------------------------------------------------------------------
69
+
70
+
71
def extract_texts(
    items: list[Item],
    instruction: str | None,
    *,
    is_query: bool,
    query_template: str | None = None,
    doc_template: str | None = None,
    always_apply_template: bool = False,
    err_msg: str = "Item must have text",
) -> list[str]:
    """Turn items into model-ready strings, optionally templated.

    The template is ``query_template`` when ``is_query`` is true, otherwise
    ``doc_template``. For each item:

    * a selected, non-empty template is applied (via ``str.format`` with
      ``text`` and ``instruction`` fields) when an instruction is given or
      ``always_apply_template`` is set (the Nomic-style mandatory prefix);
    * otherwise, a non-empty instruction is prepended with one space;
    * otherwise, the raw text is used.

    Args:
        items: Input items; each must have a non-``None`` ``text``.
        instruction: Optional instruction string.
        is_query: Chooses between the query and document template.
        query_template: Query template, e.g. ``"query: {text}"``.
        doc_template: Document template, e.g. ``"passage: {text}"``.
        always_apply_template: Apply the template even without an instruction.
        err_msg: Message for the ``ValueError`` raised on a ``None`` text.

    Returns:
        One formatted string per item, in input order.

    Raises:
        ValueError: On the first item whose ``text`` is ``None``.
    """
    chosen = query_template if is_query else doc_template
    # The template decision does not depend on the item, so make it once.
    apply_template = bool(chosen) and bool(instruction or always_apply_template)
    fill_instruction = instruction or ""

    def _render(item: Item) -> str:
        if item.text is None:
            raise ValueError(err_msg)
        raw = item.text
        if apply_template:
            return chosen.format(text=raw, instruction=fill_instruction)
        if instruction:
            return f"{instruction} {raw}"
        return raw

    return [_render(entry) for entry in items]
113
+
114
+
115
def extract_text(item: Item, *, err_msg: str = "Item must have text") -> str:
    """Return ``item.text`` as-is (cross-encoder use).

    Raises:
        ValueError: With ``err_msg`` when ``item.text`` is ``None``.
    """
    text = item.text
    if text is not None:
        return text
    raise ValueError(err_msg)
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Runtime options resolution (eliminates 5+ copies)
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
def resolve_embedding_options(
    options: dict[str, Any] | None,
    *,
    default_normalize: bool,
    default_pooling: str,
    default_query_template: str | None,
    default_doc_template: str | None,
) -> tuple[bool, str, str | None, str | None]:
    """Merge per-request options over the adapter's defaults.

    A key present in ``options`` wins, even when its value is ``None``. An
    absent key, or ``options`` being ``None``/empty, keeps the default.

    Returns:
        ``(normalize, pooling, query_template, doc_template)``
    """
    source = options or {}
    normalize = source.get("normalize", default_normalize)
    pooling = source.get("pooling", default_pooling)
    query_template = source.get("query_template", default_query_template)
    doc_template = source.get("doc_template", default_doc_template)
    return normalize, pooling, query_template, doc_template