infrahub-server 1.2.7__py3-none-any.whl → 1.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. infrahub/api/transformation.py +1 -0
  2. infrahub/artifacts/models.py +4 -0
  3. infrahub/cli/db.py +15 -6
  4. infrahub/computed_attribute/tasks.py +34 -12
  5. infrahub/config.py +2 -1
  6. infrahub/constants/__init__.py +0 -0
  7. infrahub/core/branch/tasks.py +0 -2
  8. infrahub/core/constants/__init__.py +1 -0
  9. infrahub/core/diff/calculator.py +4 -3
  10. infrahub/core/diff/combiner.py +1 -2
  11. infrahub/core/diff/coordinator.py +44 -28
  12. infrahub/core/diff/data_check_synchronizer.py +3 -2
  13. infrahub/core/diff/enricher/hierarchy.py +38 -27
  14. infrahub/core/diff/ipam_diff_parser.py +5 -4
  15. infrahub/core/diff/merger/merger.py +20 -18
  16. infrahub/core/diff/model/field_specifiers_map.py +64 -0
  17. infrahub/core/diff/model/path.py +55 -58
  18. infrahub/core/diff/parent_node_adder.py +14 -16
  19. infrahub/core/diff/query/drop_nodes.py +42 -0
  20. infrahub/core/diff/query/field_specifiers.py +8 -7
  21. infrahub/core/diff/query/filters.py +15 -1
  22. infrahub/core/diff/query/save.py +3 -0
  23. infrahub/core/diff/query_parser.py +49 -52
  24. infrahub/core/diff/repository/deserializer.py +36 -23
  25. infrahub/core/diff/repository/repository.py +31 -12
  26. infrahub/core/graph/__init__.py +1 -1
  27. infrahub/core/graph/index.py +3 -1
  28. infrahub/core/initialization.py +23 -7
  29. infrahub/core/manager.py +16 -5
  30. infrahub/core/migrations/graph/__init__.py +2 -0
  31. infrahub/core/migrations/graph/m014_remove_index_attr_value.py +9 -8
  32. infrahub/core/migrations/graph/m027_delete_isolated_nodes.py +50 -0
  33. infrahub/core/protocols.py +1 -0
  34. infrahub/core/query/branch.py +27 -17
  35. infrahub/core/query/diff.py +65 -38
  36. infrahub/core/query/node.py +111 -33
  37. infrahub/core/query/relationship.py +17 -3
  38. infrahub/core/query/subquery.py +2 -2
  39. infrahub/core/schema/definitions/core/builtin.py +2 -4
  40. infrahub/core/schema/definitions/core/transform.py +1 -0
  41. infrahub/core/schema/schema_branch.py +3 -0
  42. infrahub/core/validators/aggregated_checker.py +2 -2
  43. infrahub/core/validators/uniqueness/query.py +30 -9
  44. infrahub/database/__init__.py +1 -16
  45. infrahub/database/index.py +1 -1
  46. infrahub/database/memgraph.py +1 -12
  47. infrahub/database/neo4j.py +1 -13
  48. infrahub/git/integrator.py +27 -3
  49. infrahub/git/models.py +4 -0
  50. infrahub/git/tasks.py +3 -0
  51. infrahub/git_credential/helper.py +2 -2
  52. infrahub/graphql/mutations/computed_attribute.py +5 -1
  53. infrahub/graphql/queries/diff/tree.py +2 -1
  54. infrahub/message_bus/operations/requests/proposed_change.py +6 -0
  55. infrahub/message_bus/types.py +3 -0
  56. infrahub/patch/queries/consolidate_duplicated_nodes.py +109 -0
  57. infrahub/patch/queries/delete_duplicated_edges.py +138 -0
  58. infrahub/proposed_change/tasks.py +1 -0
  59. infrahub/server.py +1 -3
  60. infrahub/transformations/models.py +3 -0
  61. infrahub/transformations/tasks.py +1 -0
  62. infrahub/trigger/models.py +11 -1
  63. infrahub/trigger/setup.py +38 -13
  64. infrahub/trigger/tasks.py +1 -4
  65. infrahub/webhook/models.py +3 -0
  66. infrahub/workflows/initialization.py +1 -3
  67. infrahub_sdk/client.py +4 -4
  68. infrahub_sdk/config.py +17 -0
  69. infrahub_sdk/ctl/cli_commands.py +7 -1
  70. infrahub_sdk/ctl/generator.py +2 -2
  71. infrahub_sdk/generator.py +12 -66
  72. infrahub_sdk/operation.py +80 -0
  73. infrahub_sdk/protocols.py +12 -0
  74. infrahub_sdk/recorder.py +3 -0
  75. infrahub_sdk/schema/repository.py +4 -0
  76. infrahub_sdk/transforms.py +15 -27
  77. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.9.dist-info}/METADATA +2 -2
  78. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.9.dist-info}/RECORD +84 -78
  79. infrahub_testcontainers/container.py +1 -0
  80. infrahub_testcontainers/docker-compose.test.yml +5 -1
  81. infrahub/database/manager.py +0 -15
  82. /infrahub/{database/constants.py → constants/database.py} +0 -0
  83. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.9.dist-info}/LICENSE.txt +0 -0
  84. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.9.dist-info}/WHEEL +0 -0
  85. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,138 @@
