infrahub-server 1.2.7__py3-none-any.whl → 1.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. infrahub/api/transformation.py +1 -0
  2. infrahub/artifacts/models.py +4 -0
  3. infrahub/cli/db.py +1 -1
  4. infrahub/computed_attribute/tasks.py +1 -0
  5. infrahub/config.py +2 -1
  6. infrahub/constants/__init__.py +0 -0
  7. infrahub/core/constants/__init__.py +1 -0
  8. infrahub/core/graph/index.py +3 -1
  9. infrahub/core/manager.py +16 -5
  10. infrahub/core/migrations/graph/m014_remove_index_attr_value.py +7 -8
  11. infrahub/core/protocols.py +1 -0
  12. infrahub/core/query/node.py +96 -29
  13. infrahub/core/schema/definitions/core/builtin.py +2 -4
  14. infrahub/core/schema/definitions/core/transform.py +1 -0
  15. infrahub/core/validators/aggregated_checker.py +2 -2
  16. infrahub/core/validators/uniqueness/query.py +8 -3
  17. infrahub/database/__init__.py +2 -10
  18. infrahub/database/index.py +1 -1
  19. infrahub/database/memgraph.py +2 -1
  20. infrahub/database/neo4j.py +1 -1
  21. infrahub/git/integrator.py +27 -3
  22. infrahub/git/models.py +4 -0
  23. infrahub/git/tasks.py +3 -0
  24. infrahub/git_credential/helper.py +2 -2
  25. infrahub/message_bus/operations/requests/proposed_change.py +6 -0
  26. infrahub/message_bus/types.py +3 -0
  27. infrahub/patch/queries/consolidate_duplicated_nodes.py +109 -0
  28. infrahub/patch/queries/delete_duplicated_edges.py +138 -0
  29. infrahub/proposed_change/tasks.py +1 -0
  30. infrahub/server.py +3 -1
  31. infrahub/transformations/models.py +3 -0
  32. infrahub/transformations/tasks.py +1 -0
  33. infrahub/webhook/models.py +3 -0
  34. infrahub_sdk/client.py +4 -4
  35. infrahub_sdk/config.py +17 -0
  36. infrahub_sdk/ctl/cli_commands.py +7 -1
  37. infrahub_sdk/ctl/generator.py +2 -2
  38. infrahub_sdk/generator.py +12 -66
  39. infrahub_sdk/operation.py +80 -0
  40. infrahub_sdk/protocols.py +12 -0
  41. infrahub_sdk/recorder.py +3 -0
  42. infrahub_sdk/schema/repository.py +4 -0
  43. infrahub_sdk/transforms.py +15 -27
  44. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.8.dist-info}/METADATA +2 -2
  45. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.8.dist-info}/RECORD +50 -46
  46. infrahub_testcontainers/docker-compose.test.yml +2 -0
  47. /infrahub/{database/constants.py → constants/database.py} +0 -0
  48. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.8.dist-info}/LICENSE.txt +0 -0
  49. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.8.dist-info}/WHEEL +0 -0
  50. {infrahub_server-1.2.7.dist-info → infrahub_server-1.2.8.dist-info}/entry_points.txt +0 -0
infrahub/git_credential/helper.py CHANGED
@@ -3,9 +3,9 @@ import sys
 
 import typer
 from infrahub_sdk import Config, InfrahubClientSync
+from infrahub_sdk.protocols import CoreGenericRepository
 
 from infrahub import config
-from infrahub.core.constants import InfrahubKind
 
 logging.getLogger("httpx").setLevel(logging.ERROR)
 app = typer.Typer()
@@ -51,7 +51,7 @@ def get(
         raise typer.Exit(1) from exc
 
     client = InfrahubClientSync(config=Config(address=config.SETTINGS.main.internal_address, insert_tracker=True))
-    repo = client.get(kind=InfrahubKind.GENERICREPOSITORY, location__value=location)
+    repo = client.get(kind=CoreGenericRepository.__name__, location__value=location)
 
    if not repo:
        print("Repository not found in the database.")
infrahub/message_bus/operations/requests/proposed_change.py CHANGED
@@ -319,6 +319,9 @@ query GatherArtifactDefinitions {
                    file_path {
                        value
                    }
+                    convert_query_response {
+                        value
+                    }
                }
                repository {
                    node {
@@ -526,6 +529,9 @@ def _parse_artifact_definitions(definitions: list[dict]) -> list[ProposedChangeA
        elif artifact_definition.transform_kind == InfrahubKind.TRANSFORMPYTHON:
            artifact_definition.class_name = definition["node"]["transformation"]["node"]["class_name"]["value"]
            artifact_definition.file_path = definition["node"]["transformation"]["node"]["file_path"]["value"]
+            artifact_definition.convert_query_response = definition["node"]["transformation"]["node"][
+                "convert_query_response"
+            ]["value"]
 
        parsed.append(artifact_definition)
infrahub/message_bus/types.py CHANGED
@@ -96,6 +96,9 @@ class ProposedChangeArtifactDefinition(BaseModel):
    class_name: str = Field(default="")
    content_type: str
    file_path: str = Field(default="")
+    convert_query_response: bool = Field(
+        default=False, description="Convert query response to InfrahubNode objects for Python based transforms"
+    )
    timeout: int
 
    @property
infrahub/patch/queries/consolidate_duplicated_nodes.py ADDED
@@ -0,0 +1,109 @@
+from ..models import EdgeToAdd, EdgeToDelete, PatchPlan, VertexToDelete
+from .base import PatchQuery
+
+
+class ConsolidateDuplicatedNodesPatchQuery(PatchQuery):
+    """
+    Find any groups of nodes with the same labels and properties, move all the edges to one of the duplicated nodes,
+    then delete the other duplicated nodes
+    """
+
+    @property
+    def name(self) -> str:
+        return "consolidate-duplicated-nodes"
+
+    async def plan(self) -> PatchPlan:
+        query = """
+        //------------
+        // Find nodes with the same labels and UUID
+        //------------
+        MATCH (n:Node)
+        WITH n.uuid AS node_uuid, count(*) as num_nodes_with_uuid
+        WHERE num_nodes_with_uuid > 1
+        WITH DISTINCT node_uuid
+        MATCH (n:Node {uuid: node_uuid})
+        CALL {
+            WITH n
+            WITH labels(n) AS n_labels
+            UNWIND n_labels AS n_label
+            WITH n_label
+            ORDER BY n_label ASC
+            RETURN collect(n_label) AS sorted_labels
+        }
+        WITH n.uuid AS n_uuid, sorted_labels, collect(n) AS duplicate_nodes
+        WHERE size(duplicate_nodes) > 1
+        WITH n_uuid, head(duplicate_nodes) AS node_to_keep, tail(duplicate_nodes) AS nodes_to_delete
+        UNWIND nodes_to_delete AS node_to_delete
+        //------------
+        // Find the edges that we need to move to the selected node_to_keep
+        //------------
+        CALL {
+            WITH node_to_keep, node_to_delete
+            MATCH (node_to_delete)-[edge_to_delete]->(peer)
+            RETURN {
+                from_id: %(id_func_name)s(node_to_keep),
+                to_id: %(id_func_name)s(peer),
+                edge_type: type(edge_to_delete),
+                after_props: properties(edge_to_delete)
+            } AS edge_to_create
+            UNION
+            WITH node_to_keep, node_to_delete
+            MATCH (node_to_delete)<-[edge_to_delete]-(peer)
+            RETURN {
+                from_id: %(id_func_name)s(peer),
+                to_id: %(id_func_name)s(node_to_keep),
+                edge_type: type(edge_to_delete),
+                after_props: properties(edge_to_delete)
+            } AS edge_to_create
+        }
+        WITH node_to_delete, collect(edge_to_create) AS edges_to_create
+        //------------
+        // Find the edges that we need to remove from the duplicated nodes
+        //------------
+        CALL {
+            WITH node_to_delete
+            MATCH (node_to_delete)-[e]->(peer)
+            RETURN {
+                db_id: %(id_func_name)s(e),
+                from_id: %(id_func_name)s(node_to_delete),
+                to_id: %(id_func_name)s(peer),
+                edge_type: type(e),
+                before_props: properties(e)
+            } AS edge_to_delete
+            UNION
+            WITH node_to_delete
+            MATCH (node_to_delete)<-[e]-(peer)
+            RETURN {
+                db_id: %(id_func_name)s(e),
+                from_id: %(id_func_name)s(peer),
+                to_id: %(id_func_name)s(node_to_delete),
+                edge_type: type(e),
+                before_props: properties(e)
+            } AS edge_to_delete
+        }
+        WITH node_to_delete, edges_to_create, collect(edge_to_delete) AS edges_to_delete
+        RETURN
+        {db_id: %(id_func_name)s(node_to_delete), labels: labels(node_to_delete), before_props: properties(node_to_delete)} AS vertex_to_delete,
+        edges_to_create,
+        edges_to_delete
+        """ % {"id_func_name": self.db.get_id_function_name()}
+        results = await self.db.execute_query(query=query)
+        vertices_to_delete: list[VertexToDelete] = []
+        edges_to_delete: list[EdgeToDelete] = []
+        edges_to_add: list[EdgeToAdd] = []
+        for result in results:
+            serial_vertex_to_delete = result.get("vertex_to_delete")
+            if serial_vertex_to_delete:
+                vertex_to_delete = VertexToDelete(**serial_vertex_to_delete)
+                vertices_to_delete.append(vertex_to_delete)
+            for serial_edge_to_delete in result.get("edges_to_delete"):
+                edge_to_delete = EdgeToDelete(**serial_edge_to_delete)
+                edges_to_delete.append(edge_to_delete)
+            for serial_edge_to_create in result.get("edges_to_create"):
+                edges_to_add.append(EdgeToAdd(**serial_edge_to_create))
+        return PatchPlan(
+            name=self.name,
+            vertices_to_delete=vertices_to_delete,
+            edges_to_add=edges_to_add,
+            edges_to_delete=edges_to_delete,
+        )
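Note: a PatchQuery only plans work; plan() runs the read-only Cypher above and returns a PatchPlan describing the vertices and edges a separate runner would apply. A hedged driver sketch, assuming the PatchQuery constructor takes the database handle as db (the runner itself is not part of this diff):

    from infrahub.database import InfrahubDatabase

    async def preview_consolidation(db: InfrahubDatabase) -> None:
        # Constructor signature assumed; only plan() and the PatchPlan fields appear in this diff.
        patch = ConsolidateDuplicatedNodesPatchQuery(db=db)
        plan = await patch.plan()
        print(f"vertices to delete: {len(plan.vertices_to_delete)}")
        print(f"edges to re-home onto kept nodes: {len(plan.edges_to_add)}")
        print(f"edges to delete: {len(plan.edges_to_delete)}")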
infrahub/patch/queries/delete_duplicated_edges.py ADDED
@@ -0,0 +1,138 @@
+from ..models import EdgeToDelete, EdgeToUpdate, PatchPlan
+from .base import PatchQuery
+
+
+class DeleteDuplicatedEdgesPatchQuery(PatchQuery):
+    """
+    Find duplicated or overlapping edges of the same status, type, and branch to update and delete
+    - one edge will be kept for each pair of nodes and a given status, type, and branch. it will be
+      updated to have the earliest "from" and "to" times in this group
+    - all the other duplicate/overlapping edges will be deleted
+    """
+
+    @property
+    def name(self) -> str:
+        return "delete-duplicated-edges"
+
+    async def plan(self) -> PatchPlan:
+        query = """
+        // ------------
+        // Find node pairs that have duplicate edges
+        // ------------
+        MATCH (node_with_dup_edges:Node)-[edge]->(peer)
+        WITH node_with_dup_edges, type(edge) AS edge_type, edge.status AS edge_status, edge.branch AS edge_branch, peer, count(*) AS num_dup_edges
+        WHERE num_dup_edges > 1
+        WITH DISTINCT node_with_dup_edges, edge_type, edge_branch, peer
+        CALL {
+            // ------------
+            // Get the earliest active and deleted edges for this branch
+            // ------------
+            WITH node_with_dup_edges, edge_type, edge_branch, peer
+            MATCH (node_with_dup_edges)-[active_edge {branch: edge_branch, status: "active"}]->(peer)
+            WHERE type(active_edge) = edge_type
+            WITH node_with_dup_edges, edge_type, edge_branch, peer, active_edge
+            ORDER BY active_edge.from ASC
+            WITH node_with_dup_edges, edge_type, edge_branch, peer, head(collect(active_edge.from)) AS active_from
+            OPTIONAL MATCH (node_with_dup_edges)-[deleted_edge {branch: edge_branch, status: "deleted"}]->(peer)
+            WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, deleted_edge
+            ORDER BY deleted_edge.from ASC
+            WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, head(collect(deleted_edge.from)) AS deleted_from
+            // ------------
+            // Plan one active edge update with correct from and to times
+            // ------------
+            CALL {
+                WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, deleted_from
+                MATCH (node_with_dup_edges)-[active_e {branch: edge_branch, status: "active"}]->(peer)
+                WHERE type(active_e) = edge_type
+                WITH node_with_dup_edges, edge_type, edge_branch, peer, active_from, deleted_from, active_e
+                ORDER BY %(id_func_name)s(active_e)
+                LIMIT 1
+                WITH active_e, properties(active_e) AS before_props, {from: active_from, to: deleted_from} AS prop_updates
+                RETURN [
+                    {
+                        db_id: %(id_func_name)s(active_e), before_props: before_props, prop_updates: prop_updates
+                    }
+                ] AS active_edges_to_update
+            }
+            // ------------
+            // Plan deletes for all the other active edges of this type on this branch
+            // ------------
+            CALL {
+                WITH node_with_dup_edges, edge_type, edge_branch, peer
+                MATCH (node_with_dup_edges)-[active_e {branch: edge_branch, status: "active"}]->(peer)
+                WHERE type(active_e) = edge_type
+                WITH node_with_dup_edges, peer, active_e
+                ORDER BY %(id_func_name)s(active_e)
+                SKIP 1
+                RETURN collect(
+                    {
+                        db_id: %(id_func_name)s(active_e),
+                        from_id: %(id_func_name)s(node_with_dup_edges),
+                        to_id: %(id_func_name)s(peer),
+                        edge_type: type(active_e),
+                        before_props: properties(active_e)
+                    }
+                ) AS active_edges_to_delete
+            }
+            // ------------
+            // Plan one deleted edge update with correct from time
+            // ------------
+            CALL {
+                WITH node_with_dup_edges, edge_type, edge_branch, peer, deleted_from
+                MATCH (node_with_dup_edges)-[deleted_e {branch: edge_branch, status: "deleted"}]->(peer)
+                WHERE type(deleted_e) = edge_type
+                WITH node_with_dup_edges, edge_type, edge_branch, peer, deleted_from, deleted_e
+                ORDER BY %(id_func_name)s(deleted_e)
+                LIMIT 1
+                WITH deleted_e, properties(deleted_e) AS before_props, {from: deleted_from} AS prop_updates
+                RETURN [
+                    {
+                        db_id: %(id_func_name)s(deleted_e), before_props: before_props, prop_updates: prop_updates
+                    }
+                ] AS deleted_edges_to_update
+            }
+            // ------------
+            // Plan deletes for all the other deleted edges of this type on this branch
+            // ------------
+            CALL {
+                WITH node_with_dup_edges, edge_type, edge_branch, peer
+                MATCH (node_with_dup_edges)-[deleted_e {branch: edge_branch, status: "deleted"}]->(peer)
+                WHERE type(deleted_e) = edge_type
+                WITH node_with_dup_edges, peer, deleted_e
+                ORDER BY %(id_func_name)s(deleted_e)
+                SKIP 1
+                RETURN collect(
+                    {
+                        db_id: %(id_func_name)s(deleted_e),
+                        from_id: %(id_func_name)s(node_with_dup_edges),
+                        to_id: %(id_func_name)s(peer),
+                        edge_type: type(deleted_e),
+                        before_props: properties(deleted_e)
+                    }
+                ) AS deleted_edges_to_delete
+            }
+            RETURN
+            active_edges_to_update + deleted_edges_to_update AS edges_to_update,
+            active_edges_to_delete + deleted_edges_to_delete AS edges_to_delete
+        }
+        RETURN edges_to_update, edges_to_delete
+        """ % {"id_func_name": self.db.get_id_function_name()}
+        results = await self.db.execute_query(query=query)
+        edges_to_delete: list[EdgeToDelete] = []
+        edges_to_update: list[EdgeToUpdate] = []
+        for result in results:
+            for serial_edge_to_delete in result.get("edges_to_delete"):
+                edge_to_delete = EdgeToDelete(**serial_edge_to_delete)
+                edges_to_delete.append(edge_to_delete)
+            for serial_edge_to_update in result.get("edges_to_update"):
+                prop_updates = serial_edge_to_update["prop_updates"]
+                if prop_updates:
+                    serial_edge_to_update["after_props"] = serial_edge_to_update["before_props"] | prop_updates
+                del serial_edge_to_update["prop_updates"]
+                edge_to_update = EdgeToUpdate(**serial_edge_to_update)
+                edges_to_update.append(edge_to_update)
+        return PatchPlan(
+            name=self.name,
+            edges_to_delete=edges_to_delete,
+            edges_to_update=edges_to_update,
+        )
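Note: the after_props computation above relies on Python 3.9+ dict union, where keys from prop_updates override those in before_props. A minimal illustration with invented values:

    before_props = {"branch": "main", "status": "active", "from": "2024-05-02T00:00:00Z", "to": None}
    prop_updates = {"from": "2024-05-01T00:00:00Z", "to": "2024-06-01T00:00:00Z"}
    # The right-hand operand wins on key conflicts, so the kept edge receives the earliest times.
    after_props = before_props | prop_updates
    assert after_props["from"] == "2024-05-01T00:00:00Z"
    assert after_props["status"] == "active"  # untouched keys carry over unchanged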
infrahub/proposed_change/tasks.py CHANGED
@@ -607,6 +607,7 @@ async def validate_artifacts_generation(model: RequestArtifactDefinitionCheck, s
            content_type=model.artifact_definition.content_type,
            transform_type=model.artifact_definition.transform_kind,
            transform_location=model.artifact_definition.transform_location,
+            convert_query_response=model.artifact_definition.convert_query_response,
            repository_id=repository.repository_id,
            repository_name=repository.repository_name,
            repository_kind=repository.kind,
infrahub/server.py CHANGED
@@ -23,7 +23,7 @@ from infrahub import __version__, config
 from infrahub.api import router as api
 from infrahub.api.exception_handlers import generic_api_exception_handler
 from infrahub.components import ComponentType
-from infrahub.core.graph.index import node_indexes, rel_indexes
+from infrahub.core.graph.index import attr_value_index, node_indexes, rel_indexes
 from infrahub.core.initialization import initialization
 from infrahub.database import InfrahubDatabase, InfrahubDatabaseMode, get_db
 from infrahub.dependencies.registry import build_component_registry
@@ -58,6 +58,8 @@ async def app_initialization(application: FastAPI, enable_scheduler: bool = True
 
    # Initialize database Driver and load local registry
    database = application.state.db = InfrahubDatabase(mode=InfrahubDatabaseMode.DRIVER, driver=await get_db())
+    if config.SETTINGS.experimental_features.value_db_index:
+        node_indexes.append(attr_value_index)
    database.manager.index.init(nodes=node_indexes, rels=rel_indexes)
 
    build_component_registry()
infrahub/transformations/models.py CHANGED
@@ -11,6 +11,9 @@ class TransformPythonData(BaseModel):
    branch: str = Field(..., description="The branch to target")
    transform_location: str = Field(..., description="Location of the transform within the repository")
    commit: str = Field(..., description="The commit id to use when generating the artifact")
+    convert_query_response: bool = Field(
+        ..., description="Define if the GraphQL query response should be converted into InfrahubNode objects"
+    )
    timeout: int = Field(..., description="The timeout value to use when generating the artifact")
infrahub/transformations/tasks.py CHANGED
@@ -30,6 +30,7 @@ async def transform_python(message: TransformPythonData, service: InfrahubServic
        location=message.transform_location,
        data=message.data,
        client=service.client,
+        convert_query_response=message.convert_query_response,
    )  # type: ignore[misc]
 
    return transformed_data
infrahub/webhook/models.py CHANGED
@@ -204,6 +204,7 @@ class TransformWebhook(Webhook):
    transform_class: str = Field(...)
    transform_file: str = Field(...)
    transform_timeout: int = Field(...)
+    convert_query_response: bool = Field(...)
 
    async def _prepare_payload(self, data: dict[str, Any], context: EventContext, service: InfrahubServices) -> None:
        repo: InfrahubReadOnlyRepository | InfrahubRepository
@@ -229,6 +230,7 @@ class TransformWebhook(Webhook):
            branch_name=branch,
            commit=commit,
            location=f"{self.transform_file}::{self.transform_class}",
+            convert_query_response=self.convert_query_response,
            data={"data": data, **context.model_dump()},
            client=service.client,
        )  # type: ignore[misc]
@@ -247,4 +249,5 @@ class TransformWebhook(Webhook):
            transform_class=transform.class_name.value,
            transform_file=transform.file_path.value,
            transform_timeout=transform.timeout.value,
+            convert_query_response=transform.convert_query_response.value or False,
        )
infrahub_sdk/client.py CHANGED
@@ -847,9 +847,9 @@ class InfrahubClient(BaseClient):
                self.store.set(node=node)
        return nodes
 
-    def clone(self) -> InfrahubClient:
+    def clone(self, branch: str | None = None) -> InfrahubClient:
        """Return a cloned version of the client using the same configuration"""
-        return InfrahubClient(config=self.config)
+        return InfrahubClient(config=self.config.clone(branch=branch))
 
    async def execute_graphql(
        self,
@@ -1591,9 +1591,9 @@ class InfrahubClientSync(BaseClient):
        node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
        node.delete()
 
-    def clone(self) -> InfrahubClientSync:
+    def clone(self, branch: str | None = None) -> InfrahubClientSync:
        """Return a cloned version of the client using the same configuration"""
-        return InfrahubClientSync(config=self.config)
+        return InfrahubClientSync(config=self.config.clone(branch=branch))
 
    def execute_graphql(
        self,
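Note: clone() previously reused the exact same Config object; it now goes through Config.clone() (shown in the next file) so a clone can retarget another branch without mutating the original client. A short usage sketch with placeholder values:

    client = InfrahubClient(config=Config(address="http://localhost:8000", default_branch="main"))
    feature_client = client.clone(branch="feature-a")
    assert feature_client.config.default_branch == "feature-a"
    assert client.config.default_branch == "main"  # the original config is untouched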
infrahub_sdk/config.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from copy import deepcopy
 from typing import Any
 
 from pydantic import Field, field_validator, model_validator
@@ -158,3 +159,19 @@ class Config(ConfigBase):
        elif values.get("recorder") == RecorderType.JSON and "custom_recorder" not in values:
            values["custom_recorder"] = JSONRecorder()
        return values
+
+    def clone(self, branch: str | None = None) -> Config:
+        config: dict[str, Any] = {
+            "default_branch": branch or self.default_branch,
+            "recorder": self.recorder,
+            "custom_recorder": self.custom_recorder,
+            "requester": self.requester,
+            "sync_requester": self.sync_requester,
+            "log": self.log,
+        }
+        covered_keys = list(config.keys())
+        for field in Config.model_fields.keys():
+            if field not in covered_keys:
+                config[field] = deepcopy(getattr(self, field))
+
+        return Config(**config)
infrahub_sdk/ctl/cli_commands.py CHANGED
@@ -41,6 +41,7 @@ from ..ctl.utils import (
 )
 from ..ctl.validate import app as validate_app
 from ..exceptions import GraphQLError, ModuleImportError
+from ..node import InfrahubNode
 from ..protocols_generator.generator import CodeGenerator
 from ..schema import MainSchemaTypesAll, SchemaRoot
 from ..template import Jinja2Template
@@ -330,7 +331,12 @@ def transform(
        console.print(f"[red]{exc.message}")
        raise typer.Exit(1) from exc
 
-    transform = transform_class(client=client, branch=branch)
+    transform = transform_class(
+        client=client,
+        branch=branch,
+        infrahub_node=InfrahubNode,
+        convert_query_response=transform_config.convert_query_response,
+    )
    # Get data
    query_str = repository_config.get_query(name=transform.query).load_query()
    data = asyncio.run(
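Note: infrahubctl transform now constructs the transform with the same infrahub_node and convert_query_response arguments that generators receive, so Python transforms can work with typed InfrahubNode objects rather than raw GraphQL dictionaries. A hedged sketch of a transform that benefits; the class and query names are invented, and the exact InfrahubTransform surface after this release is inferred from the arguments passed above:

    from infrahub_sdk.transforms import InfrahubTransform

    class DeviceHostnames(InfrahubTransform):
        query = "device_query"  # invented query name from the repository's .infrahub.yml

        async def transform(self, data: dict) -> str:
            # With convert_query_response enabled, self.nodes is populated with
            # InfrahubNode objects parsed from the GraphQL query response.
            return "\n".join(node.name.value for node in self.nodes)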
infrahub_sdk/ctl/generator.py CHANGED
@@ -62,7 +62,7 @@ async def run(
    generator = generator_class(
        query=generator_config.query,
        client=client,
-        branch=branch,
+        branch=branch or "",
        params=variables_dict,
        convert_query_response=generator_config.convert_query_response,
        infrahub_node=InfrahubNode,
@@ -91,7 +91,7 @@ async def run(
    generator = generator_class(
        query=generator_config.query,
        client=client,
-        branch=branch,
+        branch=branch or "",
        params=params,
        convert_query_response=generator_config.convert_query_response,
        infrahub_node=InfrahubNode,
infrahub_sdk/generator.py CHANGED
@@ -1,22 +1,19 @@
 from __future__ import annotations
 
 import logging
-import os
 from abc import abstractmethod
 from typing import TYPE_CHECKING
 
-from infrahub_sdk.repository import GitRepoManager
-
 from .exceptions import UninitializedError
+from .operation import InfrahubOperation
 
 if TYPE_CHECKING:
    from .client import InfrahubClient
    from .context import RequestContext
    from .node import InfrahubNode
-    from .store import NodeStore
 
 
-class InfrahubGenerator:
+class InfrahubGenerator(InfrahubOperation):
    """Infrahub Generator class"""
 
    def __init__(
@@ -24,7 +21,7 @@ class InfrahubGenerator:
        query: str,
        client: InfrahubClient,
        infrahub_node: type[InfrahubNode],
-        branch: str | None = None,
+        branch: str = "",
        root_directory: str = "",
        generator_instance: str = "",
        params: dict | None = None,
@@ -33,37 +30,21 @@ class InfrahubGenerator:
        request_context: RequestContext | None = None,
    ) -> None:
        self.query = query
-        self.branch = branch
-        self.git: GitRepoManager | None = None
+
+        super().__init__(
+            client=client,
+            infrahub_node=infrahub_node,
+            convert_query_response=convert_query_response,
+            branch=branch,
+            root_directory=root_directory,
+        )
+
        self.params = params or {}
-        self.root_directory = root_directory or os.getcwd()
        self.generator_instance = generator_instance
-        self._init_client = client.clone()
-        self._init_client.config.default_branch = self._init_client.default_branch = self.branch_name
-        self._init_client.store._default_branch = self.branch_name
        self._client: InfrahubClient | None = None
-        self._nodes: list[InfrahubNode] = []
-        self._related_nodes: list[InfrahubNode] = []
-        self.infrahub_node = infrahub_node
-        self.convert_query_response = convert_query_response
        self.logger = logger if logger else logging.getLogger("infrahub.tasks")
        self.request_context = request_context
 
-    @property
-    def store(self) -> NodeStore:
-        """The store will be populated with nodes based on the query during the collection of data if activated"""
-        return self._init_client.store
-
-    @property
-    def nodes(self) -> list[InfrahubNode]:
-        """Returns nodes collected and parsed during the data collection process if this feature is enables"""
-        return self._nodes
-
-    @property
-    def related_nodes(self) -> list[InfrahubNode]:
-        """Returns nodes collected and parsed during the data collection process if this feature is enables"""
-        return self._related_nodes
-
    @property
    def subscribers(self) -> list[str] | None:
        if self.generator_instance:
@@ -80,20 +61,6 @@ class InfrahubGenerator:
    def client(self, value: InfrahubClient) -> None:
        self._client = value
 
-    @property
-    def branch_name(self) -> str:
-        """Return the name of the current git branch."""
-
-        if self.branch:
-            return self.branch
-
-        if not self.git:
-            self.git = GitRepoManager(self.root_directory)
-
-        self.branch = str(self.git.active_branch)
-
-        return self.branch
-
    async def collect_data(self) -> dict:
        """Query the result of the GraphQL Query defined in self.query and return the result"""
 
@@ -119,27 +86,6 @@ class InfrahubGenerator:
        ) as self.client:
            await self.generate(data=unpacked)
 
-    async def process_nodes(self, data: dict) -> None:
-        if not self.convert_query_response:
-            return
-
-        await self._init_client.schema.all(branch=self.branch_name)
-
-        for kind in data:
-            if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
-                for result in data[kind].get("edges", []):
-                    node = await self.infrahub_node.from_graphql(
-                        client=self._init_client, branch=self.branch_name, data=result
-                    )
-                    self._nodes.append(node)
-                    await node._process_relationships(
-                        node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
-                    )
-
-        for node in self._nodes + self._related_nodes:
-            if node.id:
-                self._init_client.store.set(node=node)
-
    @abstractmethod
    async def generate(self, data: dict) -> None:
        """Code to run the generator
infrahub_sdk/operation.py ADDED
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+from .repository import GitRepoManager
+
+if TYPE_CHECKING:
+    from . import InfrahubClient
+    from .node import InfrahubNode
+    from .store import NodeStore
+
+
+class InfrahubOperation:
+    def __init__(
+        self,
+        client: InfrahubClient,
+        infrahub_node: type[InfrahubNode],
+        convert_query_response: bool,
+        branch: str,
+        root_directory: str,
+    ):
+        self.branch = branch
+        self.convert_query_response = convert_query_response
+        self.root_directory = root_directory or os.getcwd()
+        self.infrahub_node = infrahub_node
+        self._nodes: list[InfrahubNode] = []
+        self._related_nodes: list[InfrahubNode] = []
+        self._init_client = client.clone(branch=self.branch_name)
+        self.git: GitRepoManager | None = None
+
+    @property
+    def branch_name(self) -> str:
+        """Return the name of the current git branch."""
+
+        if self.branch:
+            return self.branch
+
+        if not hasattr(self, "git") or not self.git:
+            self.git = GitRepoManager(self.root_directory)
+
+        self.branch = str(self.git.active_branch)
+
+        return self.branch
+
+    @property
+    def store(self) -> NodeStore:
+        """The store will be populated with nodes based on the query during the collection of data if activated"""
+        return self._init_client.store
+
+    @property
+    def nodes(self) -> list[InfrahubNode]:
+        """Returns nodes collected and parsed during the data collection process if this feature is enabled"""
+        return self._nodes
+
+    @property
+    def related_nodes(self) -> list[InfrahubNode]:
+        """Returns nodes collected and parsed during the data collection process if this feature is enabled"""
+        return self._related_nodes
+
+    async def process_nodes(self, data: dict) -> None:
+        if not self.convert_query_response:
+            return
+
+        await self._init_client.schema.all(branch=self.branch_name)
+
+        for kind in data:
+            if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
+                for result in data[kind].get("edges", []):
+                    node = await self.infrahub_node.from_graphql(
+                        client=self._init_client, branch=self.branch_name, data=result
+                    )
+                    self._nodes.append(node)
+                    await node._process_relationships(
+                        node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
+                    )
+
+        for node in self._nodes + self._related_nodes:
+            if node.id:
+                self._init_client.store.set(node=node)
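Note: with the shared base class in place, a generator inherits store, nodes, related_nodes, and process_nodes() from InfrahubOperation instead of carrying its own copies. A hedged end-to-end sketch; the kind and attribute names are invented, and it assumes the generator was constructed with convert_query_response=True so the data-collection flow has already populated self.nodes:

    from infrahub_sdk.generator import InfrahubGenerator

    class TagEveryDevice(InfrahubGenerator):
        async def generate(self, data: dict) -> None:
            # self.nodes holds InfrahubNode objects converted from the query
            # response by the inherited process_nodes() helper.
            for node in self.nodes:
                tag = await self.client.create(kind="BuiltinTag", name=f"{node.name.value}-managed")
                await tag.save(allow_upsert=True)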