orchestrator-core 4.7.0rc1__py3-none-any.whl → 4.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/app.py +34 -1
  3. orchestrator/cli/scheduler.py +53 -10
  4. orchestrator/graphql/schemas/process.py +2 -2
  5. orchestrator/llm_settings.py +0 -1
  6. orchestrator/migrations/versions/schema/2020-10-19_a76b9185b334_add_generic_workflows_to_core.py +1 -0
  7. orchestrator/migrations/versions/schema/2021-04-06_3c8b9185c221_add_validate_products_task.py +1 -0
  8. orchestrator/migrations/versions/schema/2025-11-18_961eddbd4c13_create_linker_table_workflow_apscheduler.py +1 -1
  9. orchestrator/migrations/versions/schema/2025-12-10_9736496e3eba_set_is_task_true_on_certain_tasks.py +40 -0
  10. orchestrator/schedules/__init__.py +3 -1
  11. orchestrator/schedules/scheduling.py +5 -1
  12. orchestrator/schedules/service.py +32 -3
  13. orchestrator/schemas/search_requests.py +6 -1
  14. orchestrator/search/agent/prompts.py +10 -6
  15. orchestrator/search/agent/tools.py +55 -15
  16. orchestrator/search/aggregations/base.py +6 -2
  17. orchestrator/search/core/types.py +13 -4
  18. orchestrator/search/query/builder.py +75 -3
  19. orchestrator/search/query/engine.py +65 -3
  20. orchestrator/search/query/mixins.py +62 -2
  21. orchestrator/search/query/queries.py +15 -1
  22. orchestrator/search/query/validation.py +43 -0
  23. orchestrator/settings.py +48 -0
  24. orchestrator/workflows/modify_note.py +10 -1
  25. orchestrator/workflows/removed_workflow.py +8 -1
  26. orchestrator/workflows/tasks/cleanup_tasks_log.py +9 -2
  27. orchestrator/workflows/tasks/resume_workflows.py +4 -0
  28. orchestrator/workflows/tasks/validate_product_type.py +7 -1
  29. orchestrator/workflows/tasks/validate_products.py +9 -1
  30. orchestrator/workflows/tasks/validate_subscriptions.py +11 -4
  31. {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/METADATA +8 -8
  32. {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/RECORD +34 -33
  33. {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/WHEEL +0 -0
  34. {orchestrator_core-4.7.0rc1.dist-info → orchestrator_core-4.7.1.dist-info}/licenses/LICENSE +0 -0
orchestrator/search/query/builder.py CHANGED
@@ -25,6 +25,7 @@ from orchestrator.db.models import AiSearchIndex
  from orchestrator.search.aggregations import AggregationType, BaseAggregation, CountAggregation
  from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
  from orchestrator.search.filters import LtreeFilter
+ from orchestrator.search.query.mixins import OrderDirection
  from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query


@@ -181,7 +182,8 @@ def _build_pivot_cte(base_query: Select, pivot_fields: list[str]) -> CTE:


  def _build_grouping_columns(
- query: CountQuery | AggregateQuery, pivot_cte: CTE
+ query: CountQuery | AggregateQuery,
+ pivot_cte: CTE,
  ) -> tuple[list[Any], list[Any], list[str]]:
  """Build GROUP BY columns and their SELECT columns.

@@ -244,6 +246,76 @@ def _build_aggregation_columns(query: CountQuery | AggregateQuery, pivot_cte: CT
  return [count_agg.to_expression(pivot_cte.c.entity_id)]


+ def _apply_cumulative_aggregations(
+ stmt: Select,
+ query: CountQuery | AggregateQuery,
+ group_column_names: list[str],
+ aggregation_columns: list[Label],
+ ) -> Select:
+ """Add cumulative aggregation columns."""
+
+ # At this point, cumulative validation has already happened at query build time
+ # in GroupingMixin.validate_grouping_constraints, so we know:
+ # temporal_group_by exists and has exactly 1 element when cumulative=True
+ if not query.cumulative or not aggregation_columns or not query.temporal_group_by:
+ return stmt
+
+ temporal_alias = query.temporal_group_by[0].alias
+
+ base_subquery = stmt.subquery()
+ partition_cols = [base_subquery.c[name] for name in group_column_names if name != temporal_alias]
+ order_col = base_subquery.c[temporal_alias]
+
+ base_columns = [base_subquery.c[col] for col in base_subquery.c.keys()]
+
+ cumulative_columns = []
+ for agg_col in aggregation_columns:
+ cumulative_alias = f"{agg_col.key}_cumulative"
+ over_kwargs: dict[str, Any] = {"order_by": order_col}
+ if partition_cols:
+ over_kwargs["partition_by"] = partition_cols
+ cumulative_expr = func.sum(base_subquery.c[agg_col.key]).over(**over_kwargs).label(cumulative_alias)
+ cumulative_columns.append(cumulative_expr)
+
+ return select(*(base_columns + cumulative_columns)).select_from(base_subquery)
+
+
+ def _apply_ordering(
+ stmt: Select,
+ query: CountQuery | AggregateQuery,
+ group_column_names: list[str],
+ ) -> Select:
+ """Apply ordering instructions to the SELECT statement."""
+ columns_by_key = {col.key: col for col in stmt.selected_columns}
+
+ if query.order_by:
+ order_expressions = []
+ for instruction in query.order_by:
+ # 1) exact match
+ col = columns_by_key.get(instruction.field)
+ if col is None:
+ # 2) temporal alias,
+ for tg in query.temporal_group_by or []:
+ if instruction.field == tg.field or instruction.field == tg.alias:
+ col = columns_by_key.get(tg.alias)
+ if col is not None:
+ break
+ if col is None:
+ # 3) normalized field path
+ col = columns_by_key.get(BaseAggregation.field_to_alias(instruction.field))
+ if col is None:
+ raise ValueError(f"Cannot order by '{instruction.field}'; column not found.")
+ order_expressions.append(col.desc() if instruction.direction == OrderDirection.DESC else col.asc())
+ return stmt.order_by(*order_expressions)
+
+ if query.temporal_group_by:
+ # Default ordering by all grouping columns (ascending)
+ order_expressions = [columns_by_key[col_name].asc() for col_name in group_column_names]
+ return stmt.order_by(*order_expressions)
+
+ return stmt
+
+
  def build_simple_count_query(base_query: Select) -> Select:
  """Build a simple count query without grouping.

@@ -282,7 +354,7 @@ def build_aggregation_query(query: CountQuery | AggregateQuery, base_query: Sele
  if group_cols:
  stmt = stmt.group_by(*group_cols)

- if query.temporal_group_by:
- stmt = stmt.order_by(*group_cols)
+ stmt = _apply_cumulative_aggregations(stmt, query, group_col_names, agg_cols)
+ stmt = _apply_ordering(stmt, query, group_col_names)

  return stmt, group_col_names
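Note on the change above: _apply_cumulative_aggregations wraps the grouped statement in a subquery and, for each aggregation column, adds a running SUM(...) OVER (PARTITION BY the non-temporal group columns ORDER BY the temporal bucket). A minimal standalone sketch of that SQLAlchemy window-function pattern (the table and column names below are illustrative placeholders, not part of orchestrator-core):

    from sqlalchemy import Integer, String, column, func, select, table

    # Hypothetical pre-aggregated rows: one count per (status, month).
    monthly = table(
        "monthly_counts",
        column("status", String),
        column("month", String),
        column("count", Integer),
    )
    base = select(monthly.c.status, monthly.c.month, monthly.c.count).subquery()

    # Running total per status over time, mirroring the "<agg>_cumulative" columns added above.
    stmt = select(
        base.c.status,
        base.c.month,
        base.c.count,
        func.sum(base.c.count)
        .over(partition_by=[base.c.status], order_by=base.c.month)
        .label("count_cumulative"),
    ).select_from(base)

    print(stmt)  # ... sum(count) OVER (PARTITION BY status ORDER BY month) AS count_cumulative ...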
orchestrator/search/query/engine.py CHANGED
@@ -15,7 +15,7 @@ import structlog
  from sqlalchemy.orm import Session

  from orchestrator.search.core.embedding import QueryEmbedder
- from orchestrator.search.core.types import SearchMetadata
+ from orchestrator.search.core.types import EntityType, RetrieverType, SearchMetadata
  from orchestrator.search.query.results import (
  AggregationResponse,
  SearchResponse,
@@ -23,7 +23,13 @@ from orchestrator.search.query.results import (
  format_search_response,
  )
  from orchestrator.search.retrieval.pagination import PageCursor
- from orchestrator.search.retrieval.retrievers import Retriever
+ from orchestrator.search.retrieval.retrievers import (
+ FuzzyRetriever,
+ ProcessHybridRetriever,
+ Retriever,
+ RrfHybridRetriever,
+ SemanticRetriever,
+ )

  from .builder import build_aggregation_query, build_candidate_query, build_simple_count_query
  from .export import fetch_export_data
@@ -32,6 +38,59 @@ from .queries import AggregateQuery, CountQuery, ExportQuery, SelectQuery
  logger = structlog.get_logger(__name__)


+ def _get_retriever_from_override(
+ query: SelectQuery | ExportQuery,
+ cursor: PageCursor | None,
+ query_embedding: list[float] | None,
+ ) -> Retriever | None:
+ """Get retriever instance from explicit override, or None if no override.
+
+ Args:
+ query: Query that may have a retriever override
+ cursor: Pagination cursor
+ query_embedding: Pre-computed embedding (may be None)
+
+ Returns:
+ Retriever instance matching the requested type, or None if no override specified
+
+ Raises:
+ ValueError: If override requirements aren't met (e.g., no query text or embedding)
+ """
+ if query.retriever is None:
+ return None
+
+ retriever_type = query.retriever
+
+ # Validate query_text (required for all retriever types)
+ if not query.query_text:
+ raise ValueError(f"{retriever_type.value.capitalize()} retriever requested but no query text provided.")
+
+ is_process = query.entity_type == EntityType.PROCESS
+
+ if retriever_type == RetrieverType.FUZZY:
+ return (
+ ProcessHybridRetriever(None, query.query_text, cursor)
+ if is_process
+ else FuzzyRetriever(query.query_text, cursor)
+ )
+ if retriever_type == RetrieverType.SEMANTIC:
+ if query_embedding is None:
+ raise ValueError(
+ "Semantic retriever requested but query embedding is not available. "
+ "Embedding generation may have failed."
+ )
+ return SemanticRetriever(query_embedding, cursor)
+ if query_embedding is None:
+ raise ValueError(
+ "Hybrid retriever requested but query embedding is not available. " "Embedding generation may have failed."
+ )
+ return (
+ ProcessHybridRetriever(query_embedding, query.query_text, cursor)
+ if is_process
+ else RrfHybridRetriever(query_embedding, query.query_text, cursor)
+ )
+
+
  async def _execute_search(
  query: SelectQuery | ExportQuery,
  db_session: Session,
@@ -60,7 +119,10 @@ async def _execute_search(
  if query.vector_query and not query_embedding:
  query_embedding = await QueryEmbedder.generate_for_text_async(query.vector_query)

- retriever = Retriever.route(query, cursor, query_embedding)
+ # Get retriever (from override or automatic routing)
+ retriever = _get_retriever_from_override(query, cursor, query_embedding) or Retriever.route(
+ query, cursor, query_embedding
+ )
  logger.debug("Using retriever", retriever_type=retriever.__class__.__name__)

  final_stmt = retriever.apply(candidate_query)
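Note on the change above: the explicit override is consulted first, and Retriever.route only runs when _get_retriever_from_override returns None, thanks to the short-circuiting `or`. A hedged sketch of the semantic-override error path, calling the private helper directly with a stand-in object instead of a real SelectQuery (the stub only carries the attributes the helper reads; this is illustration, not a supported API):

    from types import SimpleNamespace

    from orchestrator.search.core.types import EntityType, RetrieverType
    from orchestrator.search.query.engine import _get_retriever_from_override

    # Stub carrying only the attributes the helper inspects.
    query = SimpleNamespace(
        retriever=RetrieverType.SEMANTIC,
        query_text="stuck workflows",
        entity_type=EntityType.PROCESS,
    )

    try:
        _get_retriever_from_override(query, cursor=None, query_embedding=None)
    except ValueError as exc:
        print(exc)  # Semantic retriever requested but query embedding is not available. ...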
orchestrator/search/query/mixins.py CHANGED
@@ -1,16 +1,38 @@
  import uuid
+ from enum import Enum
+ from typing import Self

- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, Field, model_validator

  from orchestrator.search.aggregations import Aggregation, TemporalGrouping
+ from orchestrator.search.core.types import RetrieverType

  __all__ = [
  "SearchMixin",
  "GroupingMixin",
  "AggregationMixin",
+ "OrderBy",
+ "OrderDirection",
  ]


+ class OrderDirection(str, Enum):
+ """Sorting direction for aggregation results."""
+
+ ASC = "asc"
+ DESC = "desc"
+
+
+ class OrderBy(BaseModel):
+ """Ordering descriptor for aggregation responses."""
+
+ field: str = Field(description="Grouping or aggregation field/alias to order by.")
+ direction: OrderDirection = Field(
+ default=OrderDirection.ASC,
+ description="Sorting direction (asc or desc).",
+ )
+
+
  class SearchMixin(BaseModel):
  """Mixin providing text search capability.

@@ -18,6 +40,10 @@ class SearchMixin(BaseModel):
  """

  query_text: str | None = Field(default=None, description="Text query for semantic/fuzzy search")
+ retriever: RetrieverType | None = Field(
+ default=None,
+ description="Override retriever type (fuzzy/semantic/hybrid). If None, uses default routing logic.",
+ )

  @property
  def vector_query(self) -> str | None:
@@ -59,6 +85,37 @@ class GroupingMixin(BaseModel):
  default=None,
  description="Temporal grouping specifications (group by month, year, etc.)",
  )
+ cumulative: bool = Field(
+ default=False,
+ description="Enable cumulative aggregations when temporal grouping is present.",
+ )
+ order_by: list[OrderBy] | None = Field(
+ default=None,
+ description="Ordering instructions for grouped aggregation results.",
+ )
+
+ @model_validator(mode="after")
+ def validate_grouping_constraints(self) -> Self:
+ """Validate cross-field constraints for grouping features."""
+ if self.order_by and not self.group_by and not self.temporal_group_by:
+ raise ValueError(
+ "order_by requires at least one grouping field (group_by or temporal_group_by). "
+ "Ordering only applies to grouped aggregation results."
+ )
+
+ if self.cumulative:
+ if not self.temporal_group_by:
+ raise ValueError(
+ "cumulative requires at least one temporal grouping (temporal_group_by). "
+ "Cumulative aggregations compute running totals over time."
+ )
+ if len(self.temporal_group_by) > 1:
+ raise ValueError(
+ "cumulative currently supports only a single temporal grouping. "
+ "Multiple temporal dimensions with running totals are not yet supported."
+ )
+
+ return self

  def get_pivot_fields(self) -> list[str]:
  """Get all fields needed for EAV pivot from grouping.
@@ -82,7 +139,10 @@ class AggregationMixin(BaseModel):
  Used by AGGREGATE queries to define what statistics to compute.
  """

- aggregations: list[Aggregation] = Field(description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)")
+ aggregations: list[Aggregation] = Field(
+ description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)",
+ min_length=1,
+ )

  def get_aggregation_pivot_fields(self) -> list[str]:
  """Get fields needed for EAV pivot from aggregations.
orchestrator/search/query/queries.py CHANGED
@@ -13,7 +13,7 @@

  from typing import Annotated, Any, ClassVar, Literal, Self, Union

- from pydantic import BaseModel, ConfigDict, Discriminator, Field
+ from pydantic import BaseModel, ConfigDict, Discriminator, Field, model_validator

  from orchestrator.search.core.types import ActionType, EntityType
  from orchestrator.search.filters import FilterTree
@@ -112,6 +112,20 @@ class AggregateQuery(BaseQuery, GroupingMixin, AggregationMixin):
  query_type: Literal["aggregate"] = "aggregate"
  _action: ClassVar[ActionType] = ActionType.AGGREGATE

+ @model_validator(mode="after")
+ def validate_cumulative_aggregation_types(self) -> Self:
+ """Validate that cumulative is only used with COUNT and SUM aggregations."""
+ if self.cumulative:
+ from orchestrator.search.aggregations import AggregationType
+
+ for agg in self.aggregations:
+ if agg.type in (AggregationType.AVG, AggregationType.MIN, AggregationType.MAX):
+ raise ValueError(
+ f"Cumulative aggregations are not supported for {agg.type.value.upper()} aggregations. "
+ f"Cumulative only works with COUNT and SUM."
+ )
+ return self
+
  def get_pivot_fields(self) -> list[str]:
  """Get all fields needed for EAV pivot including aggregation fields."""
  # Get grouping fields from GroupingMixin
orchestrator/search/query/validation.py CHANGED
@@ -31,6 +31,7 @@ from orchestrator.search.query.exceptions import (
  InvalidLtreePatternError,
  PathNotFoundError,
  )
+ from orchestrator.search.query.mixins import OrderBy


  def is_filter_compatible_with_field_type(filter_condition: FilterCondition, field_type: FieldType) -> bool:
@@ -207,3 +208,45 @@ def validate_temporal_grouping_field(field_path: str) -> None:
  # Validate field type is datetime
  if field_type_str != FieldType.DATETIME.value:
  raise IncompatibleTemporalGroupingTypeError(field_path, field_type_str)
+
+
+ def validate_grouping_fields(group_by_paths: list[str]) -> None:
+ """Validate that all grouping field paths exist in the database.
+
+ Args:
+ group_by_paths: List of field paths to group by
+
+ Raises:
+ PathNotFoundError: If any path doesn't exist in the database
+ """
+ for path in group_by_paths:
+ field_type = validate_filter_path(path)
+ if field_type is None:
+ raise PathNotFoundError(path)
+
+
+ def validate_order_by_fields(order_by: list[OrderBy] | None) -> None:
+ """Validate that order_by field paths exist in the database.
+
+ Args:
+ order_by: List of ordering instructions, or None
+
+ Raises:
+ PathNotFoundError: If a field path doesn't exist in the database
+
+ Note:
+ Only validates fields that appear to be paths (contain dots).
+ Aggregation aliases (no dots, like 'count') are skipped as they
+ cannot be validated until query execution time.
+ """
+ if order_by is None:
+ return
+
+ for order_instr in order_by:
+ # Skip aggregation aliases (no dots, e.g., 'count', 'revenue')
+ if "." not in order_instr.field:
+ continue
+
+ field_type = validate_filter_path(order_instr.field)
+ if field_type is None:
+ raise PathNotFoundError(order_instr.field)
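Note on the change above: validate_order_by_fields only resolves dotted field paths against the search index; bare aliases such as "count" are left for execution time. A hedged sketch (the dotted path is illustrative, and the lookup requires a configured orchestrator database):

    from orchestrator.search.query.exceptions import PathNotFoundError
    from orchestrator.search.query.mixins import OrderBy
    from orchestrator.search.query.validation import validate_order_by_fields

    # Bare aliases (no dot) are skipped entirely; no database lookup happens.
    validate_order_by_fields([OrderBy(field="count")])

    # Dotted paths go through validate_filter_path; unknown ones raise PathNotFoundError.
    try:
        validate_order_by_fields([OrderBy(field="subscription.nonexistent_field")])
    except PathNotFoundError as exc:
        print(exc)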
orchestrator/settings.py CHANGED
@@ -17,10 +17,13 @@ from pathlib import Path
  from typing import Literal

  from pydantic import Field, NonNegativeInt, PostgresDsn, RedisDsn
+ from pydantic.main import BaseModel
  from pydantic_settings import BaseSettings

+ from oauth2_lib.fastapi import OIDCUserModel
  from oauth2_lib.settings import oauth2lib_settings
  from orchestrator.services.settings_env_variables import expose_settings
+ from orchestrator.utils.auth import Authorizer
  from orchestrator.utils.expose_settings import SecretStr as OrchSecretStr
  from pydantic_forms.types import strEnum

@@ -111,3 +114,48 @@ if app_settings.EXPOSE_SETTINGS:
  expose_settings("app_settings", app_settings) # type: ignore
  if app_settings.EXPOSE_OAUTH_SETTINGS:
  expose_settings("oauth2lib_settings", oauth2lib_settings) # type: ignore
+
+
+ class Authorizers(BaseModel):
+ # Callbacks specifically for orchestrator-core callbacks.
+ # Separate from defaults for user-defined workflows and steps.
+ internal_authorize_callback: Authorizer | None = None
+ internal_retry_auth_callback: Authorizer | None = None
+
+ async def authorize_callback(self, user: OIDCUserModel | None) -> bool:
+ """This is the authorize_callback to be registered for workflows defined within orchestrator-core.
+
+ If Authorizers.internal_authorize_callback is None, this function will return True.
+ i.e. any user will be authorized to start internal workflows.
+ """
+ if self.internal_authorize_callback is None:
+ return True
+ return await self.internal_authorize_callback(user)
+
+ async def retry_auth_callback(self, user: OIDCUserModel | None) -> bool:
+ """This is the retry_auth_callback to be registered for workflows defined within orchestrator-core.
+
+ If Authorizers.internal_retry_auth_callback is None, this function will return True.
+ i.e. any user will be authorized to retry internal workflows on failure.
+ """
+ if self.internal_retry_auth_callback is None:
+ return True
+ return await self.internal_retry_auth_callback(user)
+
+
+ _authorizers = Authorizers()
+
+
+ def get_authorizers() -> Authorizers:
+ """Acquire singleton of app authorizers to assign these callbacks at app setup.
+
+ Ensures downstream users can acquire singleton without being tempted to do
+ from orchestrator.settings import authorizers
+ authorizers = my_authorizers
+ or
+ from orchestrator import settings
+ settings.authorizers = my_authorizers
+
+ ...each of which goes wrong in its own way.
+ """
+ return _authorizers
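Note on the change above: get_authorizers() hands out a module-level singleton, so a deployment can install its own policy during app setup and the internal workflows below will pick it up. A hedged sketch of such a registration (the "groups" claim and its value are assumptions for illustration, and OIDCUserModel is treated here as a mapping of claims):

    from oauth2_lib.fastapi import OIDCUserModel
    from orchestrator.settings import get_authorizers

    async def internal_tasks_policy(user: OIDCUserModel | None) -> bool:
        # Illustrative policy: allow unauthenticated schedulers, otherwise require
        # an assumed "orchestrator-admin" entry in an assumed "groups" claim.
        if user is None:
            return True
        return "orchestrator-admin" in user.get("groups", [])

    authorizers = get_authorizers()
    authorizers.internal_authorize_callback = internal_tasks_policy
    authorizers.internal_retry_auth_callback = internal_tasks_policy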
orchestrator/workflows/modify_note.py CHANGED
@@ -13,6 +13,7 @@
  from orchestrator.db import db
  from orchestrator.forms import SubmitFormPage
  from orchestrator.services import subscriptions
+ from orchestrator.settings import get_authorizers
  from orchestrator.targets import Target
  from orchestrator.utils.json import to_serializable
  from orchestrator.workflow import StepList, done, init, step, workflow
@@ -21,6 +22,8 @@ from orchestrator.workflows.utils import wrap_modify_initial_input_form
  from pydantic_forms.types import FormGenerator, State, UUIDstr
  from pydantic_forms.validators import LongText

+ authorizers = get_authorizers()
+

  def initial_input_form(subscription_id: UUIDstr) -> FormGenerator:
  subscription = subscriptions.get_subscription(subscription_id)
@@ -51,6 +54,12 @@ def store_subscription_note(subscription_id: UUIDstr, note: str) -> State:
  }


- @workflow("Modify Note", initial_input_form=wrap_modify_initial_input_form(initial_input_form), target=Target.MODIFY)
+ @workflow(
+ "Modify Note",
+ initial_input_form=wrap_modify_initial_input_form(initial_input_form),
+ target=Target.MODIFY,
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
+ )
  def modify_note() -> StepList:
  return init >> store_process_subscription() >> store_subscription_note >> done
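Note on the change above: the built-in workflows now pass authorizers.authorize_callback and authorizers.retry_auth_callback, and per the docstrings in settings.py both default to allowing any user while no internal callback is registered. A small sketch of that default:

    import asyncio

    from orchestrator.settings import get_authorizers

    authorizers = get_authorizers()

    # With no internal callbacks registered, both checks fall back to True.
    print(asyncio.run(authorizers.authorize_callback(None)))   # True
    print(asyncio.run(authorizers.retry_auth_callback(None)))  # True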
orchestrator/workflows/removed_workflow.py CHANGED
@@ -12,11 +12,18 @@
  # limitations under the License.


+ from orchestrator.settings import get_authorizers
  from orchestrator.workflow import StepList, workflow

+ authorizers = get_authorizers()
+

  # This workflow has been made to create the initial import process for a SN7 subscription
  # it does not do anything but is needed for the correct showing in the GUI.
- @workflow("Dummy workflow to replace removed workflows")
+ @workflow(
+ "Dummy workflow to replace removed workflows",
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
+ )
  def removed_workflow() -> StepList:
  return StepList()
orchestrator/workflows/tasks/cleanup_tasks_log.py CHANGED
@@ -17,12 +17,14 @@ from datetime import timedelta
  from sqlalchemy import select

  from orchestrator.db import ProcessTable, db
- from orchestrator.settings import app_settings
+ from orchestrator.settings import app_settings, get_authorizers
  from orchestrator.targets import Target
  from orchestrator.utils.datetime import nowtz
  from orchestrator.workflow import ProcessStatus, StepList, done, init, step, workflow
  from pydantic_forms.types import State

+ authorizers = get_authorizers()
+

  @step("Clean up completed tasks older than TASK_LOG_RETENTION_DAYS")
  def remove_tasks() -> State:
@@ -41,6 +43,11 @@ def remove_tasks() -> State:
  return {"tasks_removed": count}


- @workflow("Clean up old tasks", target=Target.SYSTEM)
+ @workflow(
+ "Clean up old tasks",
+ target=Target.SYSTEM,
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
+ )
  def task_clean_up_tasks() -> StepList:
  return init >> remove_tasks >> done
orchestrator/workflows/tasks/resume_workflows.py CHANGED
@@ -17,10 +17,12 @@ from sqlalchemy import select

  from orchestrator.db import ProcessTable, db
  from orchestrator.services import processes
+ from orchestrator.settings import get_authorizers
  from orchestrator.targets import Target
  from orchestrator.workflow import ProcessStatus, StepList, done, init, step, workflow
  from pydantic_forms.types import State, UUIDstr

+ authorizers = get_authorizers()
  logger = structlog.get_logger(__name__)


@@ -110,6 +112,8 @@ def restart_created_workflows(created_state_process_ids: list[UUIDstr]) -> State
  @workflow(
  "Resume all workflows that are stuck on tasks with the status 'waiting', 'created' or 'resumed'",
  target=Target.SYSTEM,
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
  )
  def task_resume_workflows() -> StepList:
  return init >> find_waiting_workflows >> resume_found_workflows >> restart_created_workflows >> done
orchestrator/workflows/tasks/validate_product_type.py CHANGED
@@ -25,10 +25,12 @@ from orchestrator.services.workflows import (
  get_validation_product_workflows_for_subscription,
  start_validation_workflow_for_workflows,
  )
+ from orchestrator.settings import get_authorizers
  from orchestrator.targets import Target
  from orchestrator.workflow import StepList, done, init, step, workflow
  from pydantic_forms.types import FormGenerator, State

+ authorizers = get_authorizers()
  logger = structlog.get_logger(__name__)


@@ -86,7 +88,11 @@ def validate_product_type(product_type: str) -> State:


  @workflow(
- "Validate all subscriptions of Product Type", target=Target.SYSTEM, initial_input_form=initial_input_form_generator
+ "Validate all subscriptions of Product Type",
+ target=Target.SYSTEM,
+ initial_input_form=initial_input_form_generator,
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
  )
  def task_validate_product_type() -> StepList:
  return init >> validate_product_type >> done
orchestrator/workflows/tasks/validate_products.py CHANGED
@@ -26,12 +26,15 @@ from orchestrator.services import products
  from orchestrator.services.products import get_products
  from orchestrator.services.translations import generate_translations
  from orchestrator.services.workflows import get_workflow_by_name, get_workflows
+ from orchestrator.settings import get_authorizers
  from orchestrator.targets import Target
  from orchestrator.utils.errors import ProcessFailureError
  from orchestrator.utils.fixed_inputs import fixed_input_configuration as fi_configuration
  from orchestrator.workflow import StepList, done, init, step, workflow
  from pydantic_forms.types import State

+ authorizers = get_authorizers()
+
  # Since these errors are probably programming failures we should not throw AssertionErrors


@@ -187,7 +190,12 @@ def check_subscription_models() -> State:
  return {"check_subscription_models": True}


- @workflow("Validate products", target=Target.SYSTEM)
+ @workflow(
+ "Validate products",
+ target=Target.SYSTEM,
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
+ )
  def task_validate_products() -> StepList:
  return (
  init
orchestrator/workflows/tasks/validate_subscriptions.py CHANGED
@@ -24,15 +24,17 @@ from orchestrator.services.workflows import (
  get_validation_product_workflows_for_subscription,
  start_validation_workflow_for_workflows,
  )
- from orchestrator.settings import app_settings
+ from orchestrator.settings import app_settings, get_authorizers
  from orchestrator.targets import Target
- from orchestrator.workflow import StepList, init, step, workflow
+ from orchestrator.workflow import StepList, done, init, step, workflow

  logger = structlog.get_logger(__name__)


  task_semaphore = BoundedSemaphore(value=2)

+ authorizers = get_authorizers()
+

  @step("Validate subscriptions")
  def validate_subscriptions() -> None:
@@ -56,6 +58,11 @@ def validate_subscriptions() -> None:
  start_validation_workflow_for_workflows(subscription=subscription, workflows=validation_product_workflows)


- @workflow("Validate subscriptions", target=Target.SYSTEM)
+ @workflow(
+ "Validate subscriptions",
+ target=Target.SYSTEM,
+ authorize_callback=authorizers.authorize_callback,
+ retry_auth_callback=authorizers.retry_auth_callback,
+ )
  def task_validate_subscriptions() -> StepList:
- return init >> validate_subscriptions
+ return init >> validate_subscriptions >> done