orchestrator-core 4.7.0rc1__py3-none-any.whl → 4.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +1 -1
- orchestrator/app.py +34 -1
- orchestrator/cli/scheduler.py +53 -10
- orchestrator/graphql/schemas/process.py +2 -2
- orchestrator/llm_settings.py +0 -1
- orchestrator/migrations/versions/schema/2020-10-19_a76b9185b334_add_generic_workflows_to_core.py +1 -0
- orchestrator/migrations/versions/schema/2021-04-06_3c8b9185c221_add_validate_products_task.py +1 -0
- orchestrator/migrations/versions/schema/2025-11-18_961eddbd4c13_create_linker_table_workflow_apscheduler.py +1 -1
- orchestrator/migrations/versions/schema/2025-12-10_9736496e3eba_set_is_task_true_on_certain_tasks.py +40 -0
- orchestrator/schedules/__init__.py +3 -1
- orchestrator/schedules/scheduling.py +5 -1
- orchestrator/schedules/service.py +32 -3
- orchestrator/schemas/search_requests.py +6 -1
- orchestrator/search/agent/prompts.py +10 -6
- orchestrator/search/agent/tools.py +55 -15
- orchestrator/search/aggregations/base.py +6 -2
- orchestrator/search/core/types.py +13 -4
- orchestrator/search/query/builder.py +75 -3
- orchestrator/search/query/engine.py +65 -3
- orchestrator/search/query/mixins.py +62 -2
- orchestrator/search/query/queries.py +15 -1
- orchestrator/search/query/validation.py +43 -0
- orchestrator/settings.py +48 -0
- orchestrator/workflows/modify_note.py +10 -1
- orchestrator/workflows/removed_workflow.py +8 -1
- orchestrator/workflows/tasks/cleanup_tasks_log.py +9 -2
- orchestrator/workflows/tasks/resume_workflows.py +4 -0
- orchestrator/workflows/tasks/validate_product_type.py +7 -1
- orchestrator/workflows/tasks/validate_products.py +9 -1
- orchestrator/workflows/tasks/validate_subscriptions.py +11 -4
- {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/METADATA +8 -8
- {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/RECORD +34 -33
- {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/WHEEL +0 -0
- {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -25,6 +25,7 @@ from orchestrator.db.models import AiSearchIndex
|
|
|
25
25
|
from orchestrator.search.aggregations import AggregationType, BaseAggregation, CountAggregation
|
|
26
26
|
from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
|
|
27
27
|
from orchestrator.search.filters import LtreeFilter
|
|
28
|
+
from orchestrator.search.query.mixins import OrderDirection
|
|
28
29
|
from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query
|
|
29
30
|
|
|
30
31
|
|
|
@@ -181,7 +182,8 @@ def _build_pivot_cte(base_query: Select, pivot_fields: list[str]) -> CTE:
|
|
|
181
182
|
|
|
182
183
|
|
|
183
184
|
def _build_grouping_columns(
|
|
184
|
-
query: CountQuery | AggregateQuery,
|
|
185
|
+
query: CountQuery | AggregateQuery,
|
|
186
|
+
pivot_cte: CTE,
|
|
185
187
|
) -> tuple[list[Any], list[Any], list[str]]:
|
|
186
188
|
"""Build GROUP BY columns and their SELECT columns.
|
|
187
189
|
|
|
@@ -244,6 +246,76 @@ def _build_aggregation_columns(query: CountQuery | AggregateQuery, pivot_cte: CT
|
|
|
244
246
|
return [count_agg.to_expression(pivot_cte.c.entity_id)]
|
|
245
247
|
|
|
246
248
|
|
|
249
|
+
def _apply_cumulative_aggregations(
    stmt: Select,
    query: CountQuery | AggregateQuery,
    group_column_names: list[str],
    aggregation_columns: list[Label],
) -> Select:
    """Append a running-total column for every aggregation column.

    The grouped statement is wrapped in a subquery and, for each aggregation column
    ``<key>``, a ``<key>_cumulative`` column is added, computed as
    ``SUM(<key>) OVER (PARTITION BY <other grouping columns> ORDER BY <temporal column>)``.
    PARTITION BY is left out when the temporal grouping is the only grouping column.

    Args:
        stmt: The grouped aggregation statement.
        query: Supplies the ``cumulative`` flag and the temporal grouping.
        group_column_names: Names of all grouping columns selected by ``stmt``.
        aggregation_columns: The labelled aggregation columns selected by ``stmt``.

    Returns:
        ``stmt`` untouched when cumulative mode is off, there are no aggregation columns,
        or there is no temporal grouping; otherwise a SELECT over the wrapped subquery
        returning all of its columns followed by the cumulative ones.
    """
    # GroupingMixin.validate_grouping_constraints already rejects cumulative=True unless
    # there is exactly one temporal grouping, so its alias is the time axis below.
    if not (query.cumulative and aggregation_columns and query.temporal_group_by):
        return stmt

    time_alias = query.temporal_group_by[0].alias
    grouped = stmt.subquery()

    # Each combination of the non-temporal grouping columns gets its own running total.
    partitions = [grouped.c[name] for name in group_column_names if name != time_alias]
    time_column = grouped.c[time_alias]
    passthrough = [grouped.c[key] for key in grouped.c.keys()]

    # partition_by=None is the parameter's default, i.e. the same as omitting it.
    running_totals = [
        func.sum(grouped.c[agg.key])
        .over(partition_by=partitions or None, order_by=time_column)
        .label(f"{agg.key}_cumulative")
        for agg in aggregation_columns
    ]

    return select(*passthrough, *running_totals).select_from(grouped)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _apply_ordering(
    stmt: Select,
    query: CountQuery | AggregateQuery,
    group_column_names: list[str],
) -> Select:
    """Attach an ORDER BY clause to a grouped aggregation statement.

    Each explicit ``query.order_by`` field is resolved against the selected columns by
    trying, in order:

    1. the exact column key;
    2. a temporal grouping referenced by its field path or its alias;
    3. the field path normalized with ``BaseAggregation.field_to_alias``.

    Without explicit instructions, a query with temporal grouping is sorted ascending on
    every grouping column, and any other query is returned without ordering.

    Args:
        stmt: Statement whose selected columns can be ordered on.
        query: Supplies ``order_by`` and ``temporal_group_by``.
        group_column_names: Grouping column keys used for the default ordering.

    Returns:
        ``stmt`` with ORDER BY applied, or ``stmt`` itself when no ordering applies.

    Raises:
        ValueError: If an ``order_by`` field matches none of the selected columns.
    """
    by_key = {column.key: column for column in stmt.selected_columns}

    def resolve(field: str):
        """Map one ``order_by`` field onto a selected column."""
        found = by_key.get(field)
        if found is not None:
            return found

        # A temporal grouping may be referenced by its source field path or by its alias.
        for grouping in query.temporal_group_by or []:
            if field in (grouping.field, grouping.alias):
                found = by_key.get(grouping.alias)
                if found is not None:
                    return found

        found = by_key.get(BaseAggregation.field_to_alias(field))
        if found is None:
            raise ValueError(f"Cannot order by '{field}'; column not found.")
        return found

    if query.order_by:
        clauses = []
        for instruction in query.order_by:
            column = resolve(instruction.field)
            clauses.append(column.desc() if instruction.direction == OrderDirection.DESC else column.asc())
        return stmt.order_by(*clauses)

    if not query.temporal_group_by:
        return stmt

    # Default for temporal results: ascending on every grouping column.
    return stmt.order_by(*(by_key[name].asc() for name in group_column_names))
|
|
317
|
+
|
|
318
|
+
|
|
247
319
|
def build_simple_count_query(base_query: Select) -> Select:
|
|
248
320
|
"""Build a simple count query without grouping.
|
|
249
321
|
|
|
@@ -282,7 +354,7 @@ def build_aggregation_query(query: CountQuery | AggregateQuery, base_query: Sele
|
|
|
282
354
|
if group_cols:
|
|
283
355
|
stmt = stmt.group_by(*group_cols)
|
|
284
356
|
|
|
285
|
-
|
|
286
|
-
|
|
357
|
+
stmt = _apply_cumulative_aggregations(stmt, query, group_col_names, agg_cols)
|
|
358
|
+
stmt = _apply_ordering(stmt, query, group_col_names)
|
|
287
359
|
|
|
288
360
|
return stmt, group_col_names
|
|
@@ -15,7 +15,7 @@ import structlog
|
|
|
15
15
|
from sqlalchemy.orm import Session
|
|
16
16
|
|
|
17
17
|
from orchestrator.search.core.embedding import QueryEmbedder
|
|
18
|
-
from orchestrator.search.core.types import SearchMetadata
|
|
18
|
+
from orchestrator.search.core.types import EntityType, RetrieverType, SearchMetadata
|
|
19
19
|
from orchestrator.search.query.results import (
|
|
20
20
|
AggregationResponse,
|
|
21
21
|
SearchResponse,
|
|
@@ -23,7 +23,13 @@ from orchestrator.search.query.results import (
|
|
|
23
23
|
format_search_response,
|
|
24
24
|
)
|
|
25
25
|
from orchestrator.search.retrieval.pagination import PageCursor
|
|
26
|
-
from orchestrator.search.retrieval.retrievers import
|
|
26
|
+
from orchestrator.search.retrieval.retrievers import (
|
|
27
|
+
FuzzyRetriever,
|
|
28
|
+
ProcessHybridRetriever,
|
|
29
|
+
Retriever,
|
|
30
|
+
RrfHybridRetriever,
|
|
31
|
+
SemanticRetriever,
|
|
32
|
+
)
|
|
27
33
|
|
|
28
34
|
from .builder import build_aggregation_query, build_candidate_query, build_simple_count_query
|
|
29
35
|
from .export import fetch_export_data
|
|
@@ -32,6 +38,59 @@ from .queries import AggregateQuery, CountQuery, ExportQuery, SelectQuery
|
|
|
32
38
|
logger = structlog.get_logger(__name__)
|
|
33
39
|
|
|
34
40
|
|
|
41
|
+
def _get_retriever_from_override(
|
|
42
|
+
query: SelectQuery | ExportQuery,
|
|
43
|
+
cursor: PageCursor | None,
|
|
44
|
+
query_embedding: list[float] | None,
|
|
45
|
+
) -> Retriever | None:
|
|
46
|
+
"""Get retriever instance from explicit override, or None if no override.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
query: Query that may have a retriever override
|
|
50
|
+
cursor: Pagination cursor
|
|
51
|
+
query_embedding: Pre-computed embedding (may be None)
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Retriever instance matching the requested type, or None if no override specified
|
|
55
|
+
|
|
56
|
+
Raises:
|
|
57
|
+
ValueError: If override requirements aren't met (e.g., no query text or embedding)
|
|
58
|
+
"""
|
|
59
|
+
if query.retriever is None:
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
retriever_type = query.retriever
|
|
63
|
+
|
|
64
|
+
# Validate query_text (required for all retriever types)
|
|
65
|
+
if not query.query_text:
|
|
66
|
+
raise ValueError(f"{retriever_type.value.capitalize()} retriever requested but no query text provided.")
|
|
67
|
+
|
|
68
|
+
is_process = query.entity_type == EntityType.PROCESS
|
|
69
|
+
|
|
70
|
+
if retriever_type == RetrieverType.FUZZY:
|
|
71
|
+
return (
|
|
72
|
+
ProcessHybridRetriever(None, query.query_text, cursor)
|
|
73
|
+
if is_process
|
|
74
|
+
else FuzzyRetriever(query.query_text, cursor)
|
|
75
|
+
)
|
|
76
|
+
if retriever_type == RetrieverType.SEMANTIC:
|
|
77
|
+
if query_embedding is None:
|
|
78
|
+
raise ValueError(
|
|
79
|
+
"Semantic retriever requested but query embedding is not available. "
|
|
80
|
+
"Embedding generation may have failed."
|
|
81
|
+
)
|
|
82
|
+
return SemanticRetriever(query_embedding, cursor)
|
|
83
|
+
if query_embedding is None:
|
|
84
|
+
raise ValueError(
|
|
85
|
+
"Hybrid retriever requested but query embedding is not available. " "Embedding generation may have failed."
|
|
86
|
+
)
|
|
87
|
+
return (
|
|
88
|
+
ProcessHybridRetriever(query_embedding, query.query_text, cursor)
|
|
89
|
+
if is_process
|
|
90
|
+
else RrfHybridRetriever(query_embedding, query.query_text, cursor)
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
35
94
|
async def _execute_search(
|
|
36
95
|
query: SelectQuery | ExportQuery,
|
|
37
96
|
db_session: Session,
|
|
@@ -60,7 +119,10 @@ async def _execute_search(
|
|
|
60
119
|
if query.vector_query and not query_embedding:
|
|
61
120
|
query_embedding = await QueryEmbedder.generate_for_text_async(query.vector_query)
|
|
62
121
|
|
|
63
|
-
retriever
|
|
122
|
+
# Get retriever (from override or automatic routing)
|
|
123
|
+
retriever = _get_retriever_from_override(query, cursor, query_embedding) or Retriever.route(
|
|
124
|
+
query, cursor, query_embedding
|
|
125
|
+
)
|
|
64
126
|
logger.debug("Using retriever", retriever_type=retriever.__class__.__name__)
|
|
65
127
|
|
|
66
128
|
final_stmt = retriever.apply(candidate_query)
|
|
@@ -1,16 +1,38 @@
|
|
|
1
1
|
import uuid
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Self
|
|
2
4
|
|
|
3
|
-
from pydantic import BaseModel, Field
|
|
5
|
+
from pydantic import BaseModel, Field, model_validator
|
|
4
6
|
|
|
5
7
|
from orchestrator.search.aggregations import Aggregation, TemporalGrouping
|
|
8
|
+
from orchestrator.search.core.types import RetrieverType
|
|
6
9
|
|
|
7
10
|
__all__ = [
|
|
8
11
|
"SearchMixin",
|
|
9
12
|
"GroupingMixin",
|
|
10
13
|
"AggregationMixin",
|
|
14
|
+
"OrderBy",
|
|
15
|
+
"OrderDirection",
|
|
11
16
|
]
|
|
12
17
|
|
|
13
18
|
|
|
19
|
+
class OrderDirection(str, Enum):
    """Sorting direction for aggregation results."""

    # Subclassing str lets members compare equal to their raw values ("asc"/"desc").
    ASC = "asc"  # ascending
    DESC = "desc"  # descending
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class OrderBy(BaseModel):
    """Ordering descriptor for aggregation responses."""

    # Required: the column (grouping field, aggregation alias or field path) to sort on.
    field: str = Field(..., description="Grouping or aggregation field/alias to order by.")
    # Optional: ascending unless stated otherwise.
    direction: OrderDirection = Field(OrderDirection.ASC, description="Sorting direction (asc or desc).")
|
|
34
|
+
|
|
35
|
+
|
|
14
36
|
class SearchMixin(BaseModel):
|
|
15
37
|
"""Mixin providing text search capability.
|
|
16
38
|
|
|
@@ -18,6 +40,10 @@ class SearchMixin(BaseModel):
|
|
|
18
40
|
"""
|
|
19
41
|
|
|
20
42
|
query_text: str | None = Field(default=None, description="Text query for semantic/fuzzy search")
|
|
43
|
+
retriever: RetrieverType | None = Field(
|
|
44
|
+
default=None,
|
|
45
|
+
description="Override retriever type (fuzzy/semantic/hybrid). If None, uses default routing logic.",
|
|
46
|
+
)
|
|
21
47
|
|
|
22
48
|
@property
|
|
23
49
|
def vector_query(self) -> str | None:
|
|
@@ -59,6 +85,37 @@ class GroupingMixin(BaseModel):
|
|
|
59
85
|
default=None,
|
|
60
86
|
description="Temporal grouping specifications (group by month, year, etc.)",
|
|
61
87
|
)
|
|
88
|
+
cumulative: bool = Field(
|
|
89
|
+
default=False,
|
|
90
|
+
description="Enable cumulative aggregations when temporal grouping is present.",
|
|
91
|
+
)
|
|
92
|
+
order_by: list[OrderBy] | None = Field(
|
|
93
|
+
default=None,
|
|
94
|
+
description="Ordering instructions for grouped aggregation results.",
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
@model_validator(mode="after")
|
|
98
|
+
def validate_grouping_constraints(self) -> Self:
    """Check cross-field rules between ordering, cumulative mode and grouping.

    Returns:
        The model itself when all rules hold.

    Raises:
        ValueError: If ``order_by`` is set without any ``group_by``/``temporal_group_by``,
            or if ``cumulative`` is enabled without exactly one temporal grouping.
    """
    if self.order_by and not (self.group_by or self.temporal_group_by):
        raise ValueError(
            "order_by requires at least one grouping field (group_by or temporal_group_by). "
            "Ordering only applies to grouped aggregation results."
        )

    if not self.cumulative:
        return self

    temporal = self.temporal_group_by
    if not temporal:
        raise ValueError(
            "cumulative requires at least one temporal grouping (temporal_group_by). "
            "Cumulative aggregations compute running totals over time."
        )
    if len(temporal) > 1:
        raise ValueError(
            "cumulative currently supports only a single temporal grouping. "
            "Multiple temporal dimensions with running totals are not yet supported."
        )
    return self
|
|
62
119
|
|
|
63
120
|
def get_pivot_fields(self) -> list[str]:
|
|
64
121
|
"""Get all fields needed for EAV pivot from grouping.
|
|
@@ -82,7 +139,10 @@ class AggregationMixin(BaseModel):
|
|
|
82
139
|
Used by AGGREGATE queries to define what statistics to compute.
|
|
83
140
|
"""
|
|
84
141
|
|
|
85
|
-
aggregations: list[Aggregation] = Field(
|
|
142
|
+
aggregations: list[Aggregation] = Field(
|
|
143
|
+
description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)",
|
|
144
|
+
min_length=1,
|
|
145
|
+
)
|
|
86
146
|
|
|
87
147
|
def get_aggregation_pivot_fields(self) -> list[str]:
|
|
88
148
|
"""Get fields needed for EAV pivot from aggregations.
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
|
|
14
14
|
from typing import Annotated, Any, ClassVar, Literal, Self, Union
|
|
15
15
|
|
|
16
|
-
from pydantic import BaseModel, ConfigDict, Discriminator, Field
|
|
16
|
+
from pydantic import BaseModel, ConfigDict, Discriminator, Field, model_validator
|
|
17
17
|
|
|
18
18
|
from orchestrator.search.core.types import ActionType, EntityType
|
|
19
19
|
from orchestrator.search.filters import FilterTree
|
|
@@ -112,6 +112,20 @@ class AggregateQuery(BaseQuery, GroupingMixin, AggregationMixin):
|
|
|
112
112
|
query_type: Literal["aggregate"] = "aggregate"
|
|
113
113
|
_action: ClassVar[ActionType] = ActionType.AGGREGATE
|
|
114
114
|
|
|
115
|
+
@model_validator(mode="after")
|
|
116
|
+
def validate_cumulative_aggregation_types(self) -> Self:
    """Validate that cumulative is only used with COUNT and SUM aggregations.

    AVG, MIN and MAX aggregations are rejected when ``cumulative`` is enabled.

    Returns:
        The model itself when the combination is allowed.

    Raises:
        ValueError: On the first AVG, MIN or MAX aggregation found while ``cumulative`` is set.
    """
    if not self.cumulative:
        return self

    # Kept as a local import like the original; NOTE(review): presumably avoids an import
    # cycle with orchestrator.search.aggregations — confirm before hoisting it.
    from orchestrator.search.aggregations import AggregationType

    rejected = (AggregationType.AVG, AggregationType.MIN, AggregationType.MAX)
    for agg in self.aggregations:
        if agg.type in rejected:
            raise ValueError(
                f"Cumulative aggregations are not supported for {agg.type.value.upper()} aggregations. "
                # Plain literal: the original used an f-string with no placeholders (ruff F541).
                "Cumulative only works with COUNT and SUM."
            )
    return self
|
|
128
|
+
|
|
115
129
|
def get_pivot_fields(self) -> list[str]:
|
|
116
130
|
"""Get all fields needed for EAV pivot including aggregation fields."""
|
|
117
131
|
# Get grouping fields from GroupingMixin
|
|
@@ -31,6 +31,7 @@ from orchestrator.search.query.exceptions import (
|
|
|
31
31
|
InvalidLtreePatternError,
|
|
32
32
|
PathNotFoundError,
|
|
33
33
|
)
|
|
34
|
+
from orchestrator.search.query.mixins import OrderBy
|
|
34
35
|
|
|
35
36
|
|
|
36
37
|
def is_filter_compatible_with_field_type(filter_condition: FilterCondition, field_type: FieldType) -> bool:
|
|
@@ -207,3 +208,45 @@ def validate_temporal_grouping_field(field_path: str) -> None:
|
|
|
207
208
|
# Validate field type is datetime
|
|
208
209
|
if field_type_str != FieldType.DATETIME.value:
|
|
209
210
|
raise IncompatibleTemporalGroupingTypeError(field_path, field_type_str)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def validate_grouping_fields(group_by_paths: list[str]) -> None:
    """Ensure every grouping path exists in the database.

    Paths are checked in order with ``validate_filter_path``. Validation stops at the
    first path for which it returns None.

    Args:
        group_by_paths: Field paths the query groups by.

    Raises:
        PathNotFoundError: For the first path that ``validate_filter_path`` does not resolve.
    """
    for field_path in group_by_paths:
        if validate_filter_path(field_path) is None:
            raise PathNotFoundError(field_path)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def validate_order_by_fields(order_by: list[OrderBy] | None) -> None:
    """Ensure every dotted ``order_by`` field path exists in the database.

    Fields without a dot (e.g. ``count`` or ``revenue``) are treated as aggregation
    aliases. Those cannot be validated until query execution time, so they are skipped.

    Args:
        order_by: Ordering instructions, or None when the query has none.

    Raises:
        PathNotFoundError: For the first dotted field that ``validate_filter_path``
            does not resolve.
    """
    dotted_fields = (instr.field for instr in order_by or () if "." in instr.field)
    for field_path in dotted_fields:
        if validate_filter_path(field_path) is None:
            raise PathNotFoundError(field_path)
|
orchestrator/settings.py
CHANGED
|
@@ -17,10 +17,13 @@ from pathlib import Path
|
|
|
17
17
|
from typing import Literal
|
|
18
18
|
|
|
19
19
|
from pydantic import Field, NonNegativeInt, PostgresDsn, RedisDsn
|
|
20
|
+
from pydantic.main import BaseModel
|
|
20
21
|
from pydantic_settings import BaseSettings
|
|
21
22
|
|
|
23
|
+
from oauth2_lib.fastapi import OIDCUserModel
|
|
22
24
|
from oauth2_lib.settings import oauth2lib_settings
|
|
23
25
|
from orchestrator.services.settings_env_variables import expose_settings
|
|
26
|
+
from orchestrator.utils.auth import Authorizer
|
|
24
27
|
from orchestrator.utils.expose_settings import SecretStr as OrchSecretStr
|
|
25
28
|
from pydantic_forms.types import strEnum
|
|
26
29
|
|
|
@@ -111,3 +114,48 @@ if app_settings.EXPOSE_SETTINGS:
|
|
|
111
114
|
expose_settings("app_settings", app_settings) # type: ignore
|
|
112
115
|
if app_settings.EXPOSE_OAUTH_SETTINGS:
|
|
113
116
|
expose_settings("oauth2lib_settings", oauth2lib_settings) # type: ignore
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class Authorizers(BaseModel):
    # Authorization hooks for workflows and tasks that ship with orchestrator-core.
    # These are kept separate from the defaults used for user-defined workflows and steps.
    # Leaving a hook as None means the corresponding check always passes.
    internal_authorize_callback: Authorizer | None = None
    internal_retry_auth_callback: Authorizer | None = None

    async def authorize_callback(self, user: OIDCUserModel | None) -> bool:
        """Decide whether ``user`` may start a workflow defined within orchestrator-core.

        Delegates to ``internal_authorize_callback``. When that is None, every caller
        (including ``user=None``) is authorized.
        """
        check = self.internal_authorize_callback
        return True if check is None else await check(user)

    async def retry_auth_callback(self, user: OIDCUserModel | None) -> bool:
        """Decide whether ``user`` may retry a failed workflow defined within orchestrator-core.

        Delegates to ``internal_retry_auth_callback``. When that is None, every caller
        (including ``user=None``) is authorized.
        """
        check = self.internal_retry_auth_callback
        return True if check is None else await check(user)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# Process-wide instance; hand it out through get_authorizers() instead of importing this name.
_authorizers = Authorizers()


def get_authorizers() -> Authorizers:
    """Return the shared ``Authorizers`` instance so its callbacks can be set at app setup.

    Always mutate the object this returns. Do not rebind a module attribute instead, e.g.
    ``from orchestrator.settings import authorizers; authorizers = my_authorizers`` or
    ``settings.authorizers = my_authorizers``. Rebinding a name does not affect the
    references other modules already hold: the built-in workflow modules call
    ``get_authorizers()`` at import time and keep the returned object.
    """
    return _authorizers
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
from orchestrator.db import db
|
|
14
14
|
from orchestrator.forms import SubmitFormPage
|
|
15
15
|
from orchestrator.services import subscriptions
|
|
16
|
+
from orchestrator.settings import get_authorizers
|
|
16
17
|
from orchestrator.targets import Target
|
|
17
18
|
from orchestrator.utils.json import to_serializable
|
|
18
19
|
from orchestrator.workflow import StepList, done, init, step, workflow
|
|
@@ -21,6 +22,8 @@ from orchestrator.workflows.utils import wrap_modify_initial_input_form
|
|
|
21
22
|
from pydantic_forms.types import FormGenerator, State, UUIDstr
|
|
22
23
|
from pydantic_forms.validators import LongText
|
|
23
24
|
|
|
25
|
+
# Shared Authorizers instance; its callbacks are passed to the @workflow decorator in this module.
authorizers = get_authorizers()
|
|
26
|
+
|
|
24
27
|
|
|
25
28
|
def initial_input_form(subscription_id: UUIDstr) -> FormGenerator:
|
|
26
29
|
subscription = subscriptions.get_subscription(subscription_id)
|
|
@@ -51,6 +54,12 @@ def store_subscription_note(subscription_id: UUIDstr, note: str) -> State:
|
|
|
51
54
|
}
|
|
52
55
|
|
|
53
56
|
|
|
54
|
-
@workflow(
|
|
57
|
+
@workflow(
    "Modify Note",
    initial_input_form=wrap_modify_initial_input_form(initial_input_form),
    target=Target.MODIFY,
    authorize_callback=authorizers.authorize_callback,
    retry_auth_callback=authorizers.retry_auth_callback,
)
def modify_note() -> StepList:
    # store_process_subscription() runs first, then store_subscription_note.
    note_steps = init >> store_process_subscription() >> store_subscription_note
    return note_steps >> done
|
|
@@ -12,11 +12,18 @@
|
|
|
12
12
|
# limitations under the License.
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
from orchestrator.settings import get_authorizers
|
|
15
16
|
from orchestrator.workflow import StepList, workflow
|
|
16
17
|
|
|
18
|
+
authorizers = get_authorizers()


# Placeholder workflow with no steps. It was originally made to create the initial import
# process for an SN7 subscription; per its description it now stands in for removed
# workflows, and it is needed so those show correctly in the GUI.
@workflow(
    "Dummy workflow to replace removed workflows",
    authorize_callback=authorizers.authorize_callback,
    retry_auth_callback=authorizers.retry_auth_callback,
)
def removed_workflow() -> StepList:
    no_steps = StepList()
    return no_steps
|
|
@@ -17,12 +17,14 @@ from datetime import timedelta
|
|
|
17
17
|
from sqlalchemy import select
|
|
18
18
|
|
|
19
19
|
from orchestrator.db import ProcessTable, db
|
|
20
|
-
from orchestrator.settings import app_settings
|
|
20
|
+
from orchestrator.settings import app_settings, get_authorizers
|
|
21
21
|
from orchestrator.targets import Target
|
|
22
22
|
from orchestrator.utils.datetime import nowtz
|
|
23
23
|
from orchestrator.workflow import ProcessStatus, StepList, done, init, step, workflow
|
|
24
24
|
from pydantic_forms.types import State
|
|
25
25
|
|
|
26
|
+
# Shared Authorizers instance; its callbacks are passed to the @workflow decorator in this module.
authorizers = get_authorizers()
|
|
27
|
+
|
|
26
28
|
|
|
27
29
|
@step("Clean up completed tasks older than TASK_LOG_RETENTION_DAYS")
|
|
28
30
|
def remove_tasks() -> State:
|
|
@@ -41,6 +43,11 @@ def remove_tasks() -> State:
|
|
|
41
43
|
return {"tasks_removed": count}
|
|
42
44
|
|
|
43
45
|
|
|
44
|
-
@workflow(
|
|
46
|
+
@workflow(
    "Clean up old tasks",
    target=Target.SYSTEM,
    authorize_callback=authorizers.authorize_callback,
    retry_auth_callback=authorizers.retry_auth_callback,
)
def task_clean_up_tasks() -> StepList:
    # Single step: remove completed tasks older than TASK_LOG_RETENTION_DAYS.
    cleanup = init >> remove_tasks
    return cleanup >> done
|
|
@@ -17,10 +17,12 @@ from sqlalchemy import select
|
|
|
17
17
|
|
|
18
18
|
from orchestrator.db import ProcessTable, db
|
|
19
19
|
from orchestrator.services import processes
|
|
20
|
+
from orchestrator.settings import get_authorizers
|
|
20
21
|
from orchestrator.targets import Target
|
|
21
22
|
from orchestrator.workflow import ProcessStatus, StepList, done, init, step, workflow
|
|
22
23
|
from pydantic_forms.types import State, UUIDstr
|
|
23
24
|
|
|
25
|
+
# Shared Authorizers instance; its callbacks are passed to the @workflow decorator in this module.
authorizers = get_authorizers()
|
|
24
26
|
logger = structlog.get_logger(__name__)
|
|
25
27
|
|
|
26
28
|
|
|
@@ -110,6 +112,8 @@ def restart_created_workflows(created_state_process_ids: list[UUIDstr]) -> State
|
|
|
110
112
|
@workflow(
    "Resume all workflows that are stuck on tasks with the status 'waiting', 'created' or 'resumed'",
    target=Target.SYSTEM,
    authorize_callback=authorizers.authorize_callback,
    retry_auth_callback=authorizers.retry_auth_callback,
)
def task_resume_workflows() -> StepList:
    # Find the stuck processes, resume them, then restart the ones in the 'created' state.
    return (
        init
        >> find_waiting_workflows
        >> resume_found_workflows
        >> restart_created_workflows
        >> done
    )
|
|
@@ -25,10 +25,12 @@ from orchestrator.services.workflows import (
|
|
|
25
25
|
get_validation_product_workflows_for_subscription,
|
|
26
26
|
start_validation_workflow_for_workflows,
|
|
27
27
|
)
|
|
28
|
+
from orchestrator.settings import get_authorizers
|
|
28
29
|
from orchestrator.targets import Target
|
|
29
30
|
from orchestrator.workflow import StepList, done, init, step, workflow
|
|
30
31
|
from pydantic_forms.types import FormGenerator, State
|
|
31
32
|
|
|
33
|
+
# Shared Authorizers instance; its callbacks are passed to the @workflow decorator in this module.
authorizers = get_authorizers()
|
|
32
34
|
logger = structlog.get_logger(__name__)
|
|
33
35
|
|
|
34
36
|
|
|
@@ -86,7 +88,11 @@ def validate_product_type(product_type: str) -> State:
|
|
|
86
88
|
|
|
87
89
|
|
|
88
90
|
@workflow(
    "Validate all subscriptions of Product Type",
    target=Target.SYSTEM,
    initial_input_form=initial_input_form_generator,
    authorize_callback=authorizers.authorize_callback,
    retry_auth_callback=authorizers.retry_auth_callback,
)
def task_validate_product_type() -> StepList:
    # The product type gathered by the input form is validated in a single step.
    validation = init >> validate_product_type
    return validation >> done
|
|
@@ -26,12 +26,15 @@ from orchestrator.services import products
|
|
|
26
26
|
from orchestrator.services.products import get_products
|
|
27
27
|
from orchestrator.services.translations import generate_translations
|
|
28
28
|
from orchestrator.services.workflows import get_workflow_by_name, get_workflows
|
|
29
|
+
from orchestrator.settings import get_authorizers
|
|
29
30
|
from orchestrator.targets import Target
|
|
30
31
|
from orchestrator.utils.errors import ProcessFailureError
|
|
31
32
|
from orchestrator.utils.fixed_inputs import fixed_input_configuration as fi_configuration
|
|
32
33
|
from orchestrator.workflow import StepList, done, init, step, workflow
|
|
33
34
|
from pydantic_forms.types import State
|
|
34
35
|
|
|
36
|
+
# Shared Authorizers instance; its callbacks are passed to the @workflow decorator in this module.
authorizers = get_authorizers()
|
|
37
|
+
|
|
35
38
|
# Since these errors are probably programming failures we should not throw AssertionErrors
|
|
36
39
|
|
|
37
40
|
|
|
@@ -187,7 +190,12 @@ def check_subscription_models() -> State:
|
|
|
187
190
|
return {"check_subscription_models": True}
|
|
188
191
|
|
|
189
192
|
|
|
190
|
-
@workflow(
|
|
193
|
+
@workflow(
|
|
194
|
+
"Validate products",
|
|
195
|
+
target=Target.SYSTEM,
|
|
196
|
+
authorize_callback=authorizers.authorize_callback,
|
|
197
|
+
retry_auth_callback=authorizers.retry_auth_callback,
|
|
198
|
+
)
|
|
191
199
|
def task_validate_products() -> StepList:
|
|
192
200
|
return (
|
|
193
201
|
init
|
|
@@ -24,15 +24,17 @@ from orchestrator.services.workflows import (
|
|
|
24
24
|
get_validation_product_workflows_for_subscription,
|
|
25
25
|
start_validation_workflow_for_workflows,
|
|
26
26
|
)
|
|
27
|
-
from orchestrator.settings import app_settings
|
|
27
|
+
from orchestrator.settings import app_settings, get_authorizers
|
|
28
28
|
from orchestrator.targets import Target
|
|
29
|
-
from orchestrator.workflow import StepList, init, step, workflow
|
|
29
|
+
from orchestrator.workflow import StepList, done, init, step, workflow
|
|
30
30
|
|
|
31
31
|
logger = structlog.get_logger(__name__)
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
task_semaphore = BoundedSemaphore(value=2)
|
|
35
35
|
|
|
36
|
+
# Shared Authorizers instance; its callbacks are passed to the @workflow decorator in this module.
authorizers = get_authorizers()
|
|
37
|
+
|
|
36
38
|
|
|
37
39
|
@step("Validate subscriptions")
|
|
38
40
|
def validate_subscriptions() -> None:
|
|
@@ -56,6 +58,11 @@ def validate_subscriptions() -> None:
|
|
|
56
58
|
start_validation_workflow_for_workflows(subscription=subscription, workflows=validation_product_workflows)
|
|
57
59
|
|
|
58
60
|
|
|
59
|
-
@workflow(
|
|
61
|
+
@workflow(
    "Validate subscriptions",
    target=Target.SYSTEM,
    authorize_callback=authorizers.authorize_callback,
    retry_auth_callback=authorizers.retry_auth_callback,
)
def task_validate_subscriptions() -> StepList:
    # Closed with `done`, like the other core tasks.
    validation = init >> validate_subscriptions
    return validation >> done
|