orchestrator-core 4.4.1__py3-none-any.whl → 4.5.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. orchestrator/__init__.py +26 -2
  2. orchestrator/agentic_app.py +84 -0
  3. orchestrator/api/api_v1/api.py +10 -0
  4. orchestrator/api/api_v1/endpoints/search.py +277 -0
  5. orchestrator/app.py +32 -0
  6. orchestrator/cli/index_llm.py +73 -0
  7. orchestrator/cli/main.py +22 -1
  8. orchestrator/cli/resize_embedding.py +135 -0
  9. orchestrator/cli/search_explore.py +208 -0
  10. orchestrator/cli/speedtest.py +151 -0
  11. orchestrator/db/models.py +37 -1
  12. orchestrator/llm_settings.py +51 -0
  13. orchestrator/migrations/versions/schema/2025-08-12_52b37b5b2714_search_index_model_for_llm_integration.py +95 -0
  14. orchestrator/schemas/search.py +117 -0
  15. orchestrator/search/__init__.py +12 -0
  16. orchestrator/search/agent/__init__.py +8 -0
  17. orchestrator/search/agent/agent.py +47 -0
  18. orchestrator/search/agent/prompts.py +62 -0
  19. orchestrator/search/agent/state.py +8 -0
  20. orchestrator/search/agent/tools.py +121 -0
  21. orchestrator/search/core/__init__.py +0 -0
  22. orchestrator/search/core/embedding.py +64 -0
  23. orchestrator/search/core/exceptions.py +22 -0
  24. orchestrator/search/core/types.py +281 -0
  25. orchestrator/search/core/validators.py +27 -0
  26. orchestrator/search/docs/index.md +37 -0
  27. orchestrator/search/docs/running_local_text_embedding_inference.md +45 -0
  28. orchestrator/search/filters/__init__.py +27 -0
  29. orchestrator/search/filters/base.py +272 -0
  30. orchestrator/search/filters/date_filters.py +75 -0
  31. orchestrator/search/filters/definitions.py +93 -0
  32. orchestrator/search/filters/ltree_filters.py +43 -0
  33. orchestrator/search/filters/numeric_filter.py +60 -0
  34. orchestrator/search/indexing/__init__.py +3 -0
  35. orchestrator/search/indexing/indexer.py +323 -0
  36. orchestrator/search/indexing/registry.py +88 -0
  37. orchestrator/search/indexing/tasks.py +53 -0
  38. orchestrator/search/indexing/traverse.py +322 -0
  39. orchestrator/search/retrieval/__init__.py +3 -0
  40. orchestrator/search/retrieval/builder.py +108 -0
  41. orchestrator/search/retrieval/engine.py +152 -0
  42. orchestrator/search/retrieval/pagination.py +83 -0
  43. orchestrator/search/retrieval/retriever.py +447 -0
  44. orchestrator/search/retrieval/utils.py +106 -0
  45. orchestrator/search/retrieval/validation.py +174 -0
  46. orchestrator/search/schemas/__init__.py +0 -0
  47. orchestrator/search/schemas/parameters.py +116 -0
  48. orchestrator/search/schemas/results.py +63 -0
  49. orchestrator/services/settings_env_variables.py +2 -2
  50. orchestrator/settings.py +1 -1
  51. {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.1a1.dist-info}/METADATA +8 -3
  52. {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.1a1.dist-info}/RECORD +54 -11
  53. {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.1a1.dist-info}/WHEEL +0 -0
  54. {orchestrator_core-4.4.1.dist-info → orchestrator_core-4.5.1a1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,53 @@
1
+ import structlog
2
+ from sqlalchemy.orm import Query
3
+
4
+ from orchestrator.db import db
5
+ from orchestrator.search.core.types import EntityType
6
+ from orchestrator.search.indexing.indexer import Indexer
7
+ from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
8
+
9
+ logger = structlog.get_logger(__name__)
10
+
11
+
12
def run_indexing_for_entity(
    entity_kind: EntityType,
    entity_id: str | None = None,
    dry_run: bool = False,
    force_index: bool = False,
    chunk_size: int = 1000,
) -> None:
    """Stream and index entities for the given kind.

    Builds a streaming query via the entity's registry config, disables ORM eager
    loads when applicable and delegates processing to `Indexer`.

    Args:
        entity_kind (EntityType): The entity type to index (must exist in
            `ENTITY_CONFIG_REGISTRY`).
        entity_id (Optional[str]): If provided, restricts indexing to a single
            entity (UUID string).
        dry_run (bool): When True, runs the full pipeline without performing
            writes or external embedding calls.
        force_index (bool): When True, re-indexes all fields regardless of
            existing hashes.
        chunk_size (int): Number of rows fetched per round-trip and passed to
            the indexer per batch.

    Returns:
        None
    """
    config = ENTITY_CONFIG_REGISTRY[entity_kind]
    query = config.get_all_query(entity_id)

    # Legacy ORM `Query` objects must have eager loads disabled before we can
    # take their Core statement; a plain `Select` is already usable as-is.
    if isinstance(query, Query):
        stmt = query.enable_eagerloads(False).statement
    else:
        stmt = query

    # Stream rows server-side instead of materializing the full result set.
    stmt = stmt.execution_options(stream_results=True, yield_per=chunk_size)
    entity_stream = db.session.execute(stmt).scalars()

    Indexer(
        config=config,
        dry_run=dry_run,
        force_index=force_index,
        chunk_size=chunk_size,
    ).run(entity_stream)
@@ -0,0 +1,322 @@
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import Iterable
4
+ from enum import Enum
5
+ from typing import Any, cast, get_args
6
+ from uuid import uuid4
7
+
8
+ import structlog
9
+
10
+ from orchestrator.db import ProcessTable, ProductTable, SubscriptionTable, WorkflowTable
11
+ from orchestrator.domain import (
12
+ SUBSCRIPTION_MODEL_REGISTRY,
13
+ SubscriptionModel,
14
+ )
15
+ from orchestrator.domain.base import ProductBlockModel, ProductModel
16
+ from orchestrator.domain.lifecycle import (
17
+ lookup_specialized_type,
18
+ )
19
+ from orchestrator.schemas.process import ProcessSchema
20
+ from orchestrator.schemas.workflow import WorkflowSchema
21
+ from orchestrator.search.core.exceptions import ModelLoadError, ProductNotInRegistryError
22
+ from orchestrator.search.core.types import ExtractedField, FieldType
23
+ from orchestrator.types import SubscriptionLifecycle
24
+
25
+ logger = structlog.get_logger(__name__)
26
+
27
+ DatabaseEntity = SubscriptionTable | ProductTable | ProcessTable | WorkflowTable
28
+
29
+
30
class BaseTraverser(ABC):
    """Base class for traversing database models and extracting searchable fields.

    Subclasses implement `_load_model` to turn a database row into a Pydantic
    model; the shared machinery here walks that model recursively and yields
    one `ExtractedField` per scalar leaf, with dot-separated (ltree-style)
    paths.
    """

    # Separator used when composing ltree-compatible field paths.
    _LTREE_SEPARATOR = "."
    # Upper bound on traversal depth.
    # NOTE(review): not referenced anywhere in this class — presumably enforced
    # by a subclass or caller; confirm before relying on it.
    _MAX_DEPTH = 40

    @classmethod
    def get_fields(cls, entity: DatabaseEntity, pk_name: str, root_name: str) -> list[ExtractedField]:
        """Main entry point for extracting fields from an entity. Default implementation delegates to _load_model.

        Args:
            entity: The database row to extract fields from.
            pk_name: Attribute name of the entity's primary key (used only for
                error logging).
            root_name: Root label prefixed to every generated field path.

        Returns:
            Extracted fields sorted by path, or an empty list when the model
            cannot be loaded.
        """
        try:
            model = cls._load_model(entity)
            if model is None:
                return []
            # Sort for deterministic output regardless of traversal order.
            return sorted(cls.traverse(model, root_name), key=lambda f: f.path)

        except (ProductNotInRegistryError, ModelLoadError) as e:
            # Known load failures are logged and swallowed so one bad entity
            # does not abort a whole indexing run.
            entity_id = getattr(entity, pk_name, "unknown")
            logger.error(f"Failed to extract fields from {entity.__class__.__name__}", id=str(entity_id), error=str(e))
            return []

    @classmethod
    def traverse(cls, instance: Any, path: str = "") -> Iterable[ExtractedField]:
        """Walks the fields of a Pydantic model, dispatching each to a field handler.

        Yields `ExtractedField`s for every reachable leaf; nested models and
        lists are handled recursively by `_yield_fields_for_value`.
        """
        model_class = type(instance)

        # Handle both standard and computed fields from the Pydantic model.
        # `.copy()` is essential: `model_fields` is shared class state and must
        # not be mutated by the update below.
        all_fields = model_class.model_fields.copy()
        all_fields.update(getattr(model_class, "__pydantic_computed_fields__", {}))

        for name, field in all_fields.items():
            try:
                value = getattr(instance, name, None)
            except Exception as e:
                # Computed fields / properties may raise arbitrarily; skip them.
                logger.error(f"Failed to access field '{name}' on {model_class.__name__}", error=str(e))
                continue
            new_path = f"{path}{cls._LTREE_SEPARATOR}{name}" if path else name
            # FieldInfo exposes `annotation`; ComputedFieldInfo exposes `return_type`.
            annotation = field.annotation if hasattr(field, "annotation") else field.return_type
            yield from cls._yield_fields_for_value(value, new_path, annotation)

    @classmethod
    def _yield_fields_for_value(cls, value: Any, path: str, annotation: Any) -> Iterable[ExtractedField]:
        """Yields fields for a given value based on its type (model, list, or scalar).

        `None` values are skipped entirely; lists recurse per element; nested
        Pydantic models recurse via `traverse`; everything else is emitted as a
        single stringified leaf.
        """
        if value is None:
            return

        # If the value is a list, pass it to the list traverser
        if isinstance(value, list):
            # get_args() yields the element type of e.g. list[X]; without type
            # args the list cannot be traversed and is silently skipped.
            if element_annotation := get_args(annotation):
                yield from cls._traverse_list(value, path, element_annotation[0])
            return

        # If the value is another Pydantic model, recurse into it
        if hasattr(type(value), "model_fields"):
            yield from cls.traverse(value, path)
            return

        ftype = FieldType.from_type_hint(annotation)

        # Enums are flattened to their underlying value; other scalars are
        # stringified directly.
        if isinstance(value, Enum):
            yield ExtractedField(path, str(value.value), ftype)
        else:
            yield ExtractedField(path, str(value), ftype)

    @classmethod
    def _traverse_list(cls, items: list[Any], path: str, element_annotation: Any) -> Iterable[ExtractedField]:
        """Recursively traverses items in a list.

        Each element's path gains a numeric index segment (`parent.0`, `parent.1`, ...).
        """
        for i, item in enumerate(items):
            item_path = f"{path}.{i}"
            yield from cls._yield_fields_for_value(item, item_path, element_annotation)

    @classmethod
    def _load_model_with_schema(cls, entity: Any, schema_class: type[Any], pk_name: str) -> Any:
        """Generic helper for loading models using Pydantic schema validation.

        Raises:
            ModelLoadError: When `schema_class.model_validate(entity)` fails for
                any reason; the original exception is chained.
        """
        try:
            return schema_class.model_validate(entity)
        except Exception as e:
            entity_id = getattr(entity, pk_name, "unknown")
            raise ModelLoadError(f"Failed to load {schema_class.__name__} for {pk_name} '{entity_id}'") from e

    @classmethod
    @abstractmethod
    def _load_model(cls, entity: Any) -> Any: ...
112
+
113
+
114
class SubscriptionTraverser(BaseTraverser):
    """Traverser for subscription entities using full Pydantic model extraction."""

    @classmethod
    def _load_model(cls, sub: SubscriptionTable) -> SubscriptionModel | None:
        """Load the lifecycle-specialized domain model for a subscription row.

        Raises:
            ProductNotInRegistryError: The subscription's product has no entry
                in the subscription model registry.
            ModelLoadError: Instantiating the model from the database failed.
        """
        registered_cls = SUBSCRIPTION_MODEL_REGISTRY.get(sub.product.name)
        if not registered_cls:
            raise ProductNotInRegistryError(f"Product '{sub.product.name}' not in registry.")

        # Pick the concrete class matching the subscription's current lifecycle status.
        model_cls = cast(type[SubscriptionModel], lookup_specialized_type(registered_cls, sub.status))

        try:
            return model_cls.from_subscription(sub.subscription_id)
        except Exception as e:
            raise ModelLoadError(f"Failed to load model for subscription_id '{sub.subscription_id}'") from e
129
+
130
+
131
class ProductTraverser(BaseTraverser):
    """Traverser for product entities using a template SubscriptionModel instance.

    Products have no live subscription to walk, so a throwaway template model is
    instantiated purely to expose the product's block structure for indexing.
    """

    @classmethod
    def _sanitize_for_ltree(cls, name: str) -> str:
        """Sanitizes a string to be a valid ltree path label.

        Lowercases the input, replaces every non-alphanumeric/underscore
        character with an underscore, collapses runs of underscores and strips
        them from both ends. Returns ``"unnamed_product"`` when nothing valid
        remains.
        """
        # Convert to lowercase
        sanitized = name.lower()

        # Replace all non-alphanumeric (and non-underscore) characters with an underscore
        sanitized = re.sub(r"[^a-z0-9_]", "_", sanitized)

        # Collapse multiple underscores into a single one
        sanitized = re.sub(r"__+", "_", sanitized)

        # Remove leading or trailing underscores
        sanitized = sanitized.strip("_")

        # Handle cases where the name was only invalid characters
        if not sanitized:
            return "unnamed_product"

        return sanitized

    @classmethod
    def get_fields(cls, entity: ProductTable, pk_name: str, root_name: str) -> list[ExtractedField]:  # type: ignore[override]
        """Extracts fields by creating a template SubscriptionModel instance for the product.

        Extracts product metadata and block schema structure.

        Args:
            entity (ProductTable): The product row to describe.
            pk_name (str): Primary-key attribute name, used for error logging.
            root_name (str): Root label prefixed to all generated paths.

        Returns:
            list[ExtractedField]: Fields sorted by path; empty on load failure.
        """
        try:
            model = cls._load_model(entity)

            if not model:
                return []

            fields: list[ExtractedField] = []

            # Plain product metadata (name, tag, status, ...).
            fields.extend(cls.traverse(model.product, root_name))

            product_name = cls._sanitize_for_ltree(model.product.name)
            product_block_root = f"{root_name}.{product_name}.product_block"

            # Extract product block schema structure
            model_class = type(model)
            product_block_fields = getattr(model_class, "_product_block_fields_", {})

            for field_name in product_block_fields:
                block_value = getattr(model, field_name, None)
                if block_value is not None:
                    block_path = f"{product_block_root}.{field_name}"
                    fields.extend(cls._extract_block_schema(block_value, block_path))

            return sorted(fields, key=lambda f: f.path)

        except (ProductNotInRegistryError, ModelLoadError) as e:
            entity_id = getattr(entity, pk_name, "unknown")
            logger.error(f"Failed to extract fields from {entity.__class__.__name__}", id=str(entity_id), error=str(e))
            return []

    @classmethod
    def _extract_block_schema(cls, block_instance: ProductBlockModel, block_path: str) -> list[ExtractedField]:
        """Extract schema information from a block instance, returning field names as RESOURCE_TYPE."""
        fields: list[ExtractedField] = []

        # Add the block itself as a BLOCK type
        block_name = block_path.split(cls._LTREE_SEPARATOR)[-1]
        fields.append(ExtractedField(path=block_path, value=block_name, value_type=FieldType.BLOCK))

        # Extract all field names from the block as RESOURCE_TYPE
        if hasattr(type(block_instance), "model_fields"):
            # FIX: copy before merging computed fields. `model_fields` is the
            # Pydantic class's own dict; updating it in place would permanently
            # inject computed fields into the shared model class (the sibling
            # BaseTraverser.traverse already copies for the same reason).
            all_fields = dict(type(block_instance).model_fields)
            computed_fields = getattr(block_instance, "__pydantic_computed_fields__", None)
            if computed_fields:
                all_fields.update(computed_fields)

            for field_name in all_fields:
                field_value = getattr(block_instance, field_name, None)
                field_path = f"{block_path}.{field_name}"

                # If it's a nested block, recurse
                if field_value is not None and isinstance(field_value, ProductBlockModel):
                    fields.extend(cls._extract_block_schema(field_value, field_path))
                elif field_value is not None and isinstance(field_value, list):
                    # Handle list of blocks
                    if field_value and isinstance(field_value[0], ProductBlockModel):
                        # For lists, we still add the list field as a resource type
                        fields.append(
                            ExtractedField(path=field_path, value=field_name, value_type=FieldType.RESOURCE_TYPE)
                        )
                        # And potentially traverse the first item for schema
                        first_item_path = f"{field_path}{cls._LTREE_SEPARATOR}0"
                        fields.extend(cls._extract_block_schema(field_value[0], first_item_path))
                    else:
                        # Empty list or list of scalars: record the field itself.
                        fields.append(
                            ExtractedField(path=field_path, value=field_name, value_type=FieldType.RESOURCE_TYPE)
                        )
                else:
                    # Regular fields are resource types
                    fields.append(ExtractedField(path=field_path, value=field_name, value_type=FieldType.RESOURCE_TYPE))

        return fields

    @classmethod
    def _load_model(cls, product: ProductTable) -> SubscriptionModel | None:
        """Creates a template instance of a SubscriptionModel for a given product.

        This allows us to traverse the product's defined block structure, even
        without a real subscription instance in the database.

        Raises:
            ProductNotInRegistryError: The product has no registered model.

        Returns:
            SubscriptionModel | None: The template model, or None when
            instantiation fails (failure is logged, not raised).
        """
        # Find the SubscriptionModel class associated with this product's name.
        domain_model_cls = SUBSCRIPTION_MODEL_REGISTRY.get(product.name)
        if not domain_model_cls:
            raise ProductNotInRegistryError(f"Product '{product.name}' not in registry.")

        # Get the initial lifecycle version of that class, as it represents the base structure.
        try:
            subscription_model_cls = cast(
                type[SubscriptionModel], lookup_specialized_type(domain_model_cls, SubscriptionLifecycle.INITIAL)
            )
        except Exception:
            subscription_model_cls = domain_model_cls

        try:
            product_model = ProductModel(
                product_id=product.product_id,
                name=product.name,
                description=product.description,
                product_type=product.product_type,
                tag=product.tag,
                status=product.status,
            )

            # Generate a fake subscription ID for the template
            subscription_id = uuid4()

            # Get fixed inputs for the product
            fixed_inputs = {fi.name: fi.value for fi in product.fixed_inputs}

            # Initialize product blocks
            instances = subscription_model_cls._init_instances(subscription_id)

            return subscription_model_cls(
                product=product_model,
                customer_id="traverser_template",
                subscription_id=subscription_id,
                description="Template for schema traversal",
                status=SubscriptionLifecycle.INITIAL,
                insync=False,
                start_date=None,
                end_date=None,
                note=None,
                version=1,
                **fixed_inputs,
                **instances,
            )
        except Exception:
            logger.exception("Failed to instantiate template model for product", product_name=product.name)
            return None
295
+
296
+
297
class ProcessTraverser(BaseTraverser):
    """Traverser for process entities using ProcessSchema model.

    Note: Currently extracts only top-level process fields. Could be extended to include:
    - Related subscriptions (entity.subscriptions)
    - Related workflow information beyond workflow_name
    """

    @classmethod
    def _load_model(cls, process: ProcessTable) -> ProcessSchema:
        """Validate the process row into a ProcessSchema via the shared helper."""
        return cls._load_model_with_schema(process, ProcessSchema, "process_id")
309
+
310
+
311
class WorkflowTraverser(BaseTraverser):
    """Traverser for workflow entities using WorkflowSchema model.

    Note: Currently extracts only top-level workflow fields. Could be extended to include:
    - Related products (entity.products) - each with their own block structures
    - Related processes (entity.processes) - each with their own process data
    """

    @classmethod
    def _load_model(cls, workflow: WorkflowTable) -> WorkflowSchema:
        """Validate the workflow row into a WorkflowSchema via the shared helper."""
        return cls._load_model_with_schema(workflow, WorkflowSchema, "workflow_id")
@@ -0,0 +1,3 @@
1
+ from .engine import execute_search
2
+
3
+ __all__ = ["execute_search"]
@@ -0,0 +1,108 @@
1
+ from collections import defaultdict
2
+ from typing import Sequence
3
+
4
+ from sqlalchemy import Select, String, cast, func, select
5
+ from sqlalchemy.engine import Row
6
+
7
+ from orchestrator.db.models import AiSearchIndex
8
+ from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
9
+ from orchestrator.search.filters import LtreeFilter
10
+ from orchestrator.search.schemas.parameters import BaseSearchParameters
11
+ from orchestrator.search.schemas.results import ComponentInfo, LeafInfo
12
+
13
+
14
+ def create_path_autocomplete_lquery(prefix: str) -> str:
15
+ """Create the lquery pattern for a multi-level path autocomplete search."""
16
+ return f"{prefix}*.*"
17
+
18
+
19
+ def build_candidate_query(params: BaseSearchParameters) -> Select:
20
+ """Build the base query for retrieving candidate entities.
21
+
22
+ Constructs a `SELECT` statement that retrieves distinct `entity_id` values
23
+ from the index table for the given entity type, applying any structured
24
+ filters from the provided search parameters.
25
+
26
+ Parameters
27
+ ----------
28
+ params : BaseSearchParameters
29
+ The search parameters containing the entity type and optional filters.
30
+
31
+ Returns:
32
+ -------
33
+ Select
34
+ The SQLAlchemy `Select` object representing the query.
35
+ """
36
+ stmt = select(AiSearchIndex.entity_id).where(AiSearchIndex.entity_type == params.entity_type.value).distinct()
37
+
38
+ if params.filters is not None:
39
+ entity_id_col = AiSearchIndex.entity_id
40
+ stmt = stmt.where(
41
+ params.filters.to_expression(
42
+ entity_id_col,
43
+ entity_type_value=params.entity_type.value,
44
+ )
45
+ )
46
+
47
+ return stmt
48
+
49
+
50
+ def build_paths_query(entity_type: EntityType, prefix: str | None = None, q: str | None = None) -> Select:
51
+ """Build the query for retrieving paths and their value types for leaves/components processing."""
52
+ stmt = select(AiSearchIndex.path, AiSearchIndex.value_type).where(AiSearchIndex.entity_type == entity_type.value)
53
+
54
+ if prefix:
55
+ lquery_pattern = create_path_autocomplete_lquery(prefix)
56
+ ltree_filter = LtreeFilter(op=FilterOp.MATCHES_LQUERY, value=lquery_pattern)
57
+ stmt = stmt.where(ltree_filter.to_expression(AiSearchIndex.path, path=""))
58
+
59
+ stmt = stmt.group_by(AiSearchIndex.path, AiSearchIndex.value_type)
60
+
61
+ if q:
62
+ score = func.similarity(cast(AiSearchIndex.path, String), q)
63
+ stmt = stmt.order_by(score.desc(), AiSearchIndex.path)
64
+ else:
65
+ stmt = stmt.order_by(AiSearchIndex.path)
66
+
67
+ return stmt
68
+
69
+
70
+ def process_path_rows(rows: Sequence[Row]) -> tuple[list[LeafInfo], list[ComponentInfo]]:
71
+ """Process query results to extract leaves and components information.
72
+
73
+ Parameters
74
+ ----------
75
+ rows : Sequence[Row]
76
+ Database rows containing path and value_type information
77
+
78
+ Returns:
79
+ -------
80
+ tuple[list[LeafInfo], list[ComponentInfo]]
81
+ Processed leaves and components
82
+ """
83
+ leaves_dict: dict[str, set[UIType]] = defaultdict(set)
84
+ components_set: set[str] = set()
85
+
86
+ for row in rows:
87
+ path, value_type = row
88
+
89
+ path_str = str(path)
90
+ path_segments = path_str.split(".")
91
+
92
+ # Remove numeric segments
93
+ clean_segments = [seg for seg in path_segments if not seg.isdigit()]
94
+
95
+ if clean_segments:
96
+ # Last segment is a leaf
97
+ leaf_name = clean_segments[-1]
98
+ ui_type = UIType.from_field_type(FieldType(value_type))
99
+ leaves_dict[leaf_name].add(ui_type)
100
+
101
+ # All segments except the first/last are components
102
+ for component in clean_segments[1:-1]:
103
+ components_set.add(component)
104
+
105
+ leaves = [LeafInfo(name=leaf, ui_types=list(types)) for leaf, types in leaves_dict.items()]
106
+ components = [ComponentInfo(name=component, ui_types=[UIType.COMPONENT]) for component in sorted(components_set)]
107
+
108
+ return leaves, components
@@ -0,0 +1,152 @@
1
+ from collections.abc import Sequence
2
+
3
+ import structlog
4
+ from sqlalchemy.engine.row import RowMapping
5
+ from sqlalchemy.orm import Session
6
+
7
+ from orchestrator.search.core.types import FilterOp, SearchMetadata
8
+ from orchestrator.search.filters import FilterTree, LtreeFilter
9
+ from orchestrator.search.schemas.parameters import BaseSearchParameters
10
+ from orchestrator.search.schemas.results import MatchingField, SearchResponse, SearchResult
11
+
12
+ from .builder import build_candidate_query
13
+ from .pagination import PaginationParams
14
+ from .retriever import Retriever
15
+ from .utils import generate_highlight_indices
16
+
17
+ logger = structlog.get_logger(__name__)
18
+
19
+
20
+ def _format_response(
21
+ db_rows: Sequence[RowMapping], search_params: BaseSearchParameters, metadata: SearchMetadata
22
+ ) -> SearchResponse:
23
+ """Format database query results into a `SearchResponse`.
24
+
25
+ Converts raw SQLAlchemy `RowMapping` objects into `SearchResult` instances,
26
+ including highlight metadata if present in the database results.
27
+
28
+ Parameters
29
+ ----------
30
+ db_rows : Sequence[RowMapping]
31
+ The rows returned from the executed SQLAlchemy query.
32
+
33
+ Returns:
34
+ -------
35
+ SearchResponse
36
+ A list of `SearchResult` objects containing entity IDs, scores, and
37
+ optional highlight information.
38
+ """
39
+
40
+ if not db_rows:
41
+ return SearchResponse(results=[], metadata=metadata)
42
+
43
+ user_query = search_params.query
44
+
45
+ results = []
46
+ for row in db_rows:
47
+ matching_field = None
48
+
49
+ if user_query and row.get("highlight_text") and row.get("highlight_path"):
50
+ # Text/semantic searches
51
+ text = row.highlight_text
52
+ path = row.highlight_path
53
+
54
+ if not isinstance(text, str):
55
+ text = str(text)
56
+ if not isinstance(path, str):
57
+ path = str(path)
58
+
59
+ highlight_indices = generate_highlight_indices(text, user_query) or None
60
+ matching_field = MatchingField(text=text, path=path, highlight_indices=highlight_indices)
61
+
62
+ elif not user_query and search_params.filters and metadata.search_type == "structured":
63
+ # Structured search (filter-only)
64
+ matching_field = _extract_matching_field_from_filters(search_params.filters)
65
+
66
+ results.append(
67
+ SearchResult(
68
+ entity_id=str(row.entity_id),
69
+ score=row.score,
70
+ perfect_match=row.get("perfect_match", 0),
71
+ matching_field=matching_field,
72
+ )
73
+ )
74
+ return SearchResponse(results=results, metadata=metadata)
75
+
76
+
77
+ def _extract_matching_field_from_filters(filters: FilterTree) -> MatchingField | None:
78
+ """Extract the first path filter to use as matching field for structured searches.
79
+
80
+ TODO: Should we allow a list of matched fields in the MatchingField model?
81
+ We need a different approach, probably a cross join in StructuredRetriever.
82
+ """
83
+ leaves = filters.get_all_leaves()
84
+ if len(leaves) != 1:
85
+ return None
86
+
87
+ pf = leaves[0]
88
+
89
+ if isinstance(pf.condition, LtreeFilter):
90
+ op = pf.condition.op
91
+ # Prefer the original component/pattern (validator may set path="*" and move the value)
92
+ display = str(getattr(pf.condition, "value", "") or pf.path)
93
+
94
+ # There can be no match for abscence.
95
+ if op == FilterOp.NOT_HAS_COMPONENT:
96
+ return None
97
+
98
+ return MatchingField(text=display, path=display, highlight_indices=[(0, len(display))])
99
+
100
+ # Everything thats not Ltree
101
+ val = getattr(pf.condition, "value", "")
102
+ text = "" if val is None else str(val)
103
+ return MatchingField(text=text, path=pf.path, highlight_indices=[(0, len(text))])
104
+
105
+
106
+ async def execute_search(
107
+ search_params: BaseSearchParameters,
108
+ db_session: Session,
109
+ pagination_params: PaginationParams | None = None,
110
+ ) -> SearchResponse:
111
+ """Execute a hybrid search and return ranked results.
112
+
113
+ Builds a candidate entity query based on the given search parameters,
114
+ applies the appropriate ranking strategy, and executes the final ranked
115
+ query to retrieve results.
116
+
117
+ Parameters
118
+ ----------
119
+ search_params : BaseSearchParameters
120
+ The search parameters specifying vector, fuzzy, or filter criteria.
121
+ db_session : Session
122
+ The active SQLAlchemy session for executing the query.
123
+ limit : int, optional
124
+ The maximum number of search results to return, by default 5.
125
+
126
+ Returns:
127
+ -------
128
+ SearchResponse
129
+ A list of `SearchResult` objects containing entity IDs, scores, and
130
+ optional highlight metadata.
131
+
132
+ Notes:
133
+ -----
134
+ If no vector query, filters, or fuzzy term are provided, a warning is logged
135
+ and an empty result set is returned.
136
+ """
137
+ if not search_params.vector_query and not search_params.filters and not search_params.fuzzy_term:
138
+ logger.warning("No search criteria provided (vector_query, fuzzy_term, or filters).")
139
+ return SearchResponse(results=[], metadata=SearchMetadata.empty())
140
+
141
+ candidate_query = build_candidate_query(search_params)
142
+
143
+ pagination_params = pagination_params or PaginationParams()
144
+ retriever = await Retriever.from_params(search_params, pagination_params)
145
+ logger.debug("Using retriever", retriever_type=retriever.__class__.__name__)
146
+
147
+ final_stmt = retriever.apply(candidate_query)
148
+ final_stmt = final_stmt.limit(search_params.limit)
149
+ logger.debug(final_stmt)
150
+ result = db_session.execute(final_stmt).mappings().all()
151
+
152
+ return _format_response(result, search_params, retriever.metadata)