orchestrator-core 3.1.2rc3__py3-none-any.whl → 3.2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. orchestrator/__init__.py +2 -2
  2. orchestrator/api/api_v1/api.py +1 -1
  3. orchestrator/api/api_v1/endpoints/processes.py +6 -9
  4. orchestrator/api/api_v1/endpoints/settings.py +1 -1
  5. orchestrator/api/api_v1/endpoints/subscriptions.py +1 -1
  6. orchestrator/app.py +1 -1
  7. orchestrator/cli/database.py +1 -1
  8. orchestrator/cli/generator/generator/migration.py +2 -5
  9. orchestrator/cli/migrate_tasks.py +13 -0
  10. orchestrator/config/assignee.py +1 -1
  11. orchestrator/db/__init__.py +2 -0
  12. orchestrator/db/loaders.py +51 -3
  13. orchestrator/db/models.py +14 -1
  14. orchestrator/db/queries/__init__.py +0 -0
  15. orchestrator/db/queries/subscription.py +85 -0
  16. orchestrator/db/queries/subscription_instance.py +28 -0
  17. orchestrator/devtools/populator.py +1 -1
  18. orchestrator/domain/__init__.py +2 -3
  19. orchestrator/domain/base.py +236 -49
  20. orchestrator/domain/context_cache.py +62 -0
  21. orchestrator/domain/helpers.py +41 -1
  22. orchestrator/domain/lifecycle.py +1 -1
  23. orchestrator/domain/subscription_instance_transform.py +114 -0
  24. orchestrator/graphql/resolvers/process.py +3 -3
  25. orchestrator/graphql/resolvers/product.py +2 -2
  26. orchestrator/graphql/resolvers/product_block.py +2 -2
  27. orchestrator/graphql/resolvers/resource_type.py +2 -2
  28. orchestrator/graphql/resolvers/workflow.py +2 -2
  29. orchestrator/graphql/schema.py +1 -1
  30. orchestrator/graphql/types.py +1 -1
  31. orchestrator/graphql/utils/get_query_loaders.py +6 -48
  32. orchestrator/graphql/utils/get_subscription_product_blocks.py +21 -1
  33. orchestrator/migrations/env.py +1 -1
  34. orchestrator/migrations/helpers.py +6 -6
  35. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.py +33 -0
  36. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.sql +40 -0
  37. orchestrator/schemas/engine_settings.py +1 -1
  38. orchestrator/schemas/subscription.py +1 -1
  39. orchestrator/security.py +1 -1
  40. orchestrator/services/celery.py +1 -1
  41. orchestrator/services/processes.py +28 -9
  42. orchestrator/services/products.py +1 -1
  43. orchestrator/services/subscriptions.py +37 -7
  44. orchestrator/services/tasks.py +1 -1
  45. orchestrator/settings.py +5 -23
  46. orchestrator/targets.py +1 -1
  47. orchestrator/types.py +1 -1
  48. orchestrator/utils/errors.py +1 -1
  49. orchestrator/utils/functional.py +9 -0
  50. orchestrator/utils/redis.py +6 -0
  51. orchestrator/utils/state.py +1 -1
  52. orchestrator/websocket/websocket_manager.py +1 -1
  53. orchestrator/workflow.py +19 -4
  54. orchestrator/workflows/modify_note.py +1 -1
  55. orchestrator/workflows/steps.py +1 -1
  56. orchestrator/workflows/tasks/cleanup_tasks_log.py +1 -1
  57. orchestrator/workflows/tasks/resume_workflows.py +1 -1
  58. orchestrator/workflows/tasks/validate_product_type.py +1 -1
  59. orchestrator/workflows/tasks/validate_products.py +1 -1
  60. orchestrator/workflows/utils.py +40 -5
  61. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0rc1.dist-info}/METADATA +7 -7
  62. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0rc1.dist-info}/RECORD +65 -58
  63. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0rc1.dist-info}/WHEEL +1 -1
  64. /orchestrator/migrations/versions/schema/{2025-10-19_4fjdn13f83ga_add_validate_product_type_task.py → 2025-01-19_4fjdn13f83ga_add_validate_product_type_task.py} +0 -0
  65. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,114 @@