1
+ from ..models import EdgeToDelete, EdgeToUpdate, PatchPlan
2
+ from .base import PatchQuery
3
+
4
+
5
class DeleteDuplicatedEdgesPatchQuery(PatchQuery):
    """Plan the removal of duplicated edges that start from a ``Node``.

    Edges are grouped by (source node, peer, edge type, branch, status); a group holding
    more than one edge is a duplicate. For every affected (node, peer, type, branch):

    - the active edge with the lowest database id is kept and updated so that its ``from``
      is the earliest ``from`` of the active edges and its ``to`` is the earliest ``from``
      of the deleted edges (null when there is no deleted edge);
    - the deleted edge with the lowest database id is kept and updated so that its ``from``
      is the earliest ``from`` of the deleted edges;
    - every other active or deleted edge of that group is planned for deletion.

    Pairs that have only active edges, or only deleted edges, on the branch are handled too.
    """

    @property
    def name(self) -> str:
        return "delete-duplicated-edges"

    async def plan(self) -> PatchPlan:
        """Run the (read-only) detection query and turn its rows into a ``PatchPlan``."""
        query = """
        // ------------
        // Find node pairs that have duplicate edges
        // ------------
        MATCH (node_with_dup_edges:Node)-[edge]->(peer)
        WITH node_with_dup_edges, type(edge) AS edge_type, edge.status AS edge_status, edge.branch AS edge_branch, peer, count(*) AS num_dup_edges
        WHERE num_dup_edges > 1
        WITH DISTINCT node_with_dup_edges, edge_type, edge_branch, peer
        CALL {
            // ------------
            // Get the earliest active and deleted "from" times for this branch.
            // OPTIONAL MATCH keeps the pair when it has no active or no deleted edge.
            // ------------
            WITH node_with_dup_edges, edge_type, edge_branch, peer
            OPTIONAL MATCH (node_with_dup_edges)-[active_edge {branch: edge_branch, status: "active"}]->(peer)
            WHERE type(active_edge) = edge_type
            WITH node_with_dup_edges, edge_type, edge_branch, peer, active_edge
            ORDER BY active_edge.from ASC
            WITH node_with_dup_edges, edge_type, edge_branch, peer, head(collect(active_edge.from)) AS active_from
            OPTIONAL MATCH (node_with_dup_edges)-[deleted_edge {branch: edge_branch, status: "deleted"}]->(peer)
            WHERE type(deleted_edge) = edge_type
            WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, deleted_edge
            ORDER BY deleted_edge.from ASC
            WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, head(collect(deleted_edge.from)) AS deleted_from
            // ------------
            // Plan one active edge update with correct from and to times
            // (an empty list, not zero rows, when there is no active edge)
            // ------------
            CALL {
                WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, deleted_from
                OPTIONAL MATCH (node_with_dup_edges)-[active_e {branch: edge_branch, status: "active"}]->(peer)
                WHERE type(active_e) = edge_type
                WITH active_from, deleted_from, active_e
                ORDER BY %(id_func_name)s(active_e)
                WITH active_from, deleted_from, head(collect(active_e)) AS kept_active_e
                RETURN CASE
                    WHEN kept_active_e IS NULL THEN []
                    ELSE [
                        {
                            db_id: %(id_func_name)s(kept_active_e),
                            before_props: properties(kept_active_e),
                            prop_updates: {from: active_from, to: deleted_from}
                        }
                    ]
                END AS active_edges_to_update
            }
            // ------------
            // Plan deletes for all the other active edges of this type on this branch
            // ------------
            CALL {
                WITH node_with_dup_edges, edge_type, edge_branch, peer
                MATCH (node_with_dup_edges)-[active_e {branch: edge_branch, status: "active"}]->(peer)
                WHERE type(active_e) = edge_type
                WITH node_with_dup_edges, peer, active_e
                ORDER BY %(id_func_name)s(active_e)
                SKIP 1
                RETURN collect(
                    {
                        db_id: %(id_func_name)s(active_e),
                        from_id: %(id_func_name)s(node_with_dup_edges),
                        to_id: %(id_func_name)s(peer),
                        edge_type: type(active_e),
                        before_props: properties(active_e)
                    }
                ) AS active_edges_to_delete
            }
            // ------------
            // Plan one deleted edge update with correct from time
            // (an empty list, not zero rows, when there is no deleted edge)
            // ------------
            CALL {
                WITH node_with_dup_edges, edge_type, edge_branch, peer, deleted_from
                OPTIONAL MATCH (node_with_dup_edges)-[deleted_e {branch: edge_branch, status: "deleted"}]->(peer)
                WHERE type(deleted_e) = edge_type
                WITH deleted_from, deleted_e
                ORDER BY %(id_func_name)s(deleted_e)
                WITH deleted_from, head(collect(deleted_e)) AS kept_deleted_e
                RETURN CASE
                    WHEN kept_deleted_e IS NULL THEN []
                    ELSE [
                        {
                            db_id: %(id_func_name)s(kept_deleted_e),
                            before_props: properties(kept_deleted_e),
                            prop_updates: {from: deleted_from}
                        }
                    ]
                END AS deleted_edges_to_update
            }
            // ------------
            // Plan deletes for all the other deleted edges of this type on this branch
            // ------------
            CALL {
                WITH node_with_dup_edges, edge_type, edge_branch, peer
                MATCH (node_with_dup_edges)-[deleted_e {branch: edge_branch, status: "deleted"}]->(peer)
                WHERE type(deleted_e) = edge_type
                WITH node_with_dup_edges, peer, deleted_e
                ORDER BY %(id_func_name)s(deleted_e)
                SKIP 1
                RETURN collect(
                    {
                        db_id: %(id_func_name)s(deleted_e),
                        from_id: %(id_func_name)s(node_with_dup_edges),
                        to_id: %(id_func_name)s(peer),
                        edge_type: type(deleted_e),
                        before_props: properties(deleted_e)
                    }
                ) AS deleted_edges_to_delete
            }
            RETURN
                active_edges_to_update + deleted_edges_to_update AS edges_to_update,
                active_edges_to_delete + deleted_edges_to_delete AS edges_to_delete
        }
        RETURN edges_to_update, edges_to_delete
        """ % {"id_func_name": self.db.get_id_function_name()}
        results = await self.db.execute_query(query=query)
        edges_to_delete: list[EdgeToDelete] = []
        edges_to_update: list[EdgeToUpdate] = []
        for result in results:
            for serial_edge_to_delete in result.get("edges_to_delete") or []:
                edges_to_delete.append(EdgeToDelete(**serial_edge_to_delete))
            for serial_edge_to_update in result.get("edges_to_update") or []:
                # Work on a copy so the query record is left untouched; ``prop_updates`` is
                # folded into ``after_props`` and is never passed on to EdgeToUpdate.
                edge_kwargs = dict(serial_edge_to_update)
                prop_updates = edge_kwargs.pop("prop_updates", None)
                if prop_updates:
                    edge_kwargs["after_props"] = edge_kwargs["before_props"] | prop_updates
                edges_to_update.append(EdgeToUpdate(**edge_kwargs))
        return PatchPlan(
            name=self.name,
            edges_to_delete=edges_to_delete,
            edges_to_update=edges_to_update,
        )
