infrahub-server 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. infrahub/computed_attribute/tasks.py +71 -67
  2. infrahub/config.py +3 -0
  3. infrahub/core/graph/__init__.py +1 -1
  4. infrahub/core/migrations/graph/__init__.py +4 -1
  5. infrahub/core/migrations/graph/m024_missing_hierarchy_backfill.py +69 -0
  6. infrahub/core/models.py +6 -0
  7. infrahub/core/node/__init__.py +4 -4
  8. infrahub/core/node/constraints/grouped_uniqueness.py +24 -9
  9. infrahub/core/query/ipam.py +1 -1
  10. infrahub/core/query/node.py +16 -5
  11. infrahub/core/schema/schema_branch.py +14 -5
  12. infrahub/exceptions.py +30 -2
  13. infrahub/git/base.py +80 -29
  14. infrahub/git/integrator.py +9 -31
  15. infrahub/menu/repository.py +6 -6
  16. infrahub/trigger/tasks.py +19 -18
  17. infrahub/workflows/utils.py +5 -5
  18. infrahub_sdk/client.py +6 -6
  19. infrahub_sdk/ctl/cli_commands.py +32 -37
  20. infrahub_sdk/ctl/render.py +39 -0
  21. infrahub_sdk/exceptions.py +6 -2
  22. infrahub_sdk/generator.py +1 -1
  23. infrahub_sdk/node.py +41 -12
  24. infrahub_sdk/protocols_base.py +8 -1
  25. infrahub_sdk/pytest_plugin/items/jinja2_transform.py +22 -26
  26. infrahub_sdk/store.py +351 -75
  27. infrahub_sdk/template/__init__.py +209 -0
  28. infrahub_sdk/template/exceptions.py +38 -0
  29. infrahub_sdk/template/filters.py +151 -0
  30. infrahub_sdk/template/models.py +10 -0
  31. infrahub_sdk/utils.py +7 -0
  32. {infrahub_server-1.2.1.dist-info → infrahub_server-1.2.3.dist-info}/METADATA +2 -1
  33. {infrahub_server-1.2.1.dist-info → infrahub_server-1.2.3.dist-info}/RECORD +39 -36
  34. infrahub_testcontainers/container.py +2 -0
  35. infrahub_testcontainers/docker-compose.test.yml +1 -0
  36. infrahub_testcontainers/haproxy.cfg +3 -3
  37. infrahub/support/__init__.py +0 -0
  38. infrahub/support/macro.py +0 -69
  39. {infrahub_server-1.2.1.dist-info → infrahub_server-1.2.3.dist-info}/LICENSE.txt +0 -0
  40. {infrahub_server-1.2.1.dist-info → infrahub_server-1.2.3.dist-info}/WHEEL +0 -0
  41. {infrahub_server-1.2.1.dist-info → infrahub_server-1.2.3.dist-info}/entry_points.txt +0 -0
infrahub/git/base.py CHANGED
@@ -17,13 +17,16 @@ from pydantic import BaseModel, ConfigDict, Field
17
17
  from pydantic import ValidationError as PydanticValidationError
18
18
 
19
19
  from infrahub.core.branch import Branch
20
- from infrahub.core.constants import InfrahubKind
20
+ from infrahub.core.constants import InfrahubKind, RepositoryOperationalStatus, RepositorySyncStatus
21
21
  from infrahub.core.registry import registry
22
22
  from infrahub.exceptions import (
23
23
  CommitNotFoundError,
24
24
  FileOutOfRepositoryError,
25
+ RepositoryConnectionError,
26
+ RepositoryCredentialsError,
25
27
  RepositoryError,
26
28
  RepositoryFileNotFoundError,
29
+ RepositoryInvalidBranchError,
27
30
  RepositoryInvalidFileSystemError,
28
31
  )
29
32
  from infrahub.git.constants import BRANCHES_DIRECTORY_NAME, COMMITS_DIRECTORY_NAME, TEMPORARY_DIRECTORY_NAME
@@ -200,6 +203,54 @@ class InfrahubRepositoryBase(BaseModel, ABC):
200
203
  """Return the path to the directory where the temp worktrees of all the commits pending validation are stored."""
201
204
  return self.directory_root / TEMPORARY_DIRECTORY_NAME