1
+ # Copyright 2019-2025 SURF.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ """Functions to transform result of query SubscriptionInstanceAsJsonFunction to match the ProductBlockModel."""
15
+
16
+ from functools import partial
17
+ from typing import TYPE_CHECKING, Any, Callable, Iterable
18
+
19
+ from more_itertools import first, only
20
+
21
+ from orchestrator.types import is_list_type, is_optional_type
22
+
23
+ if TYPE_CHECKING:
24
+ from orchestrator.domain.base import ProductBlockModel
25
+
26
+
27
+ def _ensure_list(instance_or_value_list: Any) -> Any:
28
+ if instance_or_value_list is None:
29
+ return []
30
+
31
+ return instance_or_value_list
32
+
33
+
34
+ def _instance_list_to_dict(product_block_field_type: type, instance_list: Any) -> Any:
35
+ if instance_list is None:
36
+ return None
37
+
38
+ match instance_list:
39
+ case list():
40
+ if instance := only(instance_list):
41
+ return instance
42
+
43
+ if not is_optional_type(product_block_field_type):
44
+ raise ValueError("Required subscription instance is missing in database")
45
+
46
+ return None # Set the optional product block field to None
47
+ case _:
48
+ raise ValueError(f"All subscription instances should be returned as list, found {type(instance_list)}") #
49
+
50
+
51
+ def _value_list_to_value(field_type: type, value_list: Any) -> Any:
52
+ if value_list is None:
53
+ return None
54
+
55
+ match value_list:
56
+ case list():
57
+ if (value := only(value_list)) is not None:
58
+ return value
59
+
60
+ if not is_optional_type(field_type):
61
+ raise ValueError("Required subscription value is missing in database")
62
+
63
+ return None # Set the optional resource type field to None
64
+ case _:
65
+ raise ValueError(f"All instance values should be returned as list, found {type(value_list)}")
66
+
67
+
68
+ def field_transformation_rules(klass: type["ProductBlockModel"]) -> dict[str, Callable]:
69
+ """Create mapping of transformation rules for the given product block type."""
70
+
71
+ def create_rules() -> Iterable[tuple[str, Callable]]:
72
+ for field_name, product_block_field_type in klass._product_block_fields_.items():
73
+ if is_list_type(product_block_field_type):
74
+ yield field_name, _ensure_list
75
+ else:
76
+ yield field_name, partial(_instance_list_to_dict, product_block_field_type)
77
+
78
+ for field_name, field_type in klass._non_product_block_fields_.items():
79
+ if is_list_type(field_type):
80
+ yield field_name, _ensure_list
81
+ else:
82
+ yield field_name, partial(_value_list_to_value, field_type)
83
+
84
+ return dict(create_rules())
85
+
86
+
87
+ def transform_instance_fields(all_rules: dict[str, dict[str, Callable]], instance: dict) -> None:
88
+ """Apply transformation rules to the given subscription instance dict."""
89
+
90
+ from orchestrator.domain.base import ProductBlockModel
91
+
92
+ # Lookup applicable rules through product block name
93
+ field_rules = all_rules[instance["name"]]
94
+
95
+ klass = ProductBlockModel.registry[instance["name"]]
96
+
97
+ # Ensure the product block's metadata is loaded
98
+ klass._fix_pb_data()
99
+
100
+ # Transform all fields in this subscription instance
101
+ try:
102
+ for field_name, rewrite_func in field_rules.items():
103
+ field_value = instance.get(field_name)
104
+ instance[field_name] = rewrite_func(field_value)
105
+ except ValueError as e:
106
+ raise ValueError(f"Invalid subscription instance data {instance}") from e
107
+
108
+ # Recurse into nested subscription instances
109
+ for field_value in instance.values():
110
+ if isinstance(field_value, dict):
111
+ transform_instance_fields(all_rules, field_value)
112
+ if isinstance(field_value, list) and isinstance(first(field_value, None), dict):
113
+ for list_value in field_value:
114
+ transform_instance_fields(all_rules, list_value)
@@ -34,7 +34,7 @@ from orchestrator.graphql.utils import (
34
34
  is_querying_page_data,
35
35
  to_graphql_result_page,
36
36
  )
