orchestrator-core 4.1.0rc1__py3-none-any.whl → 4.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orchestrator/__init__.py CHANGED
@@ -13,7 +13,7 @@
 
 """This is the orchestrator workflow engine."""
 
-__version__ = "4.1.0rc1"
+__version__ = "4.2.0rc1"
 
 from orchestrator.app import OrchestratorCore
 from orchestrator.settings import app_settings
orchestrator/api/api_v1/endpoints/processes.py CHANGED
@@ -25,7 +25,7 @@ from fastapi.param_functions import Body, Depends, Header
 from fastapi.routing import APIRouter
 from fastapi.websockets import WebSocket
 from fastapi_etag.dependency import CacheHit
-from more_itertools import chunked
+from more_itertools import chunked, last
 from sentry_sdk.tracing import trace
 from sqlalchemy import CompoundSelect, Select, select
 from sqlalchemy.orm import defer, joinedload
@@ -56,6 +56,7 @@ from orchestrator.services.processes import (
56
56
  )
57
57
  from orchestrator.services.settings import get_engine_settings
58
58
  from orchestrator.settings import app_settings
59
+ from orchestrator.utils.auth import Authorizer
59
60
  from orchestrator.utils.enrich_process import enrich_process
60
61
  from orchestrator.websocket import (
61
62
  WS_CHANNELS,
@@ -63,7 +64,7 @@ from orchestrator.websocket import (
     broadcast_process_update_to_websocket,
     websocket_manager,
 )
-from orchestrator.workflow import ProcessStatus
+from orchestrator.workflow import ProcessStat, ProcessStatus, StepList, Workflow
 from pydantic_forms.types import JSON, State
 
 router = APIRouter()
@@ -86,6 +87,48 @@ def check_global_lock() -> None:
         )
 
 
+def get_current_steps(pstat: ProcessStat) -> StepList:
+    """Extract past and current steps from the ProcessStat."""
+    remaining_steps = pstat.log
+    past_steps = pstat.workflow.steps[: -len(remaining_steps)]
+    return StepList(past_steps + [pstat.log[0]])
+
+
+def get_auth_callbacks(steps: StepList, workflow: Workflow) -> tuple[Authorizer | None, Authorizer | None]:
+    """Iterate over workflow and prior steps to determine correct authorization callbacks for the current step.
+
+    It's safest to always iterate through the steps. We could track these callbacks statefully
+    as we progress through the workflow, but if we fail a step and the system restarts, the previous
+    callbacks will be lost if they're only available in the process state.
+
+    Priority:
+    - RESUME callback is explicit RESUME callback, else previous START/RESUME callback
+    - RETRY callback is explicit RETRY, else explicit RESUME, else previous RETRY
+    """
+    # Default to workflow start callbacks
+    auth_resume = workflow.authorize_callback
+    # auth_retry defaults to the workflow start callback if not otherwise specified.
+    # A workflow SHOULD have both callbacks set to not-None. This enforces the correct default regardless.
+    auth_retry = workflow.retry_auth_callback or auth_resume  # type: ignore[unreachable, truthy-function]
+
+    # Choose the most recently established value for resume.
+    auth_resume = last(filter(None, (step.resume_auth_callback for step in steps)), auth_resume)
+    # Choose the most recently established value for retry, unless there is a more recent value for resume.
+    auth_retry = last(
+        filter(None, (step.retry_auth_callback or step.resume_auth_callback for step in steps)), auth_retry
+    )
+    return auth_resume, auth_retry
+
+
+def can_be_resumed(status: ProcessStatus) -> bool:
+    return status in (
+        ProcessStatus.SUSPENDED,  # Can be resumed
+        ProcessStatus.FAILED,  # Can be retried
+        ProcessStatus.API_UNAVAILABLE,  # subtype of FAILED
+        ProcessStatus.INCONSISTENT_DATA,  # subtype of FAILED
+    )
+
+
 def resolve_user_name(
     *,
     reporter: Reporter | None,
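The docstring above states the priority rules in words; the sketch below (not part of the package: toy step objects via SimpleNamespace, hypothetical authorizer functions) shows how the last(filter(None, ...)) expressions resolve them — the most recently established non-None callback wins, and an explicit resume callback also becomes the retry fallback.

from types import SimpleNamespace

from more_itertools import last


def allow_all(user):  # hypothetical default workflow authorizer
    return True


def admins_only(user):  # hypothetical stricter authorizer attached to one step
    return user is not None


# Toy "steps": only the second one overrides the resume authorizer.
steps = [
    SimpleNamespace(resume_auth_callback=None, retry_auth_callback=None),
    SimpleNamespace(resume_auth_callback=admins_only, retry_auth_callback=None),
    SimpleNamespace(resume_auth_callback=None, retry_auth_callback=None),
]

auth_resume = auth_retry = allow_all  # workflow-level defaults
# Same expressions as get_auth_callbacks: keep the default unless a later step sets a callback.
auth_resume = last(filter(None, (s.resume_auth_callback for s in steps)), auth_resume)
auth_retry = last(filter(None, (s.retry_auth_callback or s.resume_auth_callback for s in steps)), auth_retry)

assert auth_resume is admins_only  # explicit RESUME override from the second step
assert auth_retry is admins_only   # no explicit RETRY anywhere, so the RESUME override applies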
@@ -115,6 +158,9 @@ def delete(process_id: UUID) -> None:
     if not process:
         raise_status(HTTPStatus.NOT_FOUND)
 
+    if not process.is_task:
+        raise_status(HTTPStatus.BAD_REQUEST)
+
     db.session.delete(db.session.get(ProcessTable, process_id))
     db.session.commit()
 
@@ -150,18 +196,25 @@ def new_process(
     dependencies=[Depends(check_global_lock, use_cache=False)],
 )
 def resume_process_endpoint(
-    process_id: UUID, request: Request, json_data: JSON = Body(...), user: str = Depends(user_name)
+    process_id: UUID,
+    request: Request,
+    json_data: JSON = Body(...),
+    user: str = Depends(user_name),
+    user_model: OIDCUserModel | None = Depends(authenticate),
 ) -> None:
     process = _get_process(process_id)
 
-    if process.last_status == ProcessStatus.COMPLETED:
-        raise_status(HTTPStatus.CONFLICT, "Resuming a completed workflow is not possible")
-
-    if process.last_status == ProcessStatus.RUNNING:
-        raise_status(HTTPStatus.CONFLICT, "Resuming a running workflow is not possible")
+    if not can_be_resumed(process.last_status):
+        raise_status(HTTPStatus.CONFLICT, f"Resuming a {process.last_status.lower()} workflow is not possible")
 
-    if process.last_status == ProcessStatus.RESUMED:
-        raise_status(HTTPStatus.CONFLICT, "Resuming a resumed workflow is not possible")
+    pstat = load_process(process)
+    auth_resume, auth_retry = get_auth_callbacks(get_current_steps(pstat), pstat.workflow)
+    if process.last_status == ProcessStatus.SUSPENDED:
+        if auth_resume is not None and not auth_resume(user_model):
+            raise_status(HTTPStatus.FORBIDDEN, "User is not authorized to resume step")
+    elif process.last_status == ProcessStatus.FAILED:
+        if auth_retry is not None and not auth_retry(user_model):
+            raise_status(HTTPStatus.FORBIDDEN, "User is not authorized to retry step")
 
     broadcast_invalidate_status_counts()
     broadcast_func = api_broadcast_process_data(request)
@@ -220,7 +273,7 @@ def update_progress_on_awaiting_process_endpoint(
 @router.put(
     "/resume-all", response_model=ProcessResumeAllSchema, dependencies=[Depends(check_global_lock, use_cache=False)]
 )
-async def resume_all_processess_endpoint(request: Request, user: str = Depends(user_name)) -> dict[str, int]:
+async def resume_all_processes_endpoint(request: Request, user: str = Depends(user_name)) -> dict[str, int]:
     """Retry all task processes in status Failed, Waiting, API Unavailable or Inconsistent Data.
 
     The retry is started in the background, returning status 200 and number of processes in message.
@@ -47,10 +47,12 @@ from orchestrator.services.subscriptions import (
     subscription_workflows,
 )
 from orchestrator.settings import app_settings
+from orchestrator.targets import Target
 from orchestrator.types import SubscriptionLifecycle
 from orchestrator.utils.deprecation_logger import deprecated_endpoint
 from orchestrator.utils.get_subscription_dict import get_subscription_dict
 from orchestrator.websocket import sync_invalidate_subscription_cache
+from orchestrator.workflows import get_workflow
 
 router = APIRouter()
 
@@ -100,6 +102,25 @@ def _filter_statuses(filter_statuses: str | None = None) -> list[str]:
     return statuses
 
 
+def _authorized_subscription_workflows(
+    subscription: SubscriptionTable, current_user: OIDCUserModel | None
+) -> dict[str, list[dict[str, list[Any] | str]]]:
+    subscription_workflows_dict = subscription_workflows(subscription)
+
+    for workflow_target in Target.values():
+        for workflow_dict in subscription_workflows_dict[workflow_target.lower()]:
+            workflow = get_workflow(workflow_dict["name"])
+            if not workflow:
+                continue
+            if (
+                not workflow.authorize_callback(current_user)  # The current user isn't allowed to run this workflow
+                and "reason" not in workflow_dict  # and there isn't already a reason why this workflow cannot run
+            ):
+                workflow_dict["reason"] = "subscription.insufficient_workflow_permissions"
+
+    return subscription_workflows_dict
+
+
 @router.get(
     "/domain-model/{subscription_id}",
     response_model=SubscriptionDomainModelSchema | None,
@@ -169,7 +190,9 @@ def subscriptions_search(
     description="This endpoint is deprecated and will be removed in a future release. Please use the GraphQL query",
     dependencies=[Depends(deprecated_endpoint)],
 )
-def subscription_workflows_by_id(subscription_id: UUID) -> dict[str, list[dict[str, list[Any] | str]]]:
+def subscription_workflows_by_id(
+    subscription_id: UUID, current_user: OIDCUserModel | None = Depends(authenticate)
+) -> dict[str, list[dict[str, list[Any] | str]]]:
     subscription = db.session.get(
         SubscriptionTable,
         subscription_id,
@@ -181,7 +204,7 @@ def subscription_workflows_by_id(subscription_id: UUID) -> dict[str, list[dict[s
     if not subscription:
         raise_status(HTTPStatus.NOT_FOUND)
 
-    return subscription_workflows(subscription)
+    return _authorized_subscription_workflows(subscription, current_user)
 
 
 @router.put("/{subscription_id}/set_in_sync", response_model=None, status_code=HTTPStatus.OK)
@@ -11,7 +11,6 @@ from pydantic_forms.types import FormGenerator, State, UUIDstr
 
 from orchestrator.forms import FormPage
 from orchestrator.forms.validators import Divider, Label, CustomerId, MigrationSummary
-from orchestrator.targets import Target
 from orchestrator.types import SubscriptionLifecycle
 from orchestrator.workflow import StepList, begin, step
 from orchestrator.workflows.steps import store_process_subscription
@@ -119,6 +118,6 @@ def create_{{ product.variable }}() -> StepList:
     return (
         begin
         >> construct_{{ product.variable }}_model
-        >> store_process_subscription(Target.CREATE)
+        >> store_process_subscription()
         # TODO add provision step(s)
     )
orchestrator/db/models.py CHANGED
@@ -117,7 +117,7 @@ class ProcessTable(BaseModel):
     is_task = mapped_column(Boolean, nullable=False, server_default=text("false"), index=True)
 
     steps = relationship(
-        "ProcessStepTable", cascade="delete", passive_deletes=True, order_by="asc(ProcessStepTable.executed_at)"
+        "ProcessStepTable", cascade="delete", passive_deletes=True, order_by="asc(ProcessStepTable.completed_at)"
     )
     input_states = relationship("InputStateTable", cascade="delete", order_by="desc(InputStateTable.input_time)")
     process_subscriptions = relationship("ProcessSubscriptionTable", back_populates="process", passive_deletes=True)
@@ -141,7 +141,8 @@ class ProcessStepTable(BaseModel):
     status = mapped_column(String(50), nullable=False)
     state = mapped_column(pg.JSONB(), nullable=False)
     created_by = mapped_column(String(255), nullable=True)
-    executed_at = mapped_column(UtcTimestamp, server_default=text("statement_timestamp()"), nullable=False)
+    completed_at = mapped_column(UtcTimestamp, server_default=text("statement_timestamp()"), nullable=False)
+    started_at = mapped_column(UtcTimestamp, server_default=text("statement_timestamp()"), nullable=False)
     commit_hash = mapped_column(String(40), nullable=True, default=GIT_COMMIT_HASH)
 
 
@@ -154,7 +155,9 @@ class ProcessSubscriptionTable(BaseModel):
     )
     subscription_id = mapped_column(UUIDType, ForeignKey("subscriptions.subscription_id"), nullable=False, index=True)
     created_at = mapped_column(UtcTimestamp, server_default=text("current_timestamp()"), nullable=False)
-    workflow_target = mapped_column(String(255), nullable=False, server_default=Target.CREATE)
+
+    # FIXME: workflow_target is already stored in the workflow table, this column should get removed in a later release.
+    workflow_target = mapped_column(String(255), nullable=True)
 
     process = relationship("ProcessTable", back_populates="process_subscriptions")
     subscription = relationship("SubscriptionTable", back_populates="processes")
@@ -6,14 +6,17 @@ from strawberry.federation.schema_directives import Key
 from strawberry.scalars import JSON
 
 from oauth2_lib.strawberry import authenticated_field
+from orchestrator.api.api_v1.endpoints.processes import get_auth_callbacks, get_current_steps
 from orchestrator.db import ProcessTable, ProductTable, db
 from orchestrator.graphql.pagination import EMPTY_PAGE, Connection
 from orchestrator.graphql.schemas.customer import CustomerType
 from orchestrator.graphql.schemas.helpers import get_original_model
 from orchestrator.graphql.schemas.product import ProductType
-from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
+from orchestrator.graphql.types import FormUserPermissionsType, GraphqlFilter, GraphqlSort, OrchestratorInfo
 from orchestrator.schemas.process import ProcessSchema, ProcessStepSchema
+from orchestrator.services.processes import load_process
 from orchestrator.settings import app_settings
+from orchestrator.workflows import get_workflow
 
 if TYPE_CHECKING:
     from orchestrator.graphql.schemas.subscription import SubscriptionInterface
@@ -29,7 +32,11 @@ class ProcessStepType:
     name: strawberry.auto
     status: strawberry.auto
     created_by: strawberry.auto
-    executed: strawberry.auto
+    executed: strawberry.auto = strawberry.field(
+        deprecation_reason="Deprecated, use 'started' and 'completed' for step start and completion times"
+    )
+    started: strawberry.auto
+    completed: strawberry.auto
     commit_hash: strawberry.auto
     state: JSON | None
     state_delta: JSON | None
@@ -74,6 +81,18 @@ class ProcessType:
             shortcode=app_settings.DEFAULT_CUSTOMER_SHORTCODE,
         )
 
+    @strawberry.field(description="Returns user permissions for operations on this process")  # type: ignore
+    def user_permissions(self, info: OrchestratorInfo) -> FormUserPermissionsType:
+        oidc_user = info.context.get_current_user
+        workflow = get_workflow(self.workflow_name)
+        process = load_process(db.session.get(ProcessTable, self.process_id))  # type: ignore[arg-type]
+        auth_resume, auth_retry = get_auth_callbacks(get_current_steps(process), workflow)  # type: ignore[arg-type]
+
+        return FormUserPermissionsType(
+            retryAllowed=auth_retry and auth_retry(oidc_user),  # type: ignore[arg-type]
+            resumeAllowed=auth_resume and auth_resume(oidc_user),  # type: ignore[arg-type]
+        )
+
     @authenticated_field(description="Returns list of subscriptions of the process")  # type: ignore
     async def subscriptions(
         self,
@@ -52,21 +52,20 @@ class ProductType:
         return await resolve_subscriptions(info, filter_by_with_related_subscriptions, sort_by, first, after)
 
     @strawberry.field(description="Returns list of all nested productblock names")  # type: ignore
-    async def all_pb_names(self) -> list[str]:
-
+    async def all_product_block_names(self) -> list[str]:
         model = get_original_model(self, ProductTable)
 
-        def get_all_pb_names(product_blocks: list[ProductBlockTable]) -> Iterable[str]:
+        def get_names(product_blocks: list[ProductBlockTable], visited: set) -> Iterable[str]:
             for product_block in product_blocks:
+                if product_block.product_block_id in visited:
+                    continue
+                visited.add(product_block.product_block_id)
                 yield product_block.name
-
                 if product_block.depends_on:
-                    yield from get_all_pb_names(product_block.depends_on)
-
-        names: list[str] = list(get_all_pb_names(model.product_blocks))
-        names.sort()
+                    yield from get_names(product_block.depends_on, visited)
 
-        return names
+        names = set(get_names(model.product_blocks, set()))
+        return sorted(names)
 
     @strawberry.field(description="Return product blocks")  # type: ignore
     async def product_blocks(self) -> list[Annotated["ProductBlock", strawberry.lazy(".product_block")]]:
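The visited set added above is what makes the traversal safe when a product block is reachable through more than one parent, or when the relationships form a cycle. A standalone sketch with a toy Block class (not the SQLAlchemy model) showing that each name is yielded once:

class Block:
    def __init__(self, block_id, name, depends_on=None):
        self.product_block_id = block_id
        self.name = name
        self.depends_on = depends_on or []


def get_names(product_blocks, visited):
    # Same traversal as the resolver above: skip anything already seen.
    for product_block in product_blocks:
        if product_block.product_block_id in visited:
            continue
        visited.add(product_block.product_block_id)
        yield product_block.name
        if product_block.depends_on:
            yield from get_names(product_block.depends_on, visited)


shared = Block(3, "shared")
a = Block(1, "a", [shared])
b = Block(2, "b", [shared])
shared.depends_on.append(a)  # deliberately introduce a cycle

assert sorted(set(get_names([a, b], set()))) == ["a", "b", "shared"]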
@@ -5,6 +5,7 @@ import strawberry
 from orchestrator.config.assignee import Assignee
 from orchestrator.db import WorkflowTable
 from orchestrator.graphql.schemas.helpers import get_original_model
+from orchestrator.graphql.types import OrchestratorInfo
 from orchestrator.schemas import StepSchema, WorkflowSchema
 from orchestrator.workflows import get_workflow
 
@@ -30,3 +31,11 @@ class Workflow:
     @strawberry.field(description="Return all steps for this workflow")  # type: ignore
     def steps(self) -> list[Step]:
         return [Step(name=step.name, assignee=step.assignee) for step in get_workflow(self.name).steps]  # type: ignore
+
+    @strawberry.field(description="Return whether the currently logged-in user is allowed to start this workflow")  # type: ignore
+    def is_allowed(self, info: OrchestratorInfo) -> bool:
+        oidc_user = info.context.get_current_user
+        workflow_table = get_original_model(self, WorkflowTable)
+        workflow = get_workflow(workflow_table.name)
+
+        return workflow.authorize_callback(oidc_user)  # type: ignore
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 SURF, GÉANT.
+# Copyright 2022-2025 SURF, GÉANT.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -132,6 +132,12 @@ SCALAR_OVERRIDES: ScalarOverrideType = {
 }
 
 
+@strawberry.type(description="User permissions on a specific process")
+class FormUserPermissionsType:
+    retryAllowed: bool
+    resumeAllowed: bool
+
+
 @strawberry.type(description="Generic class to capture errors")
 class MutationError:
     message: str = strawberry.field(description="Error message")
@@ -0,0 +1,65 @@
+"""Changed timestamping fields in process_steps.
+
+Revision ID: 93fc5834c7e5
+Revises: 4b58e336d1bf
+Create Date: 2025-07-01 14:20:44.755694
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+from orchestrator import db
+
+# revision identifiers, used by Alembic.
+revision = "93fc5834c7e5"
+down_revision = "4b58e336d1bf"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "process_steps",
+        sa.Column(
+            "started_at",
+            db.UtcTimestamp(timezone=True),
+            server_default=sa.text("statement_timestamp()"),
+            nullable=False,
+        ),
+    )
+    op.alter_column("process_steps", "executed_at", new_column_name="completed_at")
+    # conn = op.get_bind()
+    # sa.select
+    # ### end Alembic commands ###
+    # Backfill started_at field correctly using proper aliasing
+    op.execute(
+        """
+        WITH backfill_started_at AS (
+            SELECT
+                ps1.stepid,
+                COALESCE(prev.completed_at, p.started_at) AS new_started_at
+            FROM process_steps ps1
+            JOIN processes p ON ps1.pid = p.pid
+            LEFT JOIN LATERAL (
+                SELECT ps2.completed_at
+                FROM process_steps ps2
+                WHERE ps2.pid = ps1.pid AND ps2.completed_at < ps1.completed_at
+                ORDER BY ps2.completed_at DESC
+                LIMIT 1
+            ) prev ON true
+        )
+        UPDATE process_steps
+        SET started_at = b.new_started_at
+        FROM backfill_started_at b
+        WHERE process_steps.stepid = b.stepid;
+        """
+    )
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("process_steps", "started_at")
+    op.alter_column("process_steps", "completed_at", new_column_name="executed_at")
+    # ### end Alembic commands ###
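The LATERAL join above backfills each step's started_at with the completed_at of the previous step in the same process, falling back to the process's own started_at for the first step. A small Python sketch of the same rule (illustration only, not part of the migration):

def backfill_started_at(process_started_at, completed_ats):
    """Given a process start time and its steps' completed_at values in order,
    return (started_at, completed_at) pairs matching the SQL backfill."""
    previous = [process_started_at] + completed_ats[:-1]
    return list(zip(previous, completed_ats))


# The first step starts at the process start; every later step starts when the previous one completed.
assert backfill_started_at(0, [5, 7, 12]) == [(0, 5), (5, 7), (7, 12)]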
@@ -0,0 +1,30 @@
+"""Deprecating workflow target in ProcessSubscriptionTable.
+
+Revision ID: 4b58e336d1bf
+Revises: 161918133bec
+Create Date: 2025-07-04 15:27:23.814954
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "4b58e336d1bf"
+down_revision = "161918133bec"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.alter_column("processes_subscriptions", "workflow_target", existing_type=sa.VARCHAR(length=255), nullable=True)
+
+
+def downgrade() -> None:
+    op.alter_column(
+        "processes_subscriptions",
+        "workflow_target",
+        existing_type=sa.VARCHAR(length=255),
+        nullable=False,
+        existing_server_default=sa.text("'CREATE'::character varying"),
+    )
@@ -49,7 +49,11 @@ class ProcessStepSchema(OrchestratorBaseModel):
     name: str
     status: str
     created_by: str | None = None
-    executed: datetime | None = None
+    executed: datetime | None = Field(
+        None, deprecated="Deprecated, use 'started' and 'completed' for step start and completion times"
+    )
+    started: datetime | None = None
+    completed: datetime | None = None
     commit_hash: str | None = None
     state: dict[str, Any] | None = None
     state_delta: dict[str, Any] | None = None
@@ -19,6 +19,7 @@ import structlog
 from celery.result import AsyncResult
 from kombu.exceptions import ConnectionError, OperationalError
 
+from oauth2_lib.fastapi import OIDCUserModel
 from orchestrator import app_settings
 from orchestrator.api.error_handling import raise_status
 from orchestrator.db import ProcessTable, db
@@ -42,7 +43,11 @@ def _block_when_testing(task_result: AsyncResult) -> None:
 
 
 def _celery_start_process(
-    workflow_key: str, user_inputs: list[State] | None, user: str = SYSTEM_USER, **kwargs: Any
+    workflow_key: str,
+    user_inputs: list[State] | None,
+    user: str = SYSTEM_USER,
+    user_model: OIDCUserModel | None = None,
+    **kwargs: Any,
 ) -> UUID:
     """Client side call of Celery."""
     from orchestrator.services.tasks import NEW_TASK, NEW_WORKFLOW, get_celery_task
@@ -57,7 +62,7 @@
 
     task_name = NEW_TASK if wf_table.is_task else NEW_WORKFLOW
     trigger_task = get_celery_task(task_name)
-    pstat = create_process(workflow_key, user_inputs, user)
+    pstat = create_process(workflow_key, user_inputs=user_inputs, user=user, user_model=user_model)
     try:
         result = trigger_task.delay(pstat.process_id, workflow_key, user)
         _block_when_testing(result)
@@ -12,6 +12,7 @@
 # limitations under the License.
 from collections.abc import Callable, Sequence
 from concurrent.futures.thread import ThreadPoolExecutor
+from datetime import datetime
 from functools import partial
 from http import HTTPStatus
 from typing import Any
@@ -19,6 +20,7 @@ from uuid import UUID, uuid4
 
 import structlog
 from deepmerge.merger import Merger
+from pytz import utc
 from sqlalchemy import delete, select
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import joinedload
@@ -206,6 +208,10 @@ def _get_current_step_to_update(
     finally:
         step_state.pop("__remove_keys", None)
 
+    # We don't put __last_step_started_at in __remove_keys because the way __remove_keys is populated appears like
+    # it would overwrite what's put there in the step decorator in certain cases (step groups and callback steps)
+    step_start_time = step_state.pop("__last_step_started_at", None)
+
     if process_state.isfailed() or process_state.iswaiting():
         if (
             last_db_step is not None
@@ -216,7 +222,7 @@
         ):
             state_ex_info = {
                 "retries": last_db_step.state.get("retries", 0) + 1,
-                "executed_at": last_db_step.state.get("executed_at", []) + [str(last_db_step.executed_at)],
+                "completed_at": last_db_step.state.get("completed_at", []) + [str(last_db_step.completed_at)],
             }
 
             # write new state info and execution date
@@ -236,10 +242,13 @@
             state=step_state,
             created_by=stat.current_user,
         )
+    # Since the Start step does not have a __last_step_started_at in its state, we effectively assume it is instantaneous.
+    now = nowtz()
+    current_step.started_at = datetime.fromtimestamp(step_start_time or now.timestamp(), tz=utc)
 
     # Always explicitly set this instead of leaving it to the database to prevent failing tests
     # Test will fail if multiple steps have the same timestamp
-    current_step.executed_at = nowtz()
+    current_step.completed_at = now
     return current_step
 
 
@@ -467,9 +476,7 @@ def thread_start_process(
     user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
-    pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
-    if not pstat.workflow.authorize_callback(user_model):
-        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
+    pstat = create_process(workflow_key, user_inputs=user_inputs, user=user, user_model=user_model)
 
     _safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
     return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
@@ -506,7 +513,6 @@ def thread_resume_process(
     *,
     user_inputs: list[State] | None = None,
    user: str | None = None,
-    user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
     # ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
@@ -515,8 +521,6 @@
         user_inputs = [{}]
 
     pstat = load_process(process)
-    if not pstat.workflow.authorize_callback(user_model):
-        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
 
     if pstat.workflow == removed_workflow:
         raise ValueError("This workflow cannot be resumed")
@@ -556,7 +560,6 @@ def resume_process(
     *,
     user_inputs: list[State] | None = None,
     user: str | None = None,
-    user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
     """Resume a failed or suspended process.
@@ -565,7 +568,6 @@
         process: Process from database
         user_inputs: Optional user input from forms
         user: user who resumed this process
-        user_model: OIDCUserModel of user who resumed this process
         broadcast_func: Optional function to broadcast process data
 
     Returns:
@@ -573,8 +575,6 @@
 
     """
     pstat = load_process(process)
-    if not pstat.workflow.authorize_callback(user_model):
-        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
 
     try:
         post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
@@ -0,0 +1,9 @@
+from collections.abc import Callable
+from typing import TypeAlias
+
+from oauth2_lib.fastapi import OIDCUserModel
+
+# This file is broken out separately to avoid circular imports.
+
+# Can instead use "type Authorizer = ..." in later Python versions.
+Authorizer: TypeAlias = Callable[[OIDCUserModel | None], bool]
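Any callable matching this alias can serve as an authorization callback. A minimal example (an assumption about usage, not taken from the package) that simply rejects anonymous requests; real callbacks would typically inspect roles or claims on the OIDCUserModel:

from oauth2_lib.fastapi import OIDCUserModel


def authenticated_only(user: OIDCUserModel | None) -> bool:
    # Deny anonymous calls, allow any authenticated user.
    return user is not None

This is the shape get_auth_callbacks expects for workflow.authorize_callback, workflow.retry_auth_callback and the per-step resume/retry callbacks.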
@@ -57,7 +57,9 @@ def enrich_step_details(step: ProcessStepTable, previous_step: ProcessStepTable
 
     return {
         "name": step.name,
-        "executed": step.executed_at.timestamp(),
+        "executed": step.completed_at.timestamp(),
+        "started": step.started_at.timestamp(),
+        "completed": step.completed_at.timestamp(),
         "status": step.status,
         "state": step.state,
         "created_by": step.created_by,
@@ -103,7 +105,7 @@ def enrich_process(process: ProcessTable, p_stat: ProcessStat | None = None) ->
         "is_task": process.is_task,
         "workflow_id": process.workflow_id,
         "workflow_name": process.workflow.name,
-        "workflow_target": process.process_subscriptions[0].workflow_target if process.process_subscriptions else None,
+        "workflow_target": process.workflow.target,
         "failed_reason": process.failed_reason,
         "created_by": process.created_by,
         "started_at": process.started_at,