orchestrator-core 3.1.2rc4__py3-none-any.whl → 3.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/endpoints/processes.py +6 -9
  3. orchestrator/db/__init__.py +2 -0
  4. orchestrator/db/loaders.py +51 -3
  5. orchestrator/db/models.py +13 -0
  6. orchestrator/db/queries/__init__.py +0 -0
  7. orchestrator/db/queries/subscription.py +85 -0
  8. orchestrator/db/queries/subscription_instance.py +28 -0
  9. orchestrator/domain/base.py +162 -44
  10. orchestrator/domain/context_cache.py +62 -0
  11. orchestrator/domain/helpers.py +41 -1
  12. orchestrator/domain/subscription_instance_transform.py +114 -0
  13. orchestrator/graphql/resolvers/process.py +3 -3
  14. orchestrator/graphql/resolvers/product.py +2 -2
  15. orchestrator/graphql/resolvers/product_block.py +2 -2
  16. orchestrator/graphql/resolvers/resource_type.py +2 -2
  17. orchestrator/graphql/resolvers/workflow.py +2 -2
  18. orchestrator/graphql/utils/get_query_loaders.py +6 -48
  19. orchestrator/graphql/utils/get_subscription_product_blocks.py +8 -1
  20. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.py +33 -0
  21. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.sql +40 -0
  22. orchestrator/services/processes.py +28 -9
  23. orchestrator/services/subscriptions.py +36 -6
  24. orchestrator/settings.py +3 -0
  25. orchestrator/utils/functional.py +9 -0
  26. orchestrator/utils/redis.py +6 -0
  27. orchestrator/workflow.py +19 -4
  28. orchestrator/workflows/utils.py +40 -5
  29. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0rc1.dist-info}/METADATA +6 -6
  30. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0rc1.dist-info}/RECORD +33 -26
  31. /orchestrator/migrations/versions/schema/{2025-10-19_4fjdn13f83ga_add_validate_product_type_task.py → 2025-01-19_4fjdn13f83ga_add_validate_product_type_task.py} +0 -0
  32. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0rc1.dist-info}/WHEEL +0 -0
  33. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
 from orchestrator.graphql.schemas.product_block import ProductBlock
 from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
 from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
-from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
+from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
 from orchestrator.utils.search_query import create_sqlalchemy_select
 
 logger = structlog.get_logger(__name__)
@@ -39,7 +39,7 @@ async def resolve_product_blocks(
         "resolve_product_blocks() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
     )
 
-    query_loaders = get_query_loaders(info, ProductBlockTable)
+    query_loaders = get_query_loaders_for_gql_fields(ProductBlockTable, info)
     select_stmt = select(ProductBlockTable)
    select_stmt = filter_product_blocks(select_stmt, pydantic_filter_by, _error_handler)
 
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
 from orchestrator.graphql.schemas.resource_type import ResourceType
 from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
 from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
-from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
+from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
 from orchestrator.utils.search_query import create_sqlalchemy_select
 
 logger = structlog.get_logger(__name__)
@@ -38,7 +38,7 @@ async def resolve_resource_types(
     logger.debug(
         "resolve_resource_types() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
     )
-    query_loaders = get_query_loaders(info, ResourceTypeTable)
+    query_loaders = get_query_loaders_for_gql_fields(ResourceTypeTable, info)
     select_stmt = select(ResourceTypeTable).options(*query_loaders)
     select_stmt = filter_resource_types(select_stmt, pydantic_filter_by, _error_handler)
 
@@ -13,7 +13,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
 from orchestrator.graphql.schemas.workflow import Workflow
 from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
 from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
-from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
+from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
 from orchestrator.utils.search_query import create_sqlalchemy_select
 
 logger = structlog.get_logger(__name__)
@@ -33,7 +33,7 @@ async def resolve_workflows(
     pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
     logger.debug("resolve_workflows() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
 
-    query_loaders = get_query_loaders(info, WorkflowTable)
+    query_loaders = get_query_loaders_for_gql_fields(WorkflowTable, info)
     select_stmt = WorkflowTable.select().options(*query_loaders)
     select_stmt = filter_workflows(select_stmt, pydantic_filter_by, _error_handler)
 
@@ -1,61 +1,19 @@
-from typing import Iterable
-
-import structlog
 from sqlalchemy.orm import Load
 
 from orchestrator.db.database import BaseModel as DbBaseModel
-from orchestrator.db.loaders import AttrLoader, join_attr_loaders, lookup_attr_loaders
+from orchestrator.db.loaders import (
+    get_query_loaders_for_model_paths,
+)
 from orchestrator.graphql.types import OrchestratorInfo
 from orchestrator.graphql.utils.get_selected_paths import get_selected_paths
 
-logger = structlog.get_logger(__name__)
-
-
-def _split_path(query_path: str) -> Iterable[str]:
-    yield from (field for field in query_path.split(".") if field != "page")
 
-
-def get_query_loaders(info: OrchestratorInfo, root_model: type[DbBaseModel]) -> list[Load]:
+def get_query_loaders_for_gql_fields(root_model: type[DbBaseModel], info: OrchestratorInfo) -> list[Load]:
     """Get sqlalchemy query loaders for the given GraphQL query.
 
     Based on the GraphQL query's selected fields, returns the required DB loaders to use
     in SQLALchemy's `.options()` for efficiently quering (nested) relationships.
     """
-    # Strip page and sort by length to find the longest match first
-    query_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
-    query_paths.sort(key=lambda x: x.count("."), reverse=True)
-
-    def get_loader_for_path(query_path: str) -> tuple[str, Load | None]:
-        next_model = root_model
-
-        matched_fields: list[str] = []
-        path_loaders: list[AttrLoader] = []
-
-        for field in _split_path(query_path):
-            if not (attr_loaders := lookup_attr_loaders(next_model, field)):
-                break
-
-            matched_fields.append(field)
-            path_loaders.extend(attr_loaders)
-            next_model = attr_loaders[-1].next_model
-
-        return ".".join(matched_fields), join_attr_loaders(path_loaders)
-
-    query_loaders: dict[str, Load] = {}
-
-    for path in query_paths:
-        matched_path, loader = get_loader_for_path(path)
-        if not matched_path or not loader or matched_path in query_loaders:
-            continue
-        if any(known_path.startswith(f"{matched_path}.") for known_path in query_loaders):
-            continue
-        query_loaders[matched_path] = loader
+    model_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
 
-    loaders = list(query_loaders.values())
-    logger.debug(
-        "Generated query loaders",
-        root_model=root_model,
-        query_paths=query_paths,
-        query_loaders=[str(i.path) for i in loaders],
-    )
-    return loaders
+    return get_query_loaders_for_model_paths(root_model, model_paths)
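The GraphQL-facing helper is now a thin wrapper: it strips the "page." prefix from the selected field paths and delegates the path matching to the new orchestrator.db.loaders.get_query_loaders_for_model_paths. A minimal sketch of calling the loader builder directly, outside a GraphQL resolver; the dotted paths below are illustrative and not taken from this diff:

from sqlalchemy import select

from orchestrator.db import SubscriptionTable
from orchestrator.db.loaders import get_query_loaders_for_model_paths

# Dotted model paths mirroring a GraphQL selection (illustrative).
loaders = get_query_loaders_for_model_paths(SubscriptionTable, ["product", "customer_descriptions"])
stmt = select(SubscriptionTable).options(*loaders)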
@@ -70,7 +70,14 @@ def get_all_product_blocks(subscription: dict[str, Any], _tags: list[str] | None
     return list(locate_product_block(subscription))
 
 
-pb_instance_property_keys = ("id", "parent", "owner_subscription_id", "subscription_instance_id", "in_use_by_relations")
+pb_instance_property_keys = (
+    "id",
+    "parent",
+    "owner_subscription_id",
+    "subscription_instance_id",
+    "in_use_by_relations",
+    "in_use_by_ids",
+)
 
 
 async def get_subscription_product_blocks(
@@ -0,0 +1,33 @@
+"""Add postgres function subscription_instance_as_json.
+
+Revision ID: 42b3d076a85b
+Revises: bac6be6f2b4f
+Create Date: 2025-03-06 15:03:09.477225
+
+"""
+
+from pathlib import Path
+
+from alembic import op
+from sqlalchemy import text
+
+# revision identifiers, used by Alembic.
+revision = "42b3d076a85b"
+down_revision = "bac6be6f2b4f"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    conn = op.get_bind()
+
+    revision_file_path = Path(__file__)
+    with open(revision_file_path.with_suffix(".sql")) as f:
+        conn.execute(text(f.read()))
+
+
+def downgrade() -> None:
+    conn = op.get_bind()
+
+    conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_as_json;"))
+    conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_fields_as_json;"))
@@ -0,0 +1,40 @@
+CREATE OR REPLACE FUNCTION subscription_instance_fields_as_json(sub_inst_id uuid)
+    RETURNS jsonb
+    LANGUAGE sql
+    STABLE PARALLEL SAFE AS
+$func$
+select jsonb_object_agg(rts.key, rts.val)
+from (select attr.key,
+             attr.val
+      from subscription_instances si
+           join product_blocks pb ON si.product_block_id = pb.product_block_id
+           cross join lateral (
+               values ('subscription_instance_id', to_jsonb(si.subscription_instance_id)),
+                      ('owner_subscription_id', to_jsonb(si.subscription_id)),
+                      ('name', to_jsonb(pb.name))
+           ) as attr(key, val)
+      where si.subscription_instance_id = sub_inst_id
+      union all
+      select rt.resource_type key,
+             jsonb_agg(siv.value ORDER BY siv.value ASC) val
+      from subscription_instance_values siv
+           join resource_types rt on siv.resource_type_id = rt.resource_type_id
+      where siv.subscription_instance_id = sub_inst_id
+      group by rt.resource_type) as rts
+$func$;
+
+
+CREATE OR REPLACE FUNCTION subscription_instance_as_json(sub_inst_id uuid)
+    RETURNS jsonb
+    LANGUAGE sql
+    STABLE PARALLEL SAFE AS
+$func$
+select subscription_instance_fields_as_json(sub_inst_id) ||
+       coalesce(jsonb_object_agg(depends_on.block_name, depends_on.block_instances), '{}'::jsonb)
+from (select sir.domain_model_attr block_name,
+             jsonb_agg(subscription_instance_as_json(sir.depends_on_id) ORDER BY sir.order_id ASC) as block_instances
+      from subscription_instance_relations sir
+      where sir.in_use_by_id = sub_inst_id
+        and sir.domain_model_attr is not null
+      group by block_name) as depends_on
+$func$;
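Together the two functions build a recursive jsonb document per subscription instance: subscription_instance_fields_as_json aggregates the instance's own attributes and resource type values, and subscription_instance_as_json merges in nested depends_on blocks grouped by domain_model_attr. A hedged sketch of calling the new function from SQLAlchemy; the UUID is a placeholder:

from uuid import UUID

from sqlalchemy import text

from orchestrator.db import db

instance_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder id
instance_json = db.session.execute(
    text("SELECT subscription_instance_as_json(:instance_id)"),
    {"instance_id": instance_id},
).scalar_one()
# instance_json holds the instance's own fields plus nested blocks keyed by domain_model_attr.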
@@ -1,4 +1,4 @@
-# Copyright 2019-2020 SURF, GÉANT.
+# Copyright 2019-2025 SURF, GÉANT, ESnet.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -24,15 +24,10 @@ from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import joinedload
 
 from nwastdlib.ex import show_ex
+from oauth2_lib.fastapi import OIDCUserModel
 from orchestrator.api.error_handling import raise_status
 from orchestrator.config.assignee import Assignee
-from orchestrator.db import (
-    EngineSettingsTable,
-    ProcessStepTable,
-    ProcessSubscriptionTable,
-    ProcessTable,
-    db,
-)
+from orchestrator.db import EngineSettingsTable, ProcessStepTable, ProcessSubscriptionTable, ProcessTable, db
 from orchestrator.distlock import distlock_manager
 from orchestrator.schemas.engine_settings import WorkerStatus
 from orchestrator.services.input_state import store_input_state
@@ -414,10 +409,15 @@ def _run_process_async(process_id: UUID, f: Callable) -> UUID:
     return process_id
 
 
+def error_message_unauthorized(workflow_key: str) -> str:
+    return f"User is not authorized to execute '{workflow_key}' workflow"
+
+
 def create_process(
     workflow_key: str,
     user_inputs: list[State] | None = None,
     user: str = SYSTEM_USER,
+    user_model: OIDCUserModel | None = None,
 ) -> ProcessStat:
     # ATTENTION!! When modifying this function make sure you make similar changes to `run_workflow` in the test code
 
@@ -430,6 +430,9 @@
     if not workflow:
         raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
 
+    if not workflow.authorize_callback(user_model):
+        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
+
     initial_state = {
         "process_id": process_id,
         "reporter": user,
@@ -449,6 +452,7 @@
         state=Success(state | initial_state),
         log=workflow.steps,
         current_user=user,
+        user_model=user_model,
     )
 
     _db_create_process(pstat)
@@ -460,9 +464,12 @@ def thread_start_process(
     workflow_key: str,
     user_inputs: list[State] | None = None,
     user: str = SYSTEM_USER,
+    user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
     pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
+    if not pstat.workflow.authorize_callback(user_model):
+        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
 
     _safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
     return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
@@ -472,6 +479,7 @@ def start_process(
     workflow_key: str,
     user_inputs: list[State] | None = None,
     user: str = SYSTEM_USER,
+    user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
     """Start a process for workflow.
@@ -480,6 +488,7 @@
         workflow_key: name of workflow
         user_inputs: List of form inputs from frontend
         user: User who starts this process
+        user_model: Full OIDCUserModel with claims, etc
         broadcast_func: Optional function to broadcast process data
 
     Returns:
@@ -487,7 +496,9 @@
 
     """
     start_func = get_execution_context()["start"]
-    return start_func(workflow_key, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
+    return start_func(
+        workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
+    )
 
 
 def thread_resume_process(
@@ -495,6 +506,7 @@
     *,
     user_inputs: list[State] | None = None,
     user: str | None = None,
+    user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
     # ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
@@ -503,6 +515,8 @@
         user_inputs = [{}]
 
     pstat = load_process(process)
+    if not pstat.workflow.authorize_callback(user_model):
+        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
 
     if pstat.workflow == removed_workflow:
         raise ValueError("This workflow cannot be resumed")
@@ -542,6 +556,7 @@ def resume_process(
     *,
     user_inputs: list[State] | None = None,
     user: str | None = None,
+    user_model: OIDCUserModel | None = None,
     broadcast_func: BroadcastFunc | None = None,
 ) -> UUID:
     """Resume a failed or suspended process.
@@ -550,6 +565,7 @@
         process: Process from database
         user_inputs: Optional user input from forms
         user: user who resumed this process
+        user_model: OIDCUserModel of user who resumed this process
         broadcast_func: Optional function to broadcast process data
 
     Returns:
@@ -557,6 +573,9 @@
 
     """
     pstat = load_process(process)
+    if not pstat.workflow.authorize_callback(user_model):
+        raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
+
     try:
         post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
     except FormValidationError:
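The authorization hook is enforced at every entry point: create_process, thread_start_process, thread_resume_process and resume_process all call the workflow's authorize_callback and answer with HTTP 403 when it rejects the user. A hedged sketch of passing the caller's identity through start_process; the wrapper function is illustrative:

from uuid import UUID

from oauth2_lib.fastapi import OIDCUserModel

from orchestrator.services.processes import start_process


def start_for_user(workflow_key: str, user: OIDCUserModel | None) -> UUID:
    # A rejected authorize_callback surfaces here as an HTTP 403 via raise_status().
    return start_process(workflow_key, user_inputs=[{}], user_model=user)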
@@ -22,6 +22,7 @@ from uuid import UUID
 
 import more_itertools
 import structlog
+from more_itertools import first
 from sqlalchemy import Text, cast, not_, select
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import Query, aliased, joinedload
@@ -41,7 +42,11 @@ from orchestrator.db.models import (
     SubscriptionInstanceRelationTable,
     SubscriptionMetadataTable,
 )
+from orchestrator.db.queries.subscription import (
+    eagerload_all_subscription_instances_only_inuseby,
+)
 from orchestrator.domain.base import SubscriptionModel
+from orchestrator.domain.context_cache import cache_subscription_models
 from orchestrator.targets import Target
 from orchestrator.types import SubscriptionLifecycle
 from orchestrator.utils.datetime import nowtz
@@ -594,13 +599,15 @@ def convert_to_in_use_by_relation(obj: Any) -> dict[str, str]:
 
 
 def build_extended_domain_model(subscription_model: SubscriptionModel) -> dict:
     """Create a subscription dict from the SubscriptionModel with additional keys."""
-    stmt = select(SubscriptionCustomerDescriptionTable).filter(
+    from orchestrator.settings import app_settings
+
+    stmt = select(SubscriptionCustomerDescriptionTable).where(
         SubscriptionCustomerDescriptionTable.subscription_id == subscription_model.subscription_id
     )
     customer_descriptions = list(db.session.scalars(stmt))
 
-    subscription = subscription_model.model_dump()
-    paths = product_block_paths(subscription)
+    with cache_subscription_models():
+        subscription = subscription_model.model_dump()
 
     def inject_in_use_by_ids(path_to_block: str) -> None:
         if not (in_use_by_subs := getattr_in(subscription_model, f"{path_to_block}.in_use_by")):
@@ -611,15 +618,38 @@ def build_extended_domain_model(subscription_model: SubscriptionModel) -> dict:
         update_in(subscription, f"{path_to_block}.in_use_by_ids", in_use_by_ids)
         update_in(subscription, f"{path_to_block}.in_use_by_relations", in_use_by_relations)
 
-    # find all product blocks, check if they have in_use_by and inject the in_use_by_ids into the subscription dict.
-    for path in paths:
-        inject_in_use_by_ids(path)
+    if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
+        # TODO #900 remove toggle and make this path the default
+        # query all subscription instances and inject the in_use_by_ids/in_use_by_relations into the subscription dict.
+        instance_to_in_use_by = {
+            instance.subscription_instance_id: instance.in_use_by
+            for instance in eagerload_all_subscription_instances_only_inuseby(subscription_model.subscription_id)
+        }
+        inject_in_use_by_ids_v2(subscription, instance_to_in_use_by)
+    else:
+        # find all product blocks, check if they have in_use_by and inject the in_use_by_ids into the subscription dict.
+        for path in product_block_paths(subscription):
+            inject_in_use_by_ids(path)
 
     subscription["customer_descriptions"] = customer_descriptions
 
     return subscription
 
 
+def inject_in_use_by_ids_v2(dikt: dict, instance_to_in_use_by: dict[UUID, Sequence[SubscriptionInstanceTable]]) -> None:
+    for value in dikt.values():
+        if isinstance(value, dict):
+            inject_in_use_by_ids_v2(value, instance_to_in_use_by)
+        elif isinstance(value, list) and isinstance(first(value, None), dict):
+            for item in value:
+                inject_in_use_by_ids_v2(item, instance_to_in_use_by)
+
+    if subscription_instance_id := dikt.get("subscription_instance_id"):
+        in_use_by_subs = instance_to_in_use_by[subscription_instance_id]
+        dikt["in_use_by_ids"] = [i.subscription_instance_id for i in in_use_by_subs]
+        dikt["in_use_by_relations"] = [convert_to_in_use_by_relation(instance) for instance in in_use_by_subs]
+
+
 def format_special_types(subscription: dict) -> dict:
     """Modifies the subscription dict in-place, formatting special types to string.
 
orchestrator/settings.py CHANGED
@@ -87,6 +87,9 @@ class AppSettings(BaseSettings):
     ENABLE_GRAPHQL_STATS_EXTENSION: bool = False
     VALIDATE_OUT_OF_SYNC_SUBSCRIPTIONS: bool = False
     FILTER_BY_MODE: Literal["partial", "exact"] = "exact"
+    ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS: bool = (
+        True  # True=ignore cache + optimized DB queries; False=use cache + unoptimized DB queries. Remove in #900
+    )
 
 
 app_settings = AppSettings()
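ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS defaults to True, which means the optimized database queries are used and the Redis domain model cache is skipped. Because it is a pydantic BaseSettings field, it can be flipped through the environment; a minimal sketch, assuming no environment prefix is configured on AppSettings:

import os

# Fall back to the pre-3.2 behaviour: read from the cache and use the unoptimized queries.
os.environ["ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS"] = "False"

from orchestrator.settings import app_settings  # settings are read when this module is imported

assert app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS is False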
@@ -241,3 +241,12 @@ def to_ranges(i: Iterable[int]) -> Iterable[range]:
     for _, g in itertools.groupby(enumerate(i), lambda t: t[1] - t[0]):
         group = list(g)
         yield range(group[0][1], group[-1][1] + 1)
+
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+def group_by_key(items: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
+    groups = itertools.groupby(items, key=lambda item: item[0])
+    return {key: [item[1] for item in group] for key, group in groups}
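The new group_by_key helper builds on itertools.groupby, which only merges consecutive items with equal keys, so callers should pass pairs that are already sorted or grouped by key. A small usage sketch:

from orchestrator.utils.functional import group_by_key

pairs = [("a", 1), ("a", 2), ("b", 3)]
print(group_by_key(pairs))  # {'a': [1, 2], 'b': [3]}

# With non-adjacent duplicate keys the later group overwrites the earlier one, so sort first if needed:
print(group_by_key(sorted([("a", 1), ("b", 3), ("a", 2)])))  # {'a': [1, 2], 'b': [3]}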
@@ -55,6 +55,12 @@ def to_redis(subscription: dict[str, Any]) -> str | None:
 
 def from_redis(subscription_id: UUID) -> tuple[PY_JSON_TYPES, str] | None:
     log = logger.bind(subscription_id=subscription_id)
+
+    if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
+        # TODO #900 remove toggle and remove usage of this function in get_subscription_dict
+        log.warning("Using SubscriptionModel optimization, not loading subscription from cache")
+        return None
+
     if caching_models_enabled():
         log.debug("Try to retrieve subscription from cache")
         obj = cache.get(f"orchestrator:domain:{subscription_id}")
orchestrator/workflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2019-2025 SURF, GÉANT.
+# Copyright 2019-2025 SURF, GÉANT, ESnet.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -39,6 +39,7 @@ from structlog.contextvars import bound_contextvars
 from structlog.stdlib import BoundLogger
 
 from nwastdlib import const, identity
+from oauth2_lib.fastapi import OIDCUserModel
 from orchestrator.config.assignee import Assignee
 from orchestrator.db import db, transactional
 from orchestrator.services.settings import get_engine_settings
@@ -89,6 +90,7 @@ class Workflow(Protocol):
     __qualname__: str
     name: str
     description: str
+    authorize_callback: Callable[[OIDCUserModel | None], bool]
     initial_input_form: InputFormGenerator | None = None
     target: Target
     steps: StepList
@@ -178,12 +180,18 @@ def _handle_simple_input_form_generator(f: StateInputStepFunc) -> StateInputForm
     return form_generator
 
 
+def allow(_: OIDCUserModel | None) -> bool:
+    """Default function to return True in absence of user-defined authorize function."""
+    return True
+
+
 def make_workflow(
     f: Callable,
     description: str,
     initial_input_form: InputStepFunc | None,
     target: Target,
     steps: StepList,
+    authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
 ) -> Workflow:
     @functools.wraps(f)
     def wrapping_function() -> NoReturn:
@@ -193,6 +201,7 @@ def make_workflow(
 
     wrapping_function.name = f.__name__  # default, will be changed by LazyWorkflowInstance
     wrapping_function.description = description
+    wrapping_function.authorize_callback = allow if authorize_callback is None else authorize_callback
 
     if initial_input_form is None:
         # We always need a form to prevent starting a workflow when no input is needed.
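make_workflow defaults authorize_callback to allow, which accepts every user; the workflow() decorator and the helpers in orchestrator.workflows.utils forward the same keyword, as the following hunks show. A hedged sketch of a task guarded by a callback; the claim check on OIDCUserModel is illustrative and should be adapted to the claims in your OIDC tokens:

from oauth2_lib.fastapi import OIDCUserModel

from orchestrator.targets import Target
from orchestrator.workflow import StepList, done, init, workflow


def only_admins(user: OIDCUserModel | None) -> bool:
    # Illustrative check on an example claim; unauthenticated calls are rejected.
    return bool(user and "admin" in user.get("edumember_is_member_of", []))


@workflow("Example guarded task", target=Target.SYSTEM, authorize_callback=only_admins)
def example_guarded_task() -> StepList:
    return init >> done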
@@ -459,7 +468,10 @@ def focussteps(key: str) -> Callable[[Step | StepList], StepList]:
 
 
 def workflow(
-    description: str, initial_input_form: InputStepFunc | None = None, target: Target = Target.SYSTEM
+    description: str,
+    initial_input_form: InputStepFunc | None = None,
+    target: Target = Target.SYSTEM,
+    authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
 ) -> Callable[[Callable[[], StepList]], Workflow]:
     """Transform an initial_input_form and a step list into a workflow.
 
@@ -479,7 +491,9 @@
     initial_input_form_in_form_inject_args = form_inject_args(initial_input_form)
 
     def _workflow(f: Callable[[], StepList]) -> Workflow:
-        return make_workflow(f, description, initial_input_form_in_form_inject_args, target, f())
+        return make_workflow(
+            f, description, initial_input_form_in_form_inject_args, target, f(), authorize_callback=authorize_callback
+        )
 
     return _workflow
 
@@ -491,13 +505,14 @@ class ProcessStat:
     state: Process
     log: StepList
     current_user: str
+    user_model: OIDCUserModel | None = None
 
     def update(self, **vs: Any) -> ProcessStat:
         """Update ProcessStat.
 
         >>> pstat = ProcessStat('', None, {}, [], "")
         >>> pstat.update(state={"a": "b"})
-        ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='')
+        ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='', user_model=None)
         """
         return ProcessStat(**{**asdict(self), **vs})
 
@@ -1,4 +1,4 @@
-# Copyright 2019-2020 SURF, GÉANT.
+# Copyright 2019-2025 SURF, GÉANT, ESnet.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -20,6 +20,7 @@ from more_itertools import first_true
 from pydantic import field_validator, model_validator
 from sqlalchemy import select
 
+from oauth2_lib.fastapi import OIDCUserModel
 from orchestrator.db import ProductTable, SubscriptionTable, db
 from orchestrator.forms.validators import ProductId
 from orchestrator.services import subscriptions
@@ -30,7 +31,7 @@ from orchestrator.utils.errors import StaleDataError
 from orchestrator.utils.redis import caching_models_enabled
 from orchestrator.utils.state import form_inject_args
 from orchestrator.utils.validate_data_version import validate_data_version
-from orchestrator.workflow import StepList, Workflow, conditional, done, init, make_workflow, step
+from orchestrator.workflow import Step, StepList, Workflow, begin, conditional, done, init, make_workflow, step
 from orchestrator.workflows.steps import (
     cache_domain_models,
     refresh_subscription_search_index,
@@ -205,6 +206,7 @@ def create_workflow(
     initial_input_form: InputStepFunc | None = None,
     status: SubscriptionLifecycle = SubscriptionLifecycle.ACTIVE,
     additional_steps: StepList | None = None,
+    authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
 ) -> Callable[[Callable[[], StepList]], Workflow]:
     """Transform an initial_input_form and a step list into a workflow with a target=Target.CREATE.
 
@@ -231,7 +233,14 @@ def create_workflow(
             >> done
         )
 
-        return make_workflow(f, description, create_initial_input_form_generator, Target.CREATE, steplist)
+        return make_workflow(
+            f,
+            description,
+            create_initial_input_form_generator,
+            Target.CREATE,
+            steplist,
+            authorize_callback=authorize_callback,
+        )
 
     return _create_workflow
 
@@ -240,6 +249,7 @@ def modify_workflow(
     description: str,
     initial_input_form: InputStepFunc | None = None,
     additional_steps: StepList | None = None,
+    authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
 ) -> Callable[[Callable[[], StepList]], Workflow]:
     """Transform an initial_input_form and a step list into a workflow.
 
@@ -269,7 +279,14 @@ def modify_workflow(
             >> done
         )
 
-        return make_workflow(f, description, wrapped_modify_initial_input_form_generator, Target.MODIFY, steplist)
+        return make_workflow(
+            f,
+            description,
+            wrapped_modify_initial_input_form_generator,
+            Target.MODIFY,
+            steplist,
+            authorize_callback=authorize_callback,
+        )
 
     return _modify_workflow
 
@@ -278,6 +295,7 @@ def terminate_workflow(
     description: str,
     initial_input_form: InputStepFunc | None = None,
     additional_steps: StepList | None = None,
+    authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
 ) -> Callable[[Callable[[], StepList]], Workflow]:
     """Transform an initial_input_form and a step list into a workflow.
 
@@ -308,7 +326,14 @@ def terminate_workflow(
             >> done
         )
 
-        return make_workflow(f, description, wrapped_terminate_initial_input_form_generator, Target.TERMINATE, steplist)
+        return make_workflow(
+            f,
+            description,
+            wrapped_terminate_initial_input_form_generator,
+            Target.TERMINATE,
+            steplist,
+            authorize_callback=authorize_callback,
+        )
 
     return _terminate_workflow
 
@@ -344,6 +369,16 @@ def validate_workflow(description: str) -> Callable[[Callable[[], StepList]], Wo
     return _validate_workflow
 
 
+def ensure_provisioning_status(modify_steps: Step | StepList) -> StepList:
+    """Decorator to ensure subscription modifications are executed only during Provisioning status."""
+    return (
+        begin
+        >> set_status(SubscriptionLifecycle.PROVISIONING)
+        >> modify_steps
+        >> set_status(SubscriptionLifecycle.ACTIVE)
+    )
+
+
 @step("Equalize workflow step count")
 def obsolete_step() -> None:
     """Equalize workflow step counts.