37
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
37
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
38
38
  from orchestrator.schemas.process import ProcessSchema
39
39
  from orchestrator.services.processes import load_process
40
40
  from orchestrator.utils.enrich_process import enrich_process
@@ -56,7 +56,7 @@ def _enrich_process(process: ProcessTable, with_details: bool = False) -> Proces
56
56
 
57
57
 
58
58
  async def resolve_process(info: OrchestratorInfo, process_id: UUID) -> ProcessType | None:
59
- query_loaders = get_query_loaders(info, ProcessTable)
59
+ query_loaders = get_query_loaders_for_gql_fields(ProcessTable, info)
60
60
  stmt = select(ProcessTable).options(*query_loaders).where(ProcessTable.process_id == process_id)
61
61
  if process := db.session.scalar(stmt):
62
62
  is_detailed = _is_process_detailed(info)
@@ -83,7 +83,7 @@ async def resolve_processes(
83
83
  .selectinload(ProcessSubscriptionTable.subscription)
84
84
  .joinedload(SubscriptionTable.product)
85
85
  ]
86
- query_loaders = get_query_loaders(info, ProcessTable) or default_loaders
86
+ query_loaders = get_query_loaders_for_gql_fields(ProcessTable, info) or default_loaders
87
87
  select_stmt = select(ProcessTable).options(*query_loaders)
88
88
  select_stmt = filter_processes(select_stmt, pydantic_filter_by, _error_handler)
89
89
  if query is not None:
@@ -13,7 +13,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
13
13
  from orchestrator.graphql.schemas.product import ProductType
14
14
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
15
15
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
16
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
16
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
17
17
  from orchestrator.utils.search_query import create_sqlalchemy_select
18
18
 
19
19
  logger = structlog.get_logger(__name__)
