orchestrator-core 4.4.0rc2-py3-none-any.whl → 5.0.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/api.py +7 -0
  3. orchestrator/api/api_v1/endpoints/agent.py +62 -0
  4. orchestrator/api/api_v1/endpoints/processes.py +6 -12
  5. orchestrator/api/api_v1/endpoints/search.py +197 -0
  6. orchestrator/api/api_v1/endpoints/subscriptions.py +0 -1
  7. orchestrator/app.py +4 -0
  8. orchestrator/cli/index_llm.py +73 -0
  9. orchestrator/cli/main.py +8 -1
  10. orchestrator/cli/resize_embedding.py +136 -0
  11. orchestrator/cli/scheduler.py +29 -40
  12. orchestrator/cli/search_explore.py +203 -0
  13. orchestrator/db/models.py +37 -1
  14. orchestrator/graphql/schema.py +0 -5
  15. orchestrator/graphql/schemas/process.py +2 -2
  16. orchestrator/graphql/utils/create_resolver_error_handler.py +1 -1
  17. orchestrator/migrations/versions/schema/2025-08-12_52b37b5b2714_search_index_model_for_llm_integration.py +95 -0
  18. orchestrator/schedules/__init__.py +2 -1
  19. orchestrator/schedules/resume_workflows.py +2 -2
  20. orchestrator/schedules/scheduling.py +24 -64
  21. orchestrator/schedules/task_vacuum.py +2 -2
  22. orchestrator/schedules/validate_products.py +2 -8
  23. orchestrator/schedules/validate_subscriptions.py +2 -2
  24. orchestrator/schemas/search.py +101 -0
  25. orchestrator/search/__init__.py +0 -0
  26. orchestrator/search/agent/__init__.py +1 -0
  27. orchestrator/search/agent/prompts.py +62 -0
  28. orchestrator/search/agent/state.py +8 -0
  29. orchestrator/search/agent/tools.py +122 -0
  30. orchestrator/search/core/__init__.py +0 -0
  31. orchestrator/search/core/embedding.py +64 -0
  32. orchestrator/search/core/exceptions.py +16 -0
  33. orchestrator/search/core/types.py +162 -0
  34. orchestrator/search/core/validators.py +27 -0
  35. orchestrator/search/docs/index.md +37 -0
  36. orchestrator/search/docs/running_local_text_embedding_inference.md +45 -0
  37. orchestrator/search/filters/__init__.py +27 -0
  38. orchestrator/search/filters/base.py +236 -0
  39. orchestrator/search/filters/date_filters.py +75 -0
  40. orchestrator/search/filters/definitions.py +76 -0
  41. orchestrator/search/filters/ltree_filters.py +31 -0
  42. orchestrator/search/filters/numeric_filter.py +60 -0
  43. orchestrator/search/indexing/__init__.py +3 -0
  44. orchestrator/search/indexing/indexer.py +316 -0
  45. orchestrator/search/indexing/registry.py +88 -0
  46. orchestrator/search/indexing/tasks.py +53 -0
  47. orchestrator/search/indexing/traverse.py +209 -0
  48. orchestrator/search/retrieval/__init__.py +3 -0
  49. orchestrator/search/retrieval/builder.py +64 -0
  50. orchestrator/search/retrieval/engine.py +96 -0
  51. orchestrator/search/retrieval/ranker.py +202 -0
  52. orchestrator/search/retrieval/utils.py +88 -0
  53. orchestrator/search/retrieval/validation.py +174 -0
  54. orchestrator/search/schemas/__init__.py +0 -0
  55. orchestrator/search/schemas/parameters.py +114 -0
  56. orchestrator/search/schemas/results.py +47 -0
  57. orchestrator/services/processes.py +11 -16
  58. orchestrator/services/subscriptions.py +0 -4
  59. orchestrator/settings.py +29 -1
  60. orchestrator/targets.py +0 -1
  61. orchestrator/workflow.py +1 -8
  62. orchestrator/workflows/utils.py +1 -48
  63. {orchestrator_core-4.4.0rc2.dist-info → orchestrator_core-5.0.0a1.dist-info}/METADATA +6 -3
  64. {orchestrator_core-4.4.0rc2.dist-info → orchestrator_core-5.0.0a1.dist-info}/RECORD +66 -30
  65. orchestrator/graphql/resolvers/scheduled_tasks.py +0 -36
  66. orchestrator/graphql/schemas/scheduled_task.py +0 -8
  67. orchestrator/schedules/scheduler.py +0 -163
  68. {orchestrator_core-4.4.0rc2.dist-info → orchestrator_core-5.0.0a1.dist-info}/WHEEL +0 -0
  69. {orchestrator_core-4.4.0rc2.dist-info → orchestrator_core-5.0.0a1.dist-info}/licenses/LICENSE +0 -0
orchestrator/search/indexing/indexer.py
@@ -0,0 +1,316 @@
+ import hashlib
+ from collections.abc import Generator, Iterable, Iterator
+ from contextlib import contextmanager, nullcontext
+ from functools import lru_cache
+ from typing import Any
+
+ import structlog
+ from litellm.utils import encode, get_max_tokens
+ from sqlalchemy import delete, tuple_
+ from sqlalchemy.dialects.postgresql import insert
+ from sqlalchemy.dialects.postgresql.dml import Insert
+ from sqlalchemy.orm import Session
+ from sqlalchemy_utils.types.ltree import Ltree
+
+ from orchestrator.db import db
+ from orchestrator.db.database import BaseModel
+ from orchestrator.db.models import AiSearchIndex
+ from orchestrator.search.core.embedding import EmbeddingIndexer
+ from orchestrator.search.core.types import ExtractedField, IndexableRecord
+ from orchestrator.search.indexing.registry import EntityConfig
+ from orchestrator.settings import app_settings
+
+ logger = structlog.get_logger(__name__)
+
+
+ @contextmanager
+ def _maybe_begin(session: Session | None) -> Iterator[None]:
+     if session is None:
+         yield
+     else:
+         with session.begin():
+             yield
+
+
+ class Indexer:
+     """Index entities into `AiSearchIndex` using streaming reads and batched writes.
+
+     Entities are read from a streaming iterator and accumulated into chunks of
+     size `chunk_size`. For each chunk, the indexer extracts fields, diffs via
+     content hashes, deletes stale paths, and prepares upserts using a two-list
+     buffer:
+     - Embeddable list (STRING fields): maintains a running token count against a
+       token budget (model context window minus a safety margin) and flushes when
+       adding the next item would exceed the budget.
+     - Non-embeddable list: accumulated in parallel and does not contribute to the
+       flush condition.
+     Each flush (or end-of-chunk) emits a single combined UPSERT batch from both
+     lists (wrapped in a per-chunk transaction in non-dry-runs).
+
+     Args:
+         config (EntityConfig): Registry config describing the entity kind,
+             ORM table, and traverser.
+         dry_run (bool): If True, skip DELETE/UPSERT statements and external
+             embedding calls.
+         force_index (bool): If True, ignore existing hashes and reindex all
+             fields for each entity.
+
+     Notes:
+         - Non-dry-run runs open a write session and wrap each processed chunk in
+           a transaction (`Session.begin()`).
+         - Read queries use the passed session when available, otherwise the
+           generic `db.session`.
+
+     Workflow:
+         1) Stream entities (yield_per=chunk_size) and accumulate into a chunk.
+         2) Begin transaction for the chunk.
+         3) determine_changes() → fields_to_upsert, paths_to_delete.
+         4) Delete stale paths.
+         5) Build UPSERT batches with a two-list buffer:
+            - Embeddable list (STRING): track running token count; flush when next item
+              would exceed the token budget (model max context - safety margin).
+            - Non-embeddable list: accumulate in parallel; does not affect flushing.
+         6) Execute UPSERT for each batch (skip in dry_run).
+         7) Commit transaction (auto on context exit).
+         8) Repeat until the stream is exhausted.
+     """
+
+     def __init__(self, config: EntityConfig, dry_run: bool, force_index: bool) -> None:
+         self.config = config
+         self.dry_run = dry_run
+         self.force_index = force_index
+         self.embedding_model = app_settings.EMBEDDING_MODEL
+         self.logger = logger.bind(entity_kind=config.entity_kind.value)
+
+     def run(self, entities: Iterable[BaseModel], chunk_size: int) -> int:
+         """Orchestrates the entire indexing process."""
+         chunk: list[BaseModel] = []
+         total_records_processed = 0
+         total_identical_records = 0
+
+         write_scope = db.database_scope() if not self.dry_run else nullcontext()
+
+         def flush() -> None:
+             nonlocal total_records_processed, total_identical_records
+             with _maybe_begin(session):
+                 processed_in_chunk, identical_in_chunk = self._process_chunk(chunk, session)
+                 total_records_processed += processed_in_chunk
+                 total_identical_records += identical_in_chunk
+             chunk.clear()
+
+         with write_scope as database:
+             session: Session | None = getattr(database, "session", None)
+             for entity in entities:
+                 chunk.append(entity)
+                 if len(chunk) >= chunk_size:
+                     flush()
+
+             if chunk:
+                 flush()
+
+         final_log_message = (
+             f"processed {total_records_processed} records and skipped {total_identical_records} identical records."
+         )
+         self.logger.info(
+             f"Dry run, would have {final_log_message}"
+             if self.dry_run
+             else f"Indexing done, {final_log_message}"
+         )
+         return total_records_processed
+
+     def _process_chunk(self, entity_chunk: list[BaseModel], session: Session | None = None) -> tuple[int, int]:
+         """Process a chunk of entities."""
+         if not entity_chunk:
+             return 0, 0
+
+         fields_to_upsert, paths_to_delete, identical_count = self._determine_changes(entity_chunk, session)
+
+         if paths_to_delete and session is not None:
+             self.logger.debug(f"Deleting {len(paths_to_delete)} stale records in chunk.")
+             delete_stmt = delete(AiSearchIndex).where(
+                 tuple_(AiSearchIndex.entity_id, AiSearchIndex.path).in_(paths_to_delete)
+             )
+             session.execute(delete_stmt)
+
+         if fields_to_upsert:
+             upsert_stmt = self._get_upsert_statement()
+             batch_generator = self._generate_upsert_batches(fields_to_upsert)
+
+             for batch in batch_generator:
+                 if self.dry_run:
+                     self.logger.debug(f"Dry Run: Would upsert {len(batch)} records.")
+                 elif batch and session:
+                     session.execute(upsert_stmt, batch)
+
+         return len(fields_to_upsert), identical_count
+
+     def _determine_changes(
+         self, entities: list[BaseModel], session: Session | None = None
+     ) -> tuple[list[tuple[str, ExtractedField]], list[tuple[str, Ltree]], int]:
+         """Identifies all changes across all entities using pre-fetched data."""
+         entity_ids = [str(getattr(e, self.config.pk_name)) for e in entities]
+         read_session = session or db.session
+         existing_hashes = {} if self.force_index else self._get_all_existing_hashes(entity_ids, read_session)
+
+         fields_to_upsert: list[tuple[str, ExtractedField]] = []
+         paths_to_delete: list[tuple[str, Ltree]] = []
+         identical_records_count = 0
+
+         for entity in entities:
+             entity_id = str(getattr(entity, self.config.pk_name))
+             current_fields = self.config.traverser.get_fields(
+                 entity, pk_name=self.config.pk_name, root_name=self.config.root_name
+             )
+
+             entity_hashes = existing_hashes.get(entity_id, {})
+             current_paths = set()
+
+             for field in current_fields:
+                 current_paths.add(field.path)
+                 current_hash = self._compute_content_hash(field.path, field.value)
+                 if field.path not in entity_hashes or entity_hashes[field.path] != current_hash:
+                     fields_to_upsert.append((entity_id, field))
+                 else:
+                     identical_records_count += 1
+
+             stale_paths = set(entity_hashes.keys()) - current_paths
+             paths_to_delete.extend([(entity_id, Ltree(p)) for p in stale_paths])
+
+         return fields_to_upsert, paths_to_delete, identical_records_count
+
+     def _get_all_existing_hashes(self, entity_ids: list[str], session: Session) -> dict[str, dict[str, str]]:
+         """Fetches all existing hashes for a list of entity IDs in a single query."""
+         if not entity_ids:
+             return {}
+
+         results = (
+             session.query(AiSearchIndex.entity_id, AiSearchIndex.path, AiSearchIndex.content_hash)
+             .filter(AiSearchIndex.entity_id.in_(entity_ids))
+             .all()
+         )
+
+         hashes_by_entity: dict[str, dict[str, str]] = {eid: {} for eid in entity_ids}
+         for entity_id, path, content_hash in results:
+             hashes_by_entity[str(entity_id)][str(path)] = content_hash
+         return hashes_by_entity
+
+     def _generate_upsert_batches(
+         self, fields_to_upsert: Iterable[tuple[str, ExtractedField]]
+     ) -> Generator[list[IndexableRecord], None, None]:
+         """Streams through fields, buffers them by token count, and yields batches."""
+         embeddable_buffer: list[tuple[str, ExtractedField]] = []
+         non_embeddable_records: list[IndexableRecord] = []
+         current_tokens = 0
+
+         max_ctx = self._get_max_tokens()
+         safe_margin = int(max_ctx * app_settings.EMBEDDING_SAFE_MARGIN_PERCENT)
+         token_budget = max(1, max_ctx - safe_margin)
+
+         max_batch_size = None
+         if app_settings.OPENAI_BASE_URL:  # We are using a local model
+             max_batch_size = app_settings.EMBEDDING_MAX_BATCH_SIZE
+
+         for entity_id, field in fields_to_upsert:
+             if field.value_type.is_embeddable():
+                 text = self._prepare_text_for_embedding(field)
+                 try:
+                     item_tokens = len(encode(model=self.embedding_model, text=text))
+                 except Exception as e:
+                     self.logger.warning("Tokenization failed; skipping.", path=field.path, err=str(e))
+                     continue
+
+                 if item_tokens > max_ctx:
+                     self.logger.warning(
+                         "Field exceeds context; skipping.", path=field.path, tokens=item_tokens, max_ctx=max_ctx
+                     )
+                     continue
+
+                 should_flush = embeddable_buffer and (
+                     current_tokens + item_tokens > token_budget
+                     or (max_batch_size and len(embeddable_buffer) >= max_batch_size)
+                 )
+
+                 if should_flush:
+                     yield self._flush_buffer(embeddable_buffer, non_embeddable_records)
+                     embeddable_buffer.clear()
+                     non_embeddable_records.clear()
+                     current_tokens = 0
+
+                 embeddable_buffer.append((entity_id, field))
+                 current_tokens += item_tokens
+             else:
+                 record = self._make_indexable_record(field, entity_id, embedding=None)
+                 non_embeddable_records.append(record)
+
+         if embeddable_buffer or non_embeddable_records:
+             yield self._flush_buffer(embeddable_buffer, non_embeddable_records)
+
+     def _flush_buffer(self, embeddable_buffer: list, non_embeddable_records: list) -> list[IndexableRecord]:
+         """Processes and combines buffers into a single batch."""
+         if not embeddable_buffer:
+             return non_embeddable_records
+
+         texts_to_embed = [self._prepare_text_for_embedding(f) for _, f in embeddable_buffer]
+         embeddings = EmbeddingIndexer.get_embeddings_from_api_batch(texts_to_embed, self.dry_run)
+
+         if len(embeddable_buffer) != len(embeddings):
+             raise ValueError(f"Embedding mismatch: sent {len(embeddable_buffer)}, received {len(embeddings)}")
+
+         with_embeddings = [
+             self._make_indexable_record(field, entity_id, embedding)
+             for (entity_id, field), embedding in zip(embeddable_buffer, embeddings)
+         ]
+         return non_embeddable_records + with_embeddings
+
+     def _get_max_tokens(self) -> int:
+         """Gets max tokens, using a fallback from settings if necessary."""
+         try:
+             max_ctx = get_max_tokens(self.embedding_model)
+             if isinstance(max_ctx, int):
+                 return max_ctx
+         except Exception:
+             # Allow local (unknown) models to fall back.
+             self.logger.warning("Could not auto-detect max tokens.", model=self.embedding_model)
+
+         max_ctx = app_settings.EMBEDDING_FALLBACK_MAX_TOKENS
+         if not isinstance(max_ctx, int):
+             raise RuntimeError("Model not recognized and EMBEDDING_FALLBACK_MAX_TOKENS not set.")
+         self.logger.warning("Using configured fallback token limit.", fallback=max_ctx)
+         return max_ctx
+
+     @staticmethod
+     def _prepare_text_for_embedding(field: ExtractedField) -> str:
+         return f"{field.path}: {str(field.value)}"
+
+     @staticmethod
+     def _compute_content_hash(path: str, value: Any) -> str:
+         v = "" if value is None else str(value)
+         content = f"{path}:{v}"
+         return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+     def _make_indexable_record(
+         self, field: ExtractedField, entity_id: str, embedding: list[float] | None
+     ) -> IndexableRecord:
+         return IndexableRecord(
+             entity_id=entity_id,
+             entity_type=self.config.entity_kind.value,
+             path=Ltree(field.path),
+             value=field.value,
+             value_type=field.value_type,
+             content_hash=self._compute_content_hash(field.path, field.value),
+             embedding=embedding if embedding else None,
+         )
+
+     @staticmethod
+     @lru_cache(maxsize=1)
+     def _get_upsert_statement() -> Insert:
+         stmt = insert(AiSearchIndex)
+         return stmt.on_conflict_do_update(
+             index_elements=[AiSearchIndex.entity_id, AiSearchIndex.path],
+             set_={
+                 AiSearchIndex.value: stmt.excluded.value,
+                 AiSearchIndex.value_type: stmt.excluded.value_type,
+                 AiSearchIndex.content_hash: stmt.excluded.content_hash,
+                 AiSearchIndex.embedding: stmt.excluded.embedding,
+             },
+         )
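The flush condition in `_generate_upsert_batches` above works against a token budget derived from the embedding model's context window. A small worked example of that arithmetic, with illustrative numbers only (the real values come from litellm's `get_max_tokens` and the `EMBEDDING_SAFE_MARGIN_PERCENT` setting); this sketch is not part of the diff:

    # Illustrative numbers, not real settings values.
    max_ctx = 8192                                # assumed model context window, in tokens
    safe_margin = int(max_ctx * 0.10)             # assuming EMBEDDING_SAFE_MARGIN_PERCENT = 0.10 -> 819
    token_budget = max(1, max_ctx - safe_margin)  # 8192 - 819 = 7373 tokens per embedding batch

    # With 7000 tokens already buffered, a 500-token field triggers a flush,
    # because 7000 + 500 = 7500 > 7373; that field then starts the next batch.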
orchestrator/search/indexing/registry.py
@@ -0,0 +1,88 @@
+ from dataclasses import dataclass
+ from typing import Generic, TypeVar
+ from uuid import UUID
+
+ from sqlalchemy.orm import Query
+ from sqlalchemy.sql import Select
+
+ from orchestrator.db import (
+     ProcessTable,
+     ProductTable,
+     SubscriptionTable,
+     WorkflowTable,
+ )
+ from orchestrator.db.database import BaseModel
+ from orchestrator.search.core.types import EntityType
+
+ from .traverse import (
+     BaseTraverser,
+     ProcessTraverser,
+     ProductTraverser,
+     SubscriptionTraverser,
+     WorkflowTraverser,
+ )
+
+ ModelT = TypeVar("ModelT", bound=BaseModel)
+
+
+ @dataclass(frozen=True)
+ class EntityConfig(Generic[ModelT]):
+     """A container for all configuration related to a specific entity type."""
+
+     entity_kind: EntityType
+     table: type[ModelT]
+
+     traverser: "type[BaseTraverser]"
+     pk_name: str
+     root_name: str
+
+     def get_all_query(self, entity_id: str | None = None) -> Query | Select:
+         query = self.table.query
+         if entity_id:
+             pk_column = getattr(self.table, self.pk_name)
+             query = query.filter(pk_column == UUID(entity_id))
+         return query
+
+
+ @dataclass(frozen=True)
+ class WorkflowConfig(EntityConfig[WorkflowTable]):
+     """Workflows have a custom select() function that filters out deleted workflows."""
+
+     def get_all_query(self, entity_id: str | None = None) -> Select:
+         query = self.table.select()
+         if entity_id:
+             pk_column = getattr(self.table, self.pk_name)
+             query = query.where(pk_column == UUID(entity_id))
+         return query
+
+
+ ENTITY_CONFIG_REGISTRY: dict[EntityType, EntityConfig] = {
+     EntityType.SUBSCRIPTION: EntityConfig(
+         entity_kind=EntityType.SUBSCRIPTION,
+         table=SubscriptionTable,
+         traverser=SubscriptionTraverser,
+         pk_name="subscription_id",
+         root_name="subscription",
+     ),
+     EntityType.PRODUCT: EntityConfig(
+         entity_kind=EntityType.PRODUCT,
+         table=ProductTable,
+         traverser=ProductTraverser,
+         pk_name="product_id",
+         root_name="product",
+     ),
+     EntityType.PROCESS: EntityConfig(
+         entity_kind=EntityType.PROCESS,
+         table=ProcessTable,
+         traverser=ProcessTraverser,
+         pk_name="process_id",
+         root_name="process",
+     ),
+     EntityType.WORKFLOW: WorkflowConfig(
+         entity_kind=EntityType.WORKFLOW,
+         table=WorkflowTable,
+         traverser=WorkflowTraverser,
+         pk_name="workflow_id",
+         root_name="workflow",
+     ),
+ }
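For orientation, a minimal sketch of how this registry is consumed (not part of the diff; the UUID is a placeholder):

    # Illustrative only: resolve a config and scope its query to one subscription.
    from orchestrator.search.core.types import EntityType
    from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY

    config = ENTITY_CONFIG_REGISTRY[EntityType.SUBSCRIPTION]
    query = config.get_all_query("00000000-0000-0000-0000-000000000000")  # placeholder UUID
    print(config.pk_name, config.root_name)  # -> subscription_id subscription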
orchestrator/search/indexing/tasks.py
@@ -0,0 +1,53 @@
+ import structlog
+ from sqlalchemy.orm import Query
+
+ from orchestrator.db import db
+ from orchestrator.search.core.types import EntityType
+ from orchestrator.search.indexing.indexer import Indexer
+ from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
+
+ logger = structlog.get_logger(__name__)
+
+
+ def run_indexing_for_entity(
+     entity_kind: EntityType,
+     entity_id: str | None = None,
+     dry_run: bool = False,
+     force_index: bool = False,
+     chunk_size: int = 1000,
+ ) -> None:
+     """Stream and index entities for the given kind.
+
+     Builds a streaming query via the entity's registry config, disables ORM eager
+     loads when applicable, and delegates processing to `Indexer`.
+
+     Args:
+         entity_kind (EntityType): The entity type to index (must exist in
+             `ENTITY_CONFIG_REGISTRY`).
+         entity_id (str | None): If provided, restricts indexing to a single
+             entity (UUID string).
+         dry_run (bool): When True, runs the full pipeline without performing
+             writes or external embedding calls.
+         force_index (bool): When True, re-indexes all fields regardless of
+             existing hashes.
+         chunk_size (int): Number of rows fetched per round-trip and passed to
+             the indexer per batch.
+
+     Returns:
+         None
+     """
+     config = ENTITY_CONFIG_REGISTRY[entity_kind]
+
+     q = config.get_all_query(entity_id)
+
+     if isinstance(q, Query):
+         q = q.enable_eagerloads(False)
+         stmt = q.statement
+     else:
+         stmt = q
+
+     stmt = stmt.execution_options(stream_results=True, yield_per=chunk_size)
+     entities = db.session.execute(stmt).scalars()
+
+     indexer = Indexer(config=config, dry_run=dry_run, force_index=force_index)
+     indexer.run(entities, chunk_size=chunk_size)
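As illustration (not part of the diff), the task above can be exercised end to end without writes by combining `dry_run` and `force_index`; the parameter values here are arbitrary:

    # Illustrative only: dry-run re-indexing of all products.
    from orchestrator.search.core.types import EntityType
    from orchestrator.search.indexing.tasks import run_indexing_for_entity

    run_indexing_for_entity(
        EntityType.PRODUCT,
        dry_run=True,      # no DELETE/UPSERT statements, no external embedding calls
        force_index=True,  # ignore stored content hashes and re-extract every field
        chunk_size=500,    # rows fetched per round-trip and per indexer batch
    )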
orchestrator/search/indexing/traverse.py
@@ -0,0 +1,209 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from typing import Any, cast
+
+ import structlog
+ from sqlalchemy.inspection import inspect
+
+ from orchestrator.db import ProcessTable, ProductTable, SubscriptionTable, WorkflowTable
+ from orchestrator.domain import (
+     SUBSCRIPTION_MODEL_REGISTRY,
+     SubscriptionModel,
+ )
+ from orchestrator.domain.base import ProductBlockModel
+ from orchestrator.domain.lifecycle import (
+     lookup_specialized_type,
+ )
+ from orchestrator.search.core.exceptions import ModelLoadError, ProductNotInRegistryError
+ from orchestrator.search.core.types import ExtractedField, FieldType, TypedValue
+ from orchestrator.types import SubscriptionLifecycle
+
+ logger = structlog.get_logger(__name__)
+
+
+ class BaseTraverser(ABC):
+     """An abstract base class for traversing database models."""
+
+     _LTREE_SEPARATOR = "."
+     _MAX_DEPTH = 40
+
+     @staticmethod
+     def _traverse(data: Any, path: str = "", depth: int = 0, max_depth: int = _MAX_DEPTH) -> Iterable[ExtractedField]:
+ """Recursive walk through dicts / lists; returns `(path, value)`."""
33
+         if depth >= max_depth:
+             logger.error("Max recursive depth reached while traversing: path=%s", path)
+             return
+         if isinstance(data, dict):
+             for key, value in data.items():
+                 new_path = f"{path}{BaseTraverser._LTREE_SEPARATOR}{key}" if path else key
+                 yield from BaseTraverser._traverse(value, new_path, depth + 1, max_depth)
+
+         elif isinstance(data, list):
+             if len(data) == 1:
+                 yield from BaseTraverser._traverse(data[0], path, depth + 1, max_depth)
+             else:
+                 for i, item in enumerate(data):
+                     new_path = f"{path}{BaseTraverser._LTREE_SEPARATOR}{i}"
+                     yield from BaseTraverser._traverse(item, new_path, depth + 1, max_depth)
+
+         elif data is not None:
+             yield ExtractedField.from_raw(path, data)
+
+     @staticmethod
+     def _dump_sqlalchemy_fields(entity: Any, exclude: set[str] | None = None) -> dict:
+         """Serialize SQLAlchemy column attributes of an entity into a dictionary, with optional exclusions."""
+         exclude = exclude or set()
+         mapper = inspect(entity.__class__)
+         if not mapper:
+             return {}
+
+         return {
+             attr.key: getattr(entity, attr.key)
+             for attr in mapper.column_attrs
+             if hasattr(entity, attr.key) and attr.key not in exclude
+         }
+
+     @classmethod
+     @abstractmethod
+     def _dump(cls, entity: Any) -> dict:
+         """Abstract method to convert a model instance to a dictionary."""
+         ...
+
+     @classmethod
+     def get_fields(cls, entity: Any, pk_name: str, root_name: str) -> list[ExtractedField]:
+         """Serializes a model instance and returns a list of (path, value) tuples."""
+         try:
+             data_dict = cls._dump(entity)
+         except Exception as e:
+             entity_id = getattr(entity, pk_name, "unknown")
+             logger.error(f"Failed to serialize {entity.__class__.__name__}", id=str(entity_id), error=str(e))
+             return []
+
+         fields = cls._traverse(data_dict, path=root_name)
+         return sorted(fields, key=lambda field: (field.path.count(cls._LTREE_SEPARATOR), field.path))
+
+
+ class SubscriptionTraverser(BaseTraverser):
+
+     @classmethod
+     def _load_model(cls, sub: SubscriptionTable) -> SubscriptionModel | None:
+
+         base_model_cls = SUBSCRIPTION_MODEL_REGISTRY.get(sub.product.name)
+         if not base_model_cls:
+             raise ProductNotInRegistryError(f"Product '{sub.product.name}' not in registry.")
+         specialized_model_cls = cast(type[SubscriptionModel], lookup_specialized_type(base_model_cls, sub.status))
+
+         try:
+             return specialized_model_cls.from_subscription(sub.subscription_id)
+         except Exception as e:
+             raise ModelLoadError(f"Failed to load model for subscription_id '{sub.subscription_id}'") from e
+
+     @classmethod
+     def _dump(cls, sub: SubscriptionTable) -> dict:
+         """Loads a Pydantic model, dumps it to a dict, and then transforms the keys."""
+         model = cls._load_model(sub)
+         if not model:
+             return {}
+
+         return model.model_dump(exclude_unset=False)
+
+
+ class ProductTraverser(BaseTraverser):
+     """Product traverser dumps core product fields and a nested structure of product blocks."""
+
+     @classmethod
+     def _dump(cls, prod: ProductTable) -> dict[str, Any]:
+
+         def dump_block_model(model: type[ProductBlockModel], seen: set[str]) -> dict[str, Any]:
+             result = {}
+             for attr, field in model.model_fields.items():
+                 field_type = field.annotation
+                 if isinstance(field_type, type) and issubclass(field_type, ProductBlockModel):
+                     if attr not in seen:
+                         seen.add(attr)
+                         # Use TypedValue to indicate this is a block
+                         result[attr] = TypedValue(attr, FieldType.BLOCK)
+                 else:
+                     # Use TypedValue to indicate this is a resource type
+                     result[attr] = TypedValue(attr, FieldType.RESOURCE_TYPE)
+             return result
+
+         base = cls._dump_sqlalchemy_fields(prod)
+
+         # Get domain model for this product
+         domain_model_cls = SUBSCRIPTION_MODEL_REGISTRY.get(prod.name)
+         if not domain_model_cls:
+             return base  # No model = skip block info
+
+         try:
+             lifecycle_model = cast(
+                 type[SubscriptionModel], lookup_specialized_type(domain_model_cls, SubscriptionLifecycle.INITIAL)
+             )
+         except Exception:
+             lifecycle_model = domain_model_cls
+
+         seen: set[str] = set()
+         nested_blocks = {}
+
+         for attr, field in lifecycle_model.model_fields.items():
+             field_type = field.annotation
+             if isinstance(field_type, type) and issubclass(field_type, ProductBlockModel):
+                 nested_blocks[attr] = dump_block_model(field_type, seen)
+
+         if nested_blocks:
+             base["product_blocks"] = nested_blocks
+
+         return base
+
+
+ class ProcessTraverser(BaseTraverser):
+     # We explicitly exclude 'traceback' from the process dump
+     # to avoid overloading the index with too much data.
+     _process_fields_to_exclude: set[str] = {
+         "traceback",
+     }
+
+     @classmethod
+     def _dump(cls, proc: ProcessTable) -> dict:
+         """Serializes a ProcessTable instance into a dictionary, including key relationships."""
+
+         base = cls._dump_sqlalchemy_fields(proc, exclude=cls._process_fields_to_exclude)
+
+         if proc.workflow:
+             base["workflow_name"] = proc.workflow.name
+
+         if proc.subscriptions:
+             base["subscriptions"] = [
+                 cls._dump_sqlalchemy_fields(sub) for sub in sorted(proc.subscriptions, key=lambda s: s.subscription_id)
+             ]
+
+         return base
+
+
+ class WorkflowTraverser(BaseTraverser):
+     """Traverser for WorkflowTable entities."""
+
+     @classmethod
+     def _dump(cls, workflow: WorkflowTable) -> dict:
+         """Serializes a WorkflowTable instance into a dictionary including all fields."""
+
+         base = cls._dump_sqlalchemy_fields(workflow)
+
+         if workflow.products:
+             for product in sorted(workflow.products, key=lambda p: p.name):
+                 if product.tag:
+                     product_key = product.tag.lower()
+
+                     full_product_data = ProductTraverser._dump(product)
+
+                     # Ignore nested dictionaries in the product data.
+                     # We only want the top-level fields because that's what the search index expects.
+                     product_reference = {
+                         key: value for key, value in full_product_data.items() if not isinstance(value, dict)
+                     }
+
+                     base[product_key] = product_reference
+                 else:
+                     logger.warning("Workflow has an associated product without a tag", product_name=product.name)
+
+         return base
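To make the traversal rules above concrete, a small sketch (not part of the diff) of the ltree paths `BaseTraverser._traverse` yields for an invented payload:

    # Illustrative only: made-up data rooted at "subscription".
    payload = {
        "customer_id": "abc",
        "ports": [{"port_name": "ge-0/0/1"}],  # a single-element list collapses into its parent path
        "tags": ["core", "edge"],              # a multi-element list gets numeric path segments
    }
    fields = list(BaseTraverser._traverse(payload, path="subscription"))
    # Yields ExtractedField.from_raw(...) for each leaf, with paths:
    #   subscription.customer_id      -> "abc"
    #   subscription.ports.port_name  -> "ge-0/0/1"
    #   subscription.tags.0           -> "core"
    #   subscription.tags.1           -> "edge"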
orchestrator/search/retrieval/__init__.py
@@ -0,0 +1,3 @@
+ from .engine import execute_search
+
+ __all__ = ["execute_search"]