orchestrator-core 3.1.0rc1__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. orchestrator/__init__.py +2 -2
  2. orchestrator/api/api_v1/api.py +1 -1
  3. orchestrator/api/api_v1/endpoints/processes.py +29 -9
  4. orchestrator/api/api_v1/endpoints/settings.py +1 -1
  5. orchestrator/api/api_v1/endpoints/subscriptions.py +1 -1
  6. orchestrator/app.py +1 -1
  7. orchestrator/cli/database.py +1 -1
  8. orchestrator/cli/generator/generator/migration.py +2 -5
  9. orchestrator/cli/migrate_tasks.py +13 -0
  10. orchestrator/config/assignee.py +1 -1
  11. orchestrator/db/__init__.py +2 -0
  12. orchestrator/db/models.py +6 -4
  13. orchestrator/devtools/populator.py +1 -1
  14. orchestrator/domain/__init__.py +2 -3
  15. orchestrator/domain/base.py +74 -5
  16. orchestrator/domain/lifecycle.py +1 -1
  17. orchestrator/graphql/schema.py +1 -1
  18. orchestrator/graphql/types.py +1 -1
  19. orchestrator/graphql/utils/get_subscription_product_blocks.py +13 -0
  20. orchestrator/migrations/env.py +15 -2
  21. orchestrator/migrations/helpers.py +6 -6
  22. orchestrator/migrations/versions/schema/2020-10-19_c112305b07d3_initial_schema_migration.py +1 -1
  23. orchestrator/migrations/versions/schema/2023-05-25_b1970225392d_add_subscription_metadata_workflow.py +1 -1
  24. orchestrator/migrations/versions/schema/2025-02-12_bac6be6f2b4f_added_input_state_table.py +1 -1
  25. orchestrator/schemas/engine_settings.py +1 -1
  26. orchestrator/schemas/subscription.py +1 -1
  27. orchestrator/security.py +1 -1
  28. orchestrator/services/celery.py +1 -1
  29. orchestrator/services/processes.py +99 -18
  30. orchestrator/services/products.py +1 -1
  31. orchestrator/services/subscriptions.py +1 -1
  32. orchestrator/services/tasks.py +1 -1
  33. orchestrator/settings.py +2 -23
  34. orchestrator/targets.py +1 -1
  35. orchestrator/types.py +1 -1
  36. orchestrator/utils/errors.py +1 -1
  37. orchestrator/utils/state.py +74 -54
  38. orchestrator/websocket/websocket_manager.py +1 -1
  39. orchestrator/workflow.py +20 -4
  40. orchestrator/workflows/modify_note.py +1 -1
  41. orchestrator/workflows/steps.py +1 -1
  42. orchestrator/workflows/tasks/cleanup_tasks_log.py +1 -1
  43. orchestrator/workflows/tasks/resume_workflows.py +1 -1
  44. orchestrator/workflows/tasks/validate_product_type.py +1 -1
  45. orchestrator/workflows/tasks/validate_products.py +1 -1
  46. orchestrator/workflows/utils.py +40 -5
  47. {orchestrator_core-3.1.0rc1.dist-info → orchestrator_core-3.1.2.dist-info}/METADATA +13 -13
  48. {orchestrator_core-3.1.0rc1.dist-info → orchestrator_core-3.1.2.dist-info}/RECORD +50 -50
  49. {orchestrator_core-3.1.0rc1.dist-info → orchestrator_core-3.1.2.dist-info}/WHEEL +1 -1
  50. {orchestrator_core-3.1.0rc1.dist-info → orchestrator_core-3.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2025 SURF, GÉANT, ESnet.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -24,15 +24,10 @@ from sqlalchemy.exc import SQLAlchemyError
24
24
  from sqlalchemy.orm import joinedload
25
25
 
26
26
  from nwastdlib.ex import show_ex
27
+ from oauth2_lib.fastapi import OIDCUserModel
27
28
  from orchestrator.api.error_handling import raise_status
28
29
  from orchestrator.config.assignee import Assignee
29
- from orchestrator.db import (
30
- EngineSettingsTable,
31
- ProcessStepTable,
32
- ProcessSubscriptionTable,
33
- ProcessTable,
34
- db,
35
- )
30
+ from orchestrator.db import EngineSettingsTable, ProcessStepTable, ProcessSubscriptionTable, ProcessTable, db
36
31
  from orchestrator.distlock import distlock_manager
37
32
  from orchestrator.schemas.engine_settings import WorkerStatus
38
33
  from orchestrator.services.input_state import store_input_state
@@ -43,9 +38,10 @@ from orchestrator.targets import Target
43
38
  from orchestrator.types import BroadcastFunc
44
39
  from orchestrator.utils.datetime import nowtz
45
40
  from orchestrator.utils.errors import error_state_to_dict
46
- from orchestrator.websocket import broadcast_invalidate_status_counts
41
+ from orchestrator.websocket import broadcast_invalidate_status_counts, broadcast_process_update_to_websocket
47
42
  from orchestrator.workflow import (
48
43
  CALLBACK_TOKEN_KEY,
44
+ DEFAULT_CALLBACK_PROGRESS_KEY,
49
45
  Failed,
50
46
  ProcessStat,
51
47
  ProcessStatus,
@@ -413,10 +409,15 @@ def _run_process_async(process_id: UUID, f: Callable) -> UUID:
413
409
  return process_id
414
410
 
415
411
 
412
+ def error_message_unauthorized(workflow_key: str) -> str:
413
+ return f"User is not authorized to execute '{workflow_key}' workflow"
414
+
415
+
416
416
  def create_process(
417
417
  workflow_key: str,
418
418
  user_inputs: list[State] | None = None,
419
419
  user: str = SYSTEM_USER,
420
+ user_model: OIDCUserModel | None = None,
420
421
  ) -> ProcessStat:
421
422
  # ATTENTION!! When modifying this function make sure you make similar changes to `run_workflow` in the test code
422
423
 
@@ -429,6 +430,9 @@ def create_process(
429
430
  if not workflow:
430
431
  raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
431
432
 
433
+ if not workflow.authorize_callback(user_model):
434
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
435
+
432
436
  initial_state = {
433
437
  "process_id": process_id,
434
438
  "reporter": user,
@@ -448,6 +452,7 @@ def create_process(
448
452
  state=Success(state | initial_state),
449
453
  log=workflow.steps,
450
454
  current_user=user,
455
+ user_model=user_model,
451
456
  )
452
457
 
453
458
  _db_create_process(pstat)
@@ -459,9 +464,12 @@ def thread_start_process(
459
464
  workflow_key: str,
460
465
  user_inputs: list[State] | None = None,
461
466
  user: str = SYSTEM_USER,
467
+ user_model: OIDCUserModel | None = None,
462
468
  broadcast_func: BroadcastFunc | None = None,
463
469
  ) -> UUID:
464
470
  pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
471
+ if not pstat.workflow.authorize_callback(user_model):
472
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
465
473
 
466
474
  _safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
467
475
  return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
@@ -471,6 +479,7 @@ def start_process(
471
479
  workflow_key: str,
472
480
  user_inputs: list[State] | None = None,
473
481
  user: str = SYSTEM_USER,
482
+ user_model: OIDCUserModel | None = None,
474
483
  broadcast_func: BroadcastFunc | None = None,
475
484
  ) -> UUID:
476
485
  """Start a process for workflow.
@@ -479,6 +488,7 @@ def start_process(
479
488
  workflow_key: name of workflow
480
489
  user_inputs: List of form inputs from frontend
481
490
  user: User who starts this process
491
+ user_model: Full OIDCUserModel with claims, etc
482
492
  broadcast_func: Optional function to broadcast process data
483
493
 
484
494
  Returns:
@@ -486,7 +496,9 @@ def start_process(
486
496
 
487
497
  """
488
498
  start_func = get_execution_context()["start"]
489
- return start_func(workflow_key, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
499
+ return start_func(
500
+ workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
501
+ )
490
502
 
491
503
 
492
504
  def thread_resume_process(
@@ -494,6 +506,7 @@ def thread_resume_process(
494
506
  *,
495
507
  user_inputs: list[State] | None = None,
496
508
  user: str | None = None,
509
+ user_model: OIDCUserModel | None = None,
497
510
  broadcast_func: BroadcastFunc | None = None,
498
511
  ) -> UUID:
499
512
  # ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
@@ -502,6 +515,8 @@ def thread_resume_process(
502
515
  user_inputs = [{}]
503
516
 
504
517
  pstat = load_process(process)
518
+ if not pstat.workflow.authorize_callback(user_model):
519
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
505
520
 
506
521
  if pstat.workflow == removed_workflow:
507
522
  raise ValueError("This workflow cannot be resumed")
@@ -541,6 +556,7 @@ def resume_process(
541
556
  *,
542
557
  user_inputs: list[State] | None = None,
543
558
  user: str | None = None,
559
+ user_model: OIDCUserModel | None = None,
544
560
  broadcast_func: BroadcastFunc | None = None,
545
561
  ) -> UUID:
546
562
  """Resume a failed or suspended process.
@@ -549,6 +565,7 @@ def resume_process(
549
565
  process: Process from database
550
566
  user_inputs: Optional user input from forms
551
567
  user: user who resumed this process
568
+ user_model: OIDCUserModel of user who resumed this process
552
569
  broadcast_func: Optional function to broadcast process data
553
570
 
554
571
  Returns:
@@ -556,6 +573,9 @@ def resume_process(
556
573
 
557
574
  """
558
575
  pstat = load_process(process)
576
+ if not pstat.workflow.authorize_callback(user_model):
577
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
578
+
559
579
  try:
560
580
  post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
561
581
  except FormValidationError:
@@ -566,6 +586,39 @@ def resume_process(
566
586
  return resume_func(process, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
567
587
 
568
588
 
589
+ def ensure_correct_callback_token(pstat: ProcessStat, *, token: str) -> None:
590
+ """Ensure that a callback token matches the expected value in state.
591
+
592
+ Args:
593
+ pstat: ProcessStat of process.
594
+ token: The token which was generated for the process.
595
+
596
+ Raises:
597
+ AssertionError: if the supplied token does not match the generated process token.
598
+
599
+ """
600
+ state = pstat.state.unwrap()
601
+
602
+ # Check if the token matches
603
+ token_from_state = state.get(CALLBACK_TOKEN_KEY)
604
+ if token != token_from_state:
605
+ raise AssertionError("Invalid token")
606
+
607
+
608
+ def replace_current_step_state(process: ProcessTable, *, new_state: State) -> None:
609
+ """Replace the state of the current step in a process.
610
+
611
+ Args:
612
+ process: Process from database
613
+ new_state: The new state
614
+
615
+ """
616
+ current_step = process.steps[-1]
617
+ current_step.state = new_state
618
+ db.session.add(current_step)
619
+ db.session.commit()
620
+
621
+
569
622
  def continue_awaiting_process(
570
623
  process: ProcessTable,
571
624
  *,
@@ -589,10 +642,7 @@ def continue_awaiting_process(
589
642
  pstat = load_process(process)
590
643
  state = pstat.state.unwrap()
591
644
 
592
- # Check if the token matches
593
- token_from_state = state.get(CALLBACK_TOKEN_KEY)
594
- if token != token_from_state:
595
- raise AssertionError("Invalid token")
645
+ ensure_correct_callback_token(pstat, token=token)
596
646
 
597
647
  # We need to pass the callback data to the worker executor. Currently, this is not supported.
598
648
  # Therefore, we update the step state in the db and kick-off resume_workflow
@@ -600,16 +650,47 @@ def continue_awaiting_process(
600
650
  result_key = state.get("__callback_result_key", "callback_result")
601
651
  state = {**state, result_key: input_data}
602
652
 
603
- current_step = process.steps[-1]
604
- current_step.state = state
605
- db.session.add(current_step)
606
- db.session.commit()
653
+ replace_current_step_state(process, new_state=state)
607
654
 
608
655
  # Continue the workflow
609
656
  resume_func = get_execution_context()["resume"]
610
657
  return resume_func(process, broadcast_func=broadcast_func)
611
658
 
612
659
 
660
+ def update_awaiting_process_progress(
661
+ process: ProcessTable,
662
+ *,
663
+ token: str,
664
+ data: str | State,
665
+ ) -> UUID:
666
+ """Update progress for a process awaiting data from a callback.
667
+
668
+ Args:
669
+ process: Process from database
670
+ token: The token which was generated for the process. This must match.
671
+ data: Progress data posted to the callback
672
+
673
+ Returns:
674
+ process id
675
+
676
+ Raises:
677
+ AssertionError: if the supplied token does not match the generated process token.
678
+
679
+ """
680
+ pstat = load_process(process)
681
+
682
+ ensure_correct_callback_token(pstat, token=token)
683
+
684
+ state = pstat.state.unwrap()
685
+ progress_key = state.get(DEFAULT_CALLBACK_PROGRESS_KEY, "callback_progress")
686
+ state = {**state, progress_key: data} | {"__remove_keys": [progress_key]}
687
+
688
+ replace_current_step_state(process, new_state=state)
689
+ broadcast_process_update_to_websocket(process.process_id)
690
+
691
+ return process.process_id
692
+
693
+
613
694
  async def _async_resume_processes(
614
695
  processes: Sequence[ProcessTable],
615
696
  user_name: str,
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
orchestrator/settings.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -13,7 +13,6 @@
13
13
 
14
14
  import secrets
15
15
  import string
16
- import warnings
17
16
  from pathlib import Path
18
17
  from typing import Literal
19
18
 
@@ -24,10 +23,6 @@ from oauth2_lib.settings import oauth2lib_settings
24
23
  from pydantic_forms.types import strEnum
25
24
 
26
25
 
27
- class OrchestratorDeprecationWarning(DeprecationWarning):
28
- pass
29
-
30
-
31
26
  class ExecutorType(strEnum):
32
27
  WORKER = "celery"
33
28
  THREADPOOL = "threadpool"
@@ -54,7 +49,7 @@ class AppSettings(BaseSettings):
54
49
  EXECUTOR: str = ExecutorType.THREADPOOL
55
50
  WORKFLOWS_SWAGGER_HOST: str = "localhost"
56
51
  WORKFLOWS_GUI_URI: str = "http://localhost:3000"
57
- DATABASE_URI: PostgresDsn = "postgresql+psycopg://nwa:nwa@localhost/orchestrator-core" # type: ignore
52
+ DATABASE_URI: PostgresDsn = "postgresql://nwa:nwa@localhost/orchestrator-core" # type: ignore
58
53
  MAX_WORKERS: int = 5
59
54
  MAIL_SERVER: str = "localhost"
60
55
  MAIL_PORT: int = 25
@@ -93,22 +88,6 @@ class AppSettings(BaseSettings):
93
88
  VALIDATE_OUT_OF_SYNC_SUBSCRIPTIONS: bool = False
94
89
  FILTER_BY_MODE: Literal["partial", "exact"] = "exact"
95
90
 
96
- def __init__(self) -> None:
97
- super(AppSettings, self).__init__()
98
- self.DATABASE_URI = PostgresDsn(convert_database_uri(str(self.DATABASE_URI)))
99
-
100
-
101
- def convert_database_uri(db_uri: str) -> str:
102
- if db_uri.startswith(("postgresql://", "postgresql+psycopg2://")):
103
- db_uri = "postgresql+psycopg" + db_uri[db_uri.find("://") :]
104
- warnings.filterwarnings("always", category=OrchestratorDeprecationWarning)
105
- warnings.warn(
106
- "DATABASE_URI converted to postgresql+psycopg:// format, please update your enviroment variable",
107
- OrchestratorDeprecationWarning,
108
- stacklevel=2,
109
- )
110
- return db_uri
111
-
112
91
 
113
92
  app_settings = AppSettings()
114
93
 
orchestrator/targets.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
orchestrator/types.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF, ESnet
1
+ # Copyright 2019-2020 SURF, ESnet, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -150,65 +150,85 @@ def _build_arguments(func: StepFunc | InputStepFunc, state: State) -> list: # n
150
150
 
151
151
  Raises:
152
152
  KeyError: if requested argument is not in the state, or cannot be reconstructed as an initial domain model.
153
+ ValueError: if requested argument cannot be converted to the expected type.
153
154
 
154
155
  """
156
+
155
157
  sig = inspect.signature(func)
158
+ if not sig.parameters:
159
+ return []
160
+
161
+ def _convert_to_uuid(v: Any) -> UUID:
162
+ """Converts the value to a UUID instance if it is not already one."""
163
+ return v if isinstance(v, UUID) else UUID(v)
164
+
156
165
  arguments: list[Any] = []
157
- if sig.parameters:
158
- for name, param in sig.parameters.items():
159
- # Ignore dynamic arguments. Mostly need to deal with `const`
160
- if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
161
- logger.warning("*args and **kwargs are not supported as step params")
162
- continue
163
-
164
- # If we find an argument named "state" we use the whole state as argument to
165
- # This is mainly to be backward compatible with code that needs the whole state...
166
- # TODO: Remove this construction
167
- if name == "state":
168
- arguments.append(state)
169
- continue
170
-
171
- # Workaround for the fact that you can't call issubclass on typing types
166
+ for name, param in sig.parameters.items():
167
+ # Ignore dynamic arguments. Mostly need to deal with `const`
168
+ if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
169
+ logger.warning("*args and **kwargs are not supported as step params")
170
+ continue
171
+
172
+ # If we find an argument named "state" we use the whole state as argument to
173
+ # This is mainly to be backward compatible with code that needs the whole state...
174
+ # TODO: Remove this construction
175
+ if name == "state":
176
+ arguments.append(state)
177
+ continue
178
+
179
+ # Workaround for the fact that you can't call issubclass on typing types
180
+ try:
181
+ is_subscription_model_type = issubclass(param.annotation, SubscriptionModel)
182
+ except Exception:
183
+ is_subscription_model_type = False
184
+
185
+ if is_subscription_model_type:
186
+ subscription_id = _get_sub_id(state.get(name))
187
+ if subscription_id:
188
+ sub_mod = param.annotation.from_subscription(subscription_id)
189
+ arguments.append(sub_mod)
190
+ else:
191
+ logger.error("Could not find key in state.", key=name, state=state)
192
+ raise KeyError(f"Could not find key '{name}' in state.")
193
+ elif is_list_type(param.annotation, SubscriptionModel):
194
+ subscription_ids = [_get_sub_id(item) for item in state.get(name, [])]
195
+ # Actual type is first argument from list type
196
+ if (actual_type := get_args(param.annotation)[0]) == Any:
197
+ raise ValueError(
198
+ f"Step function argument '{param.name}' cannot be serialized from database with type 'Any'"
199
+ )
200
+ subscriptions = [actual_type.from_subscription(subscription_id) for subscription_id in subscription_ids]
201
+ arguments.append(subscriptions)
202
+ elif is_optional_type(param.annotation, SubscriptionModel):
203
+ subscription_id = _get_sub_id(state.get(name))
204
+ if subscription_id:
205
+ # Actual type is first argument from optional type
206
+ sub_mod = get_args(param.annotation)[0].from_subscription(subscription_id)
207
+ arguments.append(sub_mod)
208
+ else:
209
+ arguments.append(None)
210
+ elif param.default is not inspect.Parameter.empty:
211
+ arguments.append(state.get(name, param.default))
212
+ else:
172
213
  try:
173
- is_subscription_model_type = issubclass(param.annotation, SubscriptionModel)
174
- except Exception:
175
- is_subscription_model_type = False
176
-
177
- if is_subscription_model_type:
178
- subscription_id = _get_sub_id(state.get(name))
179
- if subscription_id:
180
- sub_mod = param.annotation.from_subscription(subscription_id)
181
- arguments.append(sub_mod)
182
- else:
183
- logger.error("Could not find key in state.", key=name, state=state)
184
- raise KeyError(f"Could not find key '{name}' in state.")
185
- elif is_list_type(param.annotation, SubscriptionModel):
186
- subscription_ids = map(_get_sub_id, state.get(name, []))
187
- # Actual type is first argument from list type
188
- if (actual_type := get_args(param.annotation)[0]) == Any:
189
- raise ValueError(
190
- f"Step function argument '{param.name}' cannot be serialized from database with type 'Any'"
191
- )
192
- subscriptions = [actual_type.from_subscription(subscription_id) for subscription_id in subscription_ids]
193
- arguments.append(subscriptions)
194
- elif is_optional_type(param.annotation, SubscriptionModel):
195
- subscription_id = _get_sub_id(state.get(name))
196
- if subscription_id:
197
- # Actual type is first argument from optional type
198
- sub_mod = get_args(param.annotation)[0].from_subscription(subscription_id)
199
- arguments.append(sub_mod)
214
+ value = state[name]
215
+ if param.annotation == UUID:
216
+ arguments.append(_convert_to_uuid(value))
217
+ elif is_list_type(param.annotation, UUID):
218
+ arguments.append([_convert_to_uuid(item) for item in value])
219
+ elif is_optional_type(param.annotation, UUID):
220
+ arguments.append(None if value is None else _convert_to_uuid(value))
200
221
  else:
201
- arguments.append(None)
202
- elif param.default is not inspect.Parameter.empty:
203
- arguments.append(state.get(name, param.default))
204
- else:
205
- try:
206
- arguments.append(state[name])
207
- except KeyError as key_error:
208
- logger.error("Could not find key in state.", key=name, state=state)
209
- raise KeyError(
210
- f"Could not find key '{name}' in state. for function {func.__module__}.{func.__qualname__}"
211
- ) from key_error
222
+ arguments.append(value)
223
+ except KeyError as key_error:
224
+ logger.error("Could not find key in state.", key=name, state=state)
225
+ raise KeyError(
226
+ f"Could not find key '{name}' in state. for function {func.__module__}.{func.__qualname__}"
227
+ ) from key_error
228
+ except ValueError as value_error:
229
+ logger.error("Could not convert value to expected type.", key=name, state=state, value=state[name])
230
+ raise ValueError(f"Could not convert value '{state[name]}' to {param.annotation}") from value_error
231
+
212
232
  return arguments
213
233
 
214
234
 
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
orchestrator/workflow.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2025 SURF, GÉANT, ESnet.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -39,6 +39,7 @@ from structlog.contextvars import bound_contextvars
39
39
  from structlog.stdlib import BoundLogger
40
40
 
41
41
  from nwastdlib import const, identity
42
+ from oauth2_lib.fastapi import OIDCUserModel
42
43
  from orchestrator.config.assignee import Assignee
43
44
  from orchestrator.db import db, transactional
44
45
  from orchestrator.services.settings import get_engine_settings
@@ -69,6 +70,7 @@ step_log_fn_var: contextvars.ContextVar[StepLogFuncInternal] = contextvars.Conte
69
70
 
70
71
  DEFAULT_CALLBACK_ROUTE_KEY = "callback_route"
71
72
  CALLBACK_TOKEN_KEY = "__callback_token" # noqa: S105
73
+ DEFAULT_CALLBACK_PROGRESS_KEY = "callback_progress" # noqa: S105
72
74
 
73
75
 
74
76
  @runtime_checkable
@@ -88,6 +90,7 @@ class Workflow(Protocol):
88
90
  __qualname__: str
89
91
  name: str
90
92
  description: str
93
+ authorize_callback: Callable[[OIDCUserModel | None], bool]
91
94
  initial_input_form: InputFormGenerator | None = None
92
95
  target: Target
93
96
  steps: StepList
@@ -177,12 +180,18 @@ def _handle_simple_input_form_generator(f: StateInputStepFunc) -> StateInputForm
177
180
  return form_generator
178
181
 
179
182
 
183
+ def allow(_: OIDCUserModel | None) -> bool:
184
+ """Default function to return True in absence of user-defined authorize function."""
185
+ return True
186
+
187
+
180
188
  def make_workflow(
181
189
  f: Callable,
182
190
  description: str,
183
191
  initial_input_form: InputStepFunc | None,
184
192
  target: Target,
185
193
  steps: StepList,
194
+ authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
186
195
  ) -> Workflow:
187
196
  @functools.wraps(f)
188
197
  def wrapping_function() -> NoReturn:
@@ -192,6 +201,7 @@ def make_workflow(
192
201
 
193
202
  wrapping_function.name = f.__name__ # default, will be changed by LazyWorkflowInstance
194
203
  wrapping_function.description = description
204
+ wrapping_function.authorize_callback = allow if authorize_callback is None else authorize_callback
195
205
 
196
206
  if initial_input_form is None:
197
207
  # We always need a form to prevent starting a workflow when no input is needed.
@@ -458,7 +468,10 @@ def focussteps(key: str) -> Callable[[Step | StepList], StepList]:
458
468
 
459
469
 
460
470
  def workflow(
461
- description: str, initial_input_form: InputStepFunc | None = None, target: Target = Target.SYSTEM
471
+ description: str,
472
+ initial_input_form: InputStepFunc | None = None,
473
+ target: Target = Target.SYSTEM,
474
+ authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
462
475
  ) -> Callable[[Callable[[], StepList]], Workflow]:
463
476
  """Transform an initial_input_form and a step list into a workflow.
464
477
 
@@ -478,7 +491,9 @@ def workflow(
478
491
  initial_input_form_in_form_inject_args = form_inject_args(initial_input_form)
479
492
 
480
493
  def _workflow(f: Callable[[], StepList]) -> Workflow:
481
- return make_workflow(f, description, initial_input_form_in_form_inject_args, target, f())
494
+ return make_workflow(
495
+ f, description, initial_input_form_in_form_inject_args, target, f(), authorize_callback=authorize_callback
496
+ )
482
497
 
483
498
  return _workflow
484
499
 
@@ -490,13 +505,14 @@ class ProcessStat:
490
505
  state: Process
491
506
  log: StepList
492
507
  current_user: str
508
+ user_model: OIDCUserModel | None = None
493
509
 
494
510
  def update(self, **vs: Any) -> ProcessStat:
495
511
  """Update ProcessStat.
496
512
 
497
513
  >>> pstat = ProcessStat('', None, {}, [], "")
498
514
  >>> pstat.update(state={"a": "b"})
499
- ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='')
515
+ ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='', user_model=None)
500
516
  """
501
517
  return ProcessStat(**{**asdict(self), **vs})
502
518
 
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2024 SURF.
1
+ # Copyright 2019-2024 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at