orchestrator-core 3.1.2rc4__py3-none-any.whl → 3.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orchestrator/__init__.py +1 -1
- orchestrator/api/api_v1/endpoints/processes.py +6 -9
- orchestrator/cli/generator/generator/workflow.py +13 -1
- orchestrator/cli/generator/templates/modify_product.j2 +9 -0
- orchestrator/db/__init__.py +2 -0
- orchestrator/db/loaders.py +51 -3
- orchestrator/db/models.py +13 -0
- orchestrator/db/queries/__init__.py +0 -0
- orchestrator/db/queries/subscription.py +85 -0
- orchestrator/db/queries/subscription_instance.py +28 -0
- orchestrator/domain/base.py +162 -44
- orchestrator/domain/context_cache.py +62 -0
- orchestrator/domain/helpers.py +41 -1
- orchestrator/domain/subscription_instance_transform.py +114 -0
- orchestrator/graphql/resolvers/process.py +3 -3
- orchestrator/graphql/resolvers/product.py +2 -2
- orchestrator/graphql/resolvers/product_block.py +2 -2
- orchestrator/graphql/resolvers/resource_type.py +2 -2
- orchestrator/graphql/resolvers/workflow.py +2 -2
- orchestrator/graphql/utils/get_query_loaders.py +6 -48
- orchestrator/graphql/utils/get_subscription_product_blocks.py +8 -1
- orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.py +33 -0
- orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.sql +40 -0
- orchestrator/migrations/versions/schema/2025-04-09_fc5c993a4b4a_add_cascade_constraint_on_processes_.py +44 -0
- orchestrator/services/processes.py +28 -9
- orchestrator/services/subscriptions.py +36 -6
- orchestrator/settings.py +3 -0
- orchestrator/utils/functional.py +9 -0
- orchestrator/utils/redis.py +6 -0
- orchestrator/workflow.py +29 -6
- orchestrator/workflows/utils.py +40 -5
- {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/METADATA +9 -8
- {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/RECORD +36 -28
- /orchestrator/migrations/versions/schema/{2025-10-19_4fjdn13f83ga_add_validate_product_type_task.py → 2025-01-19_4fjdn13f83ga_add_validate_product_type_task.py} +0 -0
- {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/WHEEL +0 -0
- {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
|
|
|
17
17
|
from orchestrator.graphql.schemas.product_block import ProductBlock
|
|
18
18
|
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
|
|
19
19
|
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
|
|
20
|
-
from orchestrator.graphql.utils.get_query_loaders import
|
|
20
|
+
from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
|
|
21
21
|
from orchestrator.utils.search_query import create_sqlalchemy_select
|
|
22
22
|
|
|
23
23
|
logger = structlog.get_logger(__name__)
|
|
@@ -39,7 +39,7 @@ async def resolve_product_blocks(
|
|
|
39
39
|
"resolve_product_blocks() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
|
|
40
40
|
)
|
|
41
41
|
|
|
42
|
-
query_loaders =
|
|
42
|
+
query_loaders = get_query_loaders_for_gql_fields(ProductBlockTable, info)
|
|
43
43
|
select_stmt = select(ProductBlockTable)
|
|
44
44
|
select_stmt = filter_product_blocks(select_stmt, pydantic_filter_by, _error_handler)
|
|
45
45
|
|
|
@@ -17,7 +17,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
|
|
|
17
17
|
from orchestrator.graphql.schemas.resource_type import ResourceType
|
|
18
18
|
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
|
|
19
19
|
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
|
|
20
|
-
from orchestrator.graphql.utils.get_query_loaders import
|
|
20
|
+
from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
|
|
21
21
|
from orchestrator.utils.search_query import create_sqlalchemy_select
|
|
22
22
|
|
|
23
23
|
logger = structlog.get_logger(__name__)
|
|
@@ -38,7 +38,7 @@ async def resolve_resource_types(
|
|
|
38
38
|
logger.debug(
|
|
39
39
|
"resolve_resource_types() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by
|
|
40
40
|
)
|
|
41
|
-
query_loaders =
|
|
41
|
+
query_loaders = get_query_loaders_for_gql_fields(ResourceTypeTable, info)
|
|
42
42
|
select_stmt = select(ResourceTypeTable).options(*query_loaders)
|
|
43
43
|
select_stmt = filter_resource_types(select_stmt, pydantic_filter_by, _error_handler)
|
|
44
44
|
|
|
@@ -13,7 +13,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
|
|
|
13
13
|
from orchestrator.graphql.schemas.workflow import Workflow
|
|
14
14
|
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
|
|
15
15
|
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
|
|
16
|
-
from orchestrator.graphql.utils.get_query_loaders import
|
|
16
|
+
from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
|
|
17
17
|
from orchestrator.utils.search_query import create_sqlalchemy_select
|
|
18
18
|
|
|
19
19
|
logger = structlog.get_logger(__name__)
|
|
@@ -33,7 +33,7 @@ async def resolve_workflows(
|
|
|
33
33
|
pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
|
|
34
34
|
logger.debug("resolve_workflows() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
|
|
35
35
|
|
|
36
|
-
query_loaders =
|
|
36
|
+
query_loaders = get_query_loaders_for_gql_fields(WorkflowTable, info)
|
|
37
37
|
select_stmt = WorkflowTable.select().options(*query_loaders)
|
|
38
38
|
select_stmt = filter_workflows(select_stmt, pydantic_filter_by, _error_handler)
|
|
39
39
|
|
|
@@ -1,61 +1,19 @@
|
|
|
1
|
-
from typing import Iterable
|
|
2
|
-
|
|
3
|
-
import structlog
|
|
4
1
|
from sqlalchemy.orm import Load
|
|
5
2
|
|
|
6
3
|
from orchestrator.db.database import BaseModel as DbBaseModel
|
|
7
|
-
from orchestrator.db.loaders import
|
|
4
|
+
from orchestrator.db.loaders import (
|
|
5
|
+
get_query_loaders_for_model_paths,
|
|
6
|
+
)
|
|
8
7
|
from orchestrator.graphql.types import OrchestratorInfo
|
|
9
8
|
from orchestrator.graphql.utils.get_selected_paths import get_selected_paths
|
|
10
9
|
|
|
11
|
-
logger = structlog.get_logger(__name__)
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def _split_path(query_path: str) -> Iterable[str]:
|
|
15
|
-
yield from (field for field in query_path.split(".") if field != "page")
|
|
16
10
|
|
|
17
|
-
|
|
18
|
-
def get_query_loaders(info: OrchestratorInfo, root_model: type[DbBaseModel]) -> list[Load]:
|
|
11
|
+
def get_query_loaders_for_gql_fields(root_model: type[DbBaseModel], info: OrchestratorInfo) -> list[Load]:
|
|
19
12
|
"""Get sqlalchemy query loaders for the given GraphQL query.
|
|
20
13
|
|
|
21
14
|
Based on the GraphQL query's selected fields, returns the required DB loaders to use
|
|
22
15
|
in SQLALchemy's `.options()` for efficiently quering (nested) relationships.
|
|
23
16
|
"""
|
|
24
|
-
|
|
25
|
-
query_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
|
|
26
|
-
query_paths.sort(key=lambda x: x.count("."), reverse=True)
|
|
27
|
-
|
|
28
|
-
def get_loader_for_path(query_path: str) -> tuple[str, Load | None]:
|
|
29
|
-
next_model = root_model
|
|
30
|
-
|
|
31
|
-
matched_fields: list[str] = []
|
|
32
|
-
path_loaders: list[AttrLoader] = []
|
|
33
|
-
|
|
34
|
-
for field in _split_path(query_path):
|
|
35
|
-
if not (attr_loaders := lookup_attr_loaders(next_model, field)):
|
|
36
|
-
break
|
|
37
|
-
|
|
38
|
-
matched_fields.append(field)
|
|
39
|
-
path_loaders.extend(attr_loaders)
|
|
40
|
-
next_model = attr_loaders[-1].next_model
|
|
41
|
-
|
|
42
|
-
return ".".join(matched_fields), join_attr_loaders(path_loaders)
|
|
43
|
-
|
|
44
|
-
query_loaders: dict[str, Load] = {}
|
|
45
|
-
|
|
46
|
-
for path in query_paths:
|
|
47
|
-
matched_path, loader = get_loader_for_path(path)
|
|
48
|
-
if not matched_path or not loader or matched_path in query_loaders:
|
|
49
|
-
continue
|
|
50
|
-
if any(known_path.startswith(f"{matched_path}.") for known_path in query_loaders):
|
|
51
|
-
continue
|
|
52
|
-
query_loaders[matched_path] = loader
|
|
17
|
+
model_paths = [path.removeprefix("page.") for path in get_selected_paths(info)]
|
|
53
18
|
|
|
54
|
-
|
|
55
|
-
logger.debug(
|
|
56
|
-
"Generated query loaders",
|
|
57
|
-
root_model=root_model,
|
|
58
|
-
query_paths=query_paths,
|
|
59
|
-
query_loaders=[str(i.path) for i in loaders],
|
|
60
|
-
)
|
|
61
|
-
return loaders
|
|
19
|
+
return get_query_loaders_for_model_paths(root_model, model_paths)
|
|
@@ -70,7 +70,14 @@ def get_all_product_blocks(subscription: dict[str, Any], _tags: list[str] | None
|
|
|
70
70
|
return list(locate_product_block(subscription))
|
|
71
71
|
|
|
72
72
|
|
|
73
|
-
pb_instance_property_keys = (
|
|
73
|
+
pb_instance_property_keys = (
|
|
74
|
+
"id",
|
|
75
|
+
"parent",
|
|
76
|
+
"owner_subscription_id",
|
|
77
|
+
"subscription_instance_id",
|
|
78
|
+
"in_use_by_relations",
|
|
79
|
+
"in_use_by_ids",
|
|
80
|
+
)
|
|
74
81
|
|
|
75
82
|
|
|
76
83
|
async def get_subscription_product_blocks(
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Add postgres function subscription_instance_as_json.
|
|
2
|
+
|
|
3
|
+
Revision ID: 42b3d076a85b
|
|
4
|
+
Revises: bac6be6f2b4f
|
|
5
|
+
Create Date: 2025-03-06 15:03:09.477225
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
from alembic import op
|
|
12
|
+
from sqlalchemy import text
|
|
13
|
+
|
|
14
|
+
# revision identifiers, used by Alembic.
|
|
15
|
+
revision = "42b3d076a85b"
|
|
16
|
+
down_revision = "bac6be6f2b4f"
|
|
17
|
+
branch_labels = None
|
|
18
|
+
depends_on = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def upgrade() -> None:
|
|
22
|
+
conn = op.get_bind()
|
|
23
|
+
|
|
24
|
+
revision_file_path = Path(__file__)
|
|
25
|
+
with open(revision_file_path.with_suffix(".sql")) as f:
|
|
26
|
+
conn.execute(text(f.read()))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def downgrade() -> None:
|
|
30
|
+
conn = op.get_bind()
|
|
31
|
+
|
|
32
|
+
conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_as_json;"))
|
|
33
|
+
conn.execute(text("DROP FUNCTION IF EXISTS subscription_instance_fields_as_json;"))
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
CREATE OR REPLACE FUNCTION subscription_instance_fields_as_json(sub_inst_id uuid)
|
|
2
|
+
RETURNS jsonb
|
|
3
|
+
LANGUAGE sql
|
|
4
|
+
STABLE PARALLEL SAFE AS
|
|
5
|
+
$func$
|
|
6
|
+
select jsonb_object_agg(rts.key, rts.val)
|
|
7
|
+
from (select attr.key,
|
|
8
|
+
attr.val
|
|
9
|
+
from subscription_instances si
|
|
10
|
+
join product_blocks pb ON si.product_block_id = pb.product_block_id
|
|
11
|
+
cross join lateral (
|
|
12
|
+
values ('subscription_instance_id', to_jsonb(si.subscription_instance_id)),
|
|
13
|
+
('owner_subscription_id', to_jsonb(si.subscription_id)),
|
|
14
|
+
('name', to_jsonb(pb.name))
|
|
15
|
+
) as attr(key, val)
|
|
16
|
+
where si.subscription_instance_id = sub_inst_id
|
|
17
|
+
union all
|
|
18
|
+
select rt.resource_type key,
|
|
19
|
+
jsonb_agg(siv.value ORDER BY siv.value ASC) val
|
|
20
|
+
from subscription_instance_values siv
|
|
21
|
+
join resource_types rt on siv.resource_type_id = rt.resource_type_id
|
|
22
|
+
where siv.subscription_instance_id = sub_inst_id
|
|
23
|
+
group by rt.resource_type) as rts
|
|
24
|
+
$func$;
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
CREATE OR REPLACE FUNCTION subscription_instance_as_json(sub_inst_id uuid)
|
|
28
|
+
RETURNS jsonb
|
|
29
|
+
LANGUAGE sql
|
|
30
|
+
STABLE PARALLEL SAFE AS
|
|
31
|
+
$func$
|
|
32
|
+
select subscription_instance_fields_as_json(sub_inst_id) ||
|
|
33
|
+
coalesce(jsonb_object_agg(depends_on.block_name, depends_on.block_instances), '{}'::jsonb)
|
|
34
|
+
from (select sir.domain_model_attr block_name,
|
|
35
|
+
jsonb_agg(subscription_instance_as_json(sir.depends_on_id) ORDER BY sir.order_id ASC) as block_instances
|
|
36
|
+
from subscription_instance_relations sir
|
|
37
|
+
where sir.in_use_by_id = sub_inst_id
|
|
38
|
+
and sir.domain_model_attr is not null
|
|
39
|
+
group by block_name) as depends_on
|
|
40
|
+
$func$;
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""add cascade constraint on processes input state.
|
|
2
|
+
|
|
3
|
+
Revision ID: fc5c993a4b4a
|
|
4
|
+
Revises: 42b3d076a85b
|
|
5
|
+
Create Date: 2025-04-09 18:27:31.922214
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from alembic import op
|
|
10
|
+
|
|
11
|
+
# revision identifiers, used by Alembic.
|
|
12
|
+
revision = "fc5c993a4b4a"
|
|
13
|
+
down_revision = "42b3d076a85b"
|
|
14
|
+
branch_labels = None
|
|
15
|
+
depends_on = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def upgrade() -> None:
|
|
19
|
+
# Drop the existing foreign key constraint
|
|
20
|
+
op.drop_constraint("input_states_pid_fkey", "input_states", type_="foreignkey")
|
|
21
|
+
|
|
22
|
+
# Add a new foreign key constraint with cascade delete
|
|
23
|
+
op.create_foreign_key(
|
|
24
|
+
"input_states_pid_fkey",
|
|
25
|
+
"input_states",
|
|
26
|
+
"processes",
|
|
27
|
+
["pid"],
|
|
28
|
+
["pid"],
|
|
29
|
+
ondelete="CASCADE",
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def downgrade() -> None:
|
|
34
|
+
# Drop the cascade foreign key constraint
|
|
35
|
+
op.drop_constraint("input_states_pid_fkey", "input_states", type_="foreignkey")
|
|
36
|
+
|
|
37
|
+
# Recreate the original foreign key constraint without cascade
|
|
38
|
+
op.create_foreign_key(
|
|
39
|
+
"input_states_pid_fkey",
|
|
40
|
+
"input_states",
|
|
41
|
+
"processes",
|
|
42
|
+
["pid"],
|
|
43
|
+
["pid"],
|
|
44
|
+
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2019-
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT, ESnet.
|
|
2
2
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
3
|
# you may not use this file except in compliance with the License.
|
|
4
4
|
# You may obtain a copy of the License at
|
|
@@ -24,15 +24,10 @@ from sqlalchemy.exc import SQLAlchemyError
|
|
|
24
24
|
from sqlalchemy.orm import joinedload
|
|
25
25
|
|
|
26
26
|
from nwastdlib.ex import show_ex
|
|
27
|
+
from oauth2_lib.fastapi import OIDCUserModel
|
|
27
28
|
from orchestrator.api.error_handling import raise_status
|
|
28
29
|
from orchestrator.config.assignee import Assignee
|
|
29
|
-
from orchestrator.db import
|
|
30
|
-
EngineSettingsTable,
|
|
31
|
-
ProcessStepTable,
|
|
32
|
-
ProcessSubscriptionTable,
|
|
33
|
-
ProcessTable,
|
|
34
|
-
db,
|
|
35
|
-
)
|
|
30
|
+
from orchestrator.db import EngineSettingsTable, ProcessStepTable, ProcessSubscriptionTable, ProcessTable, db
|
|
36
31
|
from orchestrator.distlock import distlock_manager
|
|
37
32
|
from orchestrator.schemas.engine_settings import WorkerStatus
|
|
38
33
|
from orchestrator.services.input_state import store_input_state
|
|
@@ -414,10 +409,15 @@ def _run_process_async(process_id: UUID, f: Callable) -> UUID:
|
|
|
414
409
|
return process_id
|
|
415
410
|
|
|
416
411
|
|
|
412
|
+
def error_message_unauthorized(workflow_key: str) -> str:
|
|
413
|
+
return f"User is not authorized to execute '{workflow_key}' workflow"
|
|
414
|
+
|
|
415
|
+
|
|
417
416
|
def create_process(
|
|
418
417
|
workflow_key: str,
|
|
419
418
|
user_inputs: list[State] | None = None,
|
|
420
419
|
user: str = SYSTEM_USER,
|
|
420
|
+
user_model: OIDCUserModel | None = None,
|
|
421
421
|
) -> ProcessStat:
|
|
422
422
|
# ATTENTION!! When modifying this function make sure you make similar changes to `run_workflow` in the test code
|
|
423
423
|
|
|
@@ -430,6 +430,9 @@ def create_process(
|
|
|
430
430
|
if not workflow:
|
|
431
431
|
raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
|
|
432
432
|
|
|
433
|
+
if not workflow.authorize_callback(user_model):
|
|
434
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
|
|
435
|
+
|
|
433
436
|
initial_state = {
|
|
434
437
|
"process_id": process_id,
|
|
435
438
|
"reporter": user,
|
|
@@ -449,6 +452,7 @@ def create_process(
|
|
|
449
452
|
state=Success(state | initial_state),
|
|
450
453
|
log=workflow.steps,
|
|
451
454
|
current_user=user,
|
|
455
|
+
user_model=user_model,
|
|
452
456
|
)
|
|
453
457
|
|
|
454
458
|
_db_create_process(pstat)
|
|
@@ -460,9 +464,12 @@ def thread_start_process(
|
|
|
460
464
|
workflow_key: str,
|
|
461
465
|
user_inputs: list[State] | None = None,
|
|
462
466
|
user: str = SYSTEM_USER,
|
|
467
|
+
user_model: OIDCUserModel | None = None,
|
|
463
468
|
broadcast_func: BroadcastFunc | None = None,
|
|
464
469
|
) -> UUID:
|
|
465
470
|
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)
|
|
471
|
+
if not pstat.workflow.authorize_callback(user_model):
|
|
472
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(workflow_key))
|
|
466
473
|
|
|
467
474
|
_safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
|
|
468
475
|
return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
|
|
@@ -472,6 +479,7 @@ def start_process(
|
|
|
472
479
|
workflow_key: str,
|
|
473
480
|
user_inputs: list[State] | None = None,
|
|
474
481
|
user: str = SYSTEM_USER,
|
|
482
|
+
user_model: OIDCUserModel | None = None,
|
|
475
483
|
broadcast_func: BroadcastFunc | None = None,
|
|
476
484
|
) -> UUID:
|
|
477
485
|
"""Start a process for workflow.
|
|
@@ -480,6 +488,7 @@ def start_process(
|
|
|
480
488
|
workflow_key: name of workflow
|
|
481
489
|
user_inputs: List of form inputs from frontend
|
|
482
490
|
user: User who starts this process
|
|
491
|
+
user_model: Full OIDCUserModel with claims, etc
|
|
483
492
|
broadcast_func: Optional function to broadcast process data
|
|
484
493
|
|
|
485
494
|
Returns:
|
|
@@ -487,7 +496,9 @@ def start_process(
|
|
|
487
496
|
|
|
488
497
|
"""
|
|
489
498
|
start_func = get_execution_context()["start"]
|
|
490
|
-
return start_func(
|
|
499
|
+
return start_func(
|
|
500
|
+
workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
|
|
501
|
+
)
|
|
491
502
|
|
|
492
503
|
|
|
493
504
|
def thread_resume_process(
|
|
@@ -495,6 +506,7 @@ def thread_resume_process(
|
|
|
495
506
|
*,
|
|
496
507
|
user_inputs: list[State] | None = None,
|
|
497
508
|
user: str | None = None,
|
|
509
|
+
user_model: OIDCUserModel | None = None,
|
|
498
510
|
broadcast_func: BroadcastFunc | None = None,
|
|
499
511
|
) -> UUID:
|
|
500
512
|
# ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
|
|
@@ -503,6 +515,8 @@ def thread_resume_process(
|
|
|
503
515
|
user_inputs = [{}]
|
|
504
516
|
|
|
505
517
|
pstat = load_process(process)
|
|
518
|
+
if not pstat.workflow.authorize_callback(user_model):
|
|
519
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
|
|
506
520
|
|
|
507
521
|
if pstat.workflow == removed_workflow:
|
|
508
522
|
raise ValueError("This workflow cannot be resumed")
|
|
@@ -542,6 +556,7 @@ def resume_process(
|
|
|
542
556
|
*,
|
|
543
557
|
user_inputs: list[State] | None = None,
|
|
544
558
|
user: str | None = None,
|
|
559
|
+
user_model: OIDCUserModel | None = None,
|
|
545
560
|
broadcast_func: BroadcastFunc | None = None,
|
|
546
561
|
) -> UUID:
|
|
547
562
|
"""Resume a failed or suspended process.
|
|
@@ -550,6 +565,7 @@ def resume_process(
|
|
|
550
565
|
process: Process from database
|
|
551
566
|
user_inputs: Optional user input from forms
|
|
552
567
|
user: user who resumed this process
|
|
568
|
+
user_model: OIDCUserModel of user who resumed this process
|
|
553
569
|
broadcast_func: Optional function to broadcast process data
|
|
554
570
|
|
|
555
571
|
Returns:
|
|
@@ -557,6 +573,9 @@ def resume_process(
|
|
|
557
573
|
|
|
558
574
|
"""
|
|
559
575
|
pstat = load_process(process)
|
|
576
|
+
if not pstat.workflow.authorize_callback(user_model):
|
|
577
|
+
raise_status(HTTPStatus.FORBIDDEN, error_message_unauthorized(str(process.workflow_name)))
|
|
578
|
+
|
|
560
579
|
try:
|
|
561
580
|
post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
|
|
562
581
|
except FormValidationError:
|
|
@@ -22,6 +22,7 @@ from uuid import UUID
|
|
|
22
22
|
|
|
23
23
|
import more_itertools
|
|
24
24
|
import structlog
|
|
25
|
+
from more_itertools import first
|
|
25
26
|
from sqlalchemy import Text, cast, not_, select
|
|
26
27
|
from sqlalchemy.exc import SQLAlchemyError
|
|
27
28
|
from sqlalchemy.orm import Query, aliased, joinedload
|
|
@@ -41,7 +42,11 @@ from orchestrator.db.models import (
|
|
|
41
42
|
SubscriptionInstanceRelationTable,
|
|
42
43
|
SubscriptionMetadataTable,
|
|
43
44
|
)
|
|
45
|
+
from orchestrator.db.queries.subscription import (
|
|
46
|
+
eagerload_all_subscription_instances_only_inuseby,
|
|
47
|
+
)
|
|
44
48
|
from orchestrator.domain.base import SubscriptionModel
|
|
49
|
+
from orchestrator.domain.context_cache import cache_subscription_models
|
|
45
50
|
from orchestrator.targets import Target
|
|
46
51
|
from orchestrator.types import SubscriptionLifecycle
|
|
47
52
|
from orchestrator.utils.datetime import nowtz
|
|
@@ -594,13 +599,15 @@ def convert_to_in_use_by_relation(obj: Any) -> dict[str, str]:
|
|
|
594
599
|
|
|
595
600
|
def build_extended_domain_model(subscription_model: SubscriptionModel) -> dict:
|
|
596
601
|
"""Create a subscription dict from the SubscriptionModel with additional keys."""
|
|
597
|
-
|
|
602
|
+
from orchestrator.settings import app_settings
|
|
603
|
+
|
|
604
|
+
stmt = select(SubscriptionCustomerDescriptionTable).where(
|
|
598
605
|
SubscriptionCustomerDescriptionTable.subscription_id == subscription_model.subscription_id
|
|
599
606
|
)
|
|
600
607
|
customer_descriptions = list(db.session.scalars(stmt))
|
|
601
608
|
|
|
602
|
-
|
|
603
|
-
|
|
609
|
+
with cache_subscription_models():
|
|
610
|
+
subscription = subscription_model.model_dump()
|
|
604
611
|
|
|
605
612
|
def inject_in_use_by_ids(path_to_block: str) -> None:
|
|
606
613
|
if not (in_use_by_subs := getattr_in(subscription_model, f"{path_to_block}.in_use_by")):
|
|
@@ -611,15 +618,38 @@ def build_extended_domain_model(subscription_model: SubscriptionModel) -> dict:
|
|
|
611
618
|
update_in(subscription, f"{path_to_block}.in_use_by_ids", in_use_by_ids)
|
|
612
619
|
update_in(subscription, f"{path_to_block}.in_use_by_relations", in_use_by_relations)
|
|
613
620
|
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
621
|
+
if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
|
|
622
|
+
# TODO #900 remove toggle and make this path the default
|
|
623
|
+
# query all subscription instances and inject the in_use_by_ids/in_use_by_relations into the subscription dict.
|
|
624
|
+
instance_to_in_use_by = {
|
|
625
|
+
instance.subscription_instance_id: instance.in_use_by
|
|
626
|
+
for instance in eagerload_all_subscription_instances_only_inuseby(subscription_model.subscription_id)
|
|
627
|
+
}
|
|
628
|
+
inject_in_use_by_ids_v2(subscription, instance_to_in_use_by)
|
|
629
|
+
else:
|
|
630
|
+
# find all product blocks, check if they have in_use_by and inject the in_use_by_ids into the subscription dict.
|
|
631
|
+
for path in product_block_paths(subscription):
|
|
632
|
+
inject_in_use_by_ids(path)
|
|
617
633
|
|
|
618
634
|
subscription["customer_descriptions"] = customer_descriptions
|
|
619
635
|
|
|
620
636
|
return subscription
|
|
621
637
|
|
|
622
638
|
|
|
639
|
+
def inject_in_use_by_ids_v2(dikt: dict, instance_to_in_use_by: dict[UUID, Sequence[SubscriptionInstanceTable]]) -> None:
|
|
640
|
+
for value in dikt.values():
|
|
641
|
+
if isinstance(value, dict):
|
|
642
|
+
inject_in_use_by_ids_v2(value, instance_to_in_use_by)
|
|
643
|
+
elif isinstance(value, list) and isinstance(first(value, None), dict):
|
|
644
|
+
for item in value:
|
|
645
|
+
inject_in_use_by_ids_v2(item, instance_to_in_use_by)
|
|
646
|
+
|
|
647
|
+
if subscription_instance_id := dikt.get("subscription_instance_id"):
|
|
648
|
+
in_use_by_subs = instance_to_in_use_by[subscription_instance_id]
|
|
649
|
+
dikt["in_use_by_ids"] = [i.subscription_instance_id for i in in_use_by_subs]
|
|
650
|
+
dikt["in_use_by_relations"] = [convert_to_in_use_by_relation(instance) for instance in in_use_by_subs]
|
|
651
|
+
|
|
652
|
+
|
|
623
653
|
def format_special_types(subscription: dict) -> dict:
|
|
624
654
|
"""Modifies the subscription dict in-place, formatting special types to string.
|
|
625
655
|
|
orchestrator/settings.py
CHANGED
|
@@ -87,6 +87,9 @@ class AppSettings(BaseSettings):
|
|
|
87
87
|
ENABLE_GRAPHQL_STATS_EXTENSION: bool = False
|
|
88
88
|
VALIDATE_OUT_OF_SYNC_SUBSCRIPTIONS: bool = False
|
|
89
89
|
FILTER_BY_MODE: Literal["partial", "exact"] = "exact"
|
|
90
|
+
ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS: bool = (
|
|
91
|
+
True # True=ignore cache + optimized DB queries; False=use cache + unoptimized DB queries. Remove in #900
|
|
92
|
+
)
|
|
90
93
|
|
|
91
94
|
|
|
92
95
|
app_settings = AppSettings()
|
orchestrator/utils/functional.py
CHANGED
|
@@ -241,3 +241,12 @@ def to_ranges(i: Iterable[int]) -> Iterable[range]:
|
|
|
241
241
|
for _, g in itertools.groupby(enumerate(i), lambda t: t[1] - t[0]):
|
|
242
242
|
group = list(g)
|
|
243
243
|
yield range(group[0][1], group[-1][1] + 1)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
K = TypeVar("K")
|
|
247
|
+
V = TypeVar("V")
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def group_by_key(items: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
|
|
251
|
+
groups = itertools.groupby(items, key=lambda item: item[0])
|
|
252
|
+
return {key: [item[1] for item in group] for key, group in groups}
|
orchestrator/utils/redis.py
CHANGED
|
@@ -55,6 +55,12 @@ def to_redis(subscription: dict[str, Any]) -> str | None:
|
|
|
55
55
|
|
|
56
56
|
def from_redis(subscription_id: UUID) -> tuple[PY_JSON_TYPES, str] | None:
|
|
57
57
|
log = logger.bind(subscription_id=subscription_id)
|
|
58
|
+
|
|
59
|
+
if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
|
|
60
|
+
# TODO #900 remove toggle and remove usage of this function in get_subscription_dict
|
|
61
|
+
log.warning("Using SubscriptionModel optimization, not loading subscription from cache")
|
|
62
|
+
return None
|
|
63
|
+
|
|
58
64
|
if caching_models_enabled():
|
|
59
65
|
log.debug("Try to retrieve subscription from cache")
|
|
60
66
|
obj = cache.get(f"orchestrator:domain:{subscription_id}")
|
orchestrator/workflow.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2019-2025 SURF, GÉANT.
|
|
1
|
+
# Copyright 2019-2025 SURF, GÉANT, ESnet.
|
|
2
2
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
3
|
# you may not use this file except in compliance with the License.
|
|
4
4
|
# You may obtain a copy of the License at
|
|
@@ -39,6 +39,7 @@ from structlog.contextvars import bound_contextvars
|
|
|
39
39
|
from structlog.stdlib import BoundLogger
|
|
40
40
|
|
|
41
41
|
from nwastdlib import const, identity
|
|
42
|
+
from oauth2_lib.fastapi import OIDCUserModel
|
|
42
43
|
from orchestrator.config.assignee import Assignee
|
|
43
44
|
from orchestrator.db import db, transactional
|
|
44
45
|
from orchestrator.services.settings import get_engine_settings
|
|
@@ -89,6 +90,7 @@ class Workflow(Protocol):
|
|
|
89
90
|
__qualname__: str
|
|
90
91
|
name: str
|
|
91
92
|
description: str
|
|
93
|
+
authorize_callback: Callable[[OIDCUserModel | None], bool]
|
|
92
94
|
initial_input_form: InputFormGenerator | None = None
|
|
93
95
|
target: Target
|
|
94
96
|
steps: StepList
|
|
@@ -178,12 +180,18 @@ def _handle_simple_input_form_generator(f: StateInputStepFunc) -> StateInputForm
|
|
|
178
180
|
return form_generator
|
|
179
181
|
|
|
180
182
|
|
|
183
|
+
def allow(_: OIDCUserModel | None) -> bool:
|
|
184
|
+
"""Default function to return True in absence of user-defined authorize function."""
|
|
185
|
+
return True
|
|
186
|
+
|
|
187
|
+
|
|
181
188
|
def make_workflow(
|
|
182
189
|
f: Callable,
|
|
183
190
|
description: str,
|
|
184
191
|
initial_input_form: InputStepFunc | None,
|
|
185
192
|
target: Target,
|
|
186
193
|
steps: StepList,
|
|
194
|
+
authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
|
|
187
195
|
) -> Workflow:
|
|
188
196
|
@functools.wraps(f)
|
|
189
197
|
def wrapping_function() -> NoReturn:
|
|
@@ -193,6 +201,7 @@ def make_workflow(
|
|
|
193
201
|
|
|
194
202
|
wrapping_function.name = f.__name__ # default, will be changed by LazyWorkflowInstance
|
|
195
203
|
wrapping_function.description = description
|
|
204
|
+
wrapping_function.authorize_callback = allow if authorize_callback is None else authorize_callback
|
|
196
205
|
|
|
197
206
|
if initial_input_form is None:
|
|
198
207
|
# We always need a form to prevent starting a workflow when no input is needed.
|
|
@@ -214,7 +223,11 @@ def step(name: str) -> Callable[[StepFunc], Step]:
|
|
|
214
223
|
def decorator(func: StepFunc) -> Step:
|
|
215
224
|
@functools.wraps(func)
|
|
216
225
|
def wrapper(state: State) -> Process:
|
|
217
|
-
with bound_contextvars(
|
|
226
|
+
with bound_contextvars(
|
|
227
|
+
func=func.__qualname__,
|
|
228
|
+
workflow_name=state.get("workflow_name"),
|
|
229
|
+
process_id=state.get("process_id"),
|
|
230
|
+
):
|
|
218
231
|
step_in_inject_args = inject_args(func)
|
|
219
232
|
try:
|
|
220
233
|
with transactional(db, logger):
|
|
@@ -239,7 +252,11 @@ def retrystep(name: str) -> Callable[[StepFunc], Step]:
|
|
|
239
252
|
def decorator(func: StepFunc) -> Step:
|
|
240
253
|
@functools.wraps(func)
|
|
241
254
|
def wrapper(state: State) -> Process:
|
|
242
|
-
with bound_contextvars(
|
|
255
|
+
with bound_contextvars(
|
|
256
|
+
func=func.__qualname__,
|
|
257
|
+
workflow_name=state.get("workflow_name"),
|
|
258
|
+
process_id=state.get("process_id"),
|
|
259
|
+
):
|
|
243
260
|
step_in_inject_args = inject_args(func)
|
|
244
261
|
try:
|
|
245
262
|
with transactional(db, logger):
|
|
@@ -459,7 +476,10 @@ def focussteps(key: str) -> Callable[[Step | StepList], StepList]:
|
|
|
459
476
|
|
|
460
477
|
|
|
461
478
|
def workflow(
|
|
462
|
-
description: str,
|
|
479
|
+
description: str,
|
|
480
|
+
initial_input_form: InputStepFunc | None = None,
|
|
481
|
+
target: Target = Target.SYSTEM,
|
|
482
|
+
authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None,
|
|
463
483
|
) -> Callable[[Callable[[], StepList]], Workflow]:
|
|
464
484
|
"""Transform an initial_input_form and a step list into a workflow.
|
|
465
485
|
|
|
@@ -479,7 +499,9 @@ def workflow(
|
|
|
479
499
|
initial_input_form_in_form_inject_args = form_inject_args(initial_input_form)
|
|
480
500
|
|
|
481
501
|
def _workflow(f: Callable[[], StepList]) -> Workflow:
|
|
482
|
-
return make_workflow(
|
|
502
|
+
return make_workflow(
|
|
503
|
+
f, description, initial_input_form_in_form_inject_args, target, f(), authorize_callback=authorize_callback
|
|
504
|
+
)
|
|
483
505
|
|
|
484
506
|
return _workflow
|
|
485
507
|
|
|
@@ -491,13 +513,14 @@ class ProcessStat:
|
|
|
491
513
|
state: Process
|
|
492
514
|
log: StepList
|
|
493
515
|
current_user: str
|
|
516
|
+
user_model: OIDCUserModel | None = None
|
|
494
517
|
|
|
495
518
|
def update(self, **vs: Any) -> ProcessStat:
|
|
496
519
|
"""Update ProcessStat.
|
|
497
520
|
|
|
498
521
|
>>> pstat = ProcessStat('', None, {}, [], "")
|
|
499
522
|
>>> pstat.update(state={"a": "b"})
|
|
500
|
-
ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='')
|
|
523
|
+
ProcessStat(process_id='', workflow=None, state={'a': 'b'}, log=[], current_user='', user_model=None)
|
|
501
524
|
"""
|
|
502
525
|
return ProcessStat(**{**asdict(self), **vs})
|
|
503
526
|
|