orchestrator-core 3.1.2rc4__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. orchestrator/__init__.py +1 -1
  2. orchestrator/api/api_v1/endpoints/processes.py +6 -9
  3. orchestrator/cli/generator/generator/workflow.py +13 -1
  4. orchestrator/cli/generator/templates/modify_product.j2 +9 -0
  5. orchestrator/db/__init__.py +2 -0
  6. orchestrator/db/loaders.py +51 -3
  7. orchestrator/db/models.py +13 -0
  8. orchestrator/db/queries/__init__.py +0 -0
  9. orchestrator/db/queries/subscription.py +85 -0
  10. orchestrator/db/queries/subscription_instance.py +28 -0
  11. orchestrator/domain/base.py +162 -44
  12. orchestrator/domain/context_cache.py +62 -0
  13. orchestrator/domain/helpers.py +41 -1
  14. orchestrator/domain/subscription_instance_transform.py +114 -0
  15. orchestrator/graphql/resolvers/process.py +3 -3
  16. orchestrator/graphql/resolvers/product.py +2 -2
  17. orchestrator/graphql/resolvers/product_block.py +2 -2
  18. orchestrator/graphql/resolvers/resource_type.py +2 -2
  19. orchestrator/graphql/resolvers/workflow.py +2 -2
  20. orchestrator/graphql/utils/get_query_loaders.py +6 -48
  21. orchestrator/graphql/utils/get_subscription_product_blocks.py +8 -1
  22. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.py +33 -0
  23. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.sql +40 -0
  24. orchestrator/migrations/versions/schema/2025-04-09_fc5c993a4b4a_add_cascade_constraint_on_processes_.py +44 -0
  25. orchestrator/services/processes.py +28 -9
  26. orchestrator/services/subscriptions.py +36 -6
  27. orchestrator/settings.py +3 -0
  28. orchestrator/utils/functional.py +9 -0
  29. orchestrator/utils/redis.py +6 -0
  30. orchestrator/workflow.py +29 -6
  31. orchestrator/workflows/utils.py +40 -5
  32. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/METADATA +9 -8
  33. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/RECORD +36 -28
  34. /orchestrator/migrations/versions/schema/{2025-10-19_4fjdn13f83ga_add_validate_product_type_task.py → 2025-01-19_4fjdn13f83ga_add_validate_product_type_task.py} +0 -0
  35. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/WHEEL +0 -0
  36. {orchestrator_core-3.1.2rc4.dist-info → orchestrator_core-3.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -36,7 +36,7 @@ from more_itertools import bucket, first, flatten, one, only
36
36
  from pydantic import BaseModel, ConfigDict, Field, ValidationError
37
37
  from pydantic.fields import PrivateAttr
38
38
  from sqlalchemy import select
39
- from sqlalchemy.orm import selectinload
39
+ from sqlalchemy.orm import joinedload, selectinload
40
40
 
41
41
  from orchestrator.db import (
42
42
  ProductBlockTable,
@@ -47,13 +47,19 @@ from orchestrator.db import (
47
47
  SubscriptionTable,
48
48
  db,
49
49
  )
50
- from orchestrator.domain.helpers import _to_product_block_field_type_iterable
50
+ from orchestrator.db.queries.subscription_instance import get_subscription_instance_dict
51
+ from orchestrator.domain.helpers import (
52
+ _to_product_block_field_type_iterable,
53
+ get_root_blocks_to_instance_ids,
54
+ no_private_attrs,
55
+ )
51
56
  from orchestrator.domain.lifecycle import (
52
57
  ProductLifecycle,
53
58
  lookup_specialized_type,
54
59
  register_specialized_type,
55
60
  validate_lifecycle_status,
56
61
  )
62
+ from orchestrator.domain.subscription_instance_transform import field_transformation_rules, transform_instance_fields
57
63
  from orchestrator.services.products import get_product_by_id
58
64
  from orchestrator.types import (
59
65
  SAFE_USED_BY_TRANSITIONS_FOR_STATUS,
@@ -98,6 +104,12 @@ class DomainModel(BaseModel):
98
104
  def __init_subclass__(cls, *args: Any, lifecycle: list[SubscriptionLifecycle] | None = None, **kwargs: Any) -> None:
99
105
  pass
100
106
 
107
def __eq__(self, other: Any) -> bool:
    """Compare two domain models while ignoring pydantic private attributes.

    The private attributes of both operands are cleared for the duration of the
    comparison, so only the regular model fields decide the outcome. This was added
    for #652, primarily because ProductBlockModel._db_model is now lazy loaded.
    """
    with no_private_attrs(self):
        with no_private_attrs(other):
            are_equal = super().__eq__(other)
    return are_equal
112
+
101
113
  @classmethod
102
114
  def __pydantic_init_subclass__(
103
115
  cls,
@@ -310,12 +322,7 @@ class DomainModel(BaseModel):
310
322
  for product_block_field_name, product_block_field_type in cls._product_block_fields_.items():
311
323
  filter_func = match_domain_model_attr_if_possible(product_block_field_name)
312
324
 
313
- product_block_model: Any = product_block_field_type
314
- if is_list_type(product_block_field_type):
315
- _origin, args = get_origin_and_args(product_block_field_type)
316
- product_block_model = one(args)
317
-
318
- possible_product_block_types = get_possible_product_block_types(product_block_model)
325
+ possible_product_block_types = flatten_product_block_types(product_block_field_type)
319
326
  field_type_names = list(possible_product_block_types.keys())
320
327
  filtered_instances = flatten([grouped_instances.get(name, []) for name in field_type_names])
321
328
  instance_list = list(filter(filter_func, filtered_instances))
@@ -458,6 +465,15 @@ class DomainModel(BaseModel):
458
465
  raise ValueError(f"Cannot link the same subscription instance multiple times: {details}")
459
466
 
460
467
 
468
def flatten_product_block_types(product_block_field_type: Any) -> dict[str, type["ProductBlockModel"]]:
    """Map product block names to product block classes for a field type.

    A list-typed field is unwrapped to its single type argument first; any other
    field type is passed on unchanged.
    """
    if not is_list_type(product_block_field_type):
        return get_possible_product_block_types(product_block_field_type)

    # `one` raises unless the list type carries exactly one type argument.
    _origin, type_args = get_origin_and_args(product_block_field_type)
    return get_possible_product_block_types(one(type_args))
475
+
476
+
461
477
  def get_depends_on_product_block_type_list(
462
478
  product_block_types: dict[str, type["ProductBlockModel"] | tuple[type["ProductBlockModel"]]],
463
479
  ) -> list[type["ProductBlockModel"]]:
@@ -522,7 +538,7 @@ class ProductBlockModel(DomainModel):
522
538
  product_block_id: ClassVar[UUID]
523
539
  description: ClassVar[str]
524
540
  tag: ClassVar[str]
525
- _db_model: SubscriptionInstanceTable = PrivateAttr()
541
+ _db_model: SubscriptionInstanceTable | None = PrivateAttr(default=None)
526
542
 
527
543
  # Product block name. This needs to be an instance var because its part of the API (we expose it to the frontend)
528
544
  # Is actually optional since abstract classes don't have it.
@@ -681,7 +697,7 @@ class ProductBlockModel(DomainModel):
681
697
  **sub_instances,
682
698
  **kwargs,
683
699
  )
684
- model._db_model = db_model
700
+ model.db_model = db_model
685
701
  return model
686
702
 
687
703
  @classmethod
@@ -737,7 +753,7 @@ class ProductBlockModel(DomainModel):
737
753
 
738
754
  cls._fix_pb_data()
739
755
  model = cls(**data)
740
- model._db_model = other._db_model
756
+ model.db_model = other.db_model
741
757
  return model
742
758
 
743
759
  @classmethod
@@ -795,7 +811,7 @@ class ProductBlockModel(DomainModel):
795
811
  **instance_values, # type: ignore
796
812
  **sub_instances,
797
813
  )
798
- model._db_model = subscription_instance
814
+ model.db_model = subscription_instance
799
815
 
800
816
  return model
801
817
  except ValidationError:
@@ -913,14 +929,15 @@ class ProductBlockModel(DomainModel):
913
929
 
914
930
  # If this is a "foreign" instance we just stop saving and return it so only its relation is saved
915
931
  # We should not touch these themselves
916
- if self.subscription and subscription_instance.subscription_id != subscription_id:
932
+ if self.owner_subscription_id != subscription_id:
917
933
  return [], subscription_instance
918
934
 
919
- self._db_model = subscription_instance
920
- else:
921
- subscription_instance = self._db_model
935
+ self.db_model = subscription_instance
936
+ elif subscription_instance := self.db_model:
922
937
  # We only need to add to the session if the subscription_instance does not exist.
923
938
  db.session.add(subscription_instance)
939
+ else:
940
+ raise ValueError("Cannot save ProductBlockModel without a db_model")
924
941
 
925
942
  subscription_instance.subscription_id = subscription_id
926
943
 
@@ -947,22 +964,32 @@ class ProductBlockModel(DomainModel):
947
964
  return sub_instances + [subscription_instance], subscription_instance
948
965
 
949
966
  @property
950
- def subscription(self) -> SubscriptionTable:
951
- return self.db_model.subscription
967
+ def subscription(self) -> SubscriptionTable | None:
968
+ return self.db_model.subscription if self.db_model else None
952
969
 
953
970
  @property
954
- def db_model(self) -> SubscriptionInstanceTable:
971
+ def db_model(self) -> SubscriptionInstanceTable | None:
972
+ if not self._db_model:
973
+ self._db_model = db.session.execute(
974
+ select(SubscriptionInstanceTable).where(
975
+ SubscriptionInstanceTable.subscription_instance_id == self.subscription_instance_id
976
+ )
977
+ ).scalar_one_or_none()
955
978
  return self._db_model
956
979
 
980
+ @db_model.setter
981
+ def db_model(self, value: SubscriptionInstanceTable) -> None:
982
+ self._db_model = value
983
+
957
984
  @property
958
- def in_use_by(self) -> list[SubscriptionInstanceTable]:
985
+ def in_use_by(self) -> list[SubscriptionInstanceTable]: # TODO check where used, might need eagerloading
959
986
  """This provides a list of product blocks that depend on this product block."""
960
- return self._db_model.in_use_by
987
+ return self.db_model.in_use_by if self.db_model else []
961
988
 
962
989
  @property
963
- def depends_on(self) -> list[SubscriptionInstanceTable]:
990
+ def depends_on(self) -> list[SubscriptionInstanceTable]: # TODO check where used, might need eagerloading
964
991
  """This provides a list of product blocks that this product block depends on."""
965
- return self._db_model.depends_on
992
+ return self.db_model.depends_on if self.db_model else []
966
993
 
967
994
 
968
995
  class ProductModel(BaseModel):
@@ -1010,9 +1037,11 @@ class SubscriptionModel(DomainModel):
1010
1037
  >>> SubscriptionInactive.from_subscription(subscription_id) # doctest:+SKIP
1011
1038
  """
1012
1039
 
1040
+ __model_dump_cache__: ClassVar[dict[UUID, "SubscriptionModel"] | None] = None
1041
+
1013
1042
  product: ProductModel
1014
1043
  customer_id: str
1015
- _db_model: SubscriptionTable = PrivateAttr()
1044
+ _db_model: SubscriptionTable | None = PrivateAttr(default=None)
1016
1045
  subscription_id: UUID = Field(default_factory=uuid4) # pragma: no mutate
1017
1046
  description: str = "Initial subscription" # pragma: no mutate
1018
1047
  status: SubscriptionLifecycle = SubscriptionLifecycle.INITIAL # pragma: no mutate
@@ -1103,6 +1132,63 @@ class SubscriptionModel(DomainModel):
1103
1132
 
1104
1133
  return missing_data
1105
1134
 
1135
@classmethod
def _load_root_instances(
    cls,
    subscription_id: UUID | UUIDstr,
) -> dict[str, Optional[dict] | list[dict]]:
    """Load the root subscription instance(s) of a subscription as plain dicts.

    Used when a subscription model is built from an existing subscription. Each
    root instance is fetched in full through get_subscription_instance_dict() and
    reshaped to match the domain model, so that Pydantic can instantiate the root
    block together with all nested blocks in one go.

    This differs from DomainModel._load_instances(), which walks the subscription
    instances recursively and instantiates nested blocks one by one. Because the
    shape is taken from the product's domain model, this method needs neither
    `status` nor `match_domain_attr`.

    Returns:
        Mapping of root product block field name to a list of instance dicts for
        list-typed fields, or to a single instance dict (or None) otherwise.
    """
    instance_ids_by_block_name = get_root_blocks_to_instance_ids(subscription_id)

    # Step 1: per root field, the product block names that may fill it.
    block_names_by_field: dict[str, list[str]] = {}
    for field_name, product_block_type in cls._product_block_fields_.items():
        block_names_by_field[field_name] = list(flatten_product_block_types(product_block_type).keys())

    # Step 2: per root field, the subscription instance dict(s) loaded from the database.
    instances: dict[str, list[dict]] = {}
    for field_name, block_names in block_names_by_field.items():
        instances[field_name] = [
            get_subscription_instance_dict(instance_id)
            for block_name in block_names
            for instance_id in instance_ids_by_block_name.get(block_name, [])
        ]

    # Step 3: reshape values according to the domain models
    # (list[dict] -> dict, None as default for optionals).
    rules = {
        klass.name: field_transformation_rules(klass) for klass in ProductBlockModel.registry.values() if klass.name
    }
    for loaded_instances in instances.values():
        for loaded_instance in loaded_instances:
            transform_instance_fields(rules, loaded_instance)

    # Step 4: keep lists only for list-typed root fields; this supports the
    # (theoretical?) use case of a list of root product blocks.
    unpacked: dict[str, Optional[dict] | list[dict]] = {}
    for field_name, loaded_instances in instances.items():
        if is_list_type(cls._product_block_fields_[field_name]):
            unpacked[field_name] = loaded_instances
        else:
            # `only` yields None for an empty list and raises on more than one item.
            unpacked[field_name] = only(loaded_instances)
    return unpacked
1191
+
1106
1192
  @classmethod
1107
1193
  def from_product_id(
1108
1194
  cls: type[S],
@@ -1168,7 +1254,7 @@ class SubscriptionModel(DomainModel):
1168
1254
  **fixed_inputs,
1169
1255
  **instances,
1170
1256
  )
1171
- model._db_model = subscription
1257
+ model.db_model = subscription
1172
1258
  return model
1173
1259
 
1174
1260
  @classmethod
@@ -1201,17 +1287,26 @@ class SubscriptionModel(DomainModel):
1201
1287
  data["end_date"] = nowtz()
1202
1288
 
1203
1289
  model = cls(**data)
1204
- model._db_model = other._db_model
1290
+ model.db_model = other._db_model
1205
1291
 
1206
1292
  return model
1207
1293
 
1208
1294
  # Some common functions shared by from_other_product and from_subscription
1209
1295
  @classmethod
1210
- def _get_subscription(cls: type[S], subscription_id: UUID | UUIDstr) -> Any:
1211
- return db.session.get(
1212
- SubscriptionTable,
1213
- subscription_id,
1214
- options=[
1296
+ def _get_subscription(cls: type[S], subscription_id: UUID | UUIDstr) -> SubscriptionTable | None:
1297
+ from orchestrator.settings import app_settings
1298
+
1299
+ if not isinstance(subscription_id, UUID | UUIDstr):
1300
+ raise TypeError(f"subscription_id is of type {type(subscription_id)} instead of UUID | UUIDstr")
1301
+
1302
+ if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
1303
+ # TODO #900 remove toggle and make this path the default
1304
+ loaders = [
1305
+ joinedload(SubscriptionTable.product).selectinload(ProductTable.fixed_inputs),
1306
+ ]
1307
+
1308
+ else:
1309
+ loaders = [
1215
1310
  selectinload(SubscriptionTable.instances)
1216
1311
  .joinedload(SubscriptionInstanceTable.product_block)
1217
1312
  .selectinload(ProductBlockTable.resource_types),
@@ -1219,8 +1314,9 @@ class SubscriptionModel(DomainModel):
1219
1314
  SubscriptionInstanceTable.in_use_by_block_relations
1220
1315
  ),
1221
1316
  selectinload(SubscriptionTable.instances).selectinload(SubscriptionInstanceTable.values),
1222
- ],
1223
- )
1317
+ ]
1318
+
1319
+ return db.session.get(SubscriptionTable, subscription_id, options=loaders)
1224
1320
 
1225
1321
  @classmethod
1226
1322
  def _to_product_model(cls: type[S], product: ProductTable) -> ProductModel:
@@ -1246,7 +1342,9 @@ class SubscriptionModel(DomainModel):
1246
1342
  if not db_product:
1247
1343
  raise KeyError("Could not find a product for the given product_id")
1248
1344
 
1249
- subscription = cls._get_subscription(old_instantiation.subscription_id)
1345
+ old_subscription_id = old_instantiation.subscription_id
1346
+ if not (subscription := cls._get_subscription(old_subscription_id)):
1347
+ raise ValueError(f"Subscription with id: {old_subscription_id}, does not exist")
1250
1348
  product = cls._to_product_model(db_product)
1251
1349
 
1252
1350
  status = SubscriptionLifecycle(subscription.status)
@@ -1266,6 +1364,7 @@ class SubscriptionModel(DomainModel):
1266
1364
  name, product_block = new_root
1267
1365
  instances = {name: product_block}
1268
1366
  else:
1367
+ # TODO test using cls._load_root_instances() here as well
1269
1368
  instances = cls._load_instances(subscription.instances, status, match_domain_attr=False) # type:ignore
1270
1369
 
1271
1370
  try:
@@ -1283,7 +1382,7 @@ class SubscriptionModel(DomainModel):
1283
1382
  **fixed_inputs,
1284
1383
  **instances,
1285
1384
  )
1286
- model._db_model = subscription
1385
+ model.db_model = subscription
1287
1386
  return model
1288
1387
  except ValidationError:
1289
1388
  logger.exception(
@@ -1294,8 +1393,13 @@ class SubscriptionModel(DomainModel):
1294
1393
  @classmethod
1295
1394
  def from_subscription(cls: type[S], subscription_id: UUID | UUIDstr) -> S:
1296
1395
  """Use a subscription_id to return required fields of an existing subscription."""
1297
- subscription = cls._get_subscription(subscription_id)
1298
- if subscription is None:
1396
+ from orchestrator.domain.context_cache import get_from_cache, store_in_cache
1397
+ from orchestrator.settings import app_settings
1398
+
1399
+ if cached_model := get_from_cache(subscription_id):
1400
+ return cast(S, cached_model)
1401
+
1402
+ if not (subscription := cls._get_subscription(subscription_id)):
1299
1403
  raise ValueError(f"Subscription with id: {subscription_id}, does not exist")
1300
1404
  product = cls._to_product_model(subscription.product)
1301
1405
 
@@ -1317,7 +1421,12 @@ class SubscriptionModel(DomainModel):
1317
1421
 
1318
1422
  fixed_inputs = {fi.name: fi.value for fi in subscription.product.fixed_inputs}
1319
1423
 
1320
- instances = cls._load_instances(subscription.instances, status, match_domain_attr=False)
1424
+ instances: dict[str, Any]
1425
+ if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
1426
+ # TODO #900 remove toggle and make this path the default
1427
+ instances = cls._load_root_instances(subscription_id)
1428
+ else:
1429
+ instances = cls._load_instances(subscription.instances, status, match_domain_attr=False)
1321
1430
 
1322
1431
  try:
1323
1432
  model = cls(
@@ -1334,7 +1443,10 @@ class SubscriptionModel(DomainModel):
1334
1443
  **fixed_inputs,
1335
1444
  **instances,
1336
1445
  )
1337
- model._db_model = subscription
1446
+ model.db_model = subscription
1447
+
1448
+ store_in_cache(model)
1449
+
1338
1450
  return model
1339
1451
  except ValidationError:
1340
1452
  logger.exception(
@@ -1350,7 +1462,7 @@ class SubscriptionModel(DomainModel):
1350
1462
  f"Lifecycle status {self.status.value} requires specialized type {specialized_type!r}, was: {type(self)!r}"
1351
1463
  )
1352
1464
 
1353
- sub = db.session.get(
1465
+ existing_sub = db.session.get(
1354
1466
  SubscriptionTable,
1355
1467
  self.subscription_id,
1356
1468
  options=[
@@ -1360,13 +1472,13 @@ class SubscriptionModel(DomainModel):
1360
1472
  selectinload(SubscriptionTable.instances).selectinload(SubscriptionInstanceTable.values),
1361
1473
  ],
1362
1474
  )
1363
- if not sub:
1364
- sub = self._db_model
1475
+ if not (sub := (existing_sub or self.db_model)):
1476
+ raise ValueError("Cannot save SubscriptionModel without a db_model")
1365
1477
 
1366
1478
  # Make sure we refresh the object and not use an already mapped object
1367
1479
  db.session.refresh(sub)
1368
1480
 
1369
- self._db_model = sub
1481
+ self.db_model = sub
1370
1482
  sub.product_id = self.product.product_id
1371
1483
  sub.customer_id = self.customer_id
1372
1484
  sub.description = self.description
@@ -1404,9 +1516,15 @@ class SubscriptionModel(DomainModel):
1404
1516
  db.session.flush()
1405
1517
 
1406
1518
  @property
1407
- def db_model(self) -> SubscriptionTable:
1519
+ def db_model(self) -> SubscriptionTable | None:
1520
+ if not self._db_model:
1521
+ self._db_model = self._get_subscription(self.subscription_id)
1408
1522
  return self._db_model
1409
1523
 
1524
+ @db_model.setter
1525
+ def db_model(self, value: SubscriptionTable) -> None:
1526
+ self._db_model = value
1527
+
1410
1528
 
1411
1529
  def validate_base_model(
1412
1530
  name: str, cls: type[Any], base_model: type[BaseModel] = DomainModel, errors: list[str] | None = None
@@ -0,0 +1,62 @@
1
+ # Copyright 2019-2025 SURF.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ import contextlib
14
+ from contextvars import ContextVar
15
+ from typing import Iterator
16
+ from uuid import UUID
17
+
18
+ from orchestrator.domain import SubscriptionModel
19
+ from pydantic_forms.types import UUIDstr
20
+
21
+ __subscription_model_cache: ContextVar[dict[UUID, SubscriptionModel] | None] = ContextVar(
22
+ "subscription_model_cache", default=None
23
+ )
24
+
25
+
26
@contextlib.contextmanager
def cache_subscription_models() -> Iterator:
    """Cache SubscriptionModels for as long as the context is active.

    While inside the context, two calls to SubscriptionModel.from_subscription()
    with the same subscription id return the same instance.

    The main purpose is to speed up `@computed_field` properties on product blocks
    that load other subscriptions.

    Example usage:
        subscription = SubscriptionModel.from_subscription("...")
        with cache_subscription_models():
            subscription_dict = subscription.model_dump()
    """
    # Start with an empty cache; the token restores whatever was set before.
    reset_token = __subscription_model_cache.set({})
    try:
        yield
    finally:
        __subscription_model_cache.reset(reset_token)
46
+
47
+
48
def get_from_cache(subscription_id: UUID | UUIDstr) -> SubscriptionModel | None:
    """Look up a SubscriptionModel in the cache.

    Returns None when no cache is active or when the id has no entry.
    """
    cache = __subscription_model_cache.get()
    if cache is None:
        return None

    # Cache keys are UUID objects, so string ids are converted before the lookup.
    if isinstance(subscription_id, UUID):
        return cache.get(subscription_id, None)
    return cache.get(UUID(subscription_id), None)
55
+
56
+
57
def store_in_cache(model: SubscriptionModel) -> None:
    """Put a SubscriptionModel in the cache; does nothing when no cache is active."""
    cache = __subscription_model_cache.get()
    if cache is not None:
        cache[model.subscription_id] = model
@@ -1,6 +1,15 @@
1
- from collections.abc import Iterable
1
+ import contextlib
2
+ from collections.abc import Iterable, Iterator
3
+ from typing import Any
4
+ from uuid import UUID
2
5
 
6
+ from pydantic import BaseModel
7
+ from sqlalchemy import select
8
+
9
+ from orchestrator.db import ProductBlockTable, SubscriptionInstanceTable, db
3
10
  from orchestrator.types import filter_nonetype, get_origin_and_args, is_union_type
11
+ from orchestrator.utils.functional import group_by_key
12
+ from pydantic_forms.types import UUIDstr
4
13
 
5
14
 
6
15
  def _to_product_block_field_type_iterable(product_block_field_type: type | tuple[type]) -> Iterable[type]:
@@ -21,3 +30,34 @@ def _to_product_block_field_type_iterable(product_block_field_type: type | tuple
21
30
  return product_block_field_type
22
31
 
23
32
  return [product_block_field_type]
33
+
34
+
35
@contextlib.contextmanager
def no_private_attrs(model: Any) -> Iterator:
    """Temporarily strip the PrivateAttrs of a pydantic BaseModel.

    For the duration of the context `model.__pydantic_private__` is an empty dict;
    the original object is put back on exit, also when the body raises. Anything
    that is not a BaseModel is left untouched.
    """
    if isinstance(model, BaseModel):
        saved_private_attrs = model.__pydantic_private__
        try:
            model.__pydantic_private__ = {}
            yield
        finally:
            model.__pydantic_private__ = saved_private_attrs
    else:
        yield
47
+
48
+
49
def get_root_blocks_to_instance_ids(subscription_id: UUID | UUIDstr) -> dict[str, list[UUID]]:
    """Return a mapping of root product block names to lists of subscription instance ids.

    The recommended practice is a single root product block, but a subscription may
    have several root blocks or even a list of them; this function supports that.

    NOTE(review): the query below selects every subscription instance that has the
    given subscription_id and does not itself filter on "root" - confirm that callers
    only look up root block names in the result.
    """
    stmt = (
        select(ProductBlockTable.name, SubscriptionInstanceTable.subscription_instance_id)
        .select_from(SubscriptionInstanceTable)
        .join(ProductBlockTable)
        .where(SubscriptionInstanceTable.subscription_id == subscription_id)
        .order_by(ProductBlockTable.name)
    )
    name_and_instance_id_rows = db.session.execute(stmt).all()

    # Presumably groups the (name, instance id) rows by name; group_by_key is not shown here.
    return group_by_key(name_and_instance_id_rows)  # type: ignore[arg-type]
@@ -0,0 +1,114 @@
1
+ # Copyright 2019-2025 SURF.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ """Functions to transform result of query SubscriptionInstanceAsJsonFunction to match the ProductBlockModel."""
15
+
16
+ from functools import partial
17
+ from typing import TYPE_CHECKING, Any, Callable, Iterable
18
+
19
+ from more_itertools import first, only
20
+
21
+ from orchestrator.types import is_list_type, is_optional_type
22
+
23
+ if TYPE_CHECKING:
24
+ from orchestrator.domain.base import ProductBlockModel
25
+
26
+
27
def _ensure_list(instance_or_value_list: Any) -> Any:
    """Return the argument unchanged, except that None becomes an empty list."""
    return [] if instance_or_value_list is None else instance_or_value_list
32
+
33
+
34
def _instance_list_to_dict(product_block_field_type: type, instance_list: Any) -> Any:
    """Reduce a list of subscription instances to the single instance of a non-list field.

    Returns the instance, or None when the input is None or when the list holds no
    (truthy) instance and the field type is optional.

    Raises:
        ValueError: if the input is neither None nor a list, if the list holds more
            than one item (raised by `only`), or if a required instance is missing.
    """
    if instance_list is None:
        return None

    if not isinstance(instance_list, list):
        raise ValueError(f"All subscription instances should be returned as list, found {type(instance_list)}")

    instance = only(instance_list)
    if instance:
        return instance

    if is_optional_type(product_block_field_type):
        return None  # Set the optional product block field to None

    raise ValueError("Required subscription instance is missing in database")
49
+
50
+
51
def _value_list_to_value(field_type: type, value_list: Any) -> Any:
    """Reduce a list of instance values to the single value of a non-list field.

    Returns the value, or None when the input is None or when the list holds no
    value (or only None) and the field type is optional.

    Raises:
        ValueError: if the input is neither None nor a list, if the list holds more
            than one item (raised by `only`), or if a required value is missing.
    """
    if value_list is None:
        return None

    if not isinstance(value_list, list):
        raise ValueError(f"All instance values should be returned as list, found {type(value_list)}")

    value = only(value_list)
    if value is not None:
        return value

    if is_optional_type(field_type):
        return None  # Set the optional resource type field to None

    raise ValueError("Required subscription value is missing in database")
66
+
67
+
68
def field_transformation_rules(klass: type["ProductBlockModel"]) -> dict[str, Callable]:
    """Build the field name -> transformation function mapping for a product block type.

    List-typed fields get `_ensure_list`. Other product block fields get
    `_instance_list_to_dict` and other non product block fields get
    `_value_list_to_value`, each bound to the field's type.
    """
    rules: dict[str, Callable] = {}

    for field_name, product_block_field_type in klass._product_block_fields_.items():
        if is_list_type(product_block_field_type):
            rules[field_name] = _ensure_list
        else:
            rules[field_name] = partial(_instance_list_to_dict, product_block_field_type)

    for field_name, field_type in klass._non_product_block_fields_.items():
        if is_list_type(field_type):
            rules[field_name] = _ensure_list
        else:
            rules[field_name] = partial(_value_list_to_value, field_type)

    return rules
85
+
86
+
87
def transform_instance_fields(all_rules: dict[str, dict[str, Callable]], instance: dict) -> None:
    """Apply the transformation rules to a subscription instance dict, in place.

    The rules are selected through the instance's product block name. Nested
    subscription instances (dict values, and lists whose first item is a dict) are
    transformed recursively.

    Raises:
        ValueError: if one of the rules rejects a field value; the message contains
            the instance data and the original error is chained.
    """
    # Imported here rather than at module level; the module only imports it under TYPE_CHECKING.
    from orchestrator.domain.base import ProductBlockModel

    # Lookup applicable rules through product block name
    block_rules = all_rules[instance["name"]]

    # Ensure the product block's metadata is loaded
    ProductBlockModel.registry[instance["name"]]._fix_pb_data()

    # Transform all fields in this subscription instance
    try:
        for field_name, rewrite_func in block_rules.items():
            instance[field_name] = rewrite_func(instance.get(field_name))
    except ValueError as e:
        raise ValueError(f"Invalid subscription instance data {instance}") from e

    # Recurse into nested subscription instances
    for nested in instance.values():
        if isinstance(nested, dict):
            transform_instance_fields(all_rules, nested)
        elif isinstance(nested, list) and isinstance(first(nested, None), dict):
            for nested_instance in nested:
                transform_instance_fields(all_rules, nested_instance)
@@ -34,7 +34,7 @@ from orchestrator.graphql.utils import (
34
34
  is_querying_page_data,
35
35
  to_graphql_result_page,
36
36
  )
37
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
37
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
38
38
  from orchestrator.schemas.process import ProcessSchema
39
39
  from orchestrator.services.processes import load_process
40
40
  from orchestrator.utils.enrich_process import enrich_process
@@ -56,7 +56,7 @@ def _enrich_process(process: ProcessTable, with_details: bool = False) -> Proces
56
56
 
57
57
 
58
58
  async def resolve_process(info: OrchestratorInfo, process_id: UUID) -> ProcessType | None:
59
- query_loaders = get_query_loaders(info, ProcessTable)
59
+ query_loaders = get_query_loaders_for_gql_fields(ProcessTable, info)
60
60
  stmt = select(ProcessTable).options(*query_loaders).where(ProcessTable.process_id == process_id)
61
61
  if process := db.session.scalar(stmt):
62
62
  is_detailed = _is_process_detailed(info)
@@ -83,7 +83,7 @@ async def resolve_processes(
83
83
  .selectinload(ProcessSubscriptionTable.subscription)
84
84
  .joinedload(SubscriptionTable.product)
85
85
  ]
86
- query_loaders = get_query_loaders(info, ProcessTable) or default_loaders
86
+ query_loaders = get_query_loaders_for_gql_fields(ProcessTable, info) or default_loaders
87
87
  select_stmt = select(ProcessTable).options(*query_loaders)
88
88
  select_stmt = filter_processes(select_stmt, pydantic_filter_by, _error_handler)
89
89
  if query is not None:
@@ -13,7 +13,7 @@ from orchestrator.graphql.resolvers.helpers import rows_from_statement
13
13
  from orchestrator.graphql.schemas.product import ProductType
14
14
  from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
15
15
  from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
16
- from orchestrator.graphql.utils.get_query_loaders import get_query_loaders
16
+ from orchestrator.graphql.utils.get_query_loaders import get_query_loaders_for_gql_fields
17
17
  from orchestrator.utils.search_query import create_sqlalchemy_select
18
18
 
19
19
  logger = structlog.get_logger(__name__)
@@ -33,7 +33,7 @@ async def resolve_products(
33
33
  pydantic_sort_by: list[Sort] = [item.to_pydantic() for item in sort_by] if sort_by else []
34
34
  logger.debug("resolve_products() called", range=[after, after + first], sort=sort_by, filter=pydantic_filter_by)
35
35
 
36
- query_loaders = get_query_loaders(info, ProductTable)
36
+ query_loaders = get_query_loaders_for_gql_fields(ProductTable, info)
37
37
  select_stmt = select(ProductTable).options(*query_loaders)
38
38
  select_stmt = filter_products(select_stmt, pydantic_filter_by, _error_handler)
39
39