orchestrator-core 4.5.2-py3-none-any.whl → 4.6.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/agentic_app.py +1 -21
  3. orchestrator/api/api_v1/api.py +5 -0
  4. orchestrator/api/api_v1/endpoints/agent.py +50 -0
  5. orchestrator/api/api_v1/endpoints/search.py +120 -201
  6. orchestrator/cli/database.py +3 -0
  7. orchestrator/cli/generate.py +11 -4
  8. orchestrator/cli/generator/generator/migration.py +7 -3
  9. orchestrator/cli/search/resize_embedding.py +28 -22
  10. orchestrator/cli/search/speedtest.py +4 -6
  11. orchestrator/db/__init__.py +6 -0
  12. orchestrator/db/models.py +75 -0
  13. orchestrator/migrations/helpers.py +46 -38
  14. orchestrator/schedules/validate_products.py +1 -1
  15. orchestrator/schemas/search.py +8 -85
  16. orchestrator/search/agent/__init__.py +2 -2
  17. orchestrator/search/agent/agent.py +25 -29
  18. orchestrator/search/agent/json_patch.py +51 -0
  19. orchestrator/search/agent/prompts.py +35 -9
  20. orchestrator/search/agent/state.py +28 -2
  21. orchestrator/search/agent/tools.py +192 -53
  22. orchestrator/search/core/exceptions.py +6 -0
  23. orchestrator/search/core/types.py +1 -0
  24. orchestrator/search/export.py +199 -0
  25. orchestrator/search/indexing/indexer.py +13 -4
  26. orchestrator/search/indexing/registry.py +14 -1
  27. orchestrator/search/llm_migration.py +55 -0
  28. orchestrator/search/retrieval/__init__.py +3 -2
  29. orchestrator/search/retrieval/builder.py +5 -1
  30. orchestrator/search/retrieval/engine.py +66 -23
  31. orchestrator/search/retrieval/pagination.py +46 -56
  32. orchestrator/search/retrieval/query_state.py +61 -0
  33. orchestrator/search/retrieval/retrievers/base.py +26 -40
  34. orchestrator/search/retrieval/retrievers/fuzzy.py +10 -9
  35. orchestrator/search/retrieval/retrievers/hybrid.py +11 -8
  36. orchestrator/search/retrieval/retrievers/semantic.py +9 -8
  37. orchestrator/search/retrieval/retrievers/structured.py +6 -6
  38. orchestrator/search/schemas/parameters.py +17 -13
  39. orchestrator/search/schemas/results.py +4 -1
  40. orchestrator/settings.py +1 -0
  41. orchestrator/utils/auth.py +3 -2
  42. orchestrator/workflows/tasks/validate_product_type.py +3 -3
  43. {orchestrator_core-4.5.2.dist-info → orchestrator_core-4.6.0rc1.dist-info}/METADATA +4 -4
  44. {orchestrator_core-4.5.2.dist-info → orchestrator_core-4.6.0rc1.dist-info}/RECORD +46 -42
  45. {orchestrator_core-4.5.2.dist-info → orchestrator_core-4.6.0rc1.dist-info}/WHEEL +0 -0
  46. {orchestrator_core-4.5.2.dist-info → orchestrator_core-4.6.0rc1.dist-info}/licenses/LICENSE +0 -0
orchestrator/__init__.py CHANGED
@@ -13,7 +13,7 @@
 
 """This is the orchestrator workflow engine."""
 
-__version__ = "4.5.2"
+__version__ = "4.6.0rc1"
 
 
 from structlog import get_logger
orchestrator/agentic_app.py CHANGED
@@ -27,7 +27,6 @@ from orchestrator.llm_settings import LLMSettings, llm_settings
 
 if TYPE_CHECKING:
     from pydantic_ai.models.openai import OpenAIModel
-    from pydantic_ai.toolsets import FunctionToolset
 
 logger = get_logger(__name__)
 
@@ -38,19 +37,17 @@ class LLMOrchestratorCore(OrchestratorCore):
         *args: Any,
         llm_settings: LLMSettings = llm_settings,
         agent_model: "OpenAIModel | str | None" = None,
-        agent_tools: "list[FunctionToolset] | None" = None,
         **kwargs: Any,
     ) -> None:
         """Initialize the `LLMOrchestratorCore` class.
 
         This class extends `OrchestratorCore` with LLM features (search and agent).
-        It runs the search migration and mounts the agent endpoint based on feature flags.
+        It runs the search migration based on feature flags.
 
         Args:
             *args: All the normal arguments passed to the `OrchestratorCore` class.
            llm_settings: A class of settings for the LLM
            agent_model: Override the agent model (defaults to llm_settings.AGENT_MODEL)
-           agent_tools: A list of tools that can be used by the agent
            **kwargs: Additional arguments passed to the `OrchestratorCore` class.
 
         Returns:
@@ -58,7 +55,6 @@ class LLMOrchestratorCore(OrchestratorCore):
         """
         self.llm_settings = llm_settings
         self.agent_model = agent_model or llm_settings.AGENT_MODEL
-        self.agent_tools = agent_tools
 
         super().__init__(*args, **kwargs)
 
@@ -79,22 +75,6 @@ class LLMOrchestratorCore(OrchestratorCore):
                 )
                 raise
 
-        # Mount agent endpoint if agent is enabled
-        if self.llm_settings.AGENT_ENABLED:
-            logger.info("Initializing agent features", model=self.agent_model)
-            try:
-                from orchestrator.search.agent import build_agent_router
-
-                agent_app = build_agent_router(self.agent_model, self.agent_tools)
-                self.mount("/agent", agent_app)
-            except ImportError as e:
-                logger.error(
-                    "Unable to initialize agent features. Please install agent dependencies: "
-                    "`pip install orchestrator-core[agent]`",
-                    error=str(e),
-                )
-                raise
-
 
 main_typer_app = typer.Typer()
 main_typer_app.add_typer(cli_app, name="orchestrator", help="The orchestrator CLI commands")
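
Note: the net effect is that `LLMOrchestratorCore` no longer mounts a sub-application for the agent; registration moves to the API router in api.py below. A minimal sketch of constructing the app after this change (the `base_settings` argument follows the usual `OrchestratorCore` pattern; adjust to your project, and the model name is illustrative):

    # Minimal sketch, assuming the usual OrchestratorCore setup; `agent_tools`
    # is gone from the constructor, leaving llm_settings and agent_model.
    from orchestrator.agentic_app import LLMOrchestratorCore
    from orchestrator.llm_settings import llm_settings
    from orchestrator.settings import AppSettings

    app = LLMOrchestratorCore(base_settings=AppSettings(), llm_settings=llm_settings)
    # agent_model defaults to llm_settings.AGENT_MODEL; override explicitly if needed:
    # app = LLMOrchestratorCore(base_settings=AppSettings(), agent_model="gpt-4o-mini")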
orchestrator/api/api_v1/api.py CHANGED
@@ -95,3 +95,8 @@ if llm_settings.SEARCH_ENABLED:
     api_router.include_router(
         search.router, prefix="/search", tags=["Core", "Search"], dependencies=[Depends(authorize)]
     )
+
+    if llm_settings.AGENT_ENABLED:
+        from orchestrator.api.api_v1.endpoints import agent
+
+        api_router.include_router(agent.router, prefix="/agent", tags=["Core", "Agent"], dependencies=[Depends(authorize)])
orchestrator/api/api_v1/endpoints/agent.py ADDED
@@ -0,0 +1,50 @@
+# Copyright 2019-2025 SURF, GÉANT.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import cache
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Request
+from pydantic_ai.ag_ui import StateDeps, handle_ag_ui_request
+from pydantic_ai.agent import Agent
+from starlette.responses import Response
+from structlog import get_logger
+
+from orchestrator.llm_settings import llm_settings
+from orchestrator.search.agent import build_agent_instance
+from orchestrator.search.agent.state import SearchState
+
+router = APIRouter()
+logger = get_logger(__name__)
+
+
+@cache
+def get_agent() -> Agent[StateDeps[SearchState], str]:
+    """Dependency to provide the agent instance.
+
+    The agent is built once and cached for the lifetime of the application.
+    """
+    return build_agent_instance(llm_settings.AGENT_MODEL)
+
+
+@router.post("/")
+async def agent_conversation(
+    request: Request,
+    agent: Annotated[Agent[StateDeps[SearchState], str], Depends(get_agent)],
+) -> Response:
+    """Agent conversation endpoint using pydantic-ai ag_ui protocol.
+
+    This endpoint handles the interactive agent conversation for search.
+    """
+    initial_state = SearchState()
+    return await handle_ag_ui_request(agent, request, deps=StateDeps(initial_state))
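
Note: the `@cache` on `get_agent` means the (potentially expensive) agent build happens once per process rather than per request. For reference, a hedged sketch of calling this endpoint; the request body shape follows the AG-UI `RunAgentInput` schema that `handle_ag_ui_request` consumes (verify the field names against your pydantic-ai/ag-ui version), and the URL assumes the `/agent` prefix above under the usual `/api` mount:

    # Hedged sketch: drive the agent endpoint over HTTP. The payload keys
    # follow the AG-UI RunAgentInput schema; confirm them for your version.
    import httpx

    payload = {
        "threadId": "thread-1",
        "runId": "run-1",
        "messages": [{"id": "1", "role": "user", "content": "Find all terminated subscriptions"}],
        "state": {},
        "tools": [],
        "context": [],
        "forwardedProps": {},
    }

    with httpx.stream("POST", "http://localhost:8080/api/agent/", json=payload) as response:
        for line in response.iter_lines():  # AG-UI responses stream as Server-Sent Events
            print(line)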
orchestrator/api/api_v1/endpoints/search.py CHANGED
@@ -11,251 +11,121 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Literal, overload
-
+import structlog
 from fastapi import APIRouter, HTTPException, Query, status
-from sqlalchemy import case, select
-from sqlalchemy.orm import selectinload
-
-from orchestrator.db import (
-    ProcessTable,
-    ProductTable,
-    WorkflowTable,
-    db,
-)
-from orchestrator.domain.base import SubscriptionModel
-from orchestrator.domain.context_cache import cache_subscription_models
+
+from orchestrator.db import db
 from orchestrator.schemas.search import (
+    ExportResponse,
     PageInfoSchema,
     PathsResponse,
-    ProcessSearchResult,
-    ProcessSearchSchema,
-    ProductSearchResult,
-    ProductSearchSchema,
     SearchResultsSchema,
-    SubscriptionSearchResult,
-    WorkflowSearchResult,
-    WorkflowSearchSchema,
 )
-from orchestrator.search.core.exceptions import InvalidCursorError
+from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError
 from orchestrator.search.core.types import EntityType, UIType
 from orchestrator.search.filters.definitions import generate_definitions
-from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
-from orchestrator.search.retrieval import execute_search
+from orchestrator.search.retrieval import SearchQueryState, execute_search, execute_search_for_export
 from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows
-from orchestrator.search.retrieval.pagination import (
-    create_next_page_cursor,
-    process_pagination_cursor,
-)
+from orchestrator.search.retrieval.pagination import PageCursor, encode_next_page_cursor
 from orchestrator.search.retrieval.validation import is_lquery_syntactically_valid
 from orchestrator.search.schemas.parameters import (
-    BaseSearchParameters,
     ProcessSearchParameters,
     ProductSearchParameters,
+    SearchParameters,
     SubscriptionSearchParameters,
     WorkflowSearchParameters,
 )
 from orchestrator.search.schemas.results import SearchResult, TypeDefinition
-from orchestrator.services.subscriptions import format_special_types
 
 router = APIRouter()
+logger = structlog.get_logger(__name__)
 
 
-def _create_search_result_item(
-    entity: WorkflowTable | ProductTable | ProcessTable, entity_type: EntityType, search_info: SearchResult
-) -> WorkflowSearchResult | ProductSearchResult | ProcessSearchResult | None:
-    match entity_type:
-        case EntityType.WORKFLOW:
-            workflow_data = WorkflowSearchSchema.model_validate(entity)
-            return WorkflowSearchResult(
-                workflow=workflow_data,
-                score=search_info.score,
-                perfect_match=search_info.perfect_match,
-                matching_field=search_info.matching_field,
-            )
-        case EntityType.PRODUCT:
-            product_data = ProductSearchSchema.model_validate(entity)
-            return ProductSearchResult(
-                product=product_data,
-                score=search_info.score,
-                perfect_match=search_info.perfect_match,
-                matching_field=search_info.matching_field,
-            )
-        case EntityType.PROCESS:
-            process_data = ProcessSearchSchema.model_validate(entity)
-            return ProcessSearchResult(
-                process=process_data,
-                score=search_info.score,
-                perfect_match=search_info.perfect_match,
-                matching_field=search_info.matching_field,
-            )
-        case _:
-            return None
-
-
-@overload
 async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: Literal[EntityType.WORKFLOW],
-    eager_loads: list[Any],
+    search_params: SearchParameters | None = None,
     cursor: str | None = None,
-) -> SearchResultsSchema[WorkflowSearchResult]: ...
-
+    query_id: str | None = None,
+) -> SearchResultsSchema[SearchResult]:
+    """Execute search with optional pagination.
+
+    Args:
+        search_params: Search parameters for new search
+        cursor: Pagination cursor (loads saved query state)
+        query_id: Saved query ID to retrieve and execute
+
+    Returns:
+        Search results with entity_id, score, and matching_field.
+    """
+    try:
+        page_cursor: PageCursor | None = None
+
+        if cursor:
+            page_cursor = PageCursor.decode(cursor)
+            query_state = SearchQueryState.load_from_id(page_cursor.query_id)
+        elif query_id:
+            query_state = SearchQueryState.load_from_id(query_id)
+        elif search_params:
+            query_state = SearchQueryState(parameters=search_params, query_embedding=None)
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Either search_params, cursor, or query_id must be provided",
+            )
 
-@overload
-async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: Literal[EntityType.PRODUCT],
-    eager_loads: list[Any],
-    cursor: str | None = None,
-) -> SearchResultsSchema[ProductSearchResult]: ...
+        search_response = await execute_search(
+            query_state.parameters, db.session, page_cursor, query_state.query_embedding
+        )
+        if not search_response.results:
+            return SearchResultsSchema(search_metadata=search_response.metadata)
 
+        next_page_cursor = encode_next_page_cursor(search_response, page_cursor, query_state.parameters)
+        has_next_page = next_page_cursor is not None
+        page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor)
 
-@overload
-async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: Literal[EntityType.PROCESS],
-    eager_loads: list[Any],
-    cursor: str | None = None,
-) -> SearchResultsSchema[ProcessSearchResult]: ...
+        return SearchResultsSchema(
+            data=search_response.results, page_info=page_info, search_metadata=search_response.metadata
+        )
+    except (InvalidCursorError, ValueError) as e:
+        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
+    except QueryStateNotFoundError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Search failed: {str(e)}",
+        )
 
 
-async def _perform_search_and_fetch(
-    search_params: BaseSearchParameters,
-    entity_type: EntityType,
-    eager_loads: list[Any],
-    cursor: str | None = None,
-) -> SearchResultsSchema[Any]:
-    try:
-        pagination_params = await process_pagination_cursor(cursor, search_params)
-    except InvalidCursorError:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor")
-
-    search_response = await execute_search(
-        search_params=search_params,
-        db_session=db.session,
-        pagination_params=pagination_params,
-    )
-    if not search_response.results:
-        return SearchResultsSchema(search_metadata=search_response.metadata)
-
-    next_page_cursor = create_next_page_cursor(search_response.results, pagination_params, search_params.limit)
-    has_next_page = next_page_cursor is not None
-    page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor)
-
-    config = ENTITY_CONFIG_REGISTRY[entity_type]
-    entity_ids = [res.entity_id for res in search_response.results]
-    pk_column = getattr(config.table, config.pk_name)
-    ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column)
-
-    stmt = select(config.table).options(*eager_loads).filter(pk_column.in_(entity_ids)).order_by(ordering_case)
-    entities = db.session.scalars(stmt).all()
-
-    search_info_map = {res.entity_id: res for res in search_response.results}
-    data = []
-    for entity in entities:
-        entity_id = getattr(entity, config.pk_name)
-        search_info = search_info_map.get(str(entity_id))
-        if not search_info:
-            continue
-
-        search_result_item = _create_search_result_item(entity, entity_type, search_info)
-        if search_result_item:
-            data.append(search_result_item)
-
-    return SearchResultsSchema(data=data, page_info=page_info, search_metadata=search_response.metadata)
-
-
-@router.post(
-    "/subscriptions",
-    response_model=SearchResultsSchema[SubscriptionSearchResult],
-)
+@router.post("/subscriptions", response_model=SearchResultsSchema[SearchResult])
 async def search_subscriptions(
     search_params: SubscriptionSearchParameters,
     cursor: str | None = None,
-) -> SearchResultsSchema[SubscriptionSearchResult]:
-    try:
-        pagination_params = await process_pagination_cursor(cursor, search_params)
-    except InvalidCursorError:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid pagination cursor")
-
-    search_response = await execute_search(
-        search_params=search_params,
-        db_session=db.session,
-        pagination_params=pagination_params,
-    )
-
-    if not search_response.results:
-        return SearchResultsSchema(search_metadata=search_response.metadata)
-
-    next_page_cursor = create_next_page_cursor(search_response.results, pagination_params, search_params.limit)
-    has_next_page = next_page_cursor is not None
-    page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor)
-
-    search_info_map = {res.entity_id: res for res in search_response.results}
-
-    with cache_subscription_models():
-        subscriptions_data = {
-            sub_id: SubscriptionModel.from_subscription(sub_id).model_dump(exclude_unset=False)
-            for sub_id in search_info_map
-        }
-
-    results_data = [
-        SubscriptionSearchResult(
-            subscription=format_special_types(subscriptions_data[sub_id]),
-            score=search_info.score,
-            perfect_match=search_info.perfect_match,
-            matching_field=search_info.matching_field,
-        )
-        for sub_id, search_info in search_info_map.items()
-    ]
-
-    return SearchResultsSchema(data=results_data, page_info=page_info, search_metadata=search_response.metadata)
+) -> SearchResultsSchema[SearchResult]:
+    return await _perform_search_and_fetch(search_params, cursor)
 
 
-@router.post("/workflows", response_model=SearchResultsSchema[WorkflowSearchResult])
+@router.post("/workflows", response_model=SearchResultsSchema[SearchResult])
 async def search_workflows(
     search_params: WorkflowSearchParameters,
     cursor: str | None = None,
-) -> SearchResultsSchema[WorkflowSearchResult]:
-    return await _perform_search_and_fetch(
-        search_params=search_params,
-        entity_type=EntityType.WORKFLOW,
-        eager_loads=[selectinload(WorkflowTable.products)],
-        cursor=cursor,
-    )
+) -> SearchResultsSchema[SearchResult]:
+    return await _perform_search_and_fetch(search_params, cursor)
 
 
-@router.post("/products", response_model=SearchResultsSchema[ProductSearchResult])
+@router.post("/products", response_model=SearchResultsSchema[SearchResult])
 async def search_products(
     search_params: ProductSearchParameters,
     cursor: str | None = None,
-) -> SearchResultsSchema[ProductSearchResult]:
-    return await _perform_search_and_fetch(
-        search_params=search_params,
-        entity_type=EntityType.PRODUCT,
-        eager_loads=[
-            selectinload(ProductTable.workflows),
-            selectinload(ProductTable.fixed_inputs),
-            selectinload(ProductTable.product_blocks),
-        ],
-        cursor=cursor,
-    )
-
-
-@router.post("/processes", response_model=SearchResultsSchema[ProcessSearchResult])
+) -> SearchResultsSchema[SearchResult]:
+    return await _perform_search_and_fetch(search_params, cursor)
+
+
+@router.post("/processes", response_model=SearchResultsSchema[SearchResult])
 async def search_processes(
     search_params: ProcessSearchParameters,
     cursor: str | None = None,
-) -> SearchResultsSchema[ProcessSearchResult]:
-    return await _perform_search_and_fetch(
-        search_params=search_params,
-        entity_type=EntityType.PROCESS,
-        eager_loads=[
-            selectinload(ProcessTable.workflow),
-        ],
-        cursor=cursor,
-    )
+) -> SearchResultsSchema[SearchResult]:
+    return await _perform_search_and_fetch(search_params, cursor)
 
 
 @router.get(
@@ -294,3 +164,52 @@ async def list_paths(
 async def get_definitions() -> dict[UIType, TypeDefinition]:
     """Provide a static definition of operators and schemas for each UI type."""
     return generate_definitions()
+
+
+@router.get(
+    "/queries/{query_id}",
+    response_model=SearchResultsSchema[SearchResult],
+    summary="Retrieve saved search results by query_id",
+)
+async def get_by_query_id(
+    query_id: str,
+    cursor: str | None = None,
+) -> SearchResultsSchema[SearchResult]:
+    """Retrieve and execute a saved search by query_id."""
+    return await _perform_search_and_fetch(query_id=query_id, cursor=cursor)
+
+
+@router.get(
+    "/queries/{query_id}/export",
+    summary="Export query results by query_id",
+    response_model=ExportResponse,
+)
+async def export_by_query_id(query_id: str) -> ExportResponse:
+    """Export search results using query_id.
+
+    The query is retrieved from the database, re-executed, and results are returned
+    as flattened records suitable for CSV download.
+
+    Args:
+        query_id: Query UUID
+
+    Returns:
+        ExportResponse containing 'page' with an array of flattened entity records.
+
+    Raises:
+        HTTPException: 404 if query not found, 400 if invalid data
+    """
+    try:
+        query_state = SearchQueryState.load_from_id(query_id)
+        export_records = await execute_search_for_export(query_state, db.session)
+        return ExportResponse(page=export_records)
+    except ValueError as e:
+        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
+    except QueryStateNotFoundError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
+    except Exception as e:
+        logger.error(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Error executing export: {str(e)}",
+        )
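
Note: the endpoints now return lightweight `SearchResult` records (entity_id, score, matching_field) instead of fully hydrated entities, and pagination cursors reference a server-side saved query state. A hedged sketch of the resulting client flow (base URL and the parameter payload are illustrative; the exact fields live in orchestrator/search/schemas/parameters.py):

    # Hedged sketch: search, paginate via the cursor, then replay/export by
    # query_id. Base URL and the parameter payload are illustrative.
    import httpx

    BASE = "http://localhost:8080/api/search"
    params = {"query": "core router in Amsterdam"}  # illustrative SubscriptionSearchParameters

    page = httpx.post(f"{BASE}/subscriptions", json=params).json()
    for hit in page["data"]:
        print(hit["entity_id"], hit["score"], hit["matching_field"])

    # The cursor encodes the saved query's id, so the next page re-uses the
    # stored parameters and embedding rather than re-submitting them.
    if page["page_info"]["has_next_page"]:
        cursor = page["page_info"]["next_page_cursor"]
        page = httpx.post(f"{BASE}/subscriptions", json=params, params={"cursor": cursor}).json()

    # Given a saved query_id (assumed to be surfaced to the client, e.g. by the
    # agent or the search metadata), results can be replayed or exported:
    query_id = "00000000-0000-0000-0000-000000000000"  # placeholder
    replay = httpx.get(f"{BASE}/queries/{query_id}").json()
    export = httpx.get(f"{BASE}/queries/{query_id}/export").json()
    rows = export["page"]  # flattened records suitable for CSV download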
orchestrator/cli/database.py CHANGED
@@ -23,6 +23,8 @@ from structlog import get_logger
 
 import orchestrator.workflows
 from orchestrator.cli.domain_gen_helpers.types import ModelUpdates
+from orchestrator.cli.generate import create_writer, get_template_environment
+from orchestrator.cli.generator.generator.migration import create_data_head_if_not_exists
 from orchestrator.cli.helpers.print_helpers import COLOR, str_fmt
 from orchestrator.cli.migrate_domain_models import create_domain_models_migration_sql
 from orchestrator.cli.migrate_tasks import create_tasks_migration_wizard
@@ -223,6 +225,7 @@ def revision(
     --head TEXT Determine the head you need to add your migration to.
     ```
     """
+    create_data_head_if_not_exists({"writer": create_writer(), "environment": get_template_environment()})
     command.revision(alembic_cfg(), message, version_path=version_path, autogenerate=autogenerate, head=head)
 
 
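Note: the practical effect is that `db revision` now ensures the alembic `data` branch exists before delegating to `command.revision`, so the first revision in a fresh project no longer requires generating the data head by hand. A hedged sketch of the call (the entrypoint module name varies per project; the flags follow the docstring above):

    python main.py db revision --message "add node product" --head data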
orchestrator/cli/generate.py CHANGED
@@ -14,6 +14,7 @@
 # ruff: noqa: S603
 import subprocess
 from pathlib import Path
+from typing import Callable
 
 import structlog
 import typer
@@ -78,7 +79,7 @@ def ruff(content: str) -> str:
     return content
 
 
-def create_context(config_file: Path, dryrun: bool, force: bool, python_version: str, tdd: bool | None = False) -> dict:
+def create_writer(dryrun: bool = False, force: bool = False) -> Callable[..., None]:
     def writer(path: Path, content: str, append: bool = False) -> None:
         content = ruff(content) if path.suffix == ".py" else content
         if dryrun:
@@ -88,9 +89,15 @@ def create_context(config_file: Path, dryrun: bool, force: bool, python_version:
         else:
             write_file(path, content, append=append, force=force)
 
+    return writer
+
+
+def get_template_environment() -> Environment:
     search_path = (settings.CUSTOM_TEMPLATES, Path(__file__).parent / "generator" / "templates")
-    environment = Environment(loader=FileSystemLoader(search_path), autoescape=True, keep_trailing_newline=True)
+    return Environment(loader=FileSystemLoader(search_path), autoescape=True, keep_trailing_newline=True)
 
+
+def create_context(config_file: Path, dryrun: bool, force: bool, python_version: str, tdd: bool | None = False) -> dict:
     config = read_config(config_file)
     config["variable"] = get_variable(config)
     for pb in config["product_blocks"]:
@@ -98,10 +105,10 @@ def create_context(config_file: Path, dryrun: bool, force: bool, python_version:
 
     return {
         "config": config,
-        "environment": environment,
+        "environment": get_template_environment(),
         "python_version": python_version,
         "tdd": tdd,
-        "writer": writer,
+        "writer": create_writer(dryrun=dryrun, force=force),
     }
 
 
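Note: extracting `create_writer` and `get_template_environment` out of `create_context` is what lets cli/database.py above build a minimal context without parsing a product config file. A hedged sketch of standalone use (the template and output paths are hypothetical):

    # Hedged sketch: use the extracted helpers outside create_context.
    from pathlib import Path

    from orchestrator.cli.generate import create_writer, get_template_environment

    environment = get_template_environment()
    writer = create_writer(dryrun=True)  # dry-run variant: report instead of writing

    template = environment.get_template("example.j2")  # hypothetical template name
    writer(Path("out/example.py"), template.render())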
orchestrator/cli/generator/generator/migration.py CHANGED
@@ -79,6 +79,12 @@ def create_data_head(context: dict, depends_on: str) -> None:
     writer(path, content)
 
 
+def create_data_head_if_not_exists(context: dict) -> None:
+    heads = get_heads()
+    if "data" not in heads:
+        create_data_head(context=context, depends_on=heads["schema"])
+
+
 def extract_revision_info(content: list[str]) -> dict:
     def process() -> Generator:
         for line in content:
@@ -136,9 +142,7 @@ def generate_product_migration(context: dict) -> None:
     environment = context["environment"]
     writer = context["writer"]
 
-    heads = get_heads()
-    if "data" not in heads:
-        create_data_head(context=context, depends_on=heads["schema"])
+    create_data_head_if_not_exists(context=context)
 
     if not (migration_file := create_migration_file(message=f"add {config['name']}", head="data")):
         logger.error("Could not create migration file")