@@ -607,6 +607,7 @@ async def validate_artifacts_generation(model: RequestArtifactDefinitionCheck, s
607
607
  content_type=model.artifact_definition.content_type,
608
608
  transform_type=model.artifact_definition.transform_kind,
609
609
  transform_location=model.artifact_definition.transform_location,
610
+ convert_query_response=model.artifact_definition.convert_query_response,
610
611
  repository_id=repository.repository_id,
611
612
  repository_name=repository.repository_name,
612
613
  repository_kind=repository.kind,
infrahub/server.py CHANGED
@@ -23,7 +23,6 @@ from infrahub import __version__, config
23
23
  from infrahub.api import router as api
24
24
  from infrahub.api.exception_handlers import generic_api_exception_handler
25
25
  from infrahub.components import ComponentType
26
- from infrahub.core.graph.index import node_indexes, rel_indexes
27
26
  from infrahub.core.initialization import initialization
28
27
  from infrahub.database import InfrahubDatabase, InfrahubDatabaseMode, get_db
29
28
  from infrahub.dependencies.registry import build_component_registry
@@ -58,7 +57,6 @@ async def app_initialization(application: FastAPI, enable_scheduler: bool = True
58
57
 
59
58
  # Initialize database Driver and load local registry
60
59
  database = application.state.db = InfrahubDatabase(mode=InfrahubDatabaseMode.DRIVER, driver=await get_db())
61
- database.manager.index.init(nodes=node_indexes, rels=rel_indexes)
62
60
 
63
61
  build_component_registry()
64
62
 
@@ -83,7 +81,7 @@ async def app_initialization(application: FastAPI, enable_scheduler: bool = True
83
81
  initialize_lock(service=service)
84
82
  # We must initialize DB after initialize lock and initialize lock depends on cache initialization
85
83
  async with application.state.db.start_session() as db:
86
- await initialization(db=db)
84
+ await initialization(db=db, add_database_indexes=True)
87
85
 
88
86
  application.state.service = service
89
87
  application.state.response_delay = config.SETTINGS.miscellaneous.response_delay
@@ -11,6 +11,9 @@ class TransformPythonData(BaseModel):
11
11
  branch: str = Field(..., description="The branch to target")
12
12
  transform_location: str = Field(..., description="Location of the transform within the repository")
13
13
  commit: str = Field(..., description="The commit id to use when generating the artifact")
14
+ convert_query_response: bool = Field(
15
+ ..., description="Define if the GraphQL query respose should be converted into InfrahubNode objects"
16
+ )
14
17
  timeout: int = Field(..., description="The timeout value to use when generating the artifact")
15
18
 
16
19
 
@@ -30,6 +30,7 @@ async def transform_python(message: TransformPythonData, service: InfrahubServic
30
30
  location=message.transform_location,
31
31
  data=message.data,
32
32
  client=service.client,
33
+ convert_query_response=message.convert_query_response,
33
34
  ) # type: ignore[misc]
34
35
 
35
36
  return transformed_data
@@ -5,8 +5,11 @@ from enum import Enum
5
5
  from typing import TYPE_CHECKING, Any
6
6
 
7
7
  from prefect.events.actions import RunDeployment
8
+ from prefect.events.schemas.automations import (
9
+ Automation, # noqa: TC002
10
+ Posture,
11
+ )
8
12
  from prefect.events.schemas.automations import EventTrigger as PrefectEventTrigger
9
- from prefect.events.schemas.automations import Posture
10
13
  from prefect.events.schemas.events import ResourceSpecification
11
14
  from pydantic import BaseModel, Field
12
15
 
@@ -19,6 +22,13 @@ if TYPE_CHECKING:
19
22
  from uuid import UUID
20
23
 
21
24
 
25
class TriggerSetupReport(BaseModel):
    """Outcome of a trigger setup run, grouped by what happened to each automation.

    ``created``, ``updated`` and ``unchanged`` collect trigger definitions, while
    ``deleted`` collects the Prefect automations that were removed. Every list starts empty.
    """

    # Triggers for which no automation existed yet.
    created: list[TriggerDefinition] = Field(default_factory=list)
    # Triggers whose existing automation was rewritten.
    updated: list[TriggerDefinition] = Field(default_factory=list)
    # Existing automations that did not match any requested trigger.
    deleted: list[Automation] = Field(default_factory=list)
    # Triggers whose existing automation was left as is.
    unchanged: list[TriggerDefinition] = Field(default_factory=list)
30
+
31
+
22
32
  class TriggerType(str, Enum):
23
33
  BUILTIN = "builtin"
24
34
  WEBHOOK = "webhook"
infrahub/trigger/setup.py CHANGED
@@ -5,23 +5,39 @@ from prefect.automations import AutomationCore
5
5
  from prefect.cache_policies import NONE
6
6
  from prefect.client.orchestration import PrefectClient
7
7
  from prefect.client.schemas.filters import DeploymentFilter, DeploymentFilterName
8
+ from prefect.events.schemas.automations import Automation
8
9
 
9
10
  from infrahub.trigger.models import TriggerDefinition
10
11
 
11
- from .models import TriggerType
12
+ from .models import TriggerSetupReport, TriggerType
12
13
 
13
14
  if TYPE_CHECKING:
14
15
  from uuid import UUID
15
16
 
16
17
 
18
def compare_automations(target: AutomationCore, existing: Automation) -> bool:
    """Tell whether ``existing`` already matches the desired ``target`` automation.

    Both models are serialized with default and ``None`` values left out, and the ``id``
    field of ``existing`` is excluded before the two dictionaries are compared.

    Returns:
        True when both serializations are equal, False otherwise.
    """
    dump_options = {"exclude_defaults": True, "exclude_none": True}
    desired = target.model_dump(**dump_options)
    current = existing.model_dump(exclude={"id"}, **dump_options)
    return current == desired
28
+
29
+
17
30
  @task(name="trigger-setup", task_run_name="Setup triggers", cache_policy=NONE) # type: ignore[arg-type]
18
31
  async def setup_triggers(
19
32
  client: PrefectClient,
20
33
  triggers: list[TriggerDefinition],
21
34
  trigger_type: TriggerType | None = None,
22
- ) -> None:
35
+ force_update: bool = False,
36
+ ) -> TriggerSetupReport:
23
37
  log = get_run_logger()
24
38
 
39
+ report = TriggerSetupReport()
40
+
25
41
  if trigger_type:
26
42
  log.info(f"Setting up triggers of type {trigger_type.value}")
27
43
  else:
@@ -38,23 +54,24 @@ async def setup_triggers(
38
54
  )
39
55
  }
40
56
  deployments_mapping: dict[str, UUID] = {name: item.id for name, item in deployments.items()}
41
- existing_automations = {item.name: item for item in await client.read_automations()}
42
57
 
43
58
  # If a trigger type is provided, narrow down the list of existing triggers to know which one to delete
59
+ existing_automations: dict[str, Automation] = {}
44
60
  if trigger_type:
45
- trigger_automations = [
46
- item.name for item in await client.read_automations() if item.name.startswith(trigger_type.value)
47
- ]
61
+ existing_automations = {
62
+ item.name: item for item in await client.read_automations() if item.name.startswith(trigger_type.value)
63
+ }
48
64
  else:
49
- trigger_automations = [item.name for item in await client.read_automations()]
65
+ existing_automations = {item.name: item for item in await client.read_automations()}
50
66
 
51
67
  trigger_names = [trigger.generate_name() for trigger in triggers]
68
+ automation_names = list(existing_automations.keys())
52
69
 
53
- log.debug(f"{len(trigger_automations)} existing triggers ({trigger_automations})")
54
- log.debug(f"{len(trigger_names)} triggers to configure ({trigger_names})")
70
+ log.debug(f"{len(automation_names)} existing triggers ({automation_names})")
71
+ log.debug(f"{len(trigger_names)} triggers to configure ({trigger_names})")
55
72
 
56
- to_delete = set(trigger_automations) - set(trigger_names)
57
- log.debug(f"{len(trigger_names)} triggers to delete ({to_delete})")
73
+ to_delete = set(automation_names) - set(trigger_names)
74
+ log.debug(f"{len(to_delete)} triggers to delete ({to_delete})")
58
75
 
59
76
  # -------------------------------------------------------------
60
77
  # Create or Update all triggers
@@ -71,11 +88,16 @@ async def setup_triggers(
71
88
  existing_automation = existing_automations.get(trigger.generate_name(), None)
72
89
 
73
90
  if existing_automation:
74
- await client.update_automation(automation_id=existing_automation.id, automation=automation)
75
- log.info(f"{trigger.generate_name()} Updated")
91
+ if force_update or not compare_automations(target=automation, existing=existing_automation):
92
+ await client.update_automation(automation_id=existing_automation.id, automation=automation)
93
+ log.info(f"{trigger.generate_name()} Updated")
94
+ report.updated.append(trigger)
95
+ else:
96
+ report.unchanged.append(trigger)
76
97
  else:
77
98
  await client.create_automation(automation=automation)
78
99
  log.info(f"{trigger.generate_name()} Created")
100
+ report.created.append(trigger)
79
101
 
80
102
  # -------------------------------------------------------------
81
103
  # Delete Triggers that shouldn't be there
@@ -86,5 +108,8 @@ async def setup_triggers(
86
108
  if not existing_automation:
87
109
  continue
88
110
 
111
+ report.deleted.append(existing_automation)
89
112
  await client.delete_automation(automation_id=existing_automation.id)
90
113
  log.info(f"{item_to_delete} Deleted")
114
+
115
+ return report
infrahub/trigger/tasks.py CHANGED
@@ -31,7 +31,4 @@ async def trigger_configure_all(service: InfrahubServices) -> None:
31
31
  )
32
32
 
33
33
  async with get_client(sync_client=False) as prefect_client:
34
- await setup_triggers(
35
- client=prefect_client,
36
- triggers=triggers,
37
- )
34
+ await setup_triggers(client=prefect_client, triggers=triggers, force_update=True)
@@ -204,6 +204,7 @@ class TransformWebhook(Webhook):
204
204
  transform_class: str = Field(...)
205
205
  transform_file: str = Field(...)
206
206
  transform_timeout: int = Field(...)
207
+ convert_query_response: bool = Field(...)
207
208
 
208
209
  async def _prepare_payload(self, data: dict[str, Any], context: EventContext, service: InfrahubServices) -> None:
209
210
  repo: InfrahubReadOnlyRepository | InfrahubRepository
@@ -229,6 +230,7 @@ class TransformWebhook(Webhook):
229
230
  branch_name=branch,
230
231
  commit=commit,
231
232
  location=f"{self.transform_file}::{self.transform_class}",
233
+ convert_query_response=self.convert_query_response,
232
234
  data={"data": data, **context.model_dump()},
233
235
  client=service.client,
234
236
  ) # type: ignore[misc]
@@ -247,4 +249,5 @@ class TransformWebhook(Webhook):
247
249
  transform_class=transform.class_name.value,
248
250
  transform_file=transform.file_path.value,
249
251
  transform_timeout=transform.timeout.value,
252
+ convert_query_response=transform.convert_query_response.value or False,
250
253
  )
