orchestrator-core: orchestrator_core-3.1.1-py3-none-any.whl → orchestrator_core-3.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +2 -2
- orchestrator/api/api_v1/api.py +1 -1
- orchestrator/api/api_v1/endpoints/processes.py +29 -9
- orchestrator/api/api_v1/endpoints/settings.py +1 -1
- orchestrator/api/api_v1/endpoints/subscriptions.py +1 -1
- orchestrator/app.py +1 -1
- orchestrator/cli/database.py +1 -1
- orchestrator/cli/generator/generator/migration.py +2 -5
- orchestrator/cli/migrate_tasks.py +13 -0
- orchestrator/config/assignee.py +1 -1
- orchestrator/db/__init__.py +2 -0
- orchestrator/db/models.py +6 -4
- orchestrator/devtools/populator.py +1 -1
- orchestrator/domain/__init__.py +2 -3
- orchestrator/domain/base.py +74 -5
- orchestrator/domain/lifecycle.py +1 -1
- orchestrator/graphql/schema.py +1 -1
- orchestrator/graphql/types.py +1 -1
- orchestrator/graphql/utils/get_subscription_product_blocks.py +13 -0
- orchestrator/migrations/env.py +15 -2
- orchestrator/migrations/helpers.py +6 -6
- orchestrator/migrations/versions/schema/2020-10-19_c112305b07d3_initial_schema_migration.py +1 -1
- orchestrator/migrations/versions/schema/2023-05-25_b1970225392d_add_subscription_metadata_workflow.py +1 -1
- orchestrator/migrations/versions/schema/2025-02-12_bac6be6f2b4f_added_input_state_table.py +1 -1
- orchestrator/schemas/engine_settings.py +1 -1
- orchestrator/schemas/subscription.py +1 -1
- orchestrator/security.py +1 -1
- orchestrator/services/celery.py +1 -1
- orchestrator/services/processes.py +99 -18
- orchestrator/services/products.py +1 -1
- orchestrator/services/subscriptions.py +1 -1
- orchestrator/services/tasks.py +1 -1
- orchestrator/settings.py +2 -23
- orchestrator/targets.py +1 -1
- orchestrator/types.py +1 -1
- orchestrator/utils/errors.py +1 -1
- orchestrator/utils/state.py +74 -54
- orchestrator/websocket/websocket_manager.py +1 -1
- orchestrator/workflow.py +20 -4
- orchestrator/workflows/modify_note.py +1 -1
- orchestrator/workflows/steps.py +1 -1
- orchestrator/workflows/tasks/cleanup_tasks_log.py +1 -1
- orchestrator/workflows/tasks/resume_workflows.py +1 -1
- orchestrator/workflows/tasks/validate_product_type.py +1 -1
- orchestrator/workflows/tasks/validate_products.py +1 -1
- orchestrator/workflows/utils.py +40 -5
- {orchestrator_core-3.1.1.dist-info → orchestrator_core-3.1.2.dist-info}/METADATA +10 -10
- {orchestrator_core-3.1.1.dist-info → orchestrator_core-3.1.2.dist-info}/RECORD +50 -50
- {orchestrator_core-3.1.1.dist-info → orchestrator_core-3.1.2.dist-info}/WHEEL +1 -1
- {orchestrator_core-3.1.1.dist-info → orchestrator_core-3.1.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2019-
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT, ESnet.
|
|
2
2
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
3
|
# you may not use this file except in compliance with the License.
|
|
4
4
|
# You may obtain a copy of the License at
|
|
@@ -24,15 +24,10 @@ from sqlalchemy.exc import SQLAlchemyError
|
|
|
24
24
|
from sqlalchemy.orm import joinedload
|
|
25
25
|
|
|
26
26
|
from nwastdlib.ex import show_ex
|
|
27
|
+
from oauth2_lib.fastapi import OIDCUserModel
|
|
27
28
|
from orchestrator.api.error_handling import raise_status
|
|
28
29
|
from orchestrator.config.assignee import Assignee
|
|
29
|
-
from orchestrator.db import
|
|
30
|
-
EngineSettingsTable,
|
|
31
|
-
ProcessStepTable,
|
|
32
|
-
ProcessSubscriptionTable,
|
|
33
|
-
ProcessTable,
|
|
34
|
-
db,
|
|
35
|
-
)
|
|
30
|
+
from orchestrator.db import EngineSettingsTable, ProcessStepTable, ProcessSubscriptionTable, ProcessTable, db
|
|
36
31
|
from orchestrator.distlock import distlock_manager
|
|
37
32
|
from orchestrator.schemas.engine_settings import WorkerStatus
|
|
38
33
|
from orchestrator.services.input_state import store_input_state
|
|
@@ -43,9 +38,10 @@ from orchestrator.targets import Target
|
|
|
43
38
|
from orchestrator.types import BroadcastFunc
|
|
44
39
|
from orchestrator.utils.datetime import nowtz
|
|
45
40
|
from orchestrator.utils.errors import error_state_to_dict
|
|
46
|
-
from orchestrator.websocket import broadcast_invalidate_status_counts
|
|
41
|
+
from orchestrator.websocket import broadcast_invalidate_status_counts, broadcast_process_update_to_websocket
|
|
47
42
|
from orchestrator.workflow import (
|
|
48
43
|
CALLBACK_TOKEN_KEY,
|
|
44
|
+
DEFAULT_CALLBACK_PROGRESS_KEY,
|
|
49
45
|
Failed,
|
|
50
46
|
ProcessStat,
|
|
51
47
|
ProcessStatus,
|
|
@@ -413,10 +409,15 @@ def _run_process_async(process_id: UUID, f: Callable) -> UUID:
|
|
|
413
409
|
return process_id
|
|
414
410
|
|
|
415
411
|
|
|
412
|
+
def error_message_unauthorized(workflow_key: str) -> str:
|
|
413
|
+
return f"User is not authorized to execute '{workflow_key}' workflow"
|
|
414
|
+
|
|
415
|
+
|
|
416
416
|
def create_process(
|
|
417
417
|
workflow_key: str,
|
|
418
418
|
user_inputs: list[State] | None = None,
|
|
419
419
|
user: str = SYSTEM_USER,
|
|
420
|
+
user_model: OIDCUserModel | None = None,
|
|
420
421
|
) -> ProcessStat:
|
|
421
422
|
# ATTENTION!! When modifying this function make sure you make similar changes to `run_workflow` in the test code
|
|
422
423
|
|
|
@@ -429,6 +430,9 @@ def create_process(
|
|
|
429
430
|
if not workflow:
|
|
430
431
|
raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
|
|
431
432
|
|
|
433
|
+
if not workflow.authorize_callback(user_model):
|
|
434
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
|
|
435
|
+
|
|
432
436
|
initial_state = {
|
|
433
437
|
"process_id": process_id,
|
|
434
438
|
"reporter": user,
|
|
@@ -448,6 +452,7 @@ def create_process(
|
|
|
448
452
|
state=Success(state | initial_state),
|
|
449
453
|
log=workflow.steps,
|
|
450
454
|
current_user=user,
|
|
455
|
+
user_model=user_model,
|
|
451
456
|
)
|
|
452
457
|
|
|
453
458
|
_db_create_process(pstat)
|
|
@@ -459,9 +464,12 @@ def thread_start_process(
|
|
|
459
464
|
workflow_key: str,
|
|
460
465
|
user_inputs: list[State] | None = None,
|
|
461
466
|
user: str = SYSTEM_USER,
|
|
467
|
+
user_model: OIDCUserModel | None = None,
|
|
462
468
|
broadcast_func: BroadcastFunc | None = None,
|
|
463
469
|
) -> UUID:
|
|
464
470
|
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
|
|
471
|
+
if not pstat.workflow.authorize_callback(user_model):
|
|
472
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
|
|
465
473
|
|
|
466
474
|
_safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
|
|
467
475
|
return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
|
|
@@ -471,6 +479,7 @@ def start_process(
|
|
|
471
479
|
workflow_key: str,
|
|
472
480
|
user_inputs: list[State] | None = None,
|
|
473
481
|
user: str = SYSTEM_USER,
|
|
482
|
+
user_model: OIDCUserModel | None = None,
|
|
474
483
|
broadcast_func: BroadcastFunc | None = None,
|
|
475
484
|
) -> UUID:
|
|
476
485
|
"""Start a process for workflow.
|
|
@@ -479,6 +488,7 @@ def start_process(
|
|
|
479
488
|
workflow_key: name of workflow
|
|
480
489
|
user_inputs: List of form inputs from frontend
|
|
481
490
|
user: User who starts this process
|
|
491
|
+
user_model: Full OIDCUserModel with claims, etc
|
|
482
492
|
broadcast_func: Optional function to broadcast process data
|
|
483
493
|
|
|
484
494
|
Returns:
|
|
@@ -486,7 +496,9 @@ def start_process(
|
|
|
486
496
|
|
|
487
497
|
"""
|
|
488
498
|
start_func = get_execution_context()["start"]
|
|
489
|
-
return start_func(
|
|
499
|
+
return start_func(
|
|
500
|
+
workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
|
|
501
|
+
)
|
|
490
502
|
|
|
491
503
|
|
|
492
504
|
def thread_resume_process(
|
|
@@ -494,6 +506,7 @@ def thread_resume_process(
|
|
|
494
506
|
*,
|
|
495
507
|
user_inputs: list[State] | None = None,
|
|
496
508
|
user: str | None = None,
|
|
509
|
+
user_model: OIDCUserModel | None = None,
|
|
497
510
|
broadcast_func: BroadcastFunc | None = None,
|
|
498
511
|
) -> UUID:
|
|
499
512
|
# ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
|
|
@@ -502,6 +515,8 @@ def thread_resume_process(
|
|
|
502
515
|
user_inputs = [{}]
|
|
503
516
|
|
|
504
517
|
pstat = load_process(process)
|
|
518
|
+
if not pstat.workflow.authorize_callback(user_model):
|
|
519
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
|
|
505
520
|
|
|
506
521
|
if pstat.workflow == removed_workflow:
|
|
507
522
|
raise ValueError("This workflow cannot be resumed")
|
|
@@ -541,6 +556,7 @@ def resume_process(
|
|
|
541
556
|
*,
|
|
542
557
|
user_inputs: list[State] | None = None,
|
|
543
558
|
user: str | None = None,
|
|
559
|
+
user_model: OIDCUserModel | None = None,
|
|
544
560
|
broadcast_func: BroadcastFunc | None = None,
|
|
545
561
|
) -> UUID:
|
|
546
562
|
"""Resume a failed or suspended process.
|
|
@@ -549,6 +565,7 @@ def resume_process(
|
|
|
549
565
|
process: Process from database
|
|
550
566
|
user_inputs: Optional user input from forms
|
|
551
567
|
user: user who resumed this process
|
|
568
|
+
user_model: OIDCUserModel of user who resumed this process
|
|
552
569
|
broadcast_func: Optional function to broadcast process data
|
|
553
570
|
|
|
554
571
|
Returns:
|
|
@@ -556,6 +573,9 @@ def resume_process(
|
|
|
556
573
|
|
|
557
574
|
"""
|
|
558
575
|
pstat = load_process(process)
|
|
576
|
+
if not pstat.workflow.authorize_callback(user_model):
|
|
577
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
|
|
578
|
+
|
|
559
579
|
try:
|
|
560
580
|
post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
|
|
561
581
|
except FormValidationError:
|
|
@@ -566,6 +586,39 @@ def resume_process(
|
|
|
566
586
|
return resume_func(process, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
|
|
567
587
|
|
|
568
588
|
|
|
589
|
+
def ensure_correct_callback_token(pstat: ProcessStat, *, token: str) -> None:
|
|
590
|
+
"""Ensure that a callback token matches the expected value in state.
|
|
591
|
+
|
|
592
|
+
Args:
|
|
593
|
+
pstat: ProcessStat of process.
|
|
594
|
+
token: The token which was generated for the process.
|
|
595
|
+
|
|
596
|
+
Raises:
|
|
597
|
+
AssertionError: if the supplied token does not match the generated process token.
|
|
598
|
+
|
|
599
|
+
"""
|
|
600
|
+
state = pstat.state.unwrap()
|
|
601
|
+
|
|
602
|
+
# Check if the token matches
|
|
603
|
+
token_from_state = state.get(CALLBACK_TOKEN_KEY)
|
|
604
|
+
if token != token_from_state:
|
|
605
|
+
raise AssertionError("Invalid token")
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def replace_current_step_state(process: ProcessTable, *, new_state: State) -> None:
|
|
609
|
+
"""Replace the state of the current step in a process.
|
|
610
|
+
|
|
611
|
+
Args:
|
|
612
|
+
process: Process from database
|
|
613
|
+
new_state: The new state
|
|
614
|
+
|
|
615
|
+
"""
|
|
616
|
+
current_step = process.steps[-1]
|
|
617
|
+
current_step.state = new_state
|
|
618
|
+
db.session.add(current_step)
|
|
619
|
+
db.session.commit()
|
|
620
|
+
|
|
621
|
+
|
|
569
622
|
def continue_awaiting_process(
|
|
570
623
|
process: ProcessTable,
|
|
571
624
|
*,
|
|
@@ -589,10 +642,7 @@ def continue_awaiting_process(
|
|
|
589
642
|
pstat = load_process(process)
|
|
590
643
|
state = pstat.state.unwrap()
|
|
591
644
|
|
|
592
|
-
|
|
593
|
-
token_from_state = state.get(CALLBACK_TOKEN_KEY)
|
|
594
|
-
if token != token_from_state:
|
|
595
|
-
raise AssertionError("Invalid token")
|
|
645
|
+
ensure_correct_callback_token(pstat, token=token)
|
|
596
646
|
|
|
597
647
|
# We need to pass the callback data to the worker executor. Currently, this is not supported.
|
|
598
648
|
# Therefore, we update the step state in the db and kick-off resume_workflow
|
|
@@ -600,16 +650,47 @@ def continue_awaiting_process(
|
|
|
600
650
|
result_key = state.get("__callback_result_key", "callback_result")
|
|
601
651
|
state = {**state, result_key: input_data}
|
|
602
652
|
|
|
603
|
-
|
|
604
|
-
current_step.state = state
|
|
605
|
-
db.session.add(current_step)
|
|
606
|
-
db.session.commit()
|
|
653
|
+
replace_current_step_state(process, new_state=state)
|
|
607
654
|
|
|
608
655
|
# Continue the workflow
|
|
609
656
|
resume_func = get_execution_context()["resume"]
|
|
610
657
|
return resume_func(process, broadcast_func=broadcast_func)
|
|
611
658
|
|
|
612
659
|
|
|
660
|
+
def update_awaiting_process_progress(
|
|
661
|
+
process: ProcessTable,
|
|
662
|
+
*,
|
|
663
|
+
token: str,
|
|
664
|
+
data: str | State,
|
|
665
|
+
) -> UUID:
|
|
666
|
+
"""Update progress for a process awaiting data from a callback.
|
|
667
|
+
|
|
668
|
+
Args:
|
|
669
|
+
process: Process from database
|
|
670
|
+
token: The token which was generated for the process. This must match.
|
|
671
|
+
data: Progress data posted to the callback
|
|
672
|
+
|
|
673
|
+
Returns:
|
|
674
|
+
process id
|
|
675
|
+
|
|
676
|
+
Raises:
|
|
677
|
+
AssertionError: if the supplied token does not match the generated process token.
|
|
678
|
+
|
|
679
|
+
"""
|
|
680
|
+
pstat = load_process(process)
|
|
681
|
+
|
|
682
|
+
ensure_correct_callback_token(pstat, token=token)
|
|
683
|
+
|
|
684
|
+
state = pstat.state.unwrap()
|
|
685
|
+
progress_key = state.get(DEFAULT_CALLBACK_PROGRESS_KEY, "callback_progress")
|
|
686
|
+
state = {**state, progress_key: data} | {"__remove_keys": [progress_key]}
|
|
687
|
+
|
|
688
|
+
replace_current_step_state(process, new_state=state)
|
|
689
|
+
broadcast_process_update_to_websocket(process.process_id)
|
|
690
|
+
|
|
691
|
+
return process.process_id
|
|
692
|
+
|
|
693
|
+
|
|
613
694
|
async def _async_resume_processes(
|
|
614
695
|
processes: Sequence[ProcessTable],
|
|
615
696
|
user_name: str,
|
orchestrator/services/tasks.py
CHANGED
orchestrator/settings.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2019-2020 SURF.
|
|
1
|
+
# Copyright 2019-2020 SURF, GÉANT.
|
|
2
2
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
3
|
# you may not use this file except in compliance with the License.
|
|
4
4
|
# You may obtain a copy of the License at
|
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
|
|
14
14
|
import secrets
|
|
15
15
|
import string
|
|
16
|
-
import warnings
|
|
17
16
|
from pathlib import Path
|
|
18
17
|
from typing import Literal
|
|
19
18
|
|
|
@@ -24,10 +23,6 @@ from oauth2_lib.settings import oauth2lib_settings
|
|
|
24
23
|
from pydantic_forms.types import strEnum
|
|
25
24
|
|
|
26
25
|
|
|
27
|
-
class OrchestratorDeprecationWarning(DeprecationWarning):
|
|
28
|
-
pass
|
|
29
|
-
|
|
30
|
-
|
|
31
26
|
class ExecutorType(strEnum):
|
|
32
27
|
WORKER = "celery"
|
|
33
28
|
THREADPOOL = "threadpool"
|
|
@@ -54,7 +49,7 @@ class AppSettings(BaseSettings):
|
|
|
54
49
|
EXECUTOR: str = ExecutorType.THREADPOOL
|
|
55
50
|
WORKFLOWS_SWAGGER_HOST: str = "localhost"
|
|
56
51
|
WORKFLOWS_GUI_URI: str = "http://localhost:3000"
|
|
57
|
-
DATABASE_URI: PostgresDsn = "postgresql
|
|
52
|
+
DATABASE_URI: PostgresDsn = "postgresql://nwa:nwa@localhost/orchestrator-core" # type: ignore
|
|
58
53
|
MAX_WORKERS: int = 5
|
|
59
54
|
MAIL_SERVER: str = "localhost"
|
|
60
55
|
MAIL_PORT: int = 25
|
|
@@ -93,22 +88,6 @@ class AppSettings(BaseSettings):
|
|
|
93
88
|
VALIDATE_OUT_OF_SYNC_SUBSCRIPTIONS: bool = False
|
|
94
89
|
FILTER_BY_MODE: Literal["partial", "exact"] = "exact"
|
|
95
90
|
|
|
96
|
-
def __init__(self) -> None:
|
|
97
|
-
super(AppSettings, self).__init__()
|
|
98
|
-
self.DATABASE_URI = PostgresDsn(convert_database_uri(str(self.DATABASE_URI)))
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def convert_database_uri(db_uri: str) -> str:
|
|
102
|
-
if db_uri.startswith(("postgresql://", "postgresql+psycopg2://")):
|
|
103
|
-
db_uri = "postgresql+psycopg" + db_uri[db_uri.find("://") :]
|
|
104
|
-
warnings.filterwarnings("always", category=OrchestratorDeprecationWarning)
|
|
105
|
-
warnings.warn(
|
|
106
|
-
"DATABASE_URI converted to postgresql+psycopg:// format, please update your enviroment variable",
|
|
107
|
-
OrchestratorDeprecationWarning,
|
|
108
|
-
stacklevel=2,
|
|
109
|
-
)
|
|
110
|
-
return db_uri
|
|
111
|
-
|
|
112
91
|
|
|
113
92
|
app_settings = AppSettings()
|
|
114
93
|
|
orchestrator/targets.py
CHANGED
orchestrator/types.py
CHANGED
orchestrator/utils/errors.py
CHANGED
orchestrator/utils/state.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2019-2020 SURF, ESnet
|
|
1
|
+
# Copyright 2019-2020 SURF, ESnet, GÉANT.
|
|
2
2
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
3
|
# you may not use this file except in compliance with the License.
|
|
4
4
|
# You may obtain a copy of the License at
|
|
@@ -150,65 +150,85 @@ def _build_arguments(func: StepFunc | InputStepFunc, state: State) -> list: # n
|
|
|
150
150
|
|
|
151
151
|
Raises:
|
|
152
152
|
KeyError: if requested argument is not in the state, or cannot be reconstructed as an initial domain model.
|
|
153
|
+
ValueError: if requested argument cannot be converted to the expected type.
|
|
153
154
|
|
|
154
155
|
"""
|
|
156
|
+
|
|
155
157
|
sig = inspect.signature(func)
|
|
158
|
+
if not sig.parameters:
|
|
159
|
+
return []
|
|
160
|
+
|
|
161
|
+
def _convert_to_uuid(v: Any) -> UUID:
|
|
162
|
+
"""Converts the value to a UUID instance if it is not already one."""
|
|
163
|
+
return v if isinstance(v, UUID) else UUID(v)
|
|
164
|
+
|
|
156
165
|
arguments: list[Any] = []
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
166
|
+
for name, param in sig.parameters.items():
|
|
167
|
+
# Ignore dynamic arguments. Mostly need to deal with `const`
|
|
168
|
+
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
|
169
|
+
logger.warning("*args and **kwargs are not supported as step params")
|
|
170
|
+
continue
|
|
171
|
+
|
|
172
|
+
# If we find an argument named "state" we use the whole state as argument to
|
|
173
|
+
# This is mainly to be backward compatible with code that needs the whole state...
|
|
174
|
+
# TODO: Remove this construction
|
|
175
|
+
if name == "state":
|
|
176
|
+
arguments.append(state)
|
|
177
|
+
continue
|
|
178
|
+
|
|
179
|
+
# Workaround for the fact that you can't call issubclass on typing types
|
|
180
|
+
try:
|
|
181
|
+
is_subscription_model_type = issubclass(param.annotation, SubscriptionModel)
|
|
182
|
+
except Exception:
|
|
183
|
+
is_subscription_model_type = False
|
|
184
|
+
|
|
185
|
+
if is_subscription_model_type:
|
|
186
|
+
subscription_id = _get_sub_id(state.get(name))
|
|
187
|
+
if subscription_id:
|
|
188
|
+
sub_mod = param.annotation.from_subscription(subscription_id)
|
|
189
|
+
arguments.append(sub_mod)
|
|
190
|
+
else:
|
|
191
|
+
logger.error("Could not find key in state.", key=name, state=state)
|
|
192
|
+
raise KeyError(f"Could not find key '{name}' in state.")
|
|
193
|
+
elif is_list_type(param.annotation, SubscriptionModel):
|
|
194
|
+
subscription_ids = [_get_sub_id(item) for item in state.get(name, [])]
|
|
195
|
+
# Actual type is first argument from list type
|
|
196
|
+
if (actual_type := get_args(param.annotation)[0]) == Any:
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"Step function argument '{param.name}' cannot be serialized from database with type 'Any'"
|
|
199
|
+
)
|
|
200
|
+
subscriptions = [actual_type.from_subscription(subscription_id) for subscription_id in subscription_ids]
|
|
201
|
+
arguments.append(subscriptions)
|
|
202
|
+
elif is_optional_type(param.annotation, SubscriptionModel):
|
|
203
|
+
subscription_id = _get_sub_id(state.get(name))
|
|
204
|
+
if subscription_id:
|
|
205
|
+
# Actual type is first argument from optional type
|
|
206
|
+
sub_mod = get_args(param.annotation)[0].from_subscription(subscription_id)
|
|
207
|
+
arguments.append(sub_mod)
|
|
208
|
+
else:
|
|
209
|
+
arguments.append(None)
|
|
210
|
+
elif param.default is not inspect.Parameter.empty:
|
|
211
|
+
arguments.append(state.get(name, param.default))
|
|
212
|
+
else:
|
|
172
213
|
try:
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
sub_mod = param.annotation.from_subscription(subscription_id)
|
|
181
|
-
arguments.append(sub_mod)
|
|
182
|
-
else:
|
|
183
|
-
logger.error("Could not find key in state.", key=name, state=state)
|
|
184
|
-
raise KeyError(f"Could not find key '{name}' in state.")
|
|
185
|
-
elif is_list_type(param.annotation, SubscriptionModel):
|
|
186
|
-
subscription_ids = map(_get_sub_id, state.get(name, []))
|
|
187
|
-
# Actual type is first argument from list type
|
|
188
|
-
if (actual_type := get_args(param.annotation)[0]) == Any:
|
|
189
|
-
raise ValueError(
|
|
190
|
-
f"Step function argument '{param.name}' cannot be serialized from database with type 'Any'"
|
|
191
|
-
)
|
|
192
|
-
subscriptions = [actual_type.from_subscription(subscription_id) for subscription_id in subscription_ids]
|
|
193
|
-
arguments.append(subscriptions)
|
|
194
|
-
elif is_optional_type(param.annotation, SubscriptionModel):
|
|
195
|
-
subscription_id = _get_sub_id(state.get(name))
|
|
196
|
-
if subscription_id:
|
|
197
|
-
# Actual type is first argument from optional type
|
|
198
|
-
sub_mod = get_args(param.annotation)[0].from_subscription(subscription_id)
|
|
199
|
-
arguments.append(sub_mod)
|
|
214
|
+
value = state[name]
|
|
215
|
+
if param.annotation == UUID:
|
|
216
|
+
arguments.append(_convert_to_uuid(value))
|
|
217
|
+
elif is_list_type(param.annotation, UUID):
|
|
218
|
+
arguments.append([_convert_to_uuid(item) for item in value])
|
|
219
|
+
elif is_optional_type(param.annotation, UUID):
|
|
220
|
+
arguments.append(None if value is None else _convert_to_uuid(value))
|
|
200
221
|
else:
|
|
201
|
-
arguments.append(
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
) from key_error
|
|
222
|
+
arguments.append(value)
|
|
223
|
+
except KeyError as key_error:
|
|
224
|
+
logger.error("Could not find key in state.", key=name, state=state)
|
|
225
|
+
raise KeyError(
|
|
226
|
+
f"Could not find key '{name}' in state. for function {func.__module__}.{func.__qualname__}"
|
|
227
|
+
) from key_error
|
|
228
|
+
except ValueError as value_error:
|
|
229
|
+
logger.error("Could not convert value to expected type.", key=name, state=state, value=state[name])
|
|
230
|
+
raise ValueError(f"Could not convert value '{state[name]}' to {param.annotation}") from value_error
|
|
231
|
+
|
|
212
232
|
return arguments
|
|
213
233
|
|
|
214
234
|
|
orchestrator/workflow.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2019-
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT, ESnet.
|
|
2
2
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
3
|
# you may not use this file except in compliance with the License.
|
|
4
4
|
# You may obtain a copy of the License at
|
|
@@ -39,6 +39,7 @@ from structlog.contextvars import bound_contextvars
|
|
|
39
39
|
from structlog.stdlib import BoundLogger
|
|
40
40
|
|
|
41
41
|
from nwastdlib import const, identity
|
|
42
|
+
from oauth2_lib.fastapi import OIDCUserModel
|
|
42
43
|
from orchestrator.config.assignee import Assignee
|
|
43
44
|
from orchestrator.db import db, transactional
|
|
44
45
|
from orchestrator.services.settings import get_engine_settings
|
|
@@ -69,6 +70,7 @@ step_log_fn_var: contextvars.ContextVar[StepLogFuncInternal] = contextvars.Conte
|
|
|
69
70
|
|
|
70
71
|
DEFAULT_CALLBACK_ROUTE_KEY = "callback_route"
|
|
71
72
|
CALLBACK_TOKEN_KEY = "__callback_token" # noqa: S105
|
|
73
|
+
DEFAULT_CALLBACK_PROGRESS_KEY = "callback_progress" # noqa: S105
|
|
72
74
|
|
|
73
75
|
|
|
74
76
|
@runtime_checkable
|
|
@@ -88,6 +90,7 @@ class Workflow(Protocol):
|
|
|
88
90
|
__qualname__: str
|
|
89
91
|
name: str
|
|
90
92
|
description: str
|
|
93
|
+
authorize_callback: Callable[[OIDCUserModel | None], bool]
|
|
91
94
|
initial_input_form: InputFormGenerator | None = None
|
|
92
95
|
target: Target
|
|
93
96
|
steps: StepList
|
|
@@ -177,12 +180,18 @@ def _handle_simple_input_form_generator(f: StateInputStepFunc) -> StateInputForm
|
|
|
177
180
|
return form_generator
|
|
178
181
|
|
|
179
182
|
|
|
183
|
+
def allow(_: OIDCUserModel | None) -> bool:
|
|
184
|
+
"""Default function to return True in absence of user-defined authorize function."""
|
|
185
|
+
return True
|
|
186
|
+
|
|
187
|
+
|
|
180
188
|
def make_workflow(
|
|
181
189
|
f: Callable,
|
|
182
190
|
description: str,
|
|
183
191
|
initial_input_form: InputStepFunc | None,
|
|
184
192
|
target: Target,
|
|
185
193
|
steps: StepList,
|
|
194
|
+
authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
|
|
186
195
|
) -> Workflow:
|
|
187
196
|
@functools.wraps(f)
|
|
188
197
|
def wrapping_function() -> NoReturn:
|
|
@@ -192,6 +201,7 @@ def make_workflow(
|
|
|
192
201
|
|
|
193
202
|
wrapping_function.name = f.__name__ # default, will be changed by LazyWorkflowInstance
|
|
194
203
|
wrapping_function.description = description
|
|
204
|
+
wrapping_function.authorize_callback = allow if authorize_callback is None else authorize_callback
|
|
195
205
|
|
|
196
206
|
if initial_input_form is None:
|
|
197
207
|
# We always need a form to prevent starting a workflow when no input is needed.
|
|
@@ -458,7 +468,10 @@ def focussteps(key: str) -> Callable[[Step | StepList], StepList]:
|
|
|
458
468
|
|
|
459
469
|
|
|
460
470
|
def workflow(
|
|
461
|
-
description: str,
|
|
471
|
+
description: str,
|
|
472
|
+
initial_input_form: InputStepFunc | None = None,
|
|
473
|
+
target: Target = Target.SYSTEM,
|
|
474
|
+
authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
|
|
462
475
|
) -> Callable[[Callable[[], StepList]], Workflow]:
|
|
463
476
|
"""Transform an initial_input_form and a step list into a workflow.
|
|
464
477
|
|
|
@@ -478,7 +491,9 @@ def workflow(
|
|
|
478
491
|
initial_input_form_in_form_inject_args = form_inject_args(initial_input_form)
|
|
479
492
|
|
|
480
493
|
def _workflow(f: Callable[[], StepList]) -> Workflow:
|
|
481
|
-
return make_workflow(
|
|
494
|
+
return make_workflow(
|
|
495
|
+
f, description, initial_input_form_in_form_inject_args, target, f(), authorize_callback=authorize_callback
|
|
496
|
+
)
|
|
482
497
|
|
|
483
498
|
return _workflow
|
|
484
499
|
|
|
@@ -490,13 +505,14 @@ class ProcessStat:
|
|
|
490
505
|
state: Process
|
|
491
506
|
log: StepList
|
|
492
507
|
current_user: str
|
|
508
|
+
user_model: OIDCUserModel | None = None
|
|
493
509
|
|
|
494
510
|
def update(self, **vs: Any) -> ProcessStat:
|
|
495
511
|
"""Update ProcessStat.
|
|
496
512
|
|
|
497
513
|
>>> pstat = ProcessStat('', None, {}, [], "")
|
|
498
514
|
>>> pstat.update(state={"a": "b"})
|
|
499
|
-
ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='')
|
|
515
|
+
ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='', user_model=None)
|
|
500
516
|
"""
|
|
501
517
|
return ProcessStat(**{**asdict(self), **vs})
|
|
502
518
|
|
orchestrator/workflows/steps.py
CHANGED