orchestrator-core 4.4.2__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. orchestrator/__init__.py +17 -2
  2. orchestrator/agentic_app.py +103 -0
  3. orchestrator/api/api_v1/api.py +14 -2
  4. orchestrator/api/api_v1/endpoints/search.py +296 -0
  5. orchestrator/app.py +32 -0
  6. orchestrator/cli/main.py +22 -1
  7. orchestrator/cli/search/__init__.py +32 -0
  8. orchestrator/cli/search/index_llm.py +73 -0
  9. orchestrator/cli/search/resize_embedding.py +135 -0
  10. orchestrator/cli/search/search_explore.py +208 -0
  11. orchestrator/cli/search/speedtest.py +151 -0
  12. orchestrator/db/models.py +37 -1
  13. orchestrator/devtools/populator.py +16 -0
  14. orchestrator/domain/base.py +2 -7
  15. orchestrator/domain/lifecycle.py +24 -7
  16. orchestrator/llm_settings.py +57 -0
  17. orchestrator/log_config.py +1 -0
  18. orchestrator/migrations/helpers.py +7 -1
  19. orchestrator/schemas/search.py +130 -0
  20. orchestrator/schemas/workflow.py +1 -0
  21. orchestrator/search/__init__.py +12 -0
  22. orchestrator/search/agent/__init__.py +21 -0
  23. orchestrator/search/agent/agent.py +62 -0
  24. orchestrator/search/agent/prompts.py +100 -0
  25. orchestrator/search/agent/state.py +21 -0
  26. orchestrator/search/agent/tools.py +258 -0
  27. orchestrator/search/core/__init__.py +12 -0
  28. orchestrator/search/core/embedding.py +73 -0
  29. orchestrator/search/core/exceptions.py +36 -0
  30. orchestrator/search/core/types.py +296 -0
  31. orchestrator/search/core/validators.py +40 -0
  32. orchestrator/search/docs/index.md +37 -0
  33. orchestrator/search/docs/running_local_text_embedding_inference.md +46 -0
  34. orchestrator/search/filters/__init__.py +40 -0
  35. orchestrator/search/filters/base.py +295 -0
  36. orchestrator/search/filters/date_filters.py +88 -0
  37. orchestrator/search/filters/definitions.py +107 -0
  38. orchestrator/search/filters/ltree_filters.py +56 -0
  39. orchestrator/search/filters/numeric_filter.py +73 -0
  40. orchestrator/search/indexing/__init__.py +16 -0
  41. orchestrator/search/indexing/indexer.py +334 -0
  42. orchestrator/search/indexing/registry.py +101 -0
  43. orchestrator/search/indexing/tasks.py +69 -0
  44. orchestrator/search/indexing/traverse.py +334 -0
  45. orchestrator/search/llm_migration.py +108 -0
  46. orchestrator/search/retrieval/__init__.py +16 -0
  47. orchestrator/search/retrieval/builder.py +123 -0
  48. orchestrator/search/retrieval/engine.py +154 -0
  49. orchestrator/search/retrieval/exceptions.py +90 -0
  50. orchestrator/search/retrieval/pagination.py +96 -0
  51. orchestrator/search/retrieval/retrievers/__init__.py +26 -0
  52. orchestrator/search/retrieval/retrievers/base.py +123 -0
  53. orchestrator/search/retrieval/retrievers/fuzzy.py +94 -0
  54. orchestrator/search/retrieval/retrievers/hybrid.py +277 -0
  55. orchestrator/search/retrieval/retrievers/semantic.py +94 -0
  56. orchestrator/search/retrieval/retrievers/structured.py +39 -0
  57. orchestrator/search/retrieval/utils.py +120 -0
  58. orchestrator/search/retrieval/validation.py +152 -0
  59. orchestrator/search/schemas/__init__.py +12 -0
  60. orchestrator/search/schemas/parameters.py +129 -0
  61. orchestrator/search/schemas/results.py +77 -0
  62. orchestrator/services/processes.py +1 -1
  63. orchestrator/services/settings_env_variables.py +2 -2
  64. orchestrator/settings.py +8 -1
  65. orchestrator/utils/state.py +6 -1
  66. orchestrator/workflows/steps.py +15 -1
  67. orchestrator/workflows/tasks/validate_products.py +1 -1
  68. {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/METADATA +15 -8
  69. {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/RECORD +71 -21
  70. {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/WHEEL +0 -0
  71. {orchestrator_core-4.4.2.dist-info → orchestrator_core-4.5.0.dist-info}/licenses/LICENSE +0 -0
orchestrator/search/indexing/traverse.py
@@ -0,0 +1,334 @@
+ # Copyright 2019-2025 SURF, GÉANT.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from enum import Enum
+ from typing import Any, cast, get_args
+ from uuid import uuid4
+
+ import structlog
+
+ from orchestrator.db import ProcessTable, ProductTable, SubscriptionTable, WorkflowTable
+ from orchestrator.domain import (
+     SUBSCRIPTION_MODEL_REGISTRY,
+     SubscriptionModel,
+ )
+ from orchestrator.domain.base import ProductBlockModel, ProductModel
+ from orchestrator.domain.lifecycle import (
+     lookup_specialized_type,
+ )
+ from orchestrator.schemas.process import ProcessSchema
+ from orchestrator.schemas.workflow import WorkflowSchema
+ from orchestrator.search.core.exceptions import ModelLoadError, ProductNotInRegistryError
+ from orchestrator.search.core.types import LTREE_SEPARATOR, ExtractedField, FieldType
+ from orchestrator.types import SubscriptionLifecycle
+
+ logger = structlog.get_logger(__name__)
+
+ DatabaseEntity = SubscriptionTable | ProductTable | ProcessTable | WorkflowTable
+
+
+ class BaseTraverser(ABC):
+     """Base class for traversing database models and extracting searchable fields."""
+
+     _MAX_DEPTH = 40
+
+     @classmethod
+     def get_fields(cls, entity: DatabaseEntity, pk_name: str, root_name: str) -> list[ExtractedField]:
+         """Main entry point for extracting fields from an entity. Default implementation delegates to _load_model."""
+         try:
+             model = cls._load_model(entity)
+             if model is None:
+                 return []
+             return sorted(cls.traverse(model, root_name), key=lambda f: f.path)
+
+         except (ProductNotInRegistryError, ModelLoadError) as e:
+             entity_id = getattr(entity, pk_name, "unknown")
+             logger.error(f"Failed to extract fields from {entity.__class__.__name__}", id=str(entity_id), error=str(e))
+             return []
+
+     @classmethod
+     def traverse(cls, instance: Any, path: str = "") -> Iterable[ExtractedField]:
+         """Walks the fields of a Pydantic model, dispatching each to a field handler."""
+         model_class = type(instance)
+
+         # Handle both standard and computed fields from the Pydantic model
+         all_fields = model_class.model_fields.copy()
+         all_fields.update(getattr(model_class, "__pydantic_computed_fields__", {}))
+
+         for name, field in all_fields.items():
+             try:
+                 value = getattr(instance, name, None)
+             except Exception as e:
+                 logger.error(f"Failed to access field '{name}' on {model_class.__name__}", error=str(e))
+                 continue
+             new_path = f"{path}{LTREE_SEPARATOR}{name}" if path else name
+             annotation = field.annotation if hasattr(field, "annotation") else field.return_type
+             yield from cls._yield_fields_for_value(value, new_path, annotation)
+
+     @classmethod
+     def _yield_fields_for_value(cls, value: Any, path: str, annotation: Any) -> Iterable[ExtractedField]:
+         """Yields fields for a given value based on its type (model, list, or scalar)."""
+         if value is None:
+             return
+
+         # If the value is a list, pass it to the list traverser
+         if isinstance(value, list):
+             if element_annotation := get_args(annotation):
+                 yield from cls._traverse_list(value, path, element_annotation[0])
+             return
+
+         # If the value is another Pydantic model, recurse into it
+         if hasattr(type(value), "model_fields"):
+             yield from cls.traverse(value, path)
+             return
+
+         ftype = FieldType.from_type_hint(annotation)
+
+         if isinstance(value, Enum):
+             yield ExtractedField(path, str(value.value), ftype)
+         else:
+             yield ExtractedField(path, str(value), ftype)
+
+     @classmethod
+     def _traverse_list(cls, items: list[Any], path: str, element_annotation: Any) -> Iterable[ExtractedField]:
+         """Recursively traverses items in a list."""
+         for i, item in enumerate(items):
+             item_path = f"{path}.{i}"
+             yield from cls._yield_fields_for_value(item, item_path, element_annotation)
+
+     @classmethod
+     def _load_model_with_schema(cls, entity: Any, schema_class: type[Any], pk_name: str) -> Any:
+         """Generic helper for loading models using Pydantic schema validation."""
+         try:
+             return schema_class.model_validate(entity)
+         except Exception as e:
+             entity_id = getattr(entity, pk_name, "unknown")
+             raise ModelLoadError(f"Failed to load {schema_class.__name__} for {pk_name} '{entity_id}'") from e
+
+     @classmethod
+     @abstractmethod
+     def _load_model(cls, entity: Any) -> Any: ...
+
+
+ class SubscriptionTraverser(BaseTraverser):
+     """Traverser for subscription entities using full Pydantic model extraction."""
+
+     @classmethod
+     def _load_model(cls, sub: SubscriptionTable) -> SubscriptionModel | None:
+         base_model_cls = SUBSCRIPTION_MODEL_REGISTRY.get(sub.product.name)
+         if not base_model_cls:
+             raise ProductNotInRegistryError(f"Product '{sub.product.name}' not in registry.")
+
+         specialized_model_cls = cast(type[SubscriptionModel], lookup_specialized_type(base_model_cls, sub.status))
+
+         try:
+             return specialized_model_cls.from_subscription(sub.subscription_id)
+         except Exception as e:
+             raise ModelLoadError(f"Failed to load model for subscription_id '{sub.subscription_id}'") from e
+
+
+ class ProductTraverser(BaseTraverser):
+     """Traverser for product entities using a template SubscriptionModel instance."""
+
+     @classmethod
+     def _sanitize_for_ltree(cls, name: str) -> str:
+         """Sanitizes a string to be a valid ltree path label."""
+         # Convert to lowercase
+         sanitized = name.lower()
+
+         # Replace all non-alphanumeric (and non-underscore) characters with an underscore
+         sanitized = re.sub(r"[^a-z0-9_]", "_", sanitized)
+
+         # Collapse multiple underscores into a single one
+         sanitized = re.sub(r"__+", "_", sanitized)
+
+         # Remove leading or trailing underscores
+         sanitized = sanitized.strip("_")
+
+         # Handle cases where the name was only invalid characters
+         if not sanitized:
+             return "unnamed_product"
+
+         return sanitized
+
+     @classmethod
+     def get_fields(cls, entity: ProductTable, pk_name: str, root_name: str) -> list[ExtractedField]:  # type: ignore[override]
+         """Extracts fields by creating a template SubscriptionModel instance for the product.
+
+         Extracts product metadata and block schema structure.
+         """
+         try:
+             model = cls._load_model(entity)
+
+             if not model:
+                 return []
+
+             fields: list[ExtractedField] = []
+
+             product_fields = cls.traverse(model.product, root_name)
+             fields.extend(product_fields)
+
+             product_name = cls._sanitize_for_ltree(model.product.name)
+
+             product_block_root = f"{root_name}.{product_name}.product_block"
+
+             # Extract product block schema structure
+             model_class = type(model)
+             product_block_fields = getattr(model_class, "_product_block_fields_", {})
+
+             for field_name in product_block_fields:
+                 block_value = getattr(model, field_name, None)
+                 if block_value is not None:
+                     block_path = f"{product_block_root}.{field_name}"
+                     schema_fields = cls._extract_block_schema(block_value, block_path)
+                     fields.extend(schema_fields)
+
+             return sorted(fields, key=lambda f: f.path)
+
+         except (ProductNotInRegistryError, ModelLoadError) as e:
+             entity_id = getattr(entity, pk_name, "unknown")
+             logger.error(f"Failed to extract fields from {entity.__class__.__name__}", id=str(entity_id), error=str(e))
+             return []
+
+     @classmethod
+     def _extract_block_schema(cls, block_instance: ProductBlockModel, block_path: str) -> list[ExtractedField]:
+         """Extract schema information from a block instance, returning field names as RESOURCE_TYPE."""
+         fields = []
+
+         # Add the block itself as a BLOCK type
+         block_name = block_path.split(LTREE_SEPARATOR)[-1]
+         fields.append(ExtractedField(path=block_path, value=block_name, value_type=FieldType.BLOCK))
+
+         # Extract all field names from the block as RESOURCE_TYPE
+         if hasattr(type(block_instance), "model_fields"):
+             all_fields = type(block_instance).model_fields.copy()  # copy so merging computed fields doesn't mutate the class attribute
+             computed_fields = getattr(block_instance, "__pydantic_computed_fields__", None)
+             if computed_fields:
+                 all_fields.update(computed_fields)
+
+             for field_name in all_fields:
+                 field_value = getattr(block_instance, field_name, None)
+                 field_path = f"{block_path}.{field_name}"
+
+                 # If it's a nested block, recurse
+                 if field_value is not None and isinstance(field_value, ProductBlockModel):
+                     nested_fields = cls._extract_block_schema(field_value, field_path)
+                     fields.extend(nested_fields)
+                 elif field_value is not None and isinstance(field_value, list):
+                     # Handle list of blocks
+                     if field_value and isinstance(field_value[0], ProductBlockModel):
+                         # For lists, we still add the list field as a resource type
+                         fields.append(
+                             ExtractedField(path=field_path, value=field_name, value_type=FieldType.RESOURCE_TYPE)
+                         )
+                         # And potentially traverse the first item for schema
+                         first_item_path = f"{field_path}{LTREE_SEPARATOR}0"
+                         nested_fields = cls._extract_block_schema(field_value[0], first_item_path)
+                         fields.extend(nested_fields)
+                     else:
+                         fields.append(
+                             ExtractedField(path=field_path, value=field_name, value_type=FieldType.RESOURCE_TYPE)
+                         )
+                 else:
+                     # Regular fields are resource types
+                     fields.append(ExtractedField(path=field_path, value=field_name, value_type=FieldType.RESOURCE_TYPE))
+
+         return fields
+
+     @classmethod
+     def _load_model(cls, product: ProductTable) -> SubscriptionModel | None:
+         """Creates a template instance of a SubscriptionModel for a given product.
+
+         This allows us to traverse the product's defined block structure, even
+         without a real subscription instance in the database.
+         """
+         # Find the SubscriptionModel class associated with this product's name.
+         domain_model_cls = SUBSCRIPTION_MODEL_REGISTRY.get(product.name)
+         if not domain_model_cls:
+             raise ProductNotInRegistryError(f"Product '{product.name}' not in registry.")
+
+         # Get the initial lifecycle version of that class, as it represents the base structure.
+         try:
+             subscription_model_cls = cast(
+                 type[SubscriptionModel], lookup_specialized_type(domain_model_cls, SubscriptionLifecycle.INITIAL)
+             )
+         except Exception:
+             subscription_model_cls = domain_model_cls
+
+         try:
+             product_model = ProductModel(
+                 product_id=product.product_id,
+                 name=product.name,
+                 description=product.description,
+                 product_type=product.product_type,
+                 tag=product.tag,
+                 status=product.status,
+             )
+
+             # Generate a fake subscription ID for the template
+             subscription_id = uuid4()
+
+             # Get fixed inputs for the product
+             fixed_inputs = {fi.name: fi.value for fi in product.fixed_inputs}
+
+             # Initialize product blocks
+             instances = subscription_model_cls._init_instances(subscription_id)
+
+             return subscription_model_cls(
+                 product=product_model,
+                 customer_id="traverser_template",
+                 subscription_id=subscription_id,
+                 description="Template for schema traversal",
+                 status=SubscriptionLifecycle.INITIAL,
+                 insync=False,
+                 start_date=None,
+                 end_date=None,
+                 note=None,
+                 version=1,
+                 **fixed_inputs,
+                 **instances,
+             )
+         except Exception:
+             logger.exception("Failed to instantiate template model for product", product_name=product.name)
+             return None
+
+
+ class ProcessTraverser(BaseTraverser):
+     """Traverser for process entities using ProcessSchema model.
+
+     Note: Currently extracts only top-level process fields. Could be extended to include:
+     - Related subscriptions (entity.subscriptions)
+     - Related workflow information beyond workflow_name
+     """
+
+     @classmethod
+     def _load_model(cls, process: ProcessTable) -> ProcessSchema:
+         """Load process model using ProcessSchema."""
+         return cls._load_model_with_schema(process, ProcessSchema, "process_id")
+
+
+ class WorkflowTraverser(BaseTraverser):
+     """Traverser for workflow entities using WorkflowSchema model.
+
+     Note: Currently extracts only top-level workflow fields. Could be extended to include:
+     - Related products (entity.products) - each with their own block structures
+     - Related processes (entity.processes) - each with their own process data
+     """
+
+     @classmethod
+     def _load_model(cls, workflow: WorkflowTable) -> WorkflowSchema:
+         """Load workflow model using WorkflowSchema."""
+         return cls._load_model_with_schema(workflow, WorkflowSchema, "workflow_id")
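
Taken together, `traverse` flattens any Pydantic model into dot-separated, ltree-compatible paths, with list indices becoming numeric labels. A minimal sketch of that behaviour on toy models (`Port` and `Node` are hypothetical and not part of this release; assumes `LTREE_SEPARATOR` is `"."`):

```python
from pydantic import BaseModel

from orchestrator.search.indexing.traverse import BaseTraverser


class Port(BaseModel):
    port_id: int
    speed: str


class Node(BaseModel):
    name: str
    ports: list[Port]


node = Node(name="asd001a", ports=[Port(port_id=1, speed="10G"), Port(port_id=2, speed="40G")])

# traverse() is a classmethod that never touches _load_model, so it can be
# called directly; each yielded ExtractedField carries (path, value, type).
for field in BaseTraverser.traverse(node, path="node"):
    print(field.path, field.value)
# node.name asd001a
# node.ports.0.port_id 1
# node.ports.0.speed 10G
# node.ports.1.port_id 2
# node.ports.1.speed 40G
```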
orchestrator/search/llm_migration.py
@@ -0,0 +1,108 @@
+ # Copyright 2019-2025 SURF
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Simple search migration function that runs when SEARCH_ENABLED = True."""
+
+ from sqlalchemy import text
+ from sqlalchemy.engine import Connection
+ from structlog import get_logger
+
+ from orchestrator.llm_settings import llm_settings
+ from orchestrator.search.core.types import FieldType
+
+ logger = get_logger(__name__)
+
+ TABLE = "ai_search_index"
+ TARGET_DIM = 1536
+
+
+ def run_migration(connection: Connection) -> None:
+     """Run LLM migration with ON CONFLICT DO NOTHING pattern."""
+     logger.info("Running LLM migration")
+
+     try:
+         # Check whether the vector extension already exists and, if so, skip creating extensions; needed when the
+         # db user has insufficient privileges to run the `CREATE EXTENSION ...` command.
+         res = connection.execute(text("SELECT * FROM pg_extension WHERE extname = 'vector';"))
+         if llm_settings.LLM_FORCE_EXTENTION_MIGRATION or res.rowcount == 0:
+             # Create PostgreSQL extensions
+             logger.info("Attempting to run the extension creation")
+             connection.execute(text("CREATE EXTENSION IF NOT EXISTS ltree;"))
+             connection.execute(text("CREATE EXTENSION IF NOT EXISTS unaccent;"))
+             connection.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
+             connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+         # Create field_type enum
+         field_type_values = "', '".join([ft.value for ft in FieldType])
+         connection.execute(
+             text(
+                 f"""
+                 DO $$
+                 BEGIN
+                     IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'field_type') THEN
+                         CREATE TYPE field_type AS ENUM ('{field_type_values}');
+                     END IF;
+                 END $$;
+                 """
+             )
+         )
+
+         # Create table with ON CONFLICT DO NOTHING pattern
+         connection.execute(
+             text(
+                 f"""
+                 CREATE TABLE IF NOT EXISTS {TABLE} (
+                     entity_type TEXT NOT NULL,
+                     entity_id UUID NOT NULL,
+                     path LTREE NOT NULL,
+                     value TEXT NOT NULL,
+                     embedding VECTOR({TARGET_DIM}),
+                     content_hash VARCHAR(64) NOT NULL,
+                     value_type field_type NOT NULL DEFAULT '{FieldType.STRING.value}',
+                     CONSTRAINT pk_ai_search_index PRIMARY KEY (entity_id, path)
+                 );
+                 """
+             )
+         )
+
+         # Drop default
+         connection.execute(text(f"ALTER TABLE {TABLE} ALTER COLUMN value_type DROP DEFAULT;"))
+
+         # Create indexes with IF NOT EXISTS
+         connection.execute(text(f"CREATE INDEX IF NOT EXISTS ix_ai_search_index_entity_id ON {TABLE} (entity_id);"))
+         connection.execute(
+             text(f"CREATE INDEX IF NOT EXISTS idx_ai_search_index_content_hash ON {TABLE} (content_hash);")
+         )
+         connection.execute(
+             text(f"CREATE INDEX IF NOT EXISTS ix_flat_path_gist ON {TABLE} USING GIST (path gist_ltree_ops);")
+         )
+         connection.execute(text(f"CREATE INDEX IF NOT EXISTS ix_flat_path_btree ON {TABLE} (path);"))
+         connection.execute(
+             text(f"CREATE INDEX IF NOT EXISTS ix_flat_value_trgm ON {TABLE} USING GIN (value gin_trgm_ops);")
+         )
+         connection.execute(
+             text(
+                 f"CREATE INDEX IF NOT EXISTS ix_flat_embed_hnsw ON {TABLE} USING HNSW (embedding vector_l2_ops) WITH (m = 16, ef_construction = 64);"
+             )
+         )
+
+         connection.commit()
+         logger.info("LLM migration completed successfully")
+
+     except Exception as e:
+         logger.error("LLM migration failed", error=str(e))
+         raise Exception(
+             f"LLM migration failed. This likely means the pgvector extension "
+             f"is not installed. Please install pgvector and ensure your PostgreSQL "
+             f"version supports it. Error: {e}"
+         ) from e
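
For local testing, the migration can also be invoked by hand against a plain SQLAlchemy connection. A minimal sketch, assuming a reachable PostgreSQL with pgvector installed (the DSN below is hypothetical); note that `run_migration` calls `connection.commit()` itself, so an ordinary `connect()` block is sufficient:

```python
from sqlalchemy import create_engine

from orchestrator.search.llm_migration import run_migration

# Hypothetical DSN; point this at a database whose user may CREATE EXTENSION.
engine = create_engine("postgresql://nwa:nwa@localhost/orchestrator-core")

with engine.connect() as connection:
    run_migration(connection)
```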
orchestrator/search/retrieval/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2019-2025 SURF, GÉANT.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .engine import execute_search
+
+ __all__ = ["execute_search"]
orchestrator/search/retrieval/builder.py
@@ -0,0 +1,123 @@
+ # Copyright 2019-2025 SURF, GÉANT.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+ from collections.abc import Sequence
+
+ from sqlalchemy import Select, String, cast, func, select
+ from sqlalchemy.engine import Row
+
+ from orchestrator.db.models import AiSearchIndex
+ from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
+ from orchestrator.search.filters import LtreeFilter
+ from orchestrator.search.schemas.parameters import BaseSearchParameters
+ from orchestrator.search.schemas.results import ComponentInfo, LeafInfo
+
+
+ def create_path_autocomplete_lquery(prefix: str) -> str:
+     """Create the lquery pattern for a multi-level path autocomplete search."""
+     return f"{prefix}*.*"
+
+
+ def build_candidate_query(params: BaseSearchParameters) -> Select:
+     """Build the base query for retrieving candidate entities.
+
+     Constructs a `SELECT` statement that retrieves distinct `entity_id` values
+     from the index table for the given entity type, applying any structured
+     filters from the provided search parameters.
+
+     Args:
+         params (BaseSearchParameters): The search parameters containing the entity type and optional filters.
+
+     Returns:
+         Select: The SQLAlchemy `Select` object representing the query.
+     """
+
+     stmt = select(AiSearchIndex.entity_id).where(AiSearchIndex.entity_type == params.entity_type.value).distinct()
+
+     if params.filters is not None:
+         entity_id_col = AiSearchIndex.entity_id
+         stmt = stmt.where(
+             params.filters.to_expression(
+                 entity_id_col,
+                 entity_type_value=params.entity_type.value,
+             )
+         )
+
+     return stmt
+
+
+ def build_paths_query(entity_type: EntityType, prefix: str | None = None, q: str | None = None) -> Select:
+     """Build the query for retrieving paths and their value types for leaves/components processing."""
+     stmt = select(AiSearchIndex.path, AiSearchIndex.value_type).where(AiSearchIndex.entity_type == entity_type.value)
+
+     if prefix:
+         lquery_pattern = create_path_autocomplete_lquery(prefix)
+         ltree_filter = LtreeFilter(op=FilterOp.MATCHES_LQUERY, value=lquery_pattern)
+         stmt = stmt.where(ltree_filter.to_expression(AiSearchIndex.path, path=""))
+
+     stmt = stmt.group_by(AiSearchIndex.path, AiSearchIndex.value_type)
+
+     if q:
+         score = func.similarity(cast(AiSearchIndex.path, String), q)
+         stmt = stmt.order_by(score.desc(), AiSearchIndex.path)
+     else:
+         stmt = stmt.order_by(AiSearchIndex.path)
+
+     return stmt
+
+
+ def process_path_rows(rows: Sequence[Row]) -> tuple[list[LeafInfo], list[ComponentInfo]]:
+     """Process query results to extract leaves and components information.
+
+     Parameters
+     ----------
+     rows : Sequence[Row]
+         Database rows containing path and value_type information
+
+     Returns
+     -------
+     tuple[list[LeafInfo], list[ComponentInfo]]
+         Processed leaves and components
+     """
+     leaves_dict: dict[str, set[UIType]] = defaultdict(set)
+     leaves_paths_dict: dict[str, set[str]] = defaultdict(set)
+     components_set: set[str] = set()
+
+     for row in rows:
+         path, value_type = row
+
+         path_str = str(path)
+         path_segments = path_str.split(".")
+
+         # Remove numeric segments
+         clean_segments = [seg for seg in path_segments if not seg.isdigit()]
+
+         if clean_segments:
+             # The last segment is a leaf
+             leaf_name = clean_segments[-1]
+             ui_type = UIType.from_field_type(FieldType(value_type))
+             leaves_dict[leaf_name].add(ui_type)
+             leaves_paths_dict[leaf_name].add(path_str)
+
+             # All segments except the first and last are components
+             for component in clean_segments[1:-1]:
+                 components_set.add(component)
+
+     leaves = [
+         LeafInfo(name=leaf, ui_types=list(types), paths=sorted(leaves_paths_dict[leaf]))
+         for leaf, types in leaves_dict.items()
+     ]
+     components = [ComponentInfo(name=component, ui_types=[UIType.COMPONENT]) for component in sorted(components_set)]
+
+     return leaves, components
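
A minimal sketch of what `process_path_rows` produces, using plain tuples in place of SQLAlchemy `Row` objects (the sample paths are invented; `FieldType.STRING` is taken from the enum used elsewhere in this release):

```python
from orchestrator.search.core.types import FieldType
from orchestrator.search.retrieval.builder import process_path_rows

rows = [
    ("subscription.description", FieldType.STRING.value),
    ("subscription.node.ports.0.speed", FieldType.STRING.value),
    ("subscription.node.ports.1.speed", FieldType.STRING.value),
]

leaves, components = process_path_rows(rows)

# Numeric segments ("0", "1") are dropped, so both port paths collapse onto the
# "speed" leaf, while the intermediate segments become components.
print([leaf.name for leaf in leaves])        # ["description", "speed"]
print([comp.name for comp in components])    # ["node", "ports"]
```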