@@ -71,7 +71,5 @@ async def setup_task_manager() -> None:
71
71
  await setup_worker_pools(client=client)
72
72
  await setup_deployments(client=client)
73
73
  await setup_triggers(
74
- client=client,
75
- triggers=builtin_triggers,
76
- trigger_type=TriggerType.BUILTIN,
74
+ client=client, triggers=builtin_triggers, trigger_type=TriggerType.BUILTIN, force_update=True
77
75
  )
infrahub_sdk/client.py CHANGED
@@ -847,9 +847,9 @@ class InfrahubClient(BaseClient):
847
847
  self.store.set(node=node)
848
848
  return nodes
849
849
 
850
- def clone(self) -> InfrahubClient:
850
+ def clone(self, branch: str | None = None) -> InfrahubClient:
851
851
  """Return a cloned version of the client using the same configuration"""
852
- return InfrahubClient(config=self.config)
852
+ return InfrahubClient(config=self.config.clone(branch=branch))
853
853
 
854
854
  async def execute_graphql(
855
855
  self,
@@ -1591,9 +1591,9 @@ class InfrahubClientSync(BaseClient):
1591
1591
  node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
1592
1592
  node.delete()
1593
1593
 
1594
- def clone(self) -> InfrahubClientSync:
1594
+ def clone(self, branch: str | None = None) -> InfrahubClientSync:
1595
1595
  """Return a cloned version of the client using the same configuration"""
1596
- return InfrahubClientSync(config=self.config)
1596
+ return InfrahubClientSync(config=self.config.clone(branch=branch))
1597
1597
 
1598
1598
  def execute_graphql(
1599
1599
  self,
infrahub_sdk/config.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from copy import deepcopy
3
4
  from typing import Any
4
5
 
5
6
  from pydantic import Field, field_validator, model_validator
@@ -158,3 +159,19 @@ class Config(ConfigBase):
158
159
  elif values.get("recorder") == RecorderType.JSON and "custom_recorder" not in values:
159
160
  values["custom_recorder"] = JSONRecorder()
160
161
  return values
162
+
163
def clone(self, branch: str | None = None) -> Config:
    """Return a copy of this configuration, optionally with another default branch.

    The recorder, custom recorder, requesters and logger are passed to the new instance
    as-is (shared with ``self``); every other model field is deep-copied so the clone can
    be modified without affecting this instance.

    Args:
        branch: Default branch of the clone. When empty or None, the current
            ``default_branch`` is kept.

    Returns:
        A new instance of the same class as ``self`` (subclasses are preserved).
    """
    config: dict[str, Any] = {
        "default_branch": branch or self.default_branch,
        "recorder": self.recorder,
        "custom_recorder": self.custom_recorder,
        "requester": self.requester,
        "sync_requester": self.sync_requester,
        "log": self.log,
    }
    # Use the actual class so fields declared by a subclass are copied too and the
    # clone keeps the subclass type.
    config_class = type(self)
    for field_name in config_class.model_fields:
        if field_name not in config:
            config[field_name] = deepcopy(getattr(self, field_name))

    return config_class(**config)
@@ -41,6 +41,7 @@ from ..ctl.utils import (
41
41
  )
42
42
  from ..ctl.validate import app as validate_app
43
43
  from ..exceptions import GraphQLError, ModuleImportError
44
+ from ..node import InfrahubNode
44
45
  from ..protocols_generator.generator import CodeGenerator
45
46
  from ..schema import MainSchemaTypesAll, SchemaRoot
46
47
  from ..template import Jinja2Template
@@ -330,7 +331,12 @@ def transform(
330
331
  console.print(f"[red]{exc.message}")
331
332
  raise typer.Exit(1) from exc
332
333
 
333
- transform = transform_class(client=client, branch=branch)
334
+ transform = transform_class(
335
+ client=client,
336
+ branch=branch,
337
+ infrahub_node=InfrahubNode,
338
+ convert_query_response=transform_config.convert_query_response,
339
+ )
334
340
  # Get data
335
341
  query_str = repository_config.get_query(name=transform.query).load_query()
336
342
  data = asyncio.run(
@@ -62,7 +62,7 @@ async def run(
62
62
  generator = generator_class(
63
63
  query=generator_config.query,
64
64
  client=client,
65
- branch=branch,
65
+ branch=branch or "",
66
66
  params=variables_dict,
67
67
  convert_query_response=generator_config.convert_query_response,
68
68
  infrahub_node=InfrahubNode,
@@ -91,7 +91,7 @@ async def run(
91
91
  generator = generator_class(
92
92
  query=generator_config.query,
93
93
  client=client,
94
- branch=branch,
94
+ branch=branch or "",
95
95
  params=params,
96
96
  convert_query_response=generator_config.convert_query_response,
97
97
  infrahub_node=InfrahubNode,
infrahub_sdk/generator.py CHANGED
@@ -1,22 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
4
  from abc import abstractmethod
6
5
  from typing import TYPE_CHECKING
7
6
 
8
- from infrahub_sdk.repository import GitRepoManager
9
-
10
7
  from .exceptions import UninitializedError
8
+ from .operation import InfrahubOperation
11
9
 
12
10
  if TYPE_CHECKING:
13
11
  from .client import InfrahubClient
14
12
  from .context import RequestContext
15
13
  from .node import InfrahubNode
16
- from .store import NodeStore
17
14
 
18
15
 
19
- class InfrahubGenerator:
16
+ class InfrahubGenerator(InfrahubOperation):
20
17
  """Infrahub Generator class"""
21
18
 
22
19
  def __init__(
@@ -24,7 +21,7 @@ class InfrahubGenerator:
24
21
  query: str,
25
22
  client: InfrahubClient,
26
23
  infrahub_node: type[InfrahubNode],
27
- branch: str | None = None,
24
+ branch: str = "",
28
25
  root_directory: str = "",
29
26
  generator_instance: str = "",
30
27
  params: dict | None = None,
@@ -33,37 +30,21 @@ class InfrahubGenerator:
33
30
  request_context: RequestContext | None = None,
34
31
  ) -> None:
35
32
  self.query = query
36
- self.branch = branch
37
- self.git: GitRepoManager | None = None
33
+
34
+ super().__init__(
35
+ client=client,
36
+ infrahub_node=infrahub_node,
37
+ convert_query_response=convert_query_response,
38
+ branch=branch,
39
+ root_directory=root_directory,
40
+ )
41
+
38
42
  self.params = params or {}
39
- self.root_directory = root_directory or os.getcwd()
40
43
  self.generator_instance = generator_instance
41
- self._init_client = client.clone()
42
- self._init_client.config.default_branch = self._init_client.default_branch = self.branch_name
43
- self._init_client.store._default_branch = self.branch_name
44
44
  self._client: InfrahubClient | None = None
45
- self._nodes: list[InfrahubNode] = []
46
- self._related_nodes: list[InfrahubNode] = []
47
- self.infrahub_node = infrahub_node
48
- self.convert_query_response = convert_query_response
49
45
  self.logger = logger if logger else logging.getLogger("infrahub.tasks")
50
46
  self.request_context = request_context
51
47
 
52
- @property
53
- def store(self) -> NodeStore:
54
- """The store will be populated with nodes based on the query during the collection of data if activated"""
55
- return self._init_client.store
56
-
57
- @property
58
- def nodes(self) -> list[InfrahubNode]:
59
- """Returns nodes collected and parsed during the data collection process if this feature is enables"""
60
- return self._nodes
61
-
62
- @property
63
- def related_nodes(self) -> list[InfrahubNode]:
64
- """Returns nodes collected and parsed during the data collection process if this feature is enables"""
65
- return self._related_nodes
66
-
67
48
  @property
68
49
  def subscribers(self) -> list[str] | None:
69
50
  if self.generator_instance:
@@ -80,20 +61,6 @@ class InfrahubGenerator:
80
61
  def client(self, value: InfrahubClient) -> None:
81
62
  self._client = value
82
63
 
83
- @property
84
- def branch_name(self) -> str:
85
- """Return the name of the current git branch."""
86
-
87
- if self.branch:
88
- return self.branch
89
-
90
- if not self.git:
91
- self.git = GitRepoManager(self.root_directory)
92
-
93
- self.branch = str(self.git.active_branch)
94
-
95
- return self.branch
96
-
97
64
  async def collect_data(self) -> dict:
98
65
  """Query the result of the GraphQL Query defined in self.query and return the result"""
99
66
 
@@ -119,27 +86,6 @@ class InfrahubGenerator:
119
86
  ) as self.client:
120
87
  await self.generate(data=unpacked)
121
88
 
122
- async def process_nodes(self, data: dict) -> None:
123
- if not self.convert_query_response:
124
- return
125
-
126
- await self._init_client.schema.all(branch=self.branch_name)
127
-
128
- for kind in data:
129
- if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
130
- for result in data[kind].get("edges", []):
131
- node = await self.infrahub_node.from_graphql(
132
- client=self._init_client, branch=self.branch_name, data=result
133
- )
134
- self._nodes.append(node)
135
- await node._process_relationships(
136
- node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
137
- )
138
-
139
- for node in self._nodes + self._related_nodes:
140
- if node.id:
141
- self._init_client.store.set(node=node)
142
-
143
89
  @abstractmethod
144
90
  async def generate(self, data: dict) -> None:
145
91
  """Code to run the generator
@@ -0,0 +1,80 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import TYPE_CHECKING
5
+
6
+ from .repository import GitRepoManager
7
+
8
+ if TYPE_CHECKING:
9
+ from . import InfrahubClient
10
+ from .node import InfrahubNode
11
+ from .store import NodeStore
12
+
13
+
14
class InfrahubOperation:
    """Common base for SDK operations (e.g. ``InfrahubGenerator``) that consume a GraphQL query result.

    It holds a separate client, cloned from the one it receives and bound to the
    operation's branch, and can convert the query result into ``InfrahubNode`` objects
    that are registered in that client's store.
    """

    def __init__(
        self,
        client: InfrahubClient,
        infrahub_node: type[InfrahubNode],
        convert_query_response: bool,
        branch: str,
        root_directory: str,
    ) -> None:
        """Initialize the operation.

        Args:
            client: Client to clone; the clone is created with ``branch_name`` as its branch.
            infrahub_node: Node class used to build objects from the query result.
            convert_query_response: Whether ``process_nodes`` converts the query result.
            branch: Target branch; when empty, the active git branch of ``root_directory`` is used.
            root_directory: Repository root; the current working directory is used when empty.
        """
        self.branch = branch
        self.convert_query_response = convert_query_response
        self.root_directory = root_directory or os.getcwd()
        self.infrahub_node = infrahub_node
        self._nodes: list[InfrahubNode] = []
        self._related_nodes: list[InfrahubNode] = []
        # Must be set before ``branch_name`` is evaluated just below: with an empty
        # ``branch`` the property creates a GitRepoManager and keeps it in ``self.git``,
        # so assigning ``None`` afterwards would throw that instance away.
        self.git: GitRepoManager | None = None
        self._init_client = client.clone(branch=self.branch_name)

    @property
    def branch_name(self) -> str:
        """Return the target branch name.

        An explicit ``branch`` is returned as is. Otherwise the active branch of the git
        repository in ``root_directory`` is resolved once and cached in ``self.branch``.
        """
        if self.branch:
            return self.branch

        # getattr tolerates subclasses that have not assigned ``self.git``.
        if not getattr(self, "git", None):
            self.git = GitRepoManager(self.root_directory)

        self.branch = str(self.git.active_branch)

        return self.branch

    @property
    def store(self) -> NodeStore:
        """The store will be populated with nodes based on the query during the collection of data if activated"""
        return self._init_client.store

    @property
    def nodes(self) -> list[InfrahubNode]:
        """Returns nodes collected and parsed during the data collection process if this feature is enabled"""
        return self._nodes

    @property
    def related_nodes(self) -> list[InfrahubNode]:
        """Returns nodes collected and parsed during the data collection process if this feature is enabled"""
        return self._related_nodes

    async def process_nodes(self, data: dict) -> None:
        """Convert a GraphQL query result into nodes when ``convert_query_response`` is set.

        For each top-level key of ``data`` that is a node kind of the branch schema, every
        item of its ``edges`` list is built with ``self.infrahub_node.from_graphql`` and its
        relationships are processed into ``related_nodes``. All collected nodes that have
        an ``id`` are then added to the store.
        """
        if not self.convert_query_response:
            return

        branch = self.branch_name
        await self._init_client.schema.all(branch=branch)

        for kind in data:
            if kind not in self._init_client.schema.cache[branch].nodes.keys():
                continue
            for result in data[kind].get("edges", []):
                node = await self.infrahub_node.from_graphql(client=self._init_client, branch=branch, data=result)
                self._nodes.append(node)
                await node._process_relationships(node_data=result, branch=branch, related_nodes=self._related_nodes)

        for node in self._nodes + self._related_nodes:
            if node.id:
                self._init_client.store.set(node=node)