infrahub-server 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. infrahub/computed_attribute/tasks.py +8 -8
  2. infrahub/config.py +3 -0
  3. infrahub/core/graph/__init__.py +1 -1
  4. infrahub/core/migrations/graph/__init__.py +4 -1
  5. infrahub/core/migrations/graph/m024_missing_hierarchy_backfill.py +69 -0
  6. infrahub/core/models.py +6 -0
  7. infrahub/core/node/__init__.py +4 -4
  8. infrahub/core/node/constraints/grouped_uniqueness.py +24 -9
  9. infrahub/core/query/ipam.py +1 -1
  10. infrahub/core/schema/schema_branch.py +14 -5
  11. infrahub/git/integrator.py +9 -7
  12. infrahub/menu/repository.py +6 -6
  13. infrahub_sdk/client.py +6 -6
  14. infrahub_sdk/ctl/cli_commands.py +32 -37
  15. infrahub_sdk/ctl/render.py +39 -0
  16. infrahub_sdk/exceptions.py +6 -2
  17. infrahub_sdk/generator.py +1 -1
  18. infrahub_sdk/node.py +38 -11
  19. infrahub_sdk/protocols_base.py +8 -1
  20. infrahub_sdk/pytest_plugin/items/jinja2_transform.py +22 -26
  21. infrahub_sdk/store.py +351 -75
  22. infrahub_sdk/template/__init__.py +209 -0
  23. infrahub_sdk/template/exceptions.py +38 -0
  24. infrahub_sdk/template/filters.py +151 -0
  25. infrahub_sdk/template/models.py +10 -0
  26. infrahub_sdk/utils.py +7 -0
  27. {infrahub_server-1.2.2.dist-info → infrahub_server-1.2.3.dist-info}/METADATA +2 -1
  28. {infrahub_server-1.2.2.dist-info → infrahub_server-1.2.3.dist-info}/RECORD +34 -31
  29. infrahub_testcontainers/container.py +2 -0
  30. infrahub_testcontainers/docker-compose.test.yml +1 -0
  31. infrahub_testcontainers/haproxy.cfg +3 -3
  32. infrahub/support/__init__.py +0 -0
  33. infrahub/support/macro.py +0 -69
  34. {infrahub_server-1.2.2.dist-info → infrahub_server-1.2.3.dist-info}/LICENSE.txt +0 -0
  35. {infrahub_server-1.2.2.dist-info → infrahub_server-1.2.3.dist-info}/WHEEL +0 -0
  36. {infrahub_server-1.2.2.dist-info → infrahub_server-1.2.3.dist-info}/entry_points.txt +0 -0
infrahub/computed_attribute/tasks.py CHANGED
@@ -6,6 +6,7 @@ from infrahub_sdk.protocols import (
      CoreNode,  # noqa: TC002
      CoreTransformPython,
  )
+ from infrahub_sdk.template import Jinja2Template
  from prefect import flow
  from prefect.client.orchestration import get_client
  from prefect.logging import get_run_logger
@@ -16,7 +17,6 @@ from infrahub.core.registry import registry
  from infrahub.events import BranchDeletedEvent
  from infrahub.git.repository import get_initialized_repo
  from infrahub.services import InfrahubServices  # noqa: TC001 needed for prefect flow
- from infrahub.support.macro import MacroDefinition
  from infrahub.trigger.models import TriggerType
  from infrahub.trigger.setup import setup_triggers
  from infrahub.workflows.catalogue import (
@@ -173,15 +173,15 @@ async def update_computed_attribute_value_jinja2(

      await add_tags(branches=[branch_name], nodes=[obj.id], db_change=True)

-     macro_definition = MacroDefinition(macro=template_value)
-     my_filter = {}
-     for variable in macro_definition.variables:
+     jinja_template = Jinja2Template(template=template_value)
+     variables = {}
+     for variable in jinja_template.get_variables():
          components = variable.split("__")
          if len(components) == 2:
              property_name = components[0]
              property_value = components[1]
              attribute_property = getattr(obj, property_name)
-             my_filter[variable] = getattr(attribute_property, property_value)
+             variables[variable] = getattr(attribute_property, property_value)
          elif len(components) == 3:
              relationship_name = components[0]
              property_name = components[1]
@@ -189,11 +189,11 @@ async def update_computed_attribute_value_jinja2(
              relationship = getattr(obj, relationship_name)
              try:
                  attribute_property = getattr(relationship.peer, property_name)
-                 my_filter[variable] = getattr(attribute_property, property_value)
+                 variables[variable] = getattr(attribute_property, property_value)
              except ValueError:
-                 my_filter[variable] = ""
+                 variables[variable] = ""

-     value = macro_definition.render(variables=my_filter)
+     value = await jinja_template.render(variables=variables)
      existing_value = getattr(obj, attribute_name).value
      if value == existing_value:
          log.debug(f"Ignoring to update {obj} with existing value on {attribute_name}={value}")
infrahub/config.py CHANGED
@@ -612,6 +612,9 @@ class SecuritySettings(BaseSettings):
      oauth2_provider_settings: SecurityOAuth2ProviderSettings = Field(default_factory=SecurityOAuth2ProviderSettings)
      oidc_providers: list[OIDCProvider] = Field(default_factory=list, description="The selected OIDC providers")
      oidc_provider_settings: SecurityOIDCProviderSettings = Field(default_factory=SecurityOIDCProviderSettings)
+     restrict_untrusted_jinja2_filters: bool = Field(
+         default=True, description="Indicates if untrusted Jinja2 filters should be disallowd for computed attributes"
+     )
      _oauth2_settings: dict[str, SecurityOAuth2Settings] = PrivateAttr(default_factory=dict)
      _oidc_settings: dict[str, SecurityOIDCSettings] = PrivateAttr(default_factory=dict)
      sso_user_default_group: str | None = Field(
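Elsewhere in this release the new flag is read as config.SETTINGS.security.restrict_untrusted_jinja2_filters (see the schema_branch.py hunk further down). A small sketch of checking it, assuming the usual "from infrahub import config" import used by the server code:

from infrahub import config

# Defaults to True: computed-attribute templates are validated against the
# restricted filter set (see jinja_template.validate(restricted=...) below).
if config.SETTINGS.security.restrict_untrusted_jinja2_filters:
    print("Untrusted Jinja2 filters are rejected for computed attributes")
else:
    print("All Jinja2 filters are allowed for computed attributes")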
infrahub/core/graph/__init__.py CHANGED
@@ -1 +1 @@
- GRAPH_VERSION = 22
+ GRAPH_VERSION = 24
infrahub/core/migrations/graph/__init__.py CHANGED
@@ -24,6 +24,8 @@ from .m019_restore_rels_to_time import Migration019
  from .m020_duplicate_edges import Migration020
  from .m021_missing_hierarchy_merge import Migration021
  from .m022_add_generate_template_attr import Migration022
+ from .m023_deduplicate_cardinality_one_relationships import Migration023
+ from .m024_missing_hierarchy_backfill import Migration024

  if TYPE_CHECKING:
      from infrahub.core.root import Root
@@ -53,7 +55,8 @@ MIGRATIONS: list[type[GraphMigration | InternalSchemaMigration | ArbitraryMigrat
      Migration020,
      Migration021,
      Migration022,
-     # Migration023, Enable this migration once it has been tested on bigger databases
+     Migration023,
+     Migration024,
  ]
infrahub/core/migrations/graph/m024_missing_hierarchy_backfill.py ADDED
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Sequence
+
+ from infrahub.core import registry
+ from infrahub.core.initialization import initialization
+ from infrahub.core.migrations.shared import GraphMigration, MigrationResult
+ from infrahub.lock import initialize_lock
+ from infrahub.log import get_logger
+
+ from ...query import Query, QueryType
+
+ if TYPE_CHECKING:
+     from infrahub.database import InfrahubDatabase
+
+ log = get_logger()
+
+
+ class BackfillMissingHierarchyQuery(Query):
+     name = "backfill_missing_hierarchy"
+     type = QueryType.WRITE
+     insert_return = False
+
+     async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
+         # load schemas from database into registry
+         initialize_lock()
+         await initialization(db=db)
+         kind_hierarchy_map: dict[str, str] = {}
+         schema_branch = await registry.schema.load_schema_from_db(db=db)
+         for node_schema_kind in schema_branch.node_names:
+             node_schema = schema_branch.get_node(name=node_schema_kind, duplicate=False)
+             if node_schema.hierarchy:
+                 kind_hierarchy_map[node_schema.kind] = node_schema.hierarchy
+
+         self.params = {"hierarchy_map": kind_hierarchy_map}
+         query = """
+         MATCH (r:Root)
+         WITH r.default_branch AS default_branch
+         MATCH (rel:Relationship {name: "parent__child"})-[e:IS_RELATED]-(n:Node)
+         WHERE e.hierarchy IS NULL
+         WITH DISTINCT rel, n, default_branch
+         CALL {
+             WITH rel, n, default_branch
+             MATCH (rel)-[e:IS_RELATED {branch: default_branch}]-(n)
+             RETURN e
+             ORDER BY e.from DESC
+             LIMIT 1
+         }
+         WITH rel, n, e
+         WHERE e.status = "active" AND e.hierarchy IS NULL
+         SET e.hierarchy = $hierarchy_map[n.kind]
+         """
+         self.add_to_query(query)
+
+
+ class Migration024(GraphMigration):
+     """
+     A bug in diff merge logic caused the hierarchy information on IS_RELATED edges to be lost when merged into
+     main. This migration backfills the missing hierarchy data and accounts for the case when the branch that
+     created the data has been deleted.
+     """
+
+     name: str = "024_backfill_hierarchy"
+     minimum_version: int = 23
+     queries: Sequence[type[Query]] = [BackfillMissingHierarchyQuery]
+
+     async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult:  # noqa: ARG002
+         result = MigrationResult()
+         return result
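The interesting part of the query is that the kind-to-hierarchy map is passed as a single parameter and indexed per matched node with $hierarchy_map[n.kind]. A standalone sketch of that Cypher pattern with the official neo4j driver; the connection details and map contents are placeholders, and the query is a simplified version of the one above (it skips the branch/latest-edge selection):

from neo4j import GraphDatabase

# Placeholder connection settings, not taken from an Infrahub deployment.
driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "admin"))

query = """
MATCH (rel:Relationship {name: "parent__child"})-[e:IS_RELATED]-(n:Node)
WHERE e.hierarchy IS NULL AND e.status = "active"
SET e.hierarchy = $hierarchy_map[n.kind]
"""

# The migration builds this map from the schema: every hierarchical node kind
# points to the name of its hierarchy. These kinds are invented for the example.
hierarchy_map = {"LocationSite": "LocationGeneric", "LocationRack": "LocationGeneric"}

with driver.session() as session:
    # Map parameters can be indexed inside Cypher, so one write query backfills
    # the correct hierarchy value for every kind at once.
    session.run(query, hierarchy_map=hierarchy_map)

driver.close()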
infrahub/core/models.py CHANGED
@@ -19,6 +19,8 @@ if TYPE_CHECKING:
      from infrahub.core.schema import MainSchemaTypes
      from infrahub.core.schema.schema_branch import SchemaBranch

+ GENERIC_ATTRIBUTES_TO_IGNORE = ["namespace", "name", "branch"]
+

  class NodeKind(BaseModel):
      namespace: str
@@ -270,6 +272,10 @@ class SchemaUpdateValidationResult(BaseModel):
          field_info = schema.model_fields[node_field_name]
          field_update = str(field_info.json_schema_extra.get("update"))  # type: ignore[union-attr]

+         # No need to execute a migration for generic nodes attributes because they are not stored in the database
+         if schema.is_generic_schema and node_field_name in GENERIC_ATTRIBUTES_TO_IGNORE:
+             return
+
          schema_path = SchemaPath(  # type: ignore[call-arg]
              schema_kind=schema.kind,
              path_type=SchemaPathType.NODE,
infrahub/core/node/__init__.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations

  from enum import Enum
  from typing import TYPE_CHECKING, Any, Sequence, TypeVar, overload

+ from infrahub_sdk.template import Jinja2Template
  from infrahub_sdk.utils import is_valid_uuid
  from infrahub_sdk.uuidt import UUIDT

@@ -24,7 +25,6 @@ from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeD
  from infrahub.core.schema import AttributeSchema, NodeSchema, ProfileSchema, RelationshipSchema, TemplateSchema
  from infrahub.core.timestamp import Timestamp
  from infrahub.exceptions import InitializationError, NodeNotFoundError, PoolExhaustedError, ValidationError
- from infrahub.support.macro import MacroDefinition
  from infrahub.types import ATTRIBUTE_TYPES

  from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
@@ -458,9 +458,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
                  ValidationError({macro: f"{macro} is missing computational_logic for macro ({attr_schema.kind})"})
              )
              continue
-         macro_definition = MacroDefinition(macro=attr_schema.computed_attribute.jinja2_template)

-         for variable in macro_definition.variables:
+         jinja_template = Jinja2Template(template=attr_schema.computed_attribute.jinja2_template)
+         for variable in jinja_template.get_variables():
              attribute_path = schema_branch.validate_schema_path(
                  node_schema=self._schema, path=variable, allowed_path_types=allowed_path_types
              )
@@ -487,7 +487,7 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
              )
              variables[variable] = attribute

-         content = macro_definition.render(variables=variables)
+         content = await jinja_template.render(variables=variables)

          generator_method_name = "_generate_attribute_default"
          if hasattr(self, f"generate_{attr_schema.name}"):
infrahub/core/node/constraints/grouped_uniqueness.py CHANGED
@@ -225,16 +225,31 @@ class NodeGroupedUniquenessConstraint(NodeConstraintInterface):
          )
          violations.extend(schema_violations)

-         is_hfid_violated = any(violation.typ == UniquenessConstraintType.HFID for violation in violations)
+         hfid_violations = [violation for violation in violations if violation.typ == UniquenessConstraintType.HFID]
+         hfid_violation = hfid_violations[0] if len(hfid_violations) > 0 else None

-         for violation in violations:
-             if violation.typ == UniquenessConstraintType.STANDARD or (
-                 violation.typ == UniquenessConstraintType.SUBSET_OF_HFID and not is_hfid_violated
-             ):
-                 error_msg = f"Violates uniqueness constraint '{'-'.join(violation.fields)}'"
-                 raise ValidationError(error_msg)
+         # If there are both a hfid violation and another one, in case of an upsert, we still want to update the node in case other violations are:
+         # - either on subset fields of hfid, which would be necessarily violated too
+         # - or on uniqueness constraints with a matching node id being the id of the hfid violation

          for violation in violations:
              if violation.typ == UniquenessConstraintType.HFID:
-                 error_msg = f"Violates uniqueness constraint '{'-'.join(violation.fields)}'"
-                 raise HFIDViolatedError(error_msg, matching_nodes_ids=violation.nodes_ids)
+                 continue
+
+             if hfid_violation:
+                 if violation.typ == UniquenessConstraintType.SUBSET_OF_HFID:
+                     continue
+
+                 if (
+                     violation.typ == UniquenessConstraintType.STANDARD
+                     and len(violation.nodes_ids) == 1
+                     and next(iter(violation.nodes_ids)) == next(iter(hfid_violation.nodes_ids))
+                 ):
+                     continue
+
+             error_msg = f"Violates uniqueness constraint '{'-'.join(violation.fields)}'"
+             raise ValidationError(error_msg)
+
+         if hfid_violation:
+             error_msg = f"Violates uniqueness constraint '{'-'.join(hfid_violation.fields)}'"
+             raise HFIDViolatedError(error_msg, matching_nodes_ids=hfid_violation.nodes_ids)
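In plain terms: an HFID violation alone now means "the node already exists, upsert it" (HFIDViolatedError), and other violations only block the operation when they point at a different node. A self-contained sketch of that decision rule, with simplified stand-ins for the real violation and error types:

from dataclasses import dataclass, field


@dataclass
class Violation:
    typ: str  # "standard", "subset_of_hfid" or "hfid" in this sketch
    fields: list[str]
    nodes_ids: set[str] = field(default_factory=set)


def check(violations: list[Violation]) -> None:
    """Simplified mirror of the rule introduced in the hunk above."""
    hfid = next((v for v in violations if v.typ == "hfid"), None)

    for violation in violations:
        if violation.typ == "hfid":
            continue
        if hfid:
            # A subset-of-HFID violation is implied by the HFID violation itself.
            if violation.typ == "subset_of_hfid":
                continue
            # A standard violation pointing at the same single node as the HFID
            # violation does not block an upsert either.
            if (
                violation.typ == "standard"
                and len(violation.nodes_ids) == 1
                and next(iter(violation.nodes_ids)) == next(iter(hfid.nodes_ids))
            ):
                continue
        raise ValueError(f"Violates uniqueness constraint '{'-'.join(violation.fields)}'")

    if hfid:
        # Stands in for HFIDViolatedError: the caller can treat this as "node exists".
        raise LookupError(f"Violates uniqueness constraint '{'-'.join(hfid.fields)}'")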
infrahub/core/query/ipam.py CHANGED
@@ -362,7 +362,7 @@ class IPPrefixReconcileQuery(Query):
          # possible prefix: highest possible prefix length for a match
          possible_prefix_map: dict[str, int] = {}
          start_prefixlen = prefixlen if is_address else prefixlen - 1
-         for max_prefix_len in range(start_prefixlen, 0, -1):
+         for max_prefix_len in range(start_prefixlen, -1, -1):
              tmp_prefix = prefix_bin_host[:max_prefix_len]
              possible_prefix = tmp_prefix.ljust(self.ip_value.max_prefixlen, "0")
              if possible_prefix not in possible_prefix_map:
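The one-character change matters because range() excludes its stop value: with a stop of 0 the loop never evaluated a zero-length prefix, so the top-level supernet (0.0.0.0/0 or ::/0) could never be selected as a possible parent prefix. Outside of any Infrahub code:

# Old behaviour: stops before 0, so a /0 candidate is never generated.
print(list(range(3, 0, -1)))   # [3, 2, 1]

# New behaviour: includes 0, so the zero-length prefix is also considered.
print(list(range(3, -1, -1)))  # [3, 2, 1, 0]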
infrahub/core/schema/schema_branch.py CHANGED
@@ -6,6 +6,8 @@ from collections import defaultdict
  from itertools import chain, combinations
  from typing import Any

+ from infrahub_sdk.template import Jinja2Template
+ from infrahub_sdk.template.exceptions import JinjaTemplateError, JinjaTemplateOperationViolationError
  from infrahub_sdk.topological_sort import DependencyCycleExistsError, topological_sort
  from infrahub_sdk.utils import compare_lists, deep_merge_dict, duplicates, intersection
  from typing_extensions import Self
@@ -51,7 +53,6 @@ from infrahub.core.schema.definitions.core import core_profile_schema_definition
  from infrahub.core.validators import CONSTRAINT_VALIDATOR_MAP
  from infrahub.exceptions import SchemaNotFoundError, ValidationError
  from infrahub.log import get_logger
- from infrahub.support.macro import MacroDefinition
  from infrahub.types import ATTRIBUTE_TYPES
  from infrahub.utils import format_label
  from infrahub.visuals import select_color
@@ -1037,14 +1038,22 @@ class SchemaBranch:
              | SchemaElementPathType.REL_ONE_MANDATORY_ATTR_WITH_PROP
              | SchemaElementPathType.REL_ONE_ATTR_WITH_PROP
          )
+
+         jinja_template = Jinja2Template(template=attribute.computed_attribute.jinja2_template)
          try:
-             macro = MacroDefinition(macro=attribute.computed_attribute.jinja2_template)
-         except ValueError as exc:
+             variables = jinja_template.get_variables()
+             jinja_template.validate(restricted=config.SETTINGS.security.restrict_untrusted_jinja2_filters)
+         except JinjaTemplateOperationViolationError as exc:
+             raise ValueError(
+                 f"{node.kind}: Attribute {attribute.name!r} is assigned by a jinja2 template, but has an invalid template: {exc.message}"
+             ) from exc
+
+         except JinjaTemplateError as exc:
              raise ValueError(
-                 f"{node.kind}: Attribute {attribute.name!r} is assigned by a jinja2 template, but has an invalid template"
+                 f"{node.kind}: Attribute {attribute.name!r} is assigned by a jinja2 template, but has an invalid template: : {exc.message}"
              ) from exc

-         for variable in macro.variables:
+         for variable in variables:
              try:
                  schema_path = self.validate_schema_path(
                      node_schema=node, path=variable, allowed_path_types=allowed_path_types
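Schema validation now separates "this template uses a disallowed filter or operation" from "this template is broken". A hedged sketch of the validation call, using only names visible in this hunk; whether the example template actually trips the restriction is up to the SDK:

from infrahub_sdk.template import Jinja2Template
from infrahub_sdk.template.exceptions import (
    JinjaTemplateError,
    JinjaTemplateOperationViolationError,
)

template = Jinja2Template(template="{{ name__value | upper }}")

try:
    variables = template.get_variables()
    # restricted=True mirrors the default of security.restrict_untrusted_jinja2_filters.
    template.validate(restricted=True)
except JinjaTemplateOperationViolationError as exc:
    print(f"disallowed operation in template: {exc.message}")
except JinjaTemplateError as exc:
    print(f"invalid template: {exc.message}")
else:
    print(f"template is valid, variables: {variables}")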
infrahub/git/integrator.py CHANGED
@@ -3,9 +3,9 @@ from __future__ import annotations

  import hashlib
  import importlib
  import sys
+ from pathlib import Path
  from typing import TYPE_CHECKING, Any

- import jinja2
  import ujson
  import yaml
  from infrahub_sdk import InfrahubClient  # noqa: TC002
@@ -28,6 +28,8 @@ from infrahub_sdk.schema.repository import (
      InfrahubPythonTransformConfig,
      InfrahubRepositoryConfig,
  )
+ from infrahub_sdk.template import Jinja2Template
+ from infrahub_sdk.template.exceptions import JinjaTemplateError
  from infrahub_sdk.utils import compare_lists
  from infrahub_sdk.yaml import SchemaFile
  from prefect import flow, task
@@ -1057,14 +1059,14 @@ class InfrahubRepositoryIntegrator(InfrahubRepositoryBase):

          self.validate_location(commit=commit, worktree_directory=commit_worktree.directory, file_path=location)

+         jinja2_template = Jinja2Template(template=Path(location), template_directory=Path(commit_worktree.directory))
          try:
-             templateLoader = jinja2.FileSystemLoader(searchpath=commit_worktree.directory)
-             templateEnv = jinja2.Environment(loader=templateLoader, trim_blocks=True, lstrip_blocks=True)
-             template = templateEnv.get_template(location)
-             return template.render(**data)
-         except Exception as exc:
+             return await jinja2_template.render(variables=data)
+         except JinjaTemplateError as exc:
              log.error(str(exc), exc_info=True)
-             raise TransformError(repository_name=self.name, commit=commit, location=location, message=str(exc)) from exc
+             raise TransformError(
+                 repository_name=self.name, commit=commit, location=location, message=exc.message
+             ) from exc

      @task(name="python-check-execute", task_run_name="Execute Python Check", cache_policy=NONE)  # type: ignore[arg-type]
      async def execute_python_check(
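For file-based templates the helper takes the template path plus a template_directory and renders asynchronously. A minimal sketch outside the repository integrator; the paths and variables are invented for illustration:

import asyncio
from pathlib import Path

from infrahub_sdk.template import Jinja2Template
from infrahub_sdk.template.exceptions import JinjaTemplateError


async def render_from_disk() -> str:
    # The template path is resolved against template_directory, mirroring how
    # the integrator uses the repository worktree as the search path.
    template = Jinja2Template(
        template=Path("templates/interface.j2"),            # hypothetical template file
        template_directory=Path("/tmp/my-repo-worktree"),   # hypothetical worktree
    )
    try:
        return await template.render(variables={"data": {"interface": "eth0"}})
    except JinjaTemplateError as exc:
        # exc.message is the human-readable text fed into TransformError above.
        raise RuntimeError(exc.message) from exc


print(asyncio.run(render_from_disk()))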
infrahub/menu/repository.py CHANGED
@@ -20,9 +20,9 @@ class MenuRepository:
          async def add_children(menu_item: MenuItemDict, menu_node: CoreMenuItem) -> MenuItemDict:
              children = await menu_node.children.get_peers(db=self.db, peer_type=CoreMenuItem)
              for child_id, child_node in children.items():
-                 child_menu_item = menu_by_ids[child_id]
-                 child = await add_children(child_menu_item, child_node)
-                 menu_item.children[str(child.identifier)] = child
+                 if child_menu_item := menu_by_ids.get(child_id):
+                     child = await add_children(child_menu_item, child_node)
+                     menu_item.children[str(child.identifier)] = child
              return menu_item

          for menu_node in nodes.values():
@@ -33,9 +33,9 @@ class MenuRepository:

              children = await menu_node.children.get_peers(db=self.db, peer_type=CoreMenuItem)
              for child_id, child_node in children.items():
-                 child_menu_item = menu_by_ids[child_id]
-                 child = await add_children(child_menu_item, child_node)
-                 menu_item.children[str(child.identifier)] = child
+                 if child_menu_item := menu_by_ids.get(child_id):
+                     child = await add_children(child_menu_item, child_node)
+                     menu_item.children[str(child.identifier)] = child

              menu.data[str(menu_item.identifier)] = menu_item

infrahub_sdk/client.py CHANGED
@@ -281,7 +281,7 @@ class InfrahubClient(BaseClient):
          self.schema = InfrahubSchema(self)
          self.branch = InfrahubBranchManager(self)
          self.object_store = ObjectStore(self)
-         self.store = NodeStore()
+         self.store = NodeStore(default_branch=self.default_branch)
          self.task = InfrahubTaskManager(self)
          self.concurrent_execution_limit = asyncio.Semaphore(self.max_concurrent_execution)
          self._request_method: AsyncRequester = self.config.requester or self._default_request_method
@@ -840,11 +840,11 @@ class InfrahubClient(BaseClient):
          if populate_store:
              for node in nodes:
                  if node.id:
-                     self.store.set(key=node.id, node=node)
+                     self.store.set(node=node)
              related_nodes = list(set(related_nodes))
              for node in related_nodes:
                  if node.id:
-                     self.store.set(key=node.id, node=node)
+                     self.store.set(node=node)
          return nodes

      def clone(self) -> InfrahubClient:
@@ -1529,7 +1529,7 @@ class InfrahubClientSync(BaseClient):
          self.schema = InfrahubSchemaSync(self)
          self.branch = InfrahubBranchManagerSync(self)
          self.object_store = ObjectStoreSync(self)
-         self.store = NodeStoreSync()
+         self.store = NodeStoreSync(default_branch=self.default_branch)
          self.task = InfrahubTaskManagerSync(self)
          self._request_method: SyncRequester = self.config.sync_requester or self._default_request_method
          self.group_context = InfrahubGroupContextSync(self)
@@ -1997,11 +1997,11 @@ class InfrahubClientSync(BaseClient):
          if populate_store:
              for node in nodes:
                  if node.id:
-                     self.store.set(key=node.id, node=node)
+                     self.store.set(node=node)
              related_nodes = list(set(related_nodes))
              for node in related_nodes:
                  if node.id:
-                     self.store.set(key=node.id, node=node)
+                     self.store.set(node=node)
          return nodes

      @overload
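The store is now branch-aware: it is created with the client's default branch and derives the key from the node itself, so callers just pass the node. A short sketch of how the populated store is then used; the address and node kind are illustrative, and store.get(key=...) is the existing SDK accessor:

import asyncio

from infrahub_sdk import Config, InfrahubClient


async def main() -> None:
    client = InfrahubClient(config=Config(address="http://localhost:8000"))  # illustrative address

    # populate_store=True now calls client.store.set(node=node); the key (and the
    # branch bucket) is derived from the node instead of being passed explicitly.
    devices = await client.all(kind="InfraDevice", populate_store=True)  # hypothetical kind

    if devices:
        cached = client.store.get(key=devices[0].id)
        print(cached.id)


asyncio.run(main())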
infrahub_sdk/ctl/cli_commands.py CHANGED
@@ -9,7 +9,6 @@ import sys
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Callable, Optional

- import jinja2
  import typer
  import ujson
  from rich.console import Console
@@ -18,7 +17,6 @@ from rich.logging import RichHandler
  from rich.panel import Panel
  from rich.pretty import Pretty
  from rich.table import Table
- from rich.traceback import Traceback

  from .. import __version__ as sdk_version
  from ..async_typer import AsyncTyper
@@ -31,7 +29,7 @@ from ..ctl.exceptions import QueryNotFoundError
  from ..ctl.generator import run as run_generator
  from ..ctl.menu import app as menu_app
  from ..ctl.object import app as object_app
- from ..ctl.render import list_jinja2_transforms
+ from ..ctl.render import list_jinja2_transforms, print_template_errors
  from ..ctl.repository import app as repository_app
  from ..ctl.repository import get_repository_config
  from ..ctl.schema import app as schema_app
@@ -44,8 +42,9 @@ from ..ctl.utils import (
  )
  from ..ctl.validate import app as validate_app
  from ..exceptions import GraphQLError, ModuleImportError
- from ..jinja2 import identify_faulty_jinja_code
  from ..schema import MainSchemaTypesAll, SchemaRoot
+ from ..template import Jinja2Template
+ from ..template.exceptions import JinjaTemplateError
  from ..utils import get_branch, write_to_file
  from ..yaml import SchemaFile
  from .exporter import dump
@@ -168,43 +167,28 @@ async def run(
          raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}")

      client = initialize_client(
-         branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name
+         branch=branch,
+         timeout=timeout,
+         max_concurrent_execution=concurrent,
+         identifier=module_name,
      )
      func = getattr(module, method)
      await func(client=client, log=log, branch=branch, **variables_dict)


- def render_jinja2_template(template_path: Path, variables: dict[str, str], data: dict[str, Any]) -> str:
-     if not template_path.is_file():
-         console.print(f"[red]Unable to locate the template at {template_path}")
-         raise typer.Exit(1)
-
-     templateLoader = jinja2.FileSystemLoader(searchpath=".")
-     templateEnv = jinja2.Environment(loader=templateLoader, trim_blocks=True, lstrip_blocks=True)
-     template = templateEnv.get_template(str(template_path))
-
+ async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str:
+     variables["data"] = data
+     jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path())
      try:
-         rendered_tpl = template.render(**variables, data=data)  # type: ignore[arg-type]
-     except jinja2.TemplateSyntaxError as exc:
-         console.print("[red]Syntax Error detected on the template")
-         console.print(f"[yellow] {exc}")
-         raise typer.Exit(1) from exc
-
-     except jinja2.UndefinedError as exc:
-         console.print("[red]An error occurred while rendering the jinja template")
-         traceback = Traceback(show_locals=False)
-         errors = identify_faulty_jinja_code(traceback=traceback)
-         for frame, syntax in errors:
-             console.print(f"[yellow]{frame.filename} on line {frame.lineno}\n")
-             console.print(syntax)
-             console.print("")
-         console.print(traceback.trace.stacks[0].exc_value)
+         rendered_tpl = await jinja_template.render(variables=variables)
+     except JinjaTemplateError as exc:
+         print_template_errors(error=exc, console=console)
          raise typer.Exit(1) from exc

      return rendered_tpl


- def _run_transform(
+ async def _run_transform(
      query_name: str,
      variables: dict[str, Any],
      transform_func: Callable,
@@ -227,7 +211,11 @@ def _run_transform(

      try:
          response = execute_graphql_query(
-             query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
+             query=query_name,
+             variables_dict=variables,
+             branch=branch,
+             debug=debug,
+             repository_config=repository_config,
          )

          # TODO: response is a dict and can't be printed to the console in this way.
@@ -249,7 +237,7 @@ def _run_transform(
          raise typer.Abort()

      if asyncio.iscoroutinefunction(transform_func):
-         output = asyncio.run(transform_func(response))
+         output = await transform_func(response)
      else:
          output = transform_func(response)
      return output
@@ -257,7 +245,7 @@

  @app.command(name="render")
  @catch_exception(console=console)
- def render(
+ async def render(
      transform_name: str = typer.Argument(default="", help="Name of the Python transformation", show_default=False),
      variables: Optional[list[str]] = typer.Argument(
          None, help="Variables to pass along with the query. Format key=value key=value."
@@ -289,7 +277,7 @@ def render(
      transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)

      # Query GQL and run the transform
-     result = _run_transform(
+     result = await _run_transform(
          query_name=transform_config.query,
          variables=variables_dict,
          transform_func=transform_func,
@@ -410,7 +398,10 @@ def version() -> None:

  @app.command(name="info")
  @catch_exception(console=console)
- def info(detail: bool = typer.Option(False, help="Display detailed information."), _: str = CONFIG_PARAM) -> None:  # noqa: PLR0915
+ def info(  # noqa: PLR0915
+     detail: bool = typer.Option(False, help="Display detailed information."),
+     _: str = CONFIG_PARAM,
+ ) -> None:
      """Display the status of the Python SDK."""

      info: dict[str, Any] = {
@@ -476,10 +467,14 @@ def info(detail: bool = typer.Option(False, help="Display detailed information."
      infrahub_info = Table(show_header=False, box=None)
      if info["user_info"]:
          infrahub_info.add_row("User:", info["user_info"]["AccountProfile"]["display_label"])
-         infrahub_info.add_row("Description:", info["user_info"]["AccountProfile"]["description"]["value"])
+         infrahub_info.add_row(
+             "Description:",
+             info["user_info"]["AccountProfile"]["description"]["value"],
+         )
          infrahub_info.add_row("Status:", info["user_info"]["AccountProfile"]["status"]["label"])
          infrahub_info.add_row(
-             "Number of Groups:", str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"])
+             "Number of Groups:",
+             str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"]),
          )

      if groups := info["groups"]:
infrahub_sdk/ctl/render.py CHANGED
@@ -1,6 +1,12 @@
  from rich.console import Console

  from ..schema.repository import InfrahubRepositoryConfig
+ from ..template.exceptions import (
+     JinjaTemplateError,
+     JinjaTemplateNotFoundError,
+     JinjaTemplateSyntaxError,
+     JinjaTemplateUndefinedError,
+ )


  def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None:
@@ -9,3 +15,36 @@ def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None:

      for transform in config.jinja2_transforms:
          console.print(f"{transform.name} ({transform.template_path})")
+
+
+ def print_template_errors(error: JinjaTemplateError, console: Console) -> None:
+     if isinstance(error, JinjaTemplateNotFoundError):
+         console.print("[red]An error occurred while rendering the jinja template")
+         console.print("")
+         if error.base_template:
+             console.print(f"Base template: [yellow]{error.base_template}")
+         console.print(f"Missing template: [yellow]{error.filename}")
+         return
+
+     if isinstance(error, JinjaTemplateUndefinedError):
+         console.print("[red]An error occurred while rendering the jinja template")
+         for current_error in error.errors:
+             console.print(f"[yellow]{current_error.frame.filename} on line {current_error.frame.lineno}\n")
+             console.print(current_error.syntax)
+             console.print("")
+         console.print(error.message)
+         return
+
+     if isinstance(error, JinjaTemplateSyntaxError):
+         console.print("[red]A syntax error was encountered within the template")
+         console.print("")
+         if error.filename:
+             console.print(f"Filename: [yellow]{error.filename}")
+         console.print(f"Line number: [yellow]{error.lineno}")
+         console.print()
+         console.print(error.message)
+         return
+
+     console.print("[red]An error occurred while rendering the jinja template")
+     console.print("")
+     console.print(f"[yellow]{error.message}")
infrahub_sdk/exceptions.py CHANGED
@@ -69,12 +69,12 @@ class ModuleImportError(Error):
  class NodeNotFoundError(Error):
      def __init__(
          self,
-         node_type: str,
          identifier: Mapping[str, list[str]],
          message: str = "Unable to find the node in the database.",
          branch_name: str | None = None,
+         node_type: str | None = None,
      ):
-         self.node_type = node_type
+         self.node_type = node_type or "unknown"
          self.identifier = identifier
          self.branch_name = branch_name

@@ -88,6 +88,10 @@ class NodeNotFoundError(Error):
          """


+ class NodeInvalidError(NodeNotFoundError):
+     pass
+
+
  class ResourceNotDefinedError(Error):
      """Raised when trying to access a resource that hasn't been defined."""
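Because node_type is now optional (and last), callers that only know the identifier can still raise the error, and NodeInvalidError inherits the same constructor. A small sketch of the updated signature in use; the identifier contents and kind are illustrative:

from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError

# node_type can be omitted and falls back to "unknown".
not_found = NodeNotFoundError(identifier={"hfid": ["atl1-edge1"]}, branch_name="main")
print(not_found.node_type)  # "unknown"

# The new subclass shares the constructor, so node_type stays optional there too.
invalid = NodeInvalidError(
    identifier={"id": ["c0ffee00-0000-4000-8000-000000000000"]},
    message="The node exists but does not match the expected schema.",
    node_type="InfraDevice",  # hypothetical kind
)
print(invalid.node_type)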