orchestrator-core 3.1.2rc3__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. orchestrator/__init__.py +2 -2
  2. orchestrator/api/api_v1/api.py +1 -1
  3. orchestrator/api/api_v1/endpoints/processes.py +6 -9
  4. orchestrator/api/api_v1/endpoints/settings.py +1 -1
  5. orchestrator/api/api_v1/endpoints/subscriptions.py +1 -1
  6. orchestrator/app.py +1 -1
  7. orchestrator/cli/database.py +1 -1
  8. orchestrator/cli/generator/generator/migration.py +2 -5
  9. orchestrator/cli/generator/generator/workflow.py +13 -1
  10. orchestrator/cli/generator/templates/modify_product.j2 +9 -0
  11. orchestrator/cli/migrate_tasks.py +13 -0
  12. orchestrator/config/assignee.py +1 -1
  13. orchestrator/db/__init__.py +2 -0
  14. orchestrator/db/loaders.py +51 -3
  15. orchestrator/db/models.py +14 -1
  16. orchestrator/db/queries/__init__.py +0 -0
  17. orchestrator/db/queries/subscription.py +85 -0
  18. orchestrator/db/queries/subscription_instance.py +28 -0
  19. orchestrator/devtools/populator.py +1 -1
  20. orchestrator/domain/__init__.py +2 -3
  21. orchestrator/domain/base.py +236 -49
  22. orchestrator/domain/context_cache.py +62 -0
  23. orchestrator/domain/helpers.py +41 -1
  24. orchestrator/domain/lifecycle.py +1 -1
  25. orchestrator/domain/subscription_instance_transform.py +114 -0
  26. orchestrator/graphql/resolvers/process.py +3 -3
  27. orchestrator/graphql/resolvers/product.py +2 -2
  28. orchestrator/graphql/resolvers/product_block.py +2 -2
  29. orchestrator/graphql/resolvers/resource_type.py +2 -2
  30. orchestrator/graphql/resolvers/workflow.py +2 -2
  31. orchestrator/graphql/schema.py +1 -1
  32. orchestrator/graphql/types.py +1 -1
  33. orchestrator/graphql/utils/get_query_loaders.py +6 -48
  34. orchestrator/graphql/utils/get_subscription_product_blocks.py +21 -1
  35. orchestrator/migrations/env.py +1 -1
  36. orchestrator/migrations/helpers.py +6 -6
  37. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.py +33 -0
  38. orchestrator/migrations/versions/schema/2025-03-06_42b3d076a85b_subscription_instance_as_json_function.sql +40 -0
  39. orchestrator/migrations/versions/schema/2025-04-09_fc5c993a4b4a_add_cascade_constraint_on_processes_.py +44 -0
  40. orchestrator/schemas/engine_settings.py +1 -1
  41. orchestrator/schemas/subscription.py +1 -1
  42. orchestrator/security.py +1 -1
  43. orchestrator/services/celery.py +1 -1
  44. orchestrator/services/processes.py +28 -9
  45. orchestrator/services/products.py +1 -1
  46. orchestrator/services/subscriptions.py +37 -7
  47. orchestrator/services/tasks.py +1 -1
  48. orchestrator/settings.py +5 -23
  49. orchestrator/targets.py +1 -1
  50. orchestrator/types.py +1 -1
  51. orchestrator/utils/errors.py +1 -1
  52. orchestrator/utils/functional.py +9 -0
  53. orchestrator/utils/redis.py +6 -0
  54. orchestrator/utils/state.py +1 -1
  55. orchestrator/websocket/websocket_manager.py +1 -1
  56. orchestrator/workflow.py +29 -6
  57. orchestrator/workflows/modify_note.py +1 -1
  58. orchestrator/workflows/steps.py +1 -1
  59. orchestrator/workflows/tasks/cleanup_tasks_log.py +1 -1
  60. orchestrator/workflows/tasks/resume_workflows.py +1 -1
  61. orchestrator/workflows/tasks/validate_product_type.py +1 -1
  62. orchestrator/workflows/tasks/validate_products.py +1 -1
  63. orchestrator/workflows/utils.py +40 -5
  64. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0.dist-info}/METADATA +10 -9
  65. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0.dist-info}/RECORD +68 -60
  66. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0.dist-info}/WHEEL +1 -1
  67. /orchestrator/migrations/versions/schema/{2025-10-19_4fjdn13f83ga_add_validate_product_type_task.py → 2025-01-19_4fjdn13f83ga_add_validate_product_type_task.py} +0 -0
  68. {orchestrator_core-3.1.2rc3.dist-info → orchestrator_core-3.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF, ESnet.
1
+ # Copyright 2019-2025 SURF, ESnet, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at
@@ -12,14 +12,16 @@
12
12
  # limitations under the License.
13
13
  import itertools
14
14
  from collections import defaultdict
15
- from collections.abc import Callable, Iterable
16
15
  from datetime import datetime
17
- from inspect import get_annotations
16
+ from inspect import get_annotations, isclass
18
17
  from itertools import groupby, zip_longest
19
18
  from operator import attrgetter
20
19
  from typing import (
21
20
  Any,
21
+ Callable,
22
22
  ClassVar,
23
+ Iterable,
24
+ Mapping,
23
25
  Optional,
24
26
  TypeVar,
25
27
  Union,
@@ -34,7 +36,7 @@ from more_itertools import bucket, first, flatten, one, only
34
36
  from pydantic import BaseModel, ConfigDict, Field, ValidationError
35
37
  from pydantic.fields import PrivateAttr
36
38
  from sqlalchemy import select
37
- from sqlalchemy.orm import selectinload
39
+ from sqlalchemy.orm import joinedload, selectinload
38
40
 
39
41
  from orchestrator.db import (
40
42
  ProductBlockTable,
@@ -45,13 +47,19 @@ from orchestrator.db import (
45
47
  SubscriptionTable,
46
48
  db,
47
49
  )
48
- from orchestrator.domain.helpers import _to_product_block_field_type_iterable
50
+ from orchestrator.db.queries.subscription_instance import get_subscription_instance_dict
51
+ from orchestrator.domain.helpers import (
52
+ _to_product_block_field_type_iterable,
53
+ get_root_blocks_to_instance_ids,
54
+ no_private_attrs,
55
+ )
49
56
  from orchestrator.domain.lifecycle import (
50
57
  ProductLifecycle,
51
58
  lookup_specialized_type,
52
59
  register_specialized_type,
53
60
  validate_lifecycle_status,
54
61
  )
62
+ from orchestrator.domain.subscription_instance_transform import field_transformation_rules, transform_instance_fields
55
63
  from orchestrator.services.products import get_product_by_id
56
64
  from orchestrator.types import (
57
65
  SAFE_USED_BY_TRANSITIONS_FOR_STATUS,
@@ -96,6 +104,12 @@ class DomainModel(BaseModel):
96
104
  def __init_subclass__(cls, *args: Any, lifecycle: list[SubscriptionLifecycle] | None = None, **kwargs: Any) -> None:
97
105
  pass
98
106
 
107
def __eq__(self, other: Any) -> bool:
    """Compare two models while ignoring the pydantic private attributes of both sides.

    The private attributes of ``self`` and ``other`` are cleared for the duration of the comparison and restored
    afterwards. Added for #652, primarily because ProductBlockModel._db_model is now lazy loaded.
    """
    with no_private_attrs(self):
        with no_private_attrs(other):
            return super().__eq__(other)
112
+
99
113
  @classmethod
100
114
  def __pydantic_init_subclass__(
101
115
  cls,
@@ -308,12 +322,7 @@ class DomainModel(BaseModel):
308
322
  for product_block_field_name, product_block_field_type in cls._product_block_fields_.items():
309
323
  filter_func = match_domain_model_attr_if_possible(product_block_field_name)
310
324
 
311
- product_block_model: Any = product_block_field_type
312
- if is_list_type(product_block_field_type):
313
- _origin, args = get_origin_and_args(product_block_field_type)
314
- product_block_model = one(args)
315
-
316
- possible_product_block_types = get_possible_product_block_types(product_block_model)
325
+ possible_product_block_types = flatten_product_block_types(product_block_field_type)
317
326
  field_type_names = list(possible_product_block_types.keys())
318
327
  filtered_instances = flatten([grouped_instances.get(name, []) for name in field_type_names])
319
328
  instance_list = list(filter(filter_func, filtered_instances))
@@ -456,6 +465,15 @@ class DomainModel(BaseModel):
456
465
  raise ValueError(f"Cannot link the same subscription instance multiple times: {details}")
457
466
 
458
467
 
468
def flatten_product_block_types(product_block_field_type: Any) -> dict[str, type["ProductBlockModel"]]:
    """Map product block names to product block classes for a (possibly list-typed) product block field type.

    When ``is_list_type`` reports a list type, its single type argument is unwrapped first (``one`` raises when
    there is not exactly one). The resulting type is resolved with ``get_possible_product_block_types``.
    """
    if not is_list_type(product_block_field_type):
        return get_possible_product_block_types(product_block_field_type)

    # Unwrap the element type of the list annotation before resolving it
    _origin, type_args = get_origin_and_args(product_block_field_type)
    return get_possible_product_block_types(one(type_args))
475
+
476
+
459
477
  def get_depends_on_product_block_type_list(
460
478
  product_block_types: dict[str, type["ProductBlockModel"] | tuple[type["ProductBlockModel"]]],
461
479
  ) -> list[type["ProductBlockModel"]]:
@@ -520,7 +538,7 @@ class ProductBlockModel(DomainModel):
520
538
  product_block_id: ClassVar[UUID]
521
539
  description: ClassVar[str]
522
540
  tag: ClassVar[str]
523
- _db_model: SubscriptionInstanceTable = PrivateAttr()
541
+ _db_model: SubscriptionInstanceTable | None = PrivateAttr(default=None)
524
542
 
525
543
  # Product block name. This needs to be an instance var because its part of the API (we expose it to the frontend)
526
544
  # Is actually optional since abstract classes don't have it.
@@ -596,7 +614,9 @@ class ProductBlockModel(DomainModel):
596
614
  product_blocks_in_model = cls._get_depends_on_product_block_types()
597
615
  product_blocks_types_in_model = get_depends_on_product_block_type_list(product_blocks_in_model)
598
616
 
599
- product_blocks_in_model = set(flatten(map(attrgetter("__names__"), product_blocks_types_in_model))) # type: ignore
617
+ product_blocks_in_model = set(
618
+ flatten(map(attrgetter("__names__"), product_blocks_types_in_model))
619
+ ) # type: ignore
600
620
 
601
621
  missing_product_blocks_in_db = product_blocks_in_model - product_blocks_in_db # type: ignore
602
622
  missing_product_blocks_in_model = product_blocks_in_db - product_blocks_in_model # type: ignore
@@ -677,7 +697,7 @@ class ProductBlockModel(DomainModel):
677
697
  **sub_instances,
678
698
  **kwargs,
679
699
  )
680
- model._db_model = db_model
700
+ model.db_model = db_model
681
701
  return model
682
702
 
683
703
  @classmethod
@@ -733,7 +753,7 @@ class ProductBlockModel(DomainModel):
733
753
 
734
754
  cls._fix_pb_data()
735
755
  model = cls(**data)
736
- model._db_model = other._db_model
756
+ model.db_model = other.db_model
737
757
  return model
738
758
 
739
759
  @classmethod
@@ -791,7 +811,7 @@ class ProductBlockModel(DomainModel):
791
811
  **instance_values, # type: ignore
792
812
  **sub_instances,
793
813
  )
794
- model._db_model = subscription_instance
814
+ model.db_model = subscription_instance
795
815
 
796
816
  return model
797
817
  except ValidationError:
@@ -909,14 +929,15 @@ class ProductBlockModel(DomainModel):
909
929
 
910
930
  # If this is a "foreign" instance we just stop saving and return it so only its relation is saved
911
931
  # We should not touch these themselves
912
- if self.subscription and subscription_instance.subscription_id != subscription_id:
932
+ if self.owner_subscription_id != subscription_id:
913
933
  return [], subscription_instance
914
934
 
915
- self._db_model = subscription_instance
916
- else:
917
- subscription_instance = self._db_model
935
+ self.db_model = subscription_instance
936
+ elif subscription_instance := self.db_model:
918
937
  # We only need to add to the session if the subscription_instance does not exist.
919
938
  db.session.add(subscription_instance)
939
+ else:
940
+ raise ValueError("Cannot save ProductBlockModel without a db_model")
920
941
 
921
942
  subscription_instance.subscription_id = subscription_id
922
943
 
@@ -943,22 +964,32 @@ class ProductBlockModel(DomainModel):
943
964
  return sub_instances + [subscription_instance], subscription_instance
944
965
 
945
966
@property
def subscription(self) -> SubscriptionTable | None:
    """Return the ``subscription`` relationship of this block's database model.

    Returns None when no database model could be obtained through ``db_model``.
    """
    if not (instance_row := self.db_model):
        return None
    return instance_row.subscription
948
969
 
949
970
@property
def db_model(self) -> SubscriptionInstanceTable | None:
    """Return the database row backing this product block, loading it lazily on first access.

    If no row has been attached yet, it is looked up by ``subscription_instance_id``. When no such row exists
    (presumably because the block has not been saved yet), None is returned. That result is not remembered, so
    the lookup runs again on the next access.
    """
    # Test explicitly for None: the truthiness of an ORM object is not a reliable "already loaded" signal.
    if self._db_model is None:
        self._db_model = db.session.execute(
            select(SubscriptionInstanceTable).where(
                SubscriptionInstanceTable.subscription_instance_id == self.subscription_instance_id
            )
        ).scalar_one_or_none()
    return self._db_model

@db_model.setter
def db_model(self, value: SubscriptionInstanceTable | None) -> None:
    """Attach the given database row (or None) to this product block."""
    self._db_model = value
983
+
953
984
@property
def in_use_by(self) -> list[SubscriptionInstanceTable]:  # TODO check where used, might need eagerloading
    """List the product blocks (subscription instances) that depend on this product block.

    Empty when no database model could be obtained through ``db_model``.
    """
    instance_row = self.db_model
    return instance_row.in_use_by if instance_row else []
957
988
 
958
989
@property
def depends_on(self) -> list[SubscriptionInstanceTable]:  # TODO check where used, might need eagerloading
    """List the product blocks (subscription instances) that this product block depends on.

    Empty when no database model could be obtained through ``db_model``.
    """
    instance_row = self.db_model
    return instance_row.depends_on if instance_row else []
962
993
 
963
994
 
964
995
  class ProductModel(BaseModel):
@@ -1006,9 +1037,11 @@ class SubscriptionModel(DomainModel):
1006
1037
  >>> SubscriptionInactive.from_subscription(subscription_id) # doctest:+SKIP
1007
1038
  """
1008
1039
 
1040
+ __model_dump_cache__: ClassVar[dict[UUID, "SubscriptionModel"] | None] = None
1041
+
1009
1042
  product: ProductModel
1010
1043
  customer_id: str
1011
- _db_model: SubscriptionTable = PrivateAttr()
1044
+ _db_model: SubscriptionTable | None = PrivateAttr(default=None)
1012
1045
  subscription_id: UUID = Field(default_factory=uuid4) # pragma: no mutate
1013
1046
  description: str = "Initial subscription" # pragma: no mutate
1014
1047
  status: SubscriptionLifecycle = SubscriptionLifecycle.INITIAL # pragma: no mutate
@@ -1051,7 +1084,9 @@ class SubscriptionModel(DomainModel):
1051
1084
  product_blocks_in_model = cls._get_depends_on_product_block_types()
1052
1085
  product_blocks_types_in_model = get_depends_on_product_block_type_list(product_blocks_in_model)
1053
1086
 
1054
- product_blocks_in_model = set(flatten(map(attrgetter("__names__"), product_blocks_types_in_model))) # type: ignore
1087
+ product_blocks_in_model = set(
1088
+ flatten(map(attrgetter("__names__"), product_blocks_types_in_model))
1089
+ ) # type: ignore
1055
1090
 
1056
1091
  missing_product_blocks_in_db = product_blocks_in_model - product_blocks_in_db # type: ignore
1057
1092
  missing_product_blocks_in_model = product_blocks_in_db - product_blocks_in_model # type: ignore
@@ -1097,6 +1132,63 @@ class SubscriptionModel(DomainModel):
1097
1132
 
1098
1133
  return missing_data
1099
1134
 
1135
@classmethod
def _load_root_instances(
    cls,
    subscription_id: UUID | UUIDstr,
) -> dict[str, Optional[dict] | list[dict]]:
    """Load root subscription instance(s) for this subscription model.

    When a subscription model is loaded from an existing subscription, each root subscription instance is fetched
    in full, including its nested instances, via ``get_subscription_instance_dict`` (an optimized postgres
    function). The resulting dicts are used to instantiate the root product block(s).

    The "old" method DomainModel._load_instances() recursively loads subscription instances from the database.
    It instantiates nested blocks individually, more or less "manually" reconstructing the subscription.

    This method takes a different approach. It has all data for each root subscription instance, so it can let
    Pydantic instantiate the root block and all nested blocks in one go. That is also why it has no ``status`` or
    ``match_domain_attr`` parameters: that information is already encoded in the domain model of the product.

    Returns:
        Mapping of each root product block field name to its transformed instance data:
        - list-typed fields get a list of dicts;
        - other fields get a single dict, or None when nothing matched.
        ``only`` raises ValueError when more than one instance matches a non-list field.
    """
    instance_ids_per_block_name = get_root_blocks_to_instance_ids(subscription_id)

    # Names of the product blocks that may appear in each root product block field
    block_names_per_field = {
        field_name: list(flatten_product_block_types(field_type).keys())
        for field_name, field_type in cls._product_block_fields_.items()
    }

    # Fetch the complete (nested) instance dict of every subscription instance matching a field's block names
    instances_per_field: dict[str, list[dict]] = {}
    for field_name, block_names in block_names_per_field.items():
        instances_per_field[field_name] = [
            get_subscription_instance_dict(instance_id)
            for block_name in block_names
            for instance_id in instance_ids_per_block_name.get(block_name, [])
        ]

    # Transform values according to domain models (list[dict] -> dict, add None as default for optionals)
    rules_per_block_name = {
        block_class.name: field_transformation_rules(block_class)
        for block_class in ProductBlockModel.registry.values()
        if block_class.name
    }
    for instance_dicts in instances_per_field.values():
        for instance_dict in instance_dicts:
            transform_instance_fields(rules_per_block_name, instance_dict)

    # Support the (theoretical?) use case of a list of root product blocks: list-typed fields keep the whole list,
    # every other field receives the single instance (or None when there is none)
    root_instances: dict[str, Optional[dict] | list[dict]] = {}
    for field_name, instance_dicts in instances_per_field.items():
        if is_list_type(cls._product_block_fields_[field_name]):
            root_instances[field_name] = instance_dicts
        else:
            root_instances[field_name] = only(instance_dicts)
    return root_instances
1191
+
1100
1192
  @classmethod
1101
1193
  def from_product_id(
1102
1194
  cls: type[S],
@@ -1162,7 +1254,7 @@ class SubscriptionModel(DomainModel):
1162
1254
  **fixed_inputs,
1163
1255
  **instances,
1164
1256
  )
1165
- model._db_model = subscription
1257
+ model.db_model = subscription
1166
1258
  return model
1167
1259
 
1168
1260
  @classmethod
@@ -1195,17 +1287,26 @@ class SubscriptionModel(DomainModel):
1195
1287
  data["end_date"] = nowtz()
1196
1288
 
1197
1289
  model = cls(**data)
1198
- model._db_model = other._db_model
1290
+ model.db_model = other._db_model
1199
1291
 
1200
1292
  return model
1201
1293
 
1202
1294
  # Some common functions shared by from_other_product and from_subscription
1203
1295
  @classmethod
1204
- def _get_subscription(cls: type[S], subscription_id: UUID | UUIDstr) -> Any:
1205
- return db.session.get(
1206
- SubscriptionTable,
1207
- subscription_id,
1208
- options=[
1296
+ def _get_subscription(cls: type[S], subscription_id: UUID | UUIDstr) -> SubscriptionTable | None:
1297
+ from orchestrator.settings import app_settings
1298
+
1299
+ if not isinstance(subscription_id, UUID | UUIDstr):
1300
+ raise TypeError(f"subscription_id is of type {type(subscription_id)} instead of UUID | UUIDstr")
1301
+
1302
+ if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
1303
+ # TODO #900 remove toggle and make this path the default
1304
+ loaders = [
1305
+ joinedload(SubscriptionTable.product).selectinload(ProductTable.fixed_inputs),
1306
+ ]
1307
+
1308
+ else:
1309
+ loaders = [
1209
1310
  selectinload(SubscriptionTable.instances)
1210
1311
  .joinedload(SubscriptionInstanceTable.product_block)
1211
1312
  .selectinload(ProductBlockTable.resource_types),
@@ -1213,8 +1314,9 @@ class SubscriptionModel(DomainModel):
1213
1314
  SubscriptionInstanceTable.in_use_by_block_relations
1214
1315
  ),
1215
1316
  selectinload(SubscriptionTable.instances).selectinload(SubscriptionInstanceTable.values),
1216
- ],
1217
- )
1317
+ ]
1318
+
1319
+ return db.session.get(SubscriptionTable, subscription_id, options=loaders)
1218
1320
 
1219
1321
  @classmethod
1220
1322
  def _to_product_model(cls: type[S], product: ProductTable) -> ProductModel:
@@ -1240,7 +1342,9 @@ class SubscriptionModel(DomainModel):
1240
1342
  if not db_product:
1241
1343
  raise KeyError("Could not find a product for the given product_id")
1242
1344
 
1243
- subscription = cls._get_subscription(old_instantiation.subscription_id)
1345
+ old_subscription_id = old_instantiation.subscription_id
1346
+ if not (subscription := cls._get_subscription(old_subscription_id)):
1347
+ raise ValueError(f"Subscription with id: {old_subscription_id}, does not exist")
1244
1348
  product = cls._to_product_model(db_product)
1245
1349
 
1246
1350
  status = SubscriptionLifecycle(subscription.status)
@@ -1260,6 +1364,7 @@ class SubscriptionModel(DomainModel):
1260
1364
  name, product_block = new_root
1261
1365
  instances = {name: product_block}
1262
1366
  else:
1367
+ # TODO test using cls._load_root_instances() here as well
1263
1368
  instances = cls._load_instances(subscription.instances, status, match_domain_attr=False) # type:ignore
1264
1369
 
1265
1370
  try:
@@ -1277,7 +1382,7 @@ class SubscriptionModel(DomainModel):
1277
1382
  **fixed_inputs,
1278
1383
  **instances,
1279
1384
  )
1280
- model._db_model = subscription
1385
+ model.db_model = subscription
1281
1386
  return model
1282
1387
  except ValidationError:
1283
1388
  logger.exception(
@@ -1288,8 +1393,13 @@ class SubscriptionModel(DomainModel):
1288
1393
  @classmethod
1289
1394
  def from_subscription(cls: type[S], subscription_id: UUID | UUIDstr) -> S:
1290
1395
  """Use a subscription_id to return required fields of an existing subscription."""
1291
- subscription = cls._get_subscription(subscription_id)
1292
- if subscription is None:
1396
+ from orchestrator.domain.context_cache import get_from_cache, store_in_cache
1397
+ from orchestrator.settings import app_settings
1398
+
1399
+ if cached_model := get_from_cache(subscription_id):
1400
+ return cast(S, cached_model)
1401
+
1402
+ if not (subscription := cls._get_subscription(subscription_id)):
1293
1403
  raise ValueError(f"Subscription with id: {subscription_id}, does not exist")
1294
1404
  product = cls._to_product_model(subscription.product)
1295
1405
 
@@ -1311,7 +1421,12 @@ class SubscriptionModel(DomainModel):
1311
1421
 
1312
1422
  fixed_inputs = {fi.name: fi.value for fi in subscription.product.fixed_inputs}
1313
1423
 
1314
- instances = cls._load_instances(subscription.instances, status, match_domain_attr=False)
1424
+ instances: dict[str, Any]
1425
+ if app_settings.ENABLE_SUBSCRIPTION_MODEL_OPTIMIZATIONS:
1426
+ # TODO #900 remove toggle and make this path the default
1427
+ instances = cls._load_root_instances(subscription_id)
1428
+ else:
1429
+ instances = cls._load_instances(subscription.instances, status, match_domain_attr=False)
1315
1430
 
1316
1431
  try:
1317
1432
  model = cls(
@@ -1328,7 +1443,10 @@ class SubscriptionModel(DomainModel):
1328
1443
  **fixed_inputs,
1329
1444
  **instances,
1330
1445
  )
1331
- model._db_model = subscription
1446
+ model.db_model = subscription
1447
+
1448
+ store_in_cache(model)
1449
+
1332
1450
  return model
1333
1451
  except ValidationError:
1334
1452
  logger.exception(
@@ -1344,7 +1462,7 @@ class SubscriptionModel(DomainModel):
1344
1462
  f"Lifecycle status {self.status.value} requires specialized type {specialized_type!r}, was: {type(self)!r}"
1345
1463
  )
1346
1464
 
1347
- sub = db.session.get(
1465
+ existing_sub = db.session.get(
1348
1466
  SubscriptionTable,
1349
1467
  self.subscription_id,
1350
1468
  options=[
@@ -1354,13 +1472,13 @@ class SubscriptionModel(DomainModel):
1354
1472
  selectinload(SubscriptionTable.instances).selectinload(SubscriptionInstanceTable.values),
1355
1473
  ],
1356
1474
  )
1357
- if not sub:
1358
- sub = self._db_model
1475
+ if not (sub := (existing_sub or self.db_model)):
1476
+ raise ValueError("Cannot save SubscriptionModel without a db_model")
1359
1477
 
1360
1478
  # Make sure we refresh the object and not use an already mapped object
1361
1479
  db.session.refresh(sub)
1362
1480
 
1363
- self._db_model = sub
1481
+ self.db_model = sub
1364
1482
  sub.product_id = self.product.product_id
1365
1483
  sub.customer_id = self.customer_id
1366
1484
  sub.description = self.description
@@ -1398,9 +1516,78 @@ class SubscriptionModel(DomainModel):
1398
1516
  db.session.flush()
1399
1517
 
1400
1518
@property
def db_model(self) -> SubscriptionTable | None:
    """Return the SubscriptionTable row backing this model, loading it lazily on first access.

    The row is fetched with ``_get_subscription(self.subscription_id)``. When no such subscription exists, None is
    returned. That result is not remembered, so the lookup runs again on the next access.
    """
    # Test explicitly for None: the truthiness of an ORM object is not a reliable "already loaded" signal.
    if self._db_model is None:
        self._db_model = self._get_subscription(self.subscription_id)
    return self._db_model

@db_model.setter
def db_model(self, value: SubscriptionTable | None) -> None:
    """Attach the given database row (or None) to this subscription model."""
    self._db_model = value
1527
+
1528
+
1529
def validate_base_model(
    name: str, cls: type[Any], base_model: type[BaseModel] = DomainModel, errors: list[str] | None = None
) -> None:
    """Validate that the given class is not Pydantic BaseModel or its direct subclass.

    For ProductBlockModel / SubscriptionModel subclasses, their product block and non product block fields are
    validated recursively. All problems are collected first and reported together, one per line.

    Args:
        name: Field (or registry key) name used in error messages.
        cls: The type to validate; anything that is not a class is ignored.
        base_model: Domain model class suggested as superclass in the error messages.
        errors: Optional list to collect errors into. It is mutated in place; a new list is used when omitted.

    Raises:
        TypeError: When any errors were collected, with all of them joined by newlines.
    """
    # Instantiate errors list if not provided and avoid mutating default
    if errors is None:
        errors = []
    # Return early when the node is not a class as there is nothing to be done
    if not isclass(cls):
        return
    # Collect the errors of the whole tree before raising, so every offending field is reported,
    # not just the first one encountered during the recursion.
    _collect_base_model_errors(name, cls, base_model, errors)
    # Format all errors as one per line and raise a TypeError when they exist
    if errors:
        raise TypeError("\n".join(errors))


def _collect_base_model_errors(name: str, cls: Any, base_model: type[BaseModel], errors: list[str]) -> None:
    """Append to ``errors`` a message for every plain Pydantic model reachable from ``cls``; never raises."""
    if not isclass(cls):
        return
    # Validate each field in the domain model's field dictionaries
    if issubclass(cls, (ProductBlockModel, SubscriptionModel)):
        # Distinct loop names so the ``name`` of ``cls`` itself is not shadowed
        for field_name, field_type in cls._product_block_fields_.items():
            _collect_base_model_errors(field_name, field_type, ProductBlockModel, errors)
        for field_name, field_type in cls._non_product_block_fields_.items():
            _collect_base_model_errors(field_name, field_type, SubscriptionModel, errors)
    # Generate error if node is Pydantic BaseModel or direct subclass
    if issubclass(cls, BaseModel):
        err_msg: str = (
            f"If this field was intended to be a {base_model.__name__}, define {name}:{cls.__name__} with "
            f"{base_model.__name__} as its superclass instead. e.g., class {cls.__name__}({base_model.__name__}):"
        )
        if cls is BaseModel:
            errors.append(f"Field {name}: {cls.__name__} can not be {BaseModel.__name__}. " + err_msg)
        if len(cls.__mro__) > 1 and cls.__mro__[1] is BaseModel:
            errors.append(
                f"Field {name}: {cls.__name__} can not be a direct subclass of {BaseModel.__name__}. " + err_msg
            )
1560
+
1561
+
1562
class SubscriptionModelRegistry(dict[str, type[SubscriptionModel]]):
    """A registry for all subscription models.

    Every entry is checked with ``validate_base_model`` before it is stored. The constructor, ``setdefault`` and
    ``|=`` are routed through ``__setitem__`` too. The corresponding ``dict`` implementations write to the underlying
    storage directly and would otherwise skip the validation.
    """

    def __init__(self, m: Any = None, /, **kwargs: type[SubscriptionModel]) -> None:
        """Create the registry from an optional mapping or iterable of pairs plus keyword arguments.

        Every initial entry is validated.
        """
        super().__init__()
        self.update(m, **kwargs)

    def __setitem__(self, __key: str, __value: type[SubscriptionModel]) -> None:
        """Set value for key while validating against Pydantic BaseModel."""
        validate_base_model(__key, __value)
        super().__setitem__(__key, __value)

    def setdefault(self, key: str, default: Any = None, /) -> Any:
        """Return ``self[key]``, first storing (and validating) ``default`` when the key is missing."""
        if key not in self:
            self[key] = default
        return self[key]

    def __ior__(self, other: Any) -> "SubscriptionModelRegistry":
        """Support ``registry |= other``, validating every new entry."""
        self.update(other)
        return self

    def update(
        self,
        m: Any = None,
        /,
        **kwargs: type[SubscriptionModel],
    ) -> None:
        """Update dictionary with mapping and/or kwargs using `__setitem__`.

        Mirrors ``dict.update``: ``m`` may be a mapping or an iterable of key/value pairs.

        Raises:
            TypeError: When ``m`` is neither a mapping nor an iterable, or an element of ``m`` is not a key/value
                pair. Also propagated from ``validate_base_model`` for invalid models.
        """
        if isinstance(m, Mapping):
            for key, value in m.items():
                self[key] = value
        elif isinstance(m, Iterable):
            for index, item in enumerate(m):
                try:
                    key, value = item
                except (TypeError, ValueError) as error:
                    raise TypeError(
                        f"dictionary update sequence element #{index} is not an iterable of length 2"
                    ) from error
                self[key] = value
        elif m is not None:
            # dict.update rejects non-iterable input instead of silently ignoring it
            raise TypeError(f"'{type(m).__name__}' object is not iterable")
        for key, value in kwargs.items():
            self[key] = value
1590
+
1404
1591
 
1405
1592
  def _validate_lifecycle_change_for_product_block(
1406
1593
  used_by: SubscriptionInstanceTable,
@@ -0,0 +1,62 @@
1
+ # Copyright 2019-2025 SURF.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ import contextlib
14
+ from contextvars import ContextVar
15
+ from typing import Iterator
16
+ from uuid import UUID
17
+
18
+ from orchestrator.domain import SubscriptionModel
19
+ from pydantic_forms.types import UUIDstr
20
+
21
# Context-local cache of SubscriptionModel instances keyed by subscription id.
# None (the default) means no cache_subscription_models() context is active, so nothing is cached.
__subscription_model_cache: ContextVar[dict[UUID, SubscriptionModel] | None] = ContextVar(
    "subscription_model_cache",
    default=None,
)
24
+
25
+
26
@contextlib.contextmanager
def cache_subscription_models() -> Iterator:
    """Share SubscriptionModel instances for the duration of the ``with`` block.

    While the context is active, calling SubscriptionModel.from_subscription() twice with the same subscription id
    returns the same instance.

    Each context starts with an empty cache. On exit, also when an exception is raised, the previous cache state is
    restored (e.g. that of an enclosing context).

    The primary usecase is to improve performance of `@computed_field` properties on product blocks
    which load other subscriptions.

    Example usage:
        subscription = SubscriptionModel.from_subscription("...")
        with cache_subscription_models():
            subscription_dict = subscription.model_dump()
    """
    cache_token = __subscription_model_cache.set({})
    try:
        yield
    finally:
        __subscription_model_cache.reset(cache_token)
46
+
47
+
48
def get_from_cache(subscription_id: UUID | UUIDstr) -> SubscriptionModel | None:
    """Return the cached SubscriptionModel for ``subscription_id``, if present.

    Returns None when no caching context is active or the id is not cached. A string id is converted to a UUID
    before the lookup; a malformed string raises ValueError, but only while caching is active.
    """
    cache = __subscription_model_cache.get()
    if cache is None:
        return None

    key = subscription_id if isinstance(subscription_id, UUID) else UUID(subscription_id)
    return cache.get(key)
55
+
56
+
57
def store_in_cache(model: SubscriptionModel) -> None:
    """Cache ``model`` under its subscription id; does nothing when no caching context is active."""
    cache = __subscription_model_cache.get()
    if cache is not None:
        cache[model.subscription_id] = model
@@ -1,6 +1,15 @@
1
- from collections.abc import Iterable
1
+ import contextlib
2
+ from collections.abc import Iterable, Iterator
3
+ from typing import Any
4
+ from uuid import UUID
2
5
 
6
+ from pydantic import BaseModel
7
+ from sqlalchemy import select
8
+
9
+ from orchestrator.db import ProductBlockTable, SubscriptionInstanceTable, db
3
10
  from orchestrator.types import filter_nonetype, get_origin_and_args, is_union_type
11
+ from orchestrator.utils.functional import group_by_key
12
+ from pydantic_forms.types import UUIDstr
4
13
 
5
14
 
6
15
  def _to_product_block_field_type_iterable(product_block_field_type: type | tuple[type]) -> Iterable[type]:
@@ -21,3 +30,34 @@ def _to_product_block_field_type_iterable(product_block_field_type: type | tuple
21
30
  return product_block_field_type
22
31
 
23
32
  return [product_block_field_type]
33
+
34
+
35
@contextlib.contextmanager
def no_private_attrs(model: Any) -> Iterator:
    """Temporarily replace the private attributes of a pydantic model with an empty dict.

    Objects that are not BaseModel instances are left untouched. The original private attributes are restored when
    the context exits, also when an exception is raised inside it.
    """
    if isinstance(model, BaseModel):
        saved_private_attrs = model.__pydantic_private__
        try:
            model.__pydantic_private__ = {}
            yield
        finally:
            model.__pydantic_private__ = saved_private_attrs
    else:
        yield
47
+
48
+
49
def get_root_blocks_to_instance_ids(subscription_id: UUID | UUIDstr) -> dict[str, list[UUID]]:
    """Return a mapping of product block names to the ids of the subscription's instances of that block.

    The query is not restricted to root blocks. It returns every subscription instance whose ``subscription_id``
    matches, grouped by product block name (ordered by name) via ``group_by_key``. Callers select the root entries by
    looking up the block names of their root product block fields.

    While recommended practice is to have only 1 root product block, a subscription can have multiple root blocks or
    even a list of them. Because all ids per block name are kept, that is supported.
    """
    query = (
        select(ProductBlockTable.name, SubscriptionInstanceTable.subscription_instance_id)
        .select_from(SubscriptionInstanceTable)
        .join(ProductBlockTable)
        .where(SubscriptionInstanceTable.subscription_id == subscription_id)
        .order_by(ProductBlockTable.name)
    )
    name_and_id_rows = db.session.execute(query).all()
    return group_by_key(name_and_id_rows)  # type: ignore[arg-type]
@@ -1,4 +1,4 @@
1
- # Copyright 2019-2020 SURF.
1
+ # Copyright 2019-2020 SURF, GÉANT.
2
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
3
  # you may not use this file except in compliance with the License.
4
4
  # You may obtain a copy of the License at