@@ -33,7 +33,7 @@ async def resolve_products(
33
33
  pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
34
34
  logger.debug("resolve_products() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
35
35
 
36
- query_loaders = get_query_loaders(info, ProductTable)
36
+ query_loaders = get_query_loaders_for_gql_fields(ProductTable, info)
37
37
  select_stmt = select(ProductTable).options(*query_loaders)
38
38
  select_stmt = filter_products(select_stmt, pydantic_filter_by, _error_handler)
39
39
 
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
17
17
  from orchestrator.graphql.schemas.product_block import ProductBlock
18
18
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
19
19
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
20
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
20
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
21
21
  from orchestrator.utils.search_query import create_sqlalchemy_select
22
22
 
23
23
  logger = structlog.get_logger(__name__)
@@ -39,7 +39,7 @@ async def resolve_product_blocks(
39
39
  "resolve_product_blocks() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
40
40
  )
41
41
 
42
- query_loaders = get_query_loaders(info, ProductBlockTable)
42
+ query_loaders = get_query_loaders_for_gql_fields(ProductBlockTable, info)
43
43
  select_stmt = select(ProductBlockTable)
44
44
  select_stmt = filter_product_blocks(select_stmt, pydantic_filter_by, _error_handler)
45
45
 
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
17
17
  from orchestrator.graphql.schemas.resource_type import ResourceType
18
18
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
19
19
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
20
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
20
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
21
21
  from orchestrator.utils.search_query import create_sqlalchemy_select
22
22
 
23
23
  logger = structlog.get_logger(__name__)
@@ -38,7 +38,7 @@ async def resolve_resource_types(
38
38
  logger.debug(
39
39
  "resolve_resource_types() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
40
40
  )
41
- query_loaders = get_query_loaders(info, ResourceTypeTable)
41
+ query_loaders = get_query_loaders_for_gql_fields(ResourceTypeTable, info)
42
42
  select_stmt = select(ResourceTypeTable).options(*query_loaders)
43
43
  select_stmt = filter_resource_types(select_stmt, pydantic_filter_by, _error_handler)
44
44
 
@@ -13,7 +13,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
13
13
  from orchestrator.graphql.schemas.workflow import Workflow
14
14
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
15
15
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
16
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
16
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
17
17
  from orchestrator.utils.search_query import create_sqlalchemy_select
18
18
 
19
19
  logger = structlog.get_logger(__name__)
@@ -33,7 +33,7 @@ async def resolve_workflows(
33
33
  pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
34
34
  logger.debug("resolve_workflows() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
35
35
 
36
- query_loaders = get_query_loaders(info, WorkflowTable)
36
+ query_loaders = get_query_loaders_for_gql_fields(WorkflowTable, info)
37
37
  select_stmt = WorkflowTable.select().options(*query_loaders)
38
38
  select_stmt = filter_workflows(select_stmt, pydantic_filter_by, _error_handler)
39
39
 
@@ -1,4 +1,4 @@
1
- # Copyright 2022-2023 SURF.
1
+ # Copyright 2022-2023 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2022-2023 SURF.
1
+ # Copyright 2022-2023 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,61 +1,19 @@
1
- from typing import Iterable
2
-
3
- import structlog
4
1
  from sqlalchemy.orm import Load
5
2
 
6
3
  from orchestrator.db.database import BaseModel as DbBaseModel
7
- from orchestrator.db.loaders import AttrLoader, join_attr_loaders, lookup_attr_loaders
4
+ from orchestrator.db.loaders import (
5
+ get_query_loaders_for_model_paths,
6
+ )
8
7
  from orchestrator.graphql.types import OrchestratorInfo
9
8
  from orchestrator.graphql.utils.get_selected_paths import get_selected_paths
10
9
 
11
- logger = structlog.get_logger(__name__)
12
-
13
-
14
- def _split_path(query_path: str) -> Iterable[str]:
15
- yield from (field for field in query_path.split(".") if field != "page")
16
10
 
17
-
18
- def get_query_loaders(info: OrchestratorInfo, root_model: type[DbBaseModel]) -> list[Load]:
11
+ def get_query_loaders_for_gql_fields(root_model: type[DbBaseModel], info: OrchestratorInfo) -> list[Load]:
19
12
  """Get sqlalchemy query loaders for the given GraphQL query.
20
13
 
21
14
  Based on the GraphQL query's selected fields, returns the required DB loaders to use
22
15
  in SQLALchemy's `.options()` for efficiently quering (nested) relationships.
23
16
  """
24
- # Strip page and sort by length to find the longest match first
25
- query_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
26
- query_paths.sort(key=lambda x: x.count("."), reverse=True)
27
-
28
- def get_loader_for_path(query_path: str) -> tuple[str, Load | None]:
29
- next_model = root_model
30
-
31
- matched_fields: list[str] = []
32
- path_loaders: list[AttrLoader] = []
33
-
34
- for field in _split_path(query_path):
35
- if not (attr_loaders := lookup_attr_loaders(next_model, field)):
36
- break
37
-
38
- matched_fields.append(field)
39
- path_loaders.extend(attr_loaders)
40
- next_model = attr_loaders[-1].next_model
41
-
42
- return ".".join(matched_fields), join_attr_loaders(path_loaders)
43
-
44
- query_loaders: dict[str, Load] = {}
45
-
46
- for path in query_paths:
47
- matched_path, loader = get_loader_for_path(path)
48
- if not matched_path or not loader or matched_path in query_loaders:
49
- continue
50
- if any(known_path.startswith(f"{matched_path}.") for known_path in query_loaders):
51
- continue
52
- query_loaders[matched_path] = loader
17
+ model_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
53
18
 
54
- loaders = list(query_loaders.values())
55
- logger.debug(
56
- "Generated query loaders",
57
- root_model=root_model,
58
- query_paths=query_paths,
59
- query_loaders=[str(i.path) for i in loaders],
60
- )
61
- return loaders
19
+ return get_query_loaders_for_model_paths(root_model, model_paths)
@@ -1,3 +1,16 @@
1
+ # Copyright 2022-2023 SURF, GÉANT.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
1
14
  from collections.abc import Generator
2
15
  from itertools import count
3
16
  from typing import TYPE_CHECKING, Annotated, Any
@@ -57,7 +70,14 @@ def get_all_product_blocks(subscription: dict[str, Any], _tags: list[str] | None
57
70
  return list(locate_product_block(subscription))
58
71
 
59
72
 
60
- pb_instance_property_keys = ("id", "parent", "owner_subscription_id", "subscription_instance_id", "in_use_by_relations")
73
+ pb_instance_property_keys = (
74
+ "id",
75
+ "parent",
76
+ "owner_subscription_id",
77
+ "subscription_instance_id",
78
+ "in_use_by_relations",
79
+ "in_use_by_ids",
80
+ )
61
81
 
62
82
 
63
83
  async def get_subscription_product_blocks(
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -880,10 +880,10 @@ def delete_product(conn: sa.engine.Connection, name: str) -> None:
880
880
  RETURNING product_id
881
881
  ),
882
882
  deleted_p_pb AS (
883
- DELETE FROM product_product_blocks WHERE product_id = ANY(SELECT product_id FROM deleted_p)
883
+ DELETE FROM product_product_blocks WHERE product_id IN (SELECT product_id FROM deleted_p)
884
884
  ),
885
885
  deleted_pb_rt AS (
886
- DELETE FROM products_workflows WHERE product_id = ANY(SELECT product_id FROM deleted_p)
886
+ DELETE FROM products_workflows WHERE product_id IN (SELECT product_id FROM deleted_p)
887
887
  )
888
888
  SELECT * from deleted_p;
889
889
  """
@@ -911,10 +911,10 @@ def delete_product_block(conn: sa.engine.Connection, name: str) -> None:
911
911
  RETURNING product_block_id
912
912
  ),
913
913
  deleted_p_pb AS (
914
- DELETE FROM product_product_blocks WHERE product_block_id =ANY(SELECT product_block_id FROM deleted_pb)
914
+ DELETE FROM product_product_blocks WHERE product_block_id IN (SELECT product_block_id FROM deleted_pb)
915
915
  ),
916
916
  deleted_pb_rt AS (
917
- DELETE FROM product_block_resource_types WHERE product_block_id =ANY(SELECT product_block_id FROM deleted_pb)
917
+ DELETE FROM product_block_resource_types WHERE product_block_id IN (SELECT product_block_id FROM deleted_pb)
918
918
  )
919
919
  SELECT * from deleted_pb;
920
920
  """
@@ -968,7 +968,7 @@ def delete_resource_type(conn: sa.engine.Connection, resource_type: str) -> None
968
968
  RETURNING resource_type_id
969
969
  ),
970
970
  deleted_pb_rt AS (
971
- DELETE FROM product_block_resource_types WHERE resource_type_id =ANY(SELECT resource_type_id FROM deleted_pb)
971
+ DELETE FROM product_block_resource_types WHERE resource_type_id IN (SELECT resource_type_id FROM deleted_pb)
972
972
  )
973
973
  SELECT * from deleted_pb;
974
974
  """
@@ -0,0 +1,33 @@
1
+ """Add postgres function subscription_instance_as_json.
2
+
3
+ Revision ID: 42b3d076a85b
4
+ Revises: bac6be6f2b4f
5
+ Create Date: 2025-03-06 15:03:09.477225
6
+
7
+ """
8
+
9
+ from pathlib import Path
10
+
11
+ from alembic import op
12
+ from sqlalchemy import text
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision = "42b3d076a85b"
16
+ down_revision = "bac6be6f2b4f"
17
+ branch_labels = None
18
+ depends_on = None
19
+
20
+
21
+ def upgrade() -> None:
22
+ conn = op.get_bind()
23
+
24
+ revision_file_path = Path(__file__)
25
+ with open(revision_file_path.with_suffix(".sql")) as f:
26
+ conn.execute(text(f.read()))
27
+
28
+
29
+ def downgrade() -> None:
30
+ conn = op.get_bind()
31
+
32
+ conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_as_json;"))
33
+ conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_fields_as_json;"))
@@ -0,0 +1,40 @@
1
+ CREATE OR REPLACE FUNCTION subscription_instance_fields_as_json(sub_inst_id uuid)
2
+ RETURNS jsonb
3
+ LANGUAGE sql
4
+ STABLE PARALLEL SAFE AS
5
+ $func$
6
+ select jsonb_object_agg(rts.key, rts.val)
7
+ from (select attr.key,
8
+ attr.val
9
+ from subscription_instances si
10
+ join product_blocks pb ON si.product_block_id = pb.product_block_id
11
+ cross join lateral (
12
+ values ('subscription_instance_id', to_jsonb(si.subscription_instance_id)),
13
+ ('owner_subscription_id', to_jsonb(si.subscription_id)),
14
+ ('name', to_jsonb(pb.name))
15
+ ) as attr(key, val)
16
+ where si.subscription_instance_id = sub_inst_id
17
+ union all
18
+ select rt.resource_type key,
19
+ jsonb_agg(siv.value ORDER BY siv.value ASC) val
20
+ from subscription_instance_values siv
21
+ join resource_types rt on siv.resource_type_id = rt.resource_type_id
22
+ where siv.subscription_instance_id = sub_inst_id
23
+ group by rt.resource_type) as rts
24
+ $func$;
25
+
26
+
27
+ CREATE OR REPLACE FUNCTION subscription_instance_as_json(sub_inst_id uuid)
28
+ RETURNS jsonb
29
+ LANGUAGE sql
30
+ STABLE PARALLEL SAFE AS
31
+ $func$
32
+ select subscription_instance_fields_as_json(sub_inst_id) ||
33
+ coalesce(jsonb_object_agg(depends_on.block_name, depends_on.block_instances), '{}'::jsonb)
34
+ from (select sir.domain_model_attr block_name,
35
+ jsonb_agg(subscription_instance_as_json(sir.depends_on_id) ORDER BY sir.order_id ASC) as block_instances
36
+ from subscription_instance_relations sir
37
+ where sir.in_use_by_id = sub_inst_id
38
+ and sir.domain_model_attr is not null
39
+ group by block_name) as depends_on
40
+ $func$;
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2025 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
orchestrator/security.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2025 SURF, GÉANT, ESnet.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -24,15 +24,10 @@ from sqlalchemy.exc import SQLAlchemyError
24
24
  from sqlalchemy.orm import joinedload
25
25
 
26
26
  from nwastdlib.ex import show_ex
27
+ from oauth2_lib.fastapi import OIDCUserModel
27
28
  from orchestrator.api.error_handling import raise_status
28
29
  from orchestrator.config.assignee import Assignee
29
- from orchestrator.db import (
30
- EngineSettingsTable,
31
- ProcessStepTable,
32
- ProcessSubscriptionTable,
33
- ProcessTable,
34
- db,
35
- )
30
+ from orchestrator.db import EngineSettingsTable, ProcessStepTable, ProcessSubscriptionTable, ProcessTable, db
36
31
  from orchestrator.distlock import distlock_manager
37
32
  from orchestrator.schemas.engine_settings import WorkerStatus
38
33
  from orchestrator.services.input_state import store_input_state
@@ -414,10 +409,15 @@ def _run_process_async(process_id: UUID, f: Callable) -> UUID:
414
409
  return process_id
415
410
 
416
411
 
412
+ def error_message_unauthorized(workflow_key: str) -> str:
413
+ return f"User is not authorized to execute '{workflow_key}' workflow"
414
+
415
+
417
416
  def create_process(
418
417
  workflow_key: str,
419
418
  user_inputs: list[State] | None = None,
420
419
  user: str = SYSTEM_USER,
420
+ user_model: OIDCUserModel | None = None,
421
421
  ) -> ProcessStat:
422
422
  # ATTENTION!! When modifying this function make sure you make similar changes to `run_workflow` in the test code
423
423
 
@@ -430,6 +430,9 @@ def create_process(
430
430
  if not workflow:
431
431
  raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
432
432
 
433
+ if not workflow.authorize_callback(user_model):
434
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
435
+
433
436
  initial_state = {
434
437
  "process_id": process_id,
435
438
  "reporter": user,
@@ -449,6 +452,7 @@ def create_process(
449
452
  state=Success(state | initial_state),
450
453
  log=workflow.steps,
451
454
  current_user=user,
455
+ user_model=user_model,
452
456
  )
453
457
 
454
458
  _db_create_process(pstat)
@@ -460,9 +464,12 @@ def thread_start_process(
460
464
  workflow_key: str,
461
465
  user_inputs: list[State] | None = None,
462
466
  user: str = SYSTEM_USER,
467
+ user_model: OIDCUserModel | None = None,
463
468
  broadcast_func: BroadcastFunc | None = None,
464
469
  ) -> UUID:
465
470
  pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
471
+ if not pstat.workflow.authorize_callback(user_model):
472
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
466
473
 
467
474
  _safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
468
475
  return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
@@ -472,6 +479,7 @@ def start_process(
472
479
  workflow_key: str,
473
480
  user_inputs: list[State] | None = None,
474
481
  user: str = SYSTEM_USER,
482
+ user_model: OIDCUserModel | None = None,
475
483
  broadcast_func: BroadcastFunc | None = None,
476
484
  ) -> UUID:
477
485
  """Start a process for workflow.
@@ -480,6 +488,7 @@ def start_process(
480
488
  workflow_key: name of workflow
481
489
  user_inputs: List of form inputs from frontend
482
490
  user: User who starts this process
491
+ user_model: Full OIDCUserModel with claims, etc
483
492
  broadcast_func: Optional function to broadcast process data
484
493
 
485
494
  Returns:
@@ -487,7 +496,9 @@ def start_process(
487
496
 
488
497
  """
489
498
  start_func = get_execution_context()["start"]
490
- return start_func(workflow_key, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
499
+ return start_func(
500
+ workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
501
+ )
491
502
 
492
503
 
493
504
  def thread_resume_process(
@@ -495,6 +506,7 @@ def thread_resume_process(
495
506
  *,
496
507
  user_inputs: list[State] | None = None,
497
508
  user: str | None = None,
509
+ user_model: OIDCUserModel | None = None,
498
510
  broadcast_func: BroadcastFunc | None = None,
499
511
  ) -> UUID:
500
512
  # ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
@@ -503,6 +515,8 @@ def thread_resume_process(
503
515
  user_inputs = [{}]
504
516
 
505
517
  pstat = load_process(process)
518
+ if not pstat.workflow.authorize_callback(user_model):
519
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
506
520
 
507
521
  if pstat.workflow == removed_workflow:
508
522
  raise ValueError("This workflow cannot be resumed")
@@ -542,6 +556,7 @@ def resume_process(
542
556
  *,
543
557
  user_inputs: list[State] | None = None,
544
558
  user: str | None = None,
559
+ user_model: OIDCUserModel | None = None,
545
560
  broadcast_func: BroadcastFunc | None = None,
546
561
  ) -> UUID:
547
562
  """Resume a failed or suspended process.
@@ -550,6 +565,7 @@ def resume_process(
550
565
  process: Process from database
551
566
  user_inputs: Optional user input from forms
552
567
  user: user who resumed this process
568
+ user_model: OIDCUserModel of user who resumed this process
553
569
  broadcast_func: Optional function to broadcast process data
554
570
 
555
571
  Returns:
@@ -557,6 +573,9 @@ def resume_process(
557
573
 
558
574
  """
559
575
  pstat = load_process(process)
576
+ if not pstat.workflow.authorize_callback(user_model):
577
+ raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
578
+
560
579
  try:
561
580
  post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
562
581
  except FormValidationError:
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at