orchestrator-core 3.1.2rc4__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/endpoints/processes.py +6 -9
  3. orchestrator/cli/generator/generator/workflow.py +13 -1
  4. orchestrator/cli/generator/templates/modify_product.j2 +9 -0
  5. orchestrator/db/__init__.py +2 -0
  6. orchestrator/db/loaders.py +51 -3
  7. orchestrator/db/models.py +13 -0
  8. orchestrator/db/queries/__init__.py +0 -0
  9. orchestrator/db/queries/subscription.py +85 -0
  10. orchestrator/db/queries/subscription_instance.py +28 -0
  11. orchestrator/domain/base.py +162 -44
  12. orchestrator/domain/context_cache.py +62 -0
  13. orchestrator/domain/helpers.py +41 -1
  14. orchestrator/domain/subscription_instance_transform.py +114 -0
  15. orchestrator/graphql/resolvers/process.py +3 -3
  16. orchestrator/graphql/resolvers/product.py +2 -2
  17. orchestrator/graphql/resolvers/product_block.py +2 -2
  18. orchestrator/graphql/resolvers/resource_type.py +2 -2
  19. orchestrator/graphql/resolvers/workflow.py +2 -2
  20. orchestrator/graphql/utils/get_query_loaders.py +6 -48
  21. orchestrator/graphql/utils/get_subscription_product_blocks.py +8 -1
  22. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.py +33 -0
  23. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.sql +40 -0
  24. orchestrator/migrations/versions/schema/2025-04-09_fc5c993a4b4a_add_cascade_constraint_on_processes_.py +44 -0
  25. orchestrator/services/processes.py +28 -9
  26. orchestrator/services/subscriptions.py +36 -6
  27. orchestrator/settings.py +3 -0
  28. orchestrator/utils/functional.py +9 -0
  29. orchestrator/utils/redis.py +6 -0
  30. orchestrator/workflow.py +29 -6
  31. orchestrator/workflows/utils.py +40 -5
  32. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/METADATA +9 -8
  33. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/RECORD +36 -28
  34. /orchestrator/migrations/versions/schema/{2025-10-19_4fjdn13f83ga_add_validate_product_type_task.py → 2025-01-19_4fjdn13f83ga_add_validate_product_type_task.py} +0 -0
  35. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/WHEEL +0 -0
  36. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
17
17
  from orchestrator.graphql.schemas.product_block import ProductBlock
18
18
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
19
19
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
20
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
20
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
21
21
  from orchestrator.utils.search_query import create_sqlalchemy_select
22
22
 
23
23
  logger = structlog.get_logger(__name__)
@@ -39,7 +39,7 @@ async def resolve_product_blocks(
39
39
  "resolve_product_blocks() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
40
40
  )
41
41
 
42
- query_loaders = get_query_loaders(info, ProductBlockTable)
42
+ query_loaders = get_query_loaders_for_gql_fields(ProductBlockTable, info)
43
43
  select_stmt = select(ProductBlockTable)
44
44
  select_stmt = filter_product_blocks(select_stmt, pydantic_filter_by, _error_handler)
45
45
 
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
17
17
  from orchestrator.graphql.schemas.resource_type import ResourceType
18
18
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
19
19
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
20
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
20
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
21
21
  from orchestrator.utils.search_query import create_sqlalchemy_select
22
22
 
23
23
  logger = structlog.get_logger(__name__)
@@ -38,7 +38,7 @@ async def resolve_resource_types(
38
38
  logger.debug(
39
39
  "resolve_resource_types() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
40
40
  )
41
- query_loaders = get_query_loaders(info, ResourceTypeTable)
41
+ query_loaders = get_query_loaders_for_gql_fields(ResourceTypeTable, info)
42
42
  select_stmt = select(ResourceTypeTable).options(*query_loaders)
43
43
  select_stmt = filter_resource_types(select_stmt, pydantic_filter_by, _error_handler)
44
44
 
@@ -13,7 +13,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
13
13
  from orchestrator.graphql.schemas.workflow import Workflow
14
14
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
15
15
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
16
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
16
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
17
17
  from orchestrator.utils.search_query import create_sqlalchemy_select
18
18
 
19
19
  logger = structlog.get_logger(__name__)
@@ -33,7 +33,7 @@ async def resolve_workflows(
33
33
  pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
34
34
  logger.debug("resolve_workflows() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
35
35
 
36
- query_loaders = get_query_loaders(info, WorkflowTable)
36
+ query_loaders = get_query_loaders_for_gql_fields(WorkflowTable, info)
37
37
  select_stmt = WorkflowTable.select().options(*query_loaders)
38
38
  select_stmt = filter_workflows(select_stmt, pydantic_filter_by, _error_handler)
39
39
 
@@ -1,61 +1,19 @@
1
- from typing import Iterable
2
-
3
- import structlog
4
1
  from sqlalchemy.orm import Load
5
2
 
6
3
  from orchestrator.db.database import BaseModel as DbBaseModel
7
- from orchestrator.db.loaders import AttrLoader, join_attr_loaders, lookup_attr_loaders
4
+ from orchestrator.db.loaders import (
5
+ get_query_loaders_for_model_paths,
6
+ )
8
7
  from orchestrator.graphql.types import OrchestratorInfo
9
8
  from orchestrator.graphql.utils.get_selected_paths import get_selected_paths
10
9
 
11
- logger = structlog.get_logger(__name__)
12
-
13
-
14
- def _split_path(query_path: str) -> Iterable[str]:
15
- yield from (field for field in query_path.split(".") if field != "page")
16
10
 
17
-
18
- def get_query_loaders(info: OrchestratorInfo, root_model: type[DbBaseModel]) -> list[Load]:
11
+ def get_query_loaders_for_gql_fields(root_model: type[DbBaseModel], info: OrchestratorInfo) -> list[Load]:
19
12
  """Get sqlalchemy query loaders for the given GraphQL query.
20
13
 
21
14
  Based on the GraphQL query's selected fields, returns the required DB loaders to use
22
15
  in SQLALchemy's `.options()` for efficiently quering (nested) relationships.
23
16
  """
24
- # Strip page and sort by length to find the longest match first
25
- query_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
26
- query_paths.sort(key=lambda x: x.count("."), reverse=True)
27
-
28
- def get_loader_for_path(query_path: str) -> tuple[str, Load | None]:
29
- next_model = root_model
30
-
31
- matched_fields: list[str] = []
32
- path_loaders: list[AttrLoader] = []
33
-
34
- for field in _split_path(query_path):
35
- if not (attr_loaders := lookup_attr_loaders(next_model, field)):
36
- break
37
-
38
- matched_fields.append(field)
39
- path_loaders.extend(attr_loaders)
40
- next_model = attr_loaders[-1].next_model
41
-
42
- return ".".join(matched_fields), join_attr_loaders(path_loaders)
43
-
44
- query_loaders: dict[str, Load] = {}
45
-
46
- for path in query_paths:
47
- matched_path, loader = get_loader_for_path(path)
48
- if not matched_path or not loader or matched_path in query_loaders:
49
- continue
50
- if any(known_path.startswith(f"{matched_path}.") for known_path in query_loaders):
51
- continue
52
- query_loaders[matched_path] = loader
17
+ model_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
53
18
 
54
- loaders = list(query_loaders.values())
55
- logger.debug(
56
- "Generated query loaders",
57
- root_model=root_model,
58
- query_paths=query_paths,
59
- query_loaders=[str(i.path) for i in loaders],
60
- )
61
- return loaders
19
+ return get_query_loaders_for_model_paths(root_model, model_paths)
@@ -70,7 +70,14 @@ def get_all_product_blocks(subscription: dict[str, Any], _tags: list[str] | None
70
70
  return list(locate_product_block(subscription))
71
71
 
72
72
 
73
- pb_instance_property_keys = ("id", "parent", "owner_subscription_id", "subscription_instance_id", "in_use_by_relations")
73
+ pb_instance_property_keys = (
74
+ "id",
75
+ "parent",
76
+ "owner_subscription_id",
77
+ "subscription_instance_id",
78
+ "in_use_by_relations",
79
+ "in_use_by_ids",
80
+ )
74
81
 
75
82
 
76
83
  async def get_subscription_product_blocks(
@@ -0,0 +1,33 @@
1
+ """Add postgres function subscription_instance_as_json.
2
+
3
+ Revision ID: 42b3d076a85b
4
+ Revises: bac6be6f2b4f
5
+ Create Date: 2025-03-06 15:03:09.477225
6
+
7
+ """
8
+
9
+ from pathlib import Path
10
+
11
+ from alembic import op
12
+ from sqlalchemy import text
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision = "42b3d076a85b"
16
+ down_revision = "bac6be6f2b4f"
17
+ branch_labels = None
18
+ depends_on = None
19
+
20
+
21
+ def upgrade() -> None:
22
+ conn = op.get_bind()
23
+
24
+ revision_file_path = Path(__file__)
25
+ with open(revision_file_path.with_suffix(".sql")) as f:
26
+ conn.execute(text(f.read()))
27
+
28
+
29
+ def downgrade() -> None:
30
+ conn = op.get_bind()
31
+
32
+ conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_as_json;"))
33
+ conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_fields_as_json;"))
@@ -0,0 +1,40 @@
1
+ CREATE OR REPLACE FUNCTION subscription_instance_fields_as_json(sub_inst_id uuid)
2
+ RETURNS jsonb
3
+ LANGUAGE sql
4
+ STABLE PARALLEL SAFE AS
5
+ $func$
6
+ select jsonb_object_agg(rts.key, rts.val)
7
+ from (select attr.key,
8
+ attr.val
9
+ from subscription_instances si
10
+ join product_blocks pb ON si.product_block_id = pb.product_block_id
11
+ cross join lateral (
12
+ values ('subscription_instance_id', to_jsonb(si.subscription_instance_id)),
13
+ ('owner_subscription_id', to_jsonb(si.subscription_id)),
14
+ ('name', to_jsonb(pb.name))
15
+ ) as attr(key, val)
16
+ where si.subscription_instance_id = sub_inst_id
17
+ union all
18
+ select rt.resource_type key,
19
+ jsonb_agg(siv.value ORDER BY siv.value ASC) val
20
+ from subscription_instance_values siv
21
+ join resource_types rt on siv.resource_type_id = rt.resource_type_id
22
+ where siv.subscription_instance_id = sub_inst_id
23
+ group by rt.resource_type) as rts
24
+ $func$;
25
+
26
+
27
+ CREATE OR REPLACE FUNCTION subscription_instance_as_json(sub_inst_id uuid)
28
+ RETURNS jsonb
29
+ LANGUAGE sql
30
+ STABLE PARALLEL SAFE AS
31
+ $func$
32
+ select subscription_instance_fields_as_json(sub_inst_id) ||
33
+ coalesce(jsonb_object_agg(depends_on.block_name, depends_on.block_instances), '{}'::jsonb)
34
+ from (select sir.domain_model_attr block_name,
35
+ jsonb_agg(subscription_instance_as_json(sir.depends_on_id) ORDER BY sir.order_id ASC) as block_instances
36
+ from subscription_instance_relations sir
37
+ where sir.in_use_by_id = sub_inst_id
38
+ and sir.domain_model_attr is not null
39
+ group by block_name) as depends_on
40
+ $func$;
@@ -0,0 +1,44 @@
1
+ """add cascade constraint on processes input state.
2
+
3
+ Revision ID: fc5c993a4b4a
4
+ Revises: 42b3d076a85b
5
+ Create Date: 2025-04-09 18:27:31.922214
6
+
7
+ """
8
+
9
+ from alembic import op
10
+
11
+ # revision identifiers, used by Alembic.
12
+ revision = "fc5c993a4b4a"
13
+ down_revision = "42b3d076a85b"
14
+ branch_labels = None
15
+ depends_on = None
16
+
17
+
18
+ def upgrade() -> None:
19
+ # Drop the existing foreign key constraint
20
+ op.drop_constraint("input_states_pid_fkey", "input_states", type_="foreignkey")
21
+
22
+ # Add a new foreign key constraint with cascade delete
23
+ op.create_foreign_key(
24
+ "input_states_pid_fkey",
25
+ "input_states",
26
+ "processes",
27
+ ["pid"],
28
+ ["pid"],
29
+ ondelete="CASCADE",
30
+ )
31
+
32
+
33
+ def downgrade() -> None:
34
+ # Drop the cascade foreign key constraint
35
+ op.drop_constraint("input_states_pid_fkey", "input_states", type_="foreignkey")
36
+
37
+ # Recreate the original foreign key constraint without cascade
38
+ op.create_foreign_key(
39
+ "input_states_pid_fkey",
40
+ "input_states",
41
+ "processes",
42
+ ["pid"],
43
+ ["pid"],
44
+ )
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF, GÉANT.
1
+ # Copyright 2019-2025 SURF, GÉANT, ESnet.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -24,15 +24,10 @@ from sqlalchemy.exc import SQLAlchemyError
24
24
  from sqlalchemy.orm import joinedload
25
25
 
26
26
  from nwastdlib.ex import show_ex
27
+ from oauth2_lib.fastapi import OIDCUserModel
27
28
  from orchestrator.api.error_handling import raise_status
28
29
  from orchestrator.config.assignee import Assignee
29
- from orchestrator.db import (
30
- EngineSettingsTable,
31
- ProcessStepTable,
32
- ProcessSubscriptionTable,
33
- ProcessTable,
34
- db,
35
- )
30
+ from orchestrator.db import EngineSettingsTable, ProcessStepTable, ProcessSubscriptionTable, ProcessTable, db
36
31
  from orchestrator.distlock import distlock_manager
37
32
  from orchestrator.schemas.engine_settings import WorkerStatus
38
33
  from orchestrator.services.input_state import store_input_state
@@ -414,10 +409,15 @@ def _run_process_async(process_id: UUID, f: Callable) -> UUID:
414
409
  return process_id
415
410
 
416
411
 
412
+ def error_message_unauthorized(workflow_key: str) -> str:
413
+ return f"User is not authorized to execute '{workflow_key}' workflow"
414
+
415
+
417
416
  def create_process(
418
417
  workflow_key: str,
419
418
  user_inputs: list[State] | None = None,
420
419
  user: str = SYSTEM_USER,
420
+ user_model: OIDCUserModel | None = None,
421
421
  ) -> ProcessStat:
422
422
  # ATTENTION!! When modifying this function make sure you make similar changes to `run_workflow` in the test code
423
423
 
@@ -430,6 +430,9 @@ def create_process(
430
430
  if not workflow:
431
431
  raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
432
432
 
433
+ if not workflow.authorize_callback(user_model):
434
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
435
+
433
436
  initial_state = {
434
437
  "process_id": process_id,
435
438
  "reporter": user,
@@ -449,6 +452,7 @@ def create_process(
449
452
  state=Success(state | initial_state),
450
453
  log=workflow.steps,
451
454
  current_user=user,
455
+ user_model=user_model,
452
456
  )
453
457
 
454
458
  _db_create_process(pstat)
@@ -460,9 +464,12 @@ def thread_start_process(
460
464
  workflow_key: str,
461
465
  user_inputs: list[State] | None = None,
462
466
  user: str = SYSTEM_USER,
467
+ user_model: OIDCUserModel | None = None,
463
468
  broadcast_func: BroadcastFunc | None = None,
464
469
  ) -> UUID:
465
470
  pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
471
+ if not pstat.workflow.authorize_callback(user_model):
472
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
466
473
 
467
474
  _safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
468
475
  return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
@@ -472,6 +479,7 @@ def start_process(
472
479
  workflow_key: str,
473
480
  user_inputs: list[State] | None = None,
474
481
  user: str = SYSTEM_USER,
482
+ user_model: OIDCUserModel | None = None,
475
483
  broadcast_func: BroadcastFunc | None = None,
476
484
  ) -> UUID:
477
485
  """Start a process for workflow.
@@ -480,6 +488,7 @@ def start_process(
480
488
  workflow_key: name of workflow
481
489
  user_inputs: List of form inputs from frontend
482
490
  user: User who starts this process
491
+ user_model: Full OIDCUserModel with claims, etc
483
492
  broadcast_func: Optional function to broadcast process data
484
493
 
485
494
  Returns:
@@ -487,7 +496,9 @@ def start_process(
487
496
 
488
497
  """
489
498
  start_func = get_execution_context()["start"]
490
- return start_func(workflow_key, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
499
+ return start_func(
500
+ workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
501
+ )
491
502
 
492
503
 
493
504
  def thread_resume_process(
@@ -495,6 +506,7 @@ def thread_resume_process(
495
506
  *,
496
507
  user_inputs: list[State] | None = None,
497
508
  user: str | None = None,
509
+ user_model: OIDCUserModel | None = None,
498
510
  broadcast_func: BroadcastFunc | None = None,
499
511
  ) -> UUID:
500
512
  # ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
@@ -503,6 +515,8 @@ def thread_resume_process(
503
515
  user_inputs = [{}]
504
516
 
505
517
  pstat = load_process(process)
518
+ if not pstat.workflow.authorize_callback(user_model):
519
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
506
520
 
507
521
  if pstat.workflow == removed_workflow:
508
522
  raise ValueError("This workflow cannot be resumed")
@@ -542,6 +556,7 @@ def resume_process(
542
556
  *,
543
557
  user_inputs: list[State] | None = None,
544
558
  user: str | None = None,
559
+ user_model: OIDCUserModel | None = None,
545
560
  broadcast_func: BroadcastFunc | None = None,
546
561
  ) -> UUID:
547
562
  """Resume a failed or suspended process.
@@ -550,6 +565,7 @@ def resume_process(
550
565
  process: Process from database
551
566
  user_inputs: Optional user input from forms
552
567
  user: user who resumed this process
568
+ user_model: OIDCUserModel of user who resumed this process
553
569
  broadcast_func: Optional function to broadcast process data
554
570
 
555
571
  Returns:
@@ -557,6 +573,9 @@ def resume_process(
557
573
 
558
574
  """
559
575
  pstat = load_process(process)
576
+ if not pstat.workflow.authorize_callback(user_model):
577
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
578
+
560
579
  try:
561
580
  post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
562
581
  except FormValidationError:
@@ -22,6 +22,7 @@ from uuid import UUID
22
22
 
23
23
  import more_itertools
24
24
  import structlog
25
+ from more_itertools import first
25
26
  from sqlalchemy import Text, cast, not_, select
26
27
  from sqlalchemy.exc import SQLAlchemyError
27
28
  from sqlalchemy.orm import Query, aliased, joinedload
@@ -41,7 +42,11 @@ from orchestrator.db.models import (
41
42
  SubscriptionInstanceRelationTable,
42
43
  SubscriptionMetadataTable,
43
44
  )
45
+ from orchestrator.db.queries.subscription import (
46
+ eagerload_all_subscription_instances_only_inuseby,
47
+ )
44
48
  from orchestrator.domain.base import SubscriptionModel
49
+ from orchestrator.domain.context_cache import cache_subscription_models
45
50
  from orchestrator.targets import Target
46
51
  from orchestrator.types import SubscriptionLifecycle
47
52
  from orchestrator.utils.datetime import nowtz
@@ -594,13 +599,15 @@ def convert_to_in_use_by_relation(obj: Any) -> dict[str, str]:
594
599
 
595
600
  def build_extended_domain_model(subscription_model: SubscriptionModel) -> dict:
596
601
  """Create a subscription dict from the SubscriptionModel with additional keys."""
597
- stmt = select(SubscriptionCustomerDescriptionTable).filter(
602
+ from orchestrator.settings import app_settings
603
+
604
+ stmt = select(SubscriptionCustomerDescriptionTable).where(
598
605
  SubscriptionCustomerDescriptionTable.subscription_id == subscription_model.subscription_id
599
606
  )
600
607
  customer_descriptions = list(db.session.scalars(stmt))
601
608
 
602
- subscription = subscription_model.model_dump()
603
- paths = product_block_paths(subscription)
609
+ with cache_subscription_models():
610
+ subscription = subscription_model.model_dump()
604
611
 
605
612
  def inject_in_use_by_ids(path_to_block: str) -> None:
606
613
  if not (in_use_by_subs := getattr_in(subscription_model, f"{path_to_block}.in_use_by")):
@@ -611,15 +618,38 @@ def build_extended_domain_model(subscription_model: SubscriptionModel) -> dict:
611
618
  update_in(subscription, f"{path_to_block}.in_use_by_ids", in_use_by_ids)
612
619
  update_in(subscription, f"{path_to_block}.in_use_by_relations", in_use_by_relations)
613
620
 
614
- # find all product blocks, check if they have in_use_by and inject the in_use_by_ids into the subscription dict.
615
- for path in paths:
616
- inject_in_use_by_ids(path)
621
+ if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
622
+ # TODO #900 remove toggle and make this path the default
623
+ # query all subscription instances and inject the in_use_by_ids/in_use_by_relations into the subscription dict.
624
+ instance_to_in_use_by = {
625
+ instance.subscription_instance_id: instance.in_use_by
626
+ for instance in eagerload_all_subscription_instances_only_inuseby(subscription_model.subscription_id)
627
+ }
628
+ inject_in_use_by_ids_v2(subscription, instance_to_in_use_by)
629
+ else:
630
+ # find all product blocks, check if they have in_use_by and inject the in_use_by_ids into the subscription dict.
631
+ for path in product_block_paths(subscription):
632
+ inject_in_use_by_ids(path)
617
633
 
618
634
  subscription["customer_descriptions"] = customer_descriptions
619
635
 
620
636
  return subscription
621
637
 
622
638
 
639
+ def inject_in_use_by_ids_v2(dikt: dict, instance_to_in_use_by: dict[UUID, Sequence[SubscriptionInstanceTable]]) -> None:
640
+ for value in dikt.values():
641
+ if isinstance(value, dict):
642
+ inject_in_use_by_ids_v2(value, instance_to_in_use_by)
643
+ elif isinstance(value, list) and isinstance(first(value, None), dict):
644
+ for item in value:
645
+ inject_in_use_by_ids_v2(item, instance_to_in_use_by)
646
+
647
+ if subscription_instance_id := dikt.get("subscription_instance_id"):
648
+ in_use_by_subs = instance_to_in_use_by[subscription_instance_id]
649
+ dikt["in_use_by_ids"] = [i.subscription_instance_id for i in in_use_by_subs]
650
+ dikt["in_use_by_relations"] = [convert_to_in_use_by_relation(instance) for instance in in_use_by_subs]
651
+
652
+
623
653
  def format_special_types(subscription: dict) -> dict:
624
654
  """Modifies the subscription dict in-place, formatting special types to string.
625
655
 
orchestrator/settings.py CHANGED
@@ -87,6 +87,9 @@ class AppSettings(BaseSettings):
87
87
  ENABLE_GRAPHQL_STATS_EXTENSION: bool = False
88
88
  VALIDATE_OUT_OF_SYNC_SUBSCRIPTIONS: bool = False
89
89
  FILTER_BY_MODE: Literal["partial", "exact"] = "exact"
90
+ ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS: bool = (
91
+ True # True=ignore cache + optimized DB queries; False=use cache + unoptimized DB queries. Remove in #900
92
+ )
90
93
 
91
94
 
92
95
  app_settings = AppSettings()
@@ -241,3 +241,12 @@ def to_ranges(i: Iterable[int]) -> Iterable[range]:
241
241
  for _, g in itertools.groupby(enumerate(i), lambda t: t[1] - t[0]):
242
242
  group = list(g)
243
243
  yield range(group[0][1], group[-1][1] + 1)
244
+
245
+
246
+ K = TypeVar("K")
247
+ V = TypeVar("V")
248
+
249
+
250
+ def group_by_key(items: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
251
+ groups = itertools.groupby(items, key=lambda item: item[0])
252
+ return {key: [item[1] for item in group] for key, group in groups}
@@ -55,6 +55,12 @@ def to_redis(subscription: dict[str, Any]) -> str | None:
55
55
 
56
56
  def from_redis(subscription_id: UUID) -> tuple[PY_JSON_TYPES, str] | None:
57
57
  log = logger.bind(subscription_id=subscription_id)
58
+
59
+ if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
60
+ # TODO #900 remove toggle and remove usage of this function in get_subscription_dict
61
+ log.warning("Using SubscriptionModel optimization, not loading subscription from cache")
62
+ return None
63
+
58
64
  if caching_models_enabled():
59
65
  log.debug("Try to retrieve subscription from cache")
60
66
  obj = cache.get(f"orchestrator:domain:{subscription_id}")
orchestrator/workflow.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2025 SURF, GÉANT.
1
+ # Copyright 2019-2025 SURF, GÉANT, ESnet.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -39,6 +39,7 @@ from structlog.contextvars import bound_contextvars
39
39
  from structlog.stdlib import BoundLogger
40
40
 
41
41
  from nwastdlib import const, identity
42
+ from oauth2_lib.fastapi import OIDCUserModel
42
43
  from orchestrator.config.assignee import Assignee
43
44
  from orchestrator.db import db, transactional
44
45
  from orchestrator.services.settings import get_engine_settings
@@ -89,6 +90,7 @@ class Workflow(Protocol):
89
90
  __qualname__: str
90
91
  name: str
91
92
  description: str
93
+ authorize_callback: Callable[[OIDCUserModel | None], bool]
92
94
  initial_input_form: InputFormGenerator | None = None
93
95
  target: Target
94
96
  steps: StepList
@@ -178,12 +180,18 @@ def _handle_simple_input_form_generator(f: StateInputStepFunc) -> StateInputForm
178
180
  return form_generator
179
181
 
180
182
 
183
+ def allow(_: OIDCUserModel | None) -> bool:
184
+ """Default function to return True in absence of user-defined authorize function."""
185
+ return True
186
+
187
+
181
188
  def make_workflow(
182
189
  f: Callable,
183
190
  description: str,
184
191
  initial_input_form: InputStepFunc | None,
185
192
  target: Target,
186
193
  steps: StepList,
194
+ authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
187
195
  ) -> Workflow:
188
196
  @functools.wraps(f)
189
197
  def wrapping_function() -> NoReturn:
@@ -193,6 +201,7 @@ def make_workflow(
193
201
 
194
202
  wrapping_function.name = f.__name__ # default, will be changed by LazyWorkflowInstance
195
203
  wrapping_function.description = description
204
+ wrapping_function.authorize_callback = allow if authorize_callback is None else authorize_callback
196
205
 
197
206
  if initial_input_form is None:
198
207
  # We always need a form to prevent starting a workflow when no input is needed.
@@ -214,7 +223,11 @@ def step(name: str) -> Callable[[StepFunc], Step]:
214
223
  def decorator(func: StepFunc) -> Step:
215
224
  @functools.wraps(func)
216
225
  def wrapper(state: State) -> Process:
217
- with bound_contextvars(func=func.__qualname__):
226
+ with bound_contextvars(
227
+ func=func.__qualname__,
228
+ workflow_name=state.get("workflow_name"),
229
+ process_id=state.get("process_id"),
230
+ ):
218
231
  step_in_inject_args = inject_args(func)
219
232
  try:
220
233
  with transactional(db, logger):
@@ -239,7 +252,11 @@ def retrystep(name: str) -> Callable[[StepFunc], Step]:
239
252
  def decorator(func: StepFunc) -> Step:
240
253
  @functools.wraps(func)
241
254
  def wrapper(state: State) -> Process:
242
- with bound_contextvars(func=func.__qualname__):
255
+ with bound_contextvars(
256
+ func=func.__qualname__,
257
+ workflow_name=state.get("workflow_name"),
258
+ process_id=state.get("process_id"),
259
+ ):
243
260
  step_in_inject_args = inject_args(func)
244
261
  try:
245
262
  with transactional(db, logger):
@@ -459,7 +476,10 @@ def focussteps(key: str) -> Callable[[Step | StepList], StepList]:
459
476
 
460
477
 
461
478
  def workflow(
462
- description: str, initial_input_form: InputStepFunc | None = None, target: Target = Target.SYSTEM
479
+ description: str,
480
+ initial_input_form: InputStepFunc | None = None,
481
+ target: Target = Target.SYSTEM,
482
+ authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
463
483
  ) -> Callable[[Callable[[], StepList]], Workflow]:
464
484
  """Transform an initial_input_form and a step list into a workflow.
465
485
 
@@ -479,7 +499,9 @@ def workflow(
479
499
  initial_input_form_in_form_inject_args = form_inject_args(initial_input_form)
480
500
 
481
501
  def _workflow(f: Callable[[], StepList]) -> Workflow:
482
- return make_workflow(f, description, initial_input_form_in_form_inject_args, target, f())
502
+ return make_workflow(
503
+ f, description, initial_input_form_in_form_inject_args, target, f(), authorize_callback=authorize_callback
504
+ )
483
505
 
484
506
  return _workflow
485
507
 
@@ -491,13 +513,14 @@ class ProcessStat:
491
513
  state: Process
492
514
  log: StepList
493
515
  current_user: str
516
+ user_model: OIDCUserModel | None = None
494
517
 
495
518
  def update(self, **vs: Any) -> ProcessStat:
496
519
  """Update ProcessStat.
497
520
 
498
521
  >>> pstat = ProcessStat('', None, {}, [], "")
499
522
  >>> pstat.update(state={"a": "b"})
500
- ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='')
523
+ ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='', user_model=None)
501
524
  """
502
525
  return ProcessStat(**{**asdict(self), **vs})
503
526