202
205
 
206
+ async def _update_operational_status(self, status: RepositoryOperationalStatus) -> None:
207
+ update_status = """
208
+ mutation UpdateRepositoryStatus(
209
+ $repo_id: String!,
210
+ $status: String!,
211
+ ) {
212
+ CoreGenericRepositoryUpdate(
213
+ data: {
214
+ id: $repo_id,
215
+ operational_status: { value: $status },
216
+ }
217
+ ) {
218
+ ok
219
+ }
220
+ }
221
+ """
222
+
223
+ await self.sdk.execute_graphql(
224
+ branch_name=self.infrahub_branch_name or registry.default_branch,
225
+ query=update_status,
226
+ variables={"repo_id": str(self.id), "status": status.value},
227
+ tracker="mutation-repository-update-operational-status",
228
+ )
229
+
230
+ async def _update_sync_status(self, branch_name: str, status: RepositorySyncStatus) -> None:
231
+ update_status = """
232
+ mutation UpdateRepositoryStatus(
233
+ $repo_id: String!,
234
+ $status: String!,
235
+ ) {
236
+ CoreGenericRepositoryUpdate(
237
+ data: {
238
+ id: $repo_id,
239
+ sync_status: { value: $status },
240
+ }
241
+ ) {
242
+ ok
243
+ }
244
+ }
245
+ """
246
+
247
+ await self.sdk.execute_graphql(
248
+ branch_name=branch_name,
249
+ query=update_status,
250
+ variables={"repo_id": str(self.id), "status": status.value},
251
+ tracker="mutation-repository-update-admin-status",
252
+ )
253
+
203
254
  def get_git_repo_main(self) -> Repo:
204
255
  """Return Git Repo object of the main repository.
205
256
 
@@ -340,7 +391,7 @@ class InfrahubRepositoryBase(BaseModel, ABC):
340
391
  repo = Repo.clone_from(self.location, self.directory_default)
341
392
  repo.git.checkout(checkout_ref or self.default_branch)
342
393
  except GitCommandError as exc:
343
- self._raise_enriched_error(error=exc)
394
+ await self._raise_enriched_error(error=exc)
344
395
 
345
396
  self.has_origin = True
346
397
 
@@ -572,7 +623,7 @@ class InfrahubRepositoryBase(BaseModel, ABC):
572
623
  try:
573
624
  br_repo.remotes.origin.pull(branch_name)
574
625
  except GitCommandError as exc:
575
- self._raise_enriched_error(error=exc, branch_name=branch_name)
626
+ await self._raise_enriched_error(error=exc, branch_name=branch_name)
576
627
  self.create_commit_worktree(str(br_repo.head.reference.commit))
577
628
  log.debug(
578
629
  f"Branch {branch_name} created in Git, tracking remote branch {remote_branch[0]}.",
@@ -668,7 +719,9 @@ class InfrahubRepositoryBase(BaseModel, ABC):
668
719
  try:
669
720
  repo.remotes.origin.fetch()
670
721
  except GitCommandError as exc:
671
- self._raise_enriched_error(error=exc)
722
+ await self._raise_enriched_error(error=exc)
723
+
724
+ await self._update_operational_status(status=RepositoryOperationalStatus.ONLINE)
672
725
 
673
726
  return True
674
727
 
@@ -765,7 +818,7 @@ class InfrahubRepositoryBase(BaseModel, ABC):
765
818
  commit_before = str(repo.head.commit)
766
819
  repo.remotes.origin.pull(branch_name)
767
820
  except GitCommandError as exc:
768
- self._raise_enriched_error(error=exc, branch_name=branch_name)
821
+ await self._raise_enriched_error(error=exc, branch_name=branch_name)
769
822
 
770
823
  commit_after = str(repo.head.commit)
771
824
 
@@ -862,37 +915,41 @@ class InfrahubRepositoryBase(BaseModel, ABC):
862
915
  except GitCommandError as exc:
863
916
  cls._raise_enriched_error_static(name=name, location=url, error=exc)
864
917
 
865
- def _raise_enriched_error(self, error: GitCommandError, branch_name: str | None = None) -> NoReturn:
866
- self._raise_enriched_error_static(
867
- error=error, name=self.name, location=self.location, branch_name=branch_name or self.default_branch
868
- )
918
+ async def _raise_enriched_error(self, error: GitCommandError, branch_name: str | None = None) -> NoReturn:
919
+ try:
920
+ self._raise_enriched_error_static(
921
+ error=error, name=self.name, location=self.location, branch_name=branch_name or self.default_branch
922
+ )
923
+ except RepositoryError as exc:
924
+ await self._update_operational_status(
925
+ status={
926
+ RepositoryConnectionError: RepositoryOperationalStatus.ERROR_CONNECTION,
927
+ RepositoryCredentialsError: RepositoryOperationalStatus.ERROR_CRED,
928
+ }.get(type(exc), RepositoryOperationalStatus.ERROR)
929
+ )
930
+ raise
869
931
 
870
932
  @staticmethod
871
933
  def _raise_enriched_error_static(
872
934
  error: GitCommandError, name: str, location: str, branch_name: str | None = None
873
935
  ) -> NoReturn:
874
936
  if "Repository not found" in error.stderr or "does not appear to be a git" in error.stderr:
875
- raise RepositoryError(
876
- identifier=name,
877
- message=f"Unable to clone the repository {name}, please check the address and the credential",
878
- ) from error
937
+ raise RepositoryConnectionError(identifier=name) from error
879
938
 
880
939
  if "error: pathspec" in error.stderr:
881
- raise RepositoryError(
882
- identifier=name,
883
- message=f"The branch {branch_name} isn't a valid branch for the repository {name} at {location}.",
884
- ) from error
940
+ raise RepositoryInvalidBranchError(identifier=name, branch_name=branch_name, location=location) from error
885
941
 
886
942
  if "SSL certificate problem" in error.stderr or "server certificate verification failed" in error.stderr:
887
- raise RepositoryError(
888
- identifier=name,
889
- message=f"SSL verification failed for {name}, please validate the certificate chain.",
943
+ raise RepositoryConnectionError(
944
+ identifier=name, message=f"SSL verification failed for {name}, please validate the certificate chain."
890
945
  ) from error
891
946
 
892
947
  if "authentication failed for" in error.stderr.lower():
893
- raise RepositoryError(
894
- identifier=name,
895
- message=f"Authentication failed for {name}, please validate the credentials.",
948
+ raise RepositoryCredentialsError(identifier=name) from error
949
+
950
+ if "fatal: could not read Username for" in error.stderr and "terminal prompts disable" in error.stderr:
951
+ raise RepositoryCredentialsError(
952
+ identifier=name, message=f"Unable to correctly lookup credentials for repository {name} ({location})."
896
953
  ) from error
897
954
 
898
955
  if any(err in error.stderr for err in ("Need to specify how to reconcile", "because you have unmerged files")):
@@ -901,12 +958,6 @@ class InfrahubRepositoryBase(BaseModel, ABC):
901
958
  message=f"Unable to pull the branch {branch_name} for repository {name}, there are conflicts that must be resolved.",
902
959
  ) from error
903
960
 
904
- if "fatal: could not read Username for" in error.stderr and "terminal prompts disable" in error.stderr:
905
- raise RepositoryError(
906
- identifier=name,
907
- message=f"Unable to correctly lookup credentials for repository {name} ({location}).",
908
- ) from error
909
-
910
961
  raise RepositoryError(identifier=name, message=error.stderr) from error
911
962
 
912
963
  def _get_mapped_remote_branch(self, branch_name: str) -> str:
@@ -3,9 +3,9 @@ from __future__ import annotations
3
3
  import hashlib
4
4
  import importlib
5
5
  import sys
6
+ from pathlib import Path
6
7
  from typing import TYPE_CHECKING, Any
7
8
 
8
- import jinja2
9
9
  import ujson
10
10
  import yaml
11
11
  from infrahub_sdk import InfrahubClient # noqa: TC002
@@ -28,6 +28,8 @@ from infrahub_sdk.schema.repository import (
28
28
  InfrahubPythonTransformConfig,
29
29
  InfrahubRepositoryConfig,
30
30
  )
31
+ from infrahub_sdk.template import Jinja2Template
32
+ from infrahub_sdk.template.exceptions import JinjaTemplateError
31
33
  from infrahub_sdk.utils import compare_lists
32
34
  from infrahub_sdk.yaml import SchemaFile
33
35
  from prefect import flow, task
@@ -212,30 +214,6 @@ class InfrahubRepositoryIntegrator(InfrahubRepositoryBase):
212
214
  )
213
215
  )
214
216
 
215
- async def _update_sync_status(self, branch_name: str, status: RepositorySyncStatus) -> None:
216
- update_status = """
217
- mutation UpdateRepositoryStatus(
218
- $repo_id: String!,
219
- $status: String!,
220
- ) {
221
- CoreGenericRepositoryUpdate(
222
- data: {
223
- id: $repo_id,
224
- sync_status: { value: $status },
225
- }
226
- ) {
227
- ok
228
- }
229
- }
230
- """
231
-
232
- await self.sdk.execute_graphql(
233
- branch_name=branch_name,
234
- query=update_status,
235
- variables={"repo_id": str(self.id), "status": status.value},
236
- tracker="mutation-repository-update-admin-status",
237
- )
238
-
239
217
  @task(name="import-jinja2-tansforms", task_run_name="Import Jinja2 transform", cache_policy=NONE) # type: ignore[arg-type]
240
218
  async def import_jinja2_transforms(
241
219
  self,
@@ -1081,14 +1059,14 @@ class InfrahubRepositoryIntegrator(InfrahubRepositoryBase):
1081
1059
 
1082
1060
  self.validate_location(commit=commit, worktree_directory=commit_worktree.directory, file_path=location)
1083
1061
 
1062
+ jinja2_template = Jinja2Template(template=Path(location), template_directory=Path(commit_worktree.directory))
1084
1063
  try:
1085
- templateLoader = jinja2.FileSystemLoader(searchpath=commit_worktree.directory)
1086
- templateEnv = jinja2.Environment(loader=templateLoader, trim_blocks=True, lstrip_blocks=True)
1087
- template = templateEnv.get_template(location)
1088
- return template.render(**data)
1089
- except Exception as exc:
1064
+ return await jinja2_template.render(variables=data)
1065
+ except JinjaTemplateError as exc:
1090
1066
  log.error(str(exc), exc_info=True)
1091
- raise TransformError(repository_name=self.name, commit=commit, location=location, message=str(exc)) from exc
1067
+ raise TransformError(
1068
+ repository_name=self.name, commit=commit, location=location, message=exc.message
1069
+ ) from exc
1092
1070
 
1093
1071
  @task(name="python-check-execute", task_run_name="Execute Python Check", cache_policy=NONE) # type: ignore[arg-type]
1094
1072
  async def execute_python_check(
@@ -20,9 +20,9 @@ class MenuRepository:
20
20
  async def add_children(menu_item: MenuItemDict, menu_node: CoreMenuItem) -> MenuItemDict:
21
21
  children = await menu_node.children.get_peers(db=self.db, peer_type=CoreMenuItem)
22
22
  for child_id, child_node in children.items():
23
- child_menu_item = menu_by_ids[child_id]
24
- child = await add_children(child_menu_item, child_node)
25
- menu_item.children[str(child.identifier)] = child
23
+ if child_menu_item := menu_by_ids.get(child_id):
24
+ child = await add_children(child_menu_item, child_node)
25
+ menu_item.children[str(child.identifier)] = child
26
26
  return menu_item
27
27
 
28
28
  for menu_node in nodes.values():
@@ -33,9 +33,9 @@ class MenuRepository:
33
33
 
34
34
  children = await menu_node.children.get_peers(db=self.db, peer_type=CoreMenuItem)
35
35
  for child_id, child_node in children.items():
36
- child_menu_item = menu_by_ids[child_id]
37
- child = await add_children(child_menu_item, child_node)
38
- menu_item.children[str(child.identifier)] = child
36
+ if child_menu_item := menu_by_ids.get(child_id):
37
+ child = await add_children(child_menu_item, child_node)
38
+ menu_item.children[str(child.identifier)] = child
39
39
 
40
40
  menu.data[str(menu_item.identifier)] = menu_item
41
41
 
infrahub/trigger/tasks.py CHANGED
@@ -14,23 +14,24 @@ from .setup import setup_triggers
14
14
 
15
15
  @flow(name="trigger-configure-all", flow_run_name="Configure all triggers")
16
16
  async def trigger_configure_all(service: InfrahubServices) -> None:
17
- webhook_trigger = await gather_trigger_webhook(db=service.database)
18
- computed_attribute_j2_triggers = await gather_trigger_computed_attribute_jinja2()
19
- (
20
- computed_attribute_python_triggers,
21
- computed_attribute_python_query_triggers,
22
- ) = await gather_trigger_computed_attribute_python(db=service.database)
17
+ async with service.database.start_session() as db:
18
+ webhook_trigger = await gather_trigger_webhook(db=db)
19
+ computed_attribute_j2_triggers = await gather_trigger_computed_attribute_jinja2()
20
+ (
21
+ computed_attribute_python_triggers,
22
+ computed_attribute_python_query_triggers,
23
+ ) = await gather_trigger_computed_attribute_python(db=db)
23
24
 
24
- triggers = (
25
- computed_attribute_j2_triggers
26
- + computed_attribute_python_triggers
27
- + computed_attribute_python_query_triggers
28
- + builtin_triggers
29
- + webhook_trigger
30
- )
31
-
32
- async with get_client(sync_client=False) as prefect_client:
33
- await setup_triggers(
34
- client=prefect_client,
35
- triggers=triggers,
25
+ triggers = (
26
+ computed_attribute_j2_triggers
27
+ + computed_attribute_python_triggers
28
+ + computed_attribute_python_query_triggers
29
+ + builtin_triggers
30
+ + webhook_trigger
36
31
  )
32
+
33
+ async with get_client(sync_client=False) as prefect_client:
34
+ await setup_triggers(
35
+ client=prefect_client,
36
+ triggers=triggers,
37
+ )
@@ -15,7 +15,8 @@ from .constants import TAG_NAMESPACE, WorkflowTag
15
15
  if TYPE_CHECKING:
16
16
  import logging
17
17
 
18
- from infrahub.services import InfrahubServices
18
+ from infrahub.database import InfrahubDatabase
19
+ from infrahub.services import InfrahubComponent
19
20
 
20
21
 
21
22
  async def add_tags(
@@ -56,7 +57,7 @@ async def add_related_node_tag(node_id: str) -> None:
56
57
 
57
58
 
58
59
  async def wait_for_schema_to_converge(
59
- branch_name: str, service: InfrahubServices, log: logging.Logger | logging.LoggerAdapter
60
+ branch_name: str, component: InfrahubComponent, db: InfrahubDatabase, log: logging.Logger | logging.LoggerAdapter
60
61
  ) -> None:
61
62
  has_converged = False
62
63
  branch_id = branch_name
@@ -67,7 +68,7 @@ async def wait_for_schema_to_converge(
67
68
  max_iterations = delay * 5 * 30
68
69
  iteration = 0
69
70
  while not has_converged:
70
- workers = await service.component.list_workers(branch=branch_id, schema_hash=True)
71
+ workers = await component.list_workers(branch=branch_id, schema_hash=True)
71
72
 
72
73
  hashes = {worker.schema_hash for worker in workers if worker.active}
73
74
  if len(hashes) == 1:
@@ -79,8 +80,7 @@ async def wait_for_schema_to_converge(
79
80
  log.warning(
80
81
  f"Schema had not converged after {delay * iteration:.2f} seconds, refreshing schema on local worker manually"
81
82
  )
82
- async with service.database.start_session() as db:
83
- await refresh_branches(db=db)
83
+ await refresh_branches(db=db)
84
84
  return
85
85
 
86
86
  iteration += 1
infrahub_sdk/client.py CHANGED
@@ -281,7 +281,7 @@ class InfrahubClient(BaseClient):
281
281
  self.schema = InfrahubSchema(self)
282
282
  self.branch = InfrahubBranchManager(self)
283
283
  self.object_store = ObjectStore(self)
284
- self.store = NodeStore()
284
+ self.store = NodeStore(default_branch=self.default_branch)
285
285
  self.task = InfrahubTaskManager(self)
286
286
  self.concurrent_execution_limit = asyncio.Semaphore(self.max_concurrent_execution)
287
287
  self._request_method: AsyncRequester = self.config.requester or self._default_request_method
@@ -840,11 +840,11 @@ class InfrahubClient(BaseClient):
840
840
  if populate_store:
841
841
  for node in nodes:
842
842
  if node.id:
843
- self.store.set(key=node.id, node=node)
843
+ self.store.set(node=node)
844
844
  related_nodes = list(set(related_nodes))
845
845
  for node in related_nodes:
846
846
  if node.id:
847
- self.store.set(key=node.id, node=node)
847
+ self.store.set(node=node)
848
848
  return nodes
849
849
 
850
850
  def clone(self) -> InfrahubClient:
@@ -1529,7 +1529,7 @@ class InfrahubClientSync(BaseClient):
1529
1529
  self.schema = InfrahubSchemaSync(self)
1530
1530
  self.branch = InfrahubBranchManagerSync(self)
1531
1531
  self.object_store = ObjectStoreSync(self)
1532
- self.store = NodeStoreSync()
1532
+ self.store = NodeStoreSync(default_branch=self.default_branch)
1533
1533
  self.task = InfrahubTaskManagerSync(self)
1534
1534
  self._request_method: SyncRequester = self.config.sync_requester or self._default_request_method
1535
1535
  self.group_context = InfrahubGroupContextSync(self)
@@ -1997,11 +1997,11 @@ class InfrahubClientSync(BaseClient):
1997
1997
  if populate_store:
1998
1998
  for node in nodes:
1999
1999
  if node.id:
2000
- self.store.set(key=node.id, node=node)
2000
+ self.store.set(node=node)
2001
2001
  related_nodes = list(set(related_nodes))
2002
2002
  for node in related_nodes:
2003
2003
  if node.id:
2004
- self.store.set(key=node.id, node=node)
2004
+ self.store.set(node=node)
2005
2005
  return nodes
2006
2006
 
2007
2007
  @overload
@@ -9,7 +9,6 @@ import sys
9
9
  from pathlib import Path
10
10
  from typing import TYPE_CHECKING, Any, Callable, Optional
11
11
 
12
- import jinja2
13
12
  import typer
14
13
  import ujson
15
14
  from rich.console import Console
@@ -18,7 +17,6 @@ from rich.logging import RichHandler
18
17
  from rich.panel import Panel
19
18
  from rich.pretty import Pretty
20
19
  from rich.table import Table
21
- from rich.traceback import Traceback
22
20
 
23
21
  from .. import __version__ as sdk_version
24
22
  from ..async_typer import AsyncTyper
@@ -31,7 +29,7 @@ from ..ctl.exceptions import QueryNotFoundError
31
29
  from ..ctl.generator import run as run_generator
32
30
  from ..ctl.menu import app as menu_app
33
31
  from ..ctl.object import app as object_app
34
- from ..ctl.render import list_jinja2_transforms
32
+ from ..ctl.render import list_jinja2_transforms, print_template_errors
35
33
  from ..ctl.repository import app as repository_app
36
34
  from ..ctl.repository import get_repository_config
37
35
  from ..ctl.schema import app as schema_app
@@ -44,8 +42,9 @@ from ..ctl.utils import (
44
42
  )
45
43
  from ..ctl.validate import app as validate_app
46
44
  from ..exceptions import GraphQLError, ModuleImportError
47
- from ..jinja2 import identify_faulty_jinja_code
48
45
  from ..schema import MainSchemaTypesAll, SchemaRoot
46
+ from ..template import Jinja2Template
47
+ from ..template.exceptions import JinjaTemplateError
49
48
  from ..utils import get_branch, write_to_file
50
49
  from ..yaml import SchemaFile
51
50
  from .exporter import dump
@@ -168,43 +167,28 @@ async def run(
168
167
  raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}")
169
168
 
170
169
  client = initialize_client(
171
- branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name
170
+ branch=branch,
171
+ timeout=timeout,
172
+ max_concurrent_execution=concurrent,
173
+ identifier=module_name,
172
174
  )
173
175
  func = getattr(module, method)
174
176
  await func(client=client, log=log, branch=branch, **variables_dict)
175
177
 
176
178
 
177
- def render_jinja2_template(template_path: Path, variables: dict[str, str], data: dict[str, Any]) -> str:
178
- if not template_path.is_file():
179
- console.print(f"[red]Unable to locate the template at {template_path}")
180
- raise typer.Exit(1)
181
-
182
- templateLoader = jinja2.FileSystemLoader(searchpath=".")
183
- templateEnv = jinja2.Environment(loader=templateLoader, trim_blocks=True, lstrip_blocks=True)
184
- template = templateEnv.get_template(str(template_path))
185
-
179
+ async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str:
180
+ variables["data"] = data
181
+ jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path())
186
182
  try:
187
- rendered_tpl = template.render(**variables, data=data) # type: ignore[arg-type]
188
- except jinja2.TemplateSyntaxError as exc:
189
- console.print("[red]Syntax Error detected on the template")
190
- console.print(f"[yellow] {exc}")
191
- raise typer.Exit(1) from exc
192
-
193
- except jinja2.UndefinedError as exc:
194
- console.print("[red]An error occurred while rendering the jinja template")
195
- traceback = Traceback(show_locals=False)
196
- errors = identify_faulty_jinja_code(traceback=traceback)
197
- for frame, syntax in errors:
198
- console.print(f"[yellow]{frame.filename} on line {frame.lineno}\n")
199
- console.print(syntax)
200
- console.print("")
201
- console.print(traceback.trace.stacks[0].exc_value)
183
+ rendered_tpl = await jinja_template.render(variables=variables)
184
+ except JinjaTemplateError as exc:
185
+ print_template_errors(error=exc, console=console)
202
186
  raise typer.Exit(1) from exc
203
187
 
204
188
  return rendered_tpl
205
189
 
206
190
 
207
- def _run_transform(
191
+ async def _run_transform(
208
192
  query_name: str,
209
193
  variables: dict[str, Any],
210
194
  transform_func: Callable,
@@ -227,7 +211,11 @@ def _run_transform(
227
211
 
228
212
  try:
229
213
  response = execute_graphql_query(
230
- query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
214
+ query=query_name,
215
+ variables_dict=variables,
216
+ branch=branch,
217
+ debug=debug,
218
+ repository_config=repository_config,
231
219
  )
232
220
 
233
221
  # TODO: response is a dict and can't be printed to the console in this way.
@@ -249,7 +237,7 @@ def _run_transform(
249
237
  raise typer.Abort()
250
238
 
251
239
  if asyncio.iscoroutinefunction(transform_func):
252
- output = asyncio.run(transform_func(response))
240
+ output = await transform_func(response)
253
241
  else:
254
242
  output = transform_func(response)
255
243
  return output
@@ -257,7 +245,7 @@ def _run_transform(
257
245
 
258
246
  @app.command(name="render")
259
247
  @catch_exception(console=console)
260
- def render(
248
+ async def render(
261
249
  transform_name: str = typer.Argument(default="", help="Name of the Python transformation", show_default=False),
262
250
  variables: Optional[list[str]] = typer.Argument(
263
251
  None, help="Variables to pass along with the query. Format key=value key=value."
@@ -289,7 +277,7 @@ def render(
289
277
  transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
290
278
 
291
279
  # Query GQL and run the transform
292
- result = _run_transform(
280
+ result = await _run_transform(
293
281
  query_name=transform_config.query,
294
282
  variables=variables_dict,
295
283
  transform_func=transform_func,
@@ -410,7 +398,10 @@ def version() -> None:
410
398
 
411
399
  @app.command(name="info")
412
400
  @catch_exception(console=console)
413
- def info(detail: bool = typer.Option(False, help="Display detailed information."), _: str = CONFIG_PARAM) -> None: # noqa: PLR0915
401
+ def info( # noqa: PLR0915
402
+ detail: bool = typer.Option(False, help="Display detailed information."),
403
+ _: str = CONFIG_PARAM,
404
+ ) -> None:
414
405
  """Display the status of the Python SDK."""
415
406
 
416
407
  info: dict[str, Any] = {
@@ -476,10 +467,14 @@ def info(detail: bool = typer.Option(False, help="Display detailed information."
476
467
  infrahub_info = Table(show_header=False, box=None)
477
468
  if info["user_info"]:
478
469
  infrahub_info.add_row("User:", info["user_info"]["AccountProfile"]["display_label"])
479
- infrahub_info.add_row("Description:", info["user_info"]["AccountProfile"]["description"]["value"])
470
+ infrahub_info.add_row(
471
+ "Description:",
472
+ info["user_info"]["AccountProfile"]["description"]["value"],
473
+ )
480
474
  infrahub_info.add_row("Status:", info["user_info"]["AccountProfile"]["status"]["label"])
481
475
  infrahub_info.add_row(
482
- "Number of Groups:", str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"])
476
+ "Number of Groups:",
477
+ str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"]),
483
478
  )
484
479
 
485
480
  if groups := info["groups"]:
@@ -1,6 +1,12 @@
1
1
  from rich.console import Console
2
2
 
3
3
  from ..schema.repository import InfrahubRepositoryConfig
4
+ from ..template.exceptions import (
5
+ JinjaTemplateError,
6
+ JinjaTemplateNotFoundError,
7
+ JinjaTemplateSyntaxError,
8
+ JinjaTemplateUndefinedError,
9
+ )
4
10
 
5
11
 
6
12
  def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None:
@@ -9,3 +15,36 @@ def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None:
9
15
 
10
16
  for transform in config.jinja2_transforms:
11
17
  console.print(f"{transform.name} ({transform.template_path})")
18
+
19
+
20
+ def print_template_errors(error: JinjaTemplateError, console: Console) -> None:
21
+ if isinstance(error, JinjaTemplateNotFoundError):
22
+ console.print("[red]An error occurred while rendering the jinja template")
23
+ console.print("")
24
+ if error.base_template:
25
+ console.print(f"Base template: [yellow]{error.base_template}")
26
+ console.print(f"Missing template: [yellow]{error.filename}")
27
+ return
28
+
29
+ if isinstance(error, JinjaTemplateUndefinedError):
30
+ console.print("[red]An error occurred while rendering the jinja template")
31
+ for current_error in error.errors:
32
+ console.print(f"[yellow]{current_error.frame.filename} on line {current_error.frame.lineno}\n")
33
+ console.print(current_error.syntax)
34
+ console.print("")
35
+ console.print(error.message)
36
+ return
37
+
38
+ if isinstance(error, JinjaTemplateSyntaxError):
39
+ console.print("[red]A syntax error was encountered within the template")
40
+ console.print("")
41
+ if error.filename:
42
+ console.print(f"Filename: [yellow]{error.filename}")
43
+ console.print(f"Line number: [yellow]{error.lineno}")
44
+ console.print()
45
+ console.print(error.message)
46
+ return
47
+
48
+ console.print("[red]An error occurred while rendering the jinja template")
49
+ console.print("")
50
+ console.print(f"[yellow]{error.message}")
@@ -69,12 +69,12 @@ class ModuleImportError(Error):
69
69
  class NodeNotFoundError(Error):
70
70
  def __init__(
71
71
  self,
72
- node_type: str,
73
72
  identifier: Mapping[str, list[str]],
74
73
  message: str = "Unable to find the node in the database.",
75
74
  branch_name: str | None = None,
75
+ node_type: str | None = None,
76
76
  ):
77
- self.node_type = node_type
77
+ self.node_type = node_type or "unknown"
78
78
  self.identifier = identifier
79
79
  self.branch_name = branch_name
80
80
 
@@ -88,6 +88,10 @@ class NodeNotFoundError(Error):
88
88
  """
89
89
 
90
90
 
91
+ class NodeInvalidError(NodeNotFoundError):
92
+ pass
93
+
94
+
91
95
  class ResourceNotDefinedError(Error):
92
96
  """Raised when trying to access a resource that hasn't been defined."""
93
97
 
infrahub_sdk/generator.py CHANGED
@@ -137,7 +137,7 @@ class InfrahubGenerator:
137
137
 
138
138
  for node in self._nodes + self._related_nodes:
139
139
  if node.id:
140
- self._init_client.store.set(key=node.id, node=node)
140
+ self._init_client.store.set(node=node)
141
141
 
142
142
  @abstractmethod
143
143
  async def generate(self, data: dict) -> None: