infrahub-server 1.2.6__py3-none-any.whl → 1.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. infrahub/cli/db.py +2 -0
  2. infrahub/cli/patch.py +153 -0
  3. infrahub/computed_attribute/models.py +81 -1
  4. infrahub/computed_attribute/tasks.py +34 -53
  5. infrahub/core/node/__init__.py +4 -1
  6. infrahub/core/query/ipam.py +7 -5
  7. infrahub/patch/__init__.py +0 -0
  8. infrahub/patch/constants.py +13 -0
  9. infrahub/patch/edge_adder.py +64 -0
  10. infrahub/patch/edge_deleter.py +33 -0
  11. infrahub/patch/edge_updater.py +28 -0
  12. infrahub/patch/models.py +98 -0
  13. infrahub/patch/plan_reader.py +107 -0
  14. infrahub/patch/plan_writer.py +92 -0
  15. infrahub/patch/queries/__init__.py +0 -0
  16. infrahub/patch/queries/base.py +17 -0
  17. infrahub/patch/runner.py +254 -0
  18. infrahub/patch/vertex_adder.py +61 -0
  19. infrahub/patch/vertex_deleter.py +33 -0
  20. infrahub/patch/vertex_updater.py +28 -0
  21. infrahub_sdk/checks.py +1 -1
  22. infrahub_sdk/ctl/cli_commands.py +2 -2
  23. infrahub_sdk/ctl/menu.py +56 -13
  24. infrahub_sdk/ctl/object.py +55 -5
  25. infrahub_sdk/ctl/utils.py +22 -1
  26. infrahub_sdk/exceptions.py +19 -1
  27. infrahub_sdk/node.py +42 -26
  28. infrahub_sdk/protocols_generator/__init__.py +0 -0
  29. infrahub_sdk/protocols_generator/constants.py +28 -0
  30. infrahub_sdk/{code_generator.py → protocols_generator/generator.py} +47 -34
  31. infrahub_sdk/protocols_generator/template.j2 +114 -0
  32. infrahub_sdk/schema/__init__.py +110 -74
  33. infrahub_sdk/schema/main.py +36 -2
  34. infrahub_sdk/schema/repository.py +2 -0
  35. infrahub_sdk/spec/menu.py +3 -3
  36. infrahub_sdk/spec/object.py +522 -41
  37. infrahub_sdk/testing/docker.py +4 -5
  38. infrahub_sdk/testing/schemas/animal.py +7 -0
  39. infrahub_sdk/yaml.py +63 -7
  40. {infrahub_server-1.2.6.dist-info → infrahub_server-1.2.7.dist-info}/METADATA +1 -1
  41. {infrahub_server-1.2.6.dist-info → infrahub_server-1.2.7.dist-info}/RECORD +44 -27
  42. infrahub_sdk/ctl/constants.py +0 -115
  43. {infrahub_server-1.2.6.dist-info → infrahub_server-1.2.7.dist-info}/LICENSE.txt +0 -0
  44. {infrahub_server-1.2.6.dist-info → infrahub_server-1.2.7.dist-info}/WHEEL +0 -0
  45. {infrahub_server-1.2.6.dist-info → infrahub_server-1.2.7.dist-info}/entry_points.txt +0 -0
infrahub/cli/db.py CHANGED
@@ -54,12 +54,14 @@ from infrahub.services.adapters.message_bus.local import BusSimulator
 from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
 
 from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
+from .patch import patch_app
 
 if TYPE_CHECKING:
     from infrahub.cli.context import CliContext
     from infrahub.database import InfrahubDatabase
 
 app = AsyncTyper()
+app.add_typer(patch_app, name="patch")
 
 PERMISSIONS_AVAILABLE = ["read", "write", "admin"]
 
infrahub/cli/patch.py ADDED
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+import importlib
+import inspect
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import typer
+from infrahub_sdk.async_typer import AsyncTyper
+from rich import print as rprint
+
+from infrahub import config
+from infrahub.patch.edge_adder import PatchPlanEdgeAdder
+from infrahub.patch.edge_deleter import PatchPlanEdgeDeleter
+from infrahub.patch.edge_updater import PatchPlanEdgeUpdater
+from infrahub.patch.plan_reader import PatchPlanReader
+from infrahub.patch.plan_writer import PatchPlanWriter
+from infrahub.patch.queries.base import PatchQuery
+from infrahub.patch.runner import (
+    PatchPlanEdgeDbIdTranslator,
+    PatchRunner,
+)
+from infrahub.patch.vertex_adder import PatchPlanVertexAdder
+from infrahub.patch.vertex_deleter import PatchPlanVertexDeleter
+from infrahub.patch.vertex_updater import PatchPlanVertexUpdater
+
+from .constants import ERROR_BADGE, SUCCESS_BADGE
+
+if TYPE_CHECKING:
+    from infrahub.cli.context import CliContext
+    from infrahub.database import InfrahubDatabase
+
+
+patch_app = AsyncTyper(help="Commands for planning, applying, and reverting database patches")
+
+
+def get_patch_runner(db: InfrahubDatabase) -> PatchRunner:
+    return PatchRunner(
+        plan_writer=PatchPlanWriter(),
+        plan_reader=PatchPlanReader(),
+        edge_db_id_translator=PatchPlanEdgeDbIdTranslator(),
+        vertex_adder=PatchPlanVertexAdder(db=db),
+        vertex_deleter=PatchPlanVertexDeleter(db=db),
+        vertex_updater=PatchPlanVertexUpdater(db=db),
+        edge_adder=PatchPlanEdgeAdder(db=db),
+        edge_deleter=PatchPlanEdgeDeleter(db=db),
+        edge_updater=PatchPlanEdgeUpdater(db=db),
+    )
+
+
+@patch_app.command(name="plan")
+async def plan_patch_cmd(
+    ctx: typer.Context,
+    patch_path: str = typer.Argument(
+        help="Path to the file containing the PatchQuery instance to run. Use Python-style dot paths, such as infrahub.cli.patch.queries.base"
+    ),
+    patch_plans_dir: Path = typer.Option(Path("infrahub-patches"), help="Path to patch plans directory"),  # noqa: B008
+    apply: bool = typer.Option(False, help="Apply the patch immediately after creating it"),
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+) -> None:
+    """Create a plan for a given patch and save it in the patch plans directory to be applied/reverted"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    patch_module = importlib.import_module(patch_path)
+    patch_query_class = None
+    patch_query_class_count = 0
+    for _, cls in inspect.getmembers(patch_module, inspect.isclass):
+        if issubclass(cls, PatchQuery) and cls is not PatchQuery:
+            patch_query_class = cls
+            patch_query_class_count += 1
+
+    patch_query_path = f"{PatchQuery.__module__}.{PatchQuery.__name__}"
+    if patch_query_class is None:
+        rprint(f"{ERROR_BADGE} No subclass of {patch_query_path} found in {patch_path}")
+        raise typer.Exit(1)
+    if patch_query_class_count > 1:
+        rprint(
+            f"{ERROR_BADGE} Multiple subclasses of {patch_query_path} found in {patch_path}. Please only define one per file."
+        )
+        raise typer.Exit(1)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    context: CliContext = ctx.obj
+    dbdriver = await context.init_db(retry=1)
+
+    patch_query_instance = patch_query_class(db=dbdriver)
+    async with dbdriver.start_session() as db:
+        patch_runner = get_patch_runner(db=db)
+        patch_plan_dir = await patch_runner.prepare_plan(patch_query_instance, directory=Path(patch_plans_dir))
+        rprint(f"{SUCCESS_BADGE} Patch plan created at {patch_plan_dir}")
+        if apply:
+            await patch_runner.apply(patch_plan_directory=patch_plan_dir)
+            rprint(f"{SUCCESS_BADGE} Patch plan successfully applied")
+
+    await dbdriver.close()
+
+
+@patch_app.command(name="apply")
+async def apply_patch_cmd(
+    ctx: typer.Context,
+    patch_plan_dir: Path = typer.Argument(help="Path to the directory containing a patch plan"),
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+) -> None:
+    """Apply a given patch plan"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+
+    config.load_and_exit(config_file_name=config_file)
+
+    context: CliContext = ctx.obj
+    dbdriver = await context.init_db(retry=1)
+
+    if not patch_plan_dir.exists() or not patch_plan_dir.is_dir():
+        rprint(f"{ERROR_BADGE} patch_plan_dir must be an existing directory")
+        raise typer.Exit(1)
+
+    async with dbdriver.start_session() as db:
+        patch_runner = get_patch_runner(db=db)
+        await patch_runner.apply(patch_plan_directory=patch_plan_dir)
+        rprint(f"{SUCCESS_BADGE} Patch plan successfully applied")
+
+    await dbdriver.close()
+
+
+@patch_app.command(name="revert")
+async def revert_patch_cmd(
+    ctx: typer.Context,
+    patch_plan_dir: Path = typer.Argument(help="Path to the directory containing a patch plan"),
+    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+) -> None:
+    """Revert a given patch plan"""
+    logging.getLogger("infrahub").setLevel(logging.WARNING)
+    logging.getLogger("neo4j").setLevel(logging.ERROR)
+    logging.getLogger("prefect").setLevel(logging.ERROR)
+    config.load_and_exit(config_file_name=config_file)
+
+    context: CliContext = ctx.obj
+    db = await context.init_db(retry=1)
+
+    if not patch_plan_dir.exists() or not patch_plan_dir.is_dir():
+        rprint(f"{ERROR_BADGE} patch_plan_dir must be an existing directory")
+        raise typer.Exit(1)
+
+    patch_runner = get_patch_runner(db=db)
+    await patch_runner.revert(patch_plan_directory=patch_plan_dir)
+    rprint(f"{SUCCESS_BADGE} Patch plan successfully reverted")
+
+    await db.close()
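Note: the plan command above imports the given dot path and requires exactly one PatchQuery subclass in that module. A minimal sketch of such a module, assuming the db CLI is exposed as `infrahub db`; the class below is hypothetical, since the base-class contract in infrahub/patch/queries/base.py is not displayed in this diff:

# my_patches/example.py — hypothetical patch module, for illustration only
from infrahub.patch.queries.base import PatchQuery


class ExamplePatchQuery(PatchQuery):  # exactly one subclass per module
    """Sketch only: the concrete members PatchQuery requires are not shown in this diff."""


# Possible invocations (plan takes a Python-style dot path; <plan-dir> is created by plan):
#   infrahub db patch plan my_patches.example --patch-plans-dir infrahub-patches
#   infrahub db patch apply infrahub-patches/<plan-dir>
#   infrahub db patch revert infrahub-patches/<plan-dir>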
infrahub/computed_attribute/models.py CHANGED
@@ -2,13 +2,16 @@ from __future__ import annotations
 
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from infrahub_sdk.graphql import Query
 from prefect.events.schemas.automations import Automation  # noqa: TC002
 from pydantic import BaseModel, ConfigDict, Field, computed_field
 from typing_extensions import Self
 
 from infrahub.core import registry
+from infrahub.core.constants import RelationshipCardinality
+from infrahub.core.schema import AttributeSchema, NodeSchema  # noqa: TC001
 from infrahub.core.schema.schema_branch_computed import (  # noqa: TC001
     ComputedAttributeTarget,
     ComputedAttributeTriggerNode,
@@ -309,3 +312,80 @@ class ComputedAttrPythonQueryTriggerDefinition(TriggerBranchDefinition):
         )
 
         return definition
+
+
+class ComputedAttrJinja2GraphQLResponse(BaseModel):
+    node_id: str
+    computed_attribute_value: str | None
+    variables: dict[str, Any] = Field(default_factory=dict)
+
+
+class ComputedAttrJinja2GraphQL(BaseModel):
+    node_schema: NodeSchema = Field(..., description="The node kind where the computed attribute is defined")
+    attribute_schema: AttributeSchema = Field(..., description="The computed attribute")
+    variables: list[str] = Field(..., description="The list of variable names used within the computed attribute")
+
+    def render_graphql_query(self, query_filter: str, filter_id: str) -> str:
+        query_fields = self.query_fields
+        query_fields["id"] = None
+        query_fields[self.attribute_schema.name] = {"value": None}
+        query = Query(
+            name="ComputedAttributeFilter",
+            query={
+                self.node_schema.kind: {
+                    "@filters": {query_filter: filter_id},
+                    "edges": {"node": query_fields},
+                }
+            },
+        )
+
+        return query.render()
+
+    @property
+    def query_fields(self) -> dict[str, Any]:
+        output: dict[str, Any] = {}
+        for variable in self.variables:
+            field_name, remainder = variable.split("__", maxsplit=1)
+            if field_name in self.node_schema.attribute_names:
+                output[field_name] = {remainder: None}
+            elif field_name in self.node_schema.relationship_names:
+                related_attribute, related_value = remainder.split("__", maxsplit=1)
+                relationship = self.node_schema.get_relationship(name=field_name)
+                if relationship.cardinality == RelationshipCardinality.ONE:
+                    if field_name not in output:
+                        output[field_name] = {"node": {}}
+                    output[field_name]["node"][related_attribute] = {related_value: None}
+        return output
+
+    def parse_response(self, response: dict[str, Any]) -> list[ComputedAttrJinja2GraphQLResponse]:
+        rendered_response: list[ComputedAttrJinja2GraphQLResponse] = []
+        if kind_payload := response.get(self.node_schema.kind):
+            edges = kind_payload.get("edges", [])
+            for node in edges:
+                if node_response := self.to_node_response(node_dict=node):
+                    rendered_response.append(node_response)
+        return rendered_response
+
+    def to_node_response(self, node_dict: dict[str, Any]) -> ComputedAttrJinja2GraphQLResponse | None:
+        if node := node_dict.get("node"):
+            node_id = node.get("id")
+        else:
+            return None
+
+        computed_attribute = node.get(self.attribute_schema.name, {}).get("value")
+        response = ComputedAttrJinja2GraphQLResponse(node_id=node_id, computed_attribute_value=computed_attribute)
+        for variable in self.variables:
+            field_name, remainder = variable.split("__", maxsplit=1)
+            response.variables[variable] = None
+            if field_content := node.get(field_name):
+                if field_name in self.node_schema.attribute_names:
+                    response.variables[variable] = field_content.get(remainder)
+                elif field_name in self.node_schema.relationship_names:
+                    relationship = self.node_schema.get_relationship(name=field_name)
+                    if relationship.cardinality == RelationshipCardinality.ONE:
+                        related_attribute, related_value = remainder.split("__", maxsplit=1)
+                        node_content = field_content.get("node") or {}
+                        related_attribute_content = node_content.get(related_attribute) or {}
+                        response.variables[variable] = related_attribute_content.get(related_value)
+
+        return response
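The new ComputedAttrJinja2GraphQL model builds a targeted GraphQL query from the Jinja2 variables used by a computed attribute. A minimal sketch (hypothetical field names) of the naming convention it relies on — attribute variables are "<attribute>__<property>", cardinality-one relationship variables "<relationship>__<attribute>__<property>":

variables = ["hostname__value", "site__name__value"]

for variable in variables:
    field_name, remainder = variable.split("__", maxsplit=1)
    print(field_name, remainder)
# "hostname value"   -> selected as {"hostname": {"value": None}}
# "site name__value" -> split once more into ("name", "value") and selected as
#                       {"site": {"node": {"name": {"value": None}}}}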
infrahub/computed_attribute/tasks.py CHANGED
@@ -2,10 +2,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from infrahub_sdk.protocols import (
-    CoreNode,  # noqa: TC002
-    CoreTransformPython,
-)
+from infrahub_sdk.protocols import CoreTransformPython
 from infrahub_sdk.template import Jinja2Template
 from prefect import flow
 from prefect.client.orchestration import get_client
@@ -28,9 +25,7 @@ from infrahub.workflows.catalogue import (
 from infrahub.workflows.utils import add_tags, wait_for_schema_to_converge
 
 from .gather import gather_trigger_computed_attribute_jinja2, gather_trigger_computed_attribute_python
-from .models import (
-    PythonTransformTarget,
-)
+from .models import ComputedAttrJinja2GraphQL, ComputedAttrJinja2GraphQLResponse, PythonTransformTarget
 
 if TYPE_CHECKING:
     from infrahub.core.schema.computed_attribute import ComputedAttribute
@@ -167,49 +162,33 @@ async def trigger_update_python_computed_attributes(
     flow_run_name="Update value for computed attribute {attribute_name}",
 )
 async def update_computed_attribute_value_jinja2(
-    branch_name: str, obj: CoreNode, attribute_name: str, template_value: str, service: InfrahubServices
+    branch_name: str,
+    obj: ComputedAttrJinja2GraphQLResponse,
+    node_kind: str,
+    attribute_name: str,
+    template: Jinja2Template,
+    service: InfrahubServices,
 ) -> None:
     log = get_run_logger()
 
-    await add_tags(branches=[branch_name], nodes=[obj.id], db_change=True)
-
-    jinja_template = Jinja2Template(template=template_value)
-    variables = {}
-    for variable in jinja_template.get_variables():
-        components = variable.split("__")
-        if len(components) == 2:
-            property_name = components[0]
-            property_value = components[1]
-            attribute_property = getattr(obj, property_name)
-            variables[variable] = getattr(attribute_property, property_value)
-        elif len(components) == 3:
-            relationship_name = components[0]
-            property_name = components[1]
-            property_value = components[2]
-            relationship = getattr(obj, relationship_name)
-            try:
-                attribute_property = getattr(relationship.peer, property_name)
-                variables[variable] = getattr(attribute_property, property_value)
-            except ValueError:
-                variables[variable] = ""
-
-    value = await jinja_template.render(variables=variables)
-    existing_value = getattr(obj, attribute_name).value
-    if value == existing_value:
+    await add_tags(branches=[branch_name], nodes=[obj.node_id], db_change=True)
+
+    value = await template.render(variables=obj.variables)
+    if value == obj.computed_attribute_value:
         log.debug(f"Ignoring to update {obj} with existing value on {attribute_name}={value}")
         return
 
     await service.client.execute_graphql(
         query=UPDATE_ATTRIBUTE,
         variables={
-            "id": obj.id,
-            "kind": obj.get_kind(),
+            "id": obj.node_id,
+            "kind": node_kind,
             "attribute": attribute_name,
             "value": value,
         },
         branch_name=branch_name,
     )
-    log.info(f"Updating computed attribute {obj.get_kind()}.{attribute_name}='{value}' ({obj.id})")
+    log.info(f"Updating computed attribute {node_kind}.{attribute_name}='{value}' ({obj.node_id})")
 
 
 @flow(
@@ -235,41 +214,43 @@ async def process_jinja2(
         branch_name if branch_name in registry.get_altered_schema_branches() else registry.default_branch
     )
     schema_branch = registry.schema.get_schema_branch(name=target_branch_schema)
-    await service.client.schema.all(branch=branch_name, refresh=True, schema_hash=schema_branch.get_hash())
-
+    node_schema = schema_branch.get_node(name=computed_attribute_kind, duplicate=False)
     computed_macros = [
         attrib
         for attrib in schema_branch.computed_attributes.get_impacted_jinja2_targets(kind=node_kind, updates=updates)
         if attrib.kind == computed_attribute_kind and attrib.attribute.name == computed_attribute_name
     ]
     for computed_macro in computed_macros:
-        found: list[CoreNode] = []
+        found: list[ComputedAttrJinja2GraphQLResponse] = []
+        template_string = "n/a"
+        if computed_macro.attribute.computed_attribute and computed_macro.attribute.computed_attribute.jinja2_template:
+            template_string = computed_macro.attribute.computed_attribute.jinja2_template
+
+        jinja_template = Jinja2Template(template=template_string)
+        variables = jinja_template.get_variables()
+
+        attribute_graphql = ComputedAttrJinja2GraphQL(
+            node_schema=node_schema, attribute_schema=computed_macro.attribute, variables=variables
+        )
+
         for id_filter in computed_macro.node_filters:
-            filters = {id_filter: object_id}
-            nodes: list[CoreNode] = await service.client.filters(
-                kind=computed_macro.kind,
-                branch=branch_name,
-                prefetch_relationships=True,
-                populate_store=True,
-                **filters,
-            )
-            found.extend(nodes)
+            query = attribute_graphql.render_graphql_query(query_filter=id_filter, filter_id=object_id)
+            response = await service.client.execute_graphql(query=query, branch_name=branch_name)
+            output = attribute_graphql.parse_response(response=response)
+            found.extend(output)
 
         if not found:
             log.debug("No nodes found that requires updates")
 
-        template_string = "n/a"
-        if computed_macro.attribute.computed_attribute and computed_macro.attribute.computed_attribute.jinja2_template:
-            template_string = computed_macro.attribute.computed_attribute.jinja2_template
-
         batch = await service.client.create_batch()
         for node in found:
             batch.add(
                 task=update_computed_attribute_value_jinja2,
                 branch_name=branch_name,
                 obj=node,
+                node_kind=node_schema.kind,
                 attribute_name=computed_macro.attribute.name,
-                template_value=template_string,
+                template=jinja_template,
                 service=service,
             )
 
infrahub/core/node/__init__.py CHANGED
@@ -29,6 +29,7 @@ from infrahub.types import ATTRIBUTE_TYPES
 
 from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
 from ...graphql.models import OrderModel
+from ...log import get_logger
 from ..query.relationship import RelationshipDeleteAllQuery
 from ..relationship import RelationshipManager
 from ..utils import update_relationships_to
@@ -53,6 +54,8 @@ SchemaProtocol = TypeVar("SchemaProtocol")
 # -
 # ---------------------------------------------------------------------------------------
 
+log = get_logger()
+
 
 class Node(BaseNode, metaclass=BaseNodeMeta):
     @classmethod
@@ -348,7 +351,7 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
         fields.pop("updated_at")
         for field_name in fields.keys():
             if field_name not in self._schema.valid_input_names:
-                errors.append(ValidationError({field_name: f"{field_name} is not a valid input for {self.get_kind()}"}))
+                log.error(f"{field_name} is not a valid input for {self.get_kind()}")
 
         # Backfill fields with the ones from the template if there's one
         await self.handle_object_template(fields=fields, db=db, errors=errors)
infrahub/core/query/ipam.py CHANGED
@@ -367,9 +367,11 @@ class IPPrefixReconcileQuery(Query):
             possible_prefix = tmp_prefix.ljust(self.ip_value.max_prefixlen, "0")
             if possible_prefix not in possible_prefix_map:
                 possible_prefix_map[possible_prefix] = max_prefix_len
-        self.params["possible_prefix_and_length_list"] = [
-            [possible_prefix, max_length] for possible_prefix, max_length in possible_prefix_map.items()
-        ]
+        self.params["possible_prefix_and_length_list"] = []
+        self.params["possible_prefix_list"] = []
+        for possible_prefix, max_length in possible_prefix_map.items():
+            self.params["possible_prefix_and_length_list"].append([possible_prefix, max_length])
+            self.params["possible_prefix_list"].append(possible_prefix)
 
         namespace_query = """
         // ------------------
@@ -386,8 +388,7 @@ class IPPrefixReconcileQuery(Query):
         // ------------------
         // Get IP Prefix node by UUID
         // ------------------
-        MATCH (ip_node {uuid: $node_uuid})
-        WHERE "%(ip_kind)s" IN labels(ip_node)
+        MATCH (ip_node:%(ip_kind)s {uuid: $node_uuid})
         """ % {
             "ip_kind": InfrahubKind.IPADDRESS
             if isinstance(self.ip_value, IPAddressType)
@@ -487,6 +488,7 @@ class IPPrefixReconcileQuery(Query):
             -[hvr:HAS_VALUE]->(av:%(ip_prefix_attribute_kind)s)
             WHERE all(r IN relationships(parent_path) WHERE (%(branch_filter)s))
             AND av.version = $ip_version
+            AND av.binary_address IN $possible_prefix_list
             AND any(prefix_and_length IN $possible_prefix_and_length_list WHERE av.binary_address = prefix_and_length[0] AND av.prefixlen <= prefix_and_length[1])
             WITH
             maybe_new_parent,
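The added `av.binary_address IN $possible_prefix_list` predicate restates part of the any(...) check as a flat membership test the query planner can evaluate cheaply before the pairwise length comparison. A minimal sketch (hypothetical bit strings) of how the two parameters stay in sync, mirroring the Python change above:

possible_prefix_map = {"10100000": 8, "10101100": 16}  # binary_address -> max prefix length

possible_prefix_and_length_list: list[list] = []
possible_prefix_list: list[str] = []
for possible_prefix, max_length in possible_prefix_map.items():
    possible_prefix_and_length_list.append([possible_prefix, max_length])
    possible_prefix_list.append(possible_prefix)

# Cypher can now check the cheap membership first:
#   av.binary_address IN $possible_prefix_list
# before enforcing the per-prefix bound:
#   any(p IN $possible_prefix_and_length_list
#       WHERE av.binary_address = p[0] AND av.prefixlen <= p[1])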
infrahub/patch/__init__.py ADDED
File without changes
infrahub/patch/constants.py ADDED
@@ -0,0 +1,13 @@
+from enum import Enum
+
+
+class PatchPlanFilename(str, Enum):
+    VERTICES_TO_ADD = "vertices_to_add.json"
+    VERTICES_TO_UPDATE = "vertices_to_update.json"
+    VERTICES_TO_DELETE = "vertices_to_delete.json"
+    EDGES_TO_ADD = "edges_to_add.json"
+    EDGES_TO_UPDATE = "edges_to_update.json"
+    EDGES_TO_DELETE = "edges_to_delete.json"
+    ADDED_DB_IDS = "added_db_ids.json"
+    DELETED_DB_IDS = "deleted_db_ids.json"
+    REVERTED_DELETED_DB_IDS = "reverted_deleted_db_ids.json"
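A sketch of how these filenames might map onto a patch plan directory, assuming the writer (infrahub/patch/plan_writer.py, not displayed in this diff) names its artifacts with PatchPlanFilename; the plan directory name below is hypothetical:

from pathlib import Path

from infrahub.patch.constants import PatchPlanFilename

plan_dir = Path("infrahub-patches") / "example-plan"  # hypothetical plan directory
for filename in PatchPlanFilename:
    print(plan_dir / filename.value)  # e.g. infrahub-patches/example-plan/vertices_to_add.json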
infrahub/patch/edge_adder.py ADDED
@@ -0,0 +1,64 @@
+from collections import defaultdict
+from dataclasses import asdict
+from typing import AsyncGenerator
+
+from infrahub.core.query import QueryType
+from infrahub.database import InfrahubDatabase
+
+from .models import EdgeToAdd
+
+
+class PatchPlanEdgeAdder:
+    def __init__(self, db: InfrahubDatabase, batch_size_limit: int = 1000) -> None:
+        self.db = db
+        self.batch_size_limit = batch_size_limit
+
+    async def _run_add_query(self, edge_type: str, edges_to_add: list[EdgeToAdd]) -> dict[str, str]:
+        query = """
+        UNWIND $edges_to_add AS edge_to_add
+        MATCH (a) WHERE %(id_func_name)s(a) = edge_to_add.from_id
+        MATCH (b) WHERE %(id_func_name)s(b) = edge_to_add.to_id
+        CREATE (a)-[e:%(edge_type)s]->(b)
+        SET e = edge_to_add.after_props
+        RETURN edge_to_add.identifier AS abstract_id, %(id_func_name)s(e) AS db_id
+        """ % {
+            "edge_type": edge_type,
+            "id_func_name": self.db.get_id_function_name(),
+        }
+        edges_to_add_dicts = [asdict(v) for v in edges_to_add]
+        # use transaction to make sure we record the results before committing them
+        try:
+            txn_db = self.db.start_transaction()
+            async with txn_db as txn:
+                results = await txn.execute_query(
+                    query=query, params={"edges_to_add": edges_to_add_dicts}, type=QueryType.WRITE
+                )
+                abstract_to_concrete_id_map: dict[str, str] = {}
+                for result in results:
+                    abstract_id = result.get("abstract_id")
+                    concrete_id = result.get("db_id")
+                    abstract_to_concrete_id_map[abstract_id] = concrete_id
+        finally:
+            await txn_db.close()
+        return abstract_to_concrete_id_map
+
+    async def execute(
+        self,
+        edges_to_add: list[EdgeToAdd],
+    ) -> AsyncGenerator[dict[str, str], None]:
+        """
+        Create edges_to_add on the database.
+        Returns a generator that yields dictionaries mapping EdgeToAdd.identifier to the database-level ID of the newly created edge.
+        """
+        edges_map_queue: dict[str, list[EdgeToAdd]] = defaultdict(list)
+        for edge_to_add in edges_to_add:
+            edges_map_queue[edge_to_add.edge_type].append(edge_to_add)
+            if len(edges_map_queue[edge_to_add.edge_type]) > self.batch_size_limit:
+                yield await self._run_add_query(
+                    edge_type=edge_to_add.edge_type,
+                    edges_to_add=edges_map_queue[edge_to_add.edge_type],
+                )
+                edges_map_queue[edge_to_add.edge_type] = []
+
+        for edge_type, edges_group in edges_map_queue.items():
+            yield await self._run_add_query(edge_type=edge_type, edges_to_add=edges_group)
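PatchPlanEdgeAdder.execute is an async generator: it batches edges per edge type and yields an identifier-to-database-id map as each batch is written. A sketch of a caller (EdgeToAdd is real but its fields are inferred from the UNWIND query above; the dataclass itself lives in infrahub/patch/models.py, which this diff does not display):

from infrahub.patch.edge_adder import PatchPlanEdgeAdder
from infrahub.patch.models import EdgeToAdd


async def add_edges(adder: PatchPlanEdgeAdder, edges: list[EdgeToAdd]) -> dict[str, str]:
    id_map: dict[str, str] = {}
    # Each yielded dict maps EdgeToAdd.identifier to the new database-level edge id,
    # letting the runner persist them (e.g. added_db_ids.json) for a later revert.
    async for batch_result in adder.execute(edges_to_add=edges):
        id_map.update(batch_result)
    return id_map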
infrahub/patch/edge_deleter.py ADDED
@@ -0,0 +1,33 @@
+from typing import AsyncGenerator
+
+from infrahub.core.query import QueryType
+from infrahub.database import InfrahubDatabase
+
+from .models import EdgeToDelete
+
+
+class PatchPlanEdgeDeleter:
+    def __init__(self, db: InfrahubDatabase, batch_size_limit: int = 1000) -> None:
+        self.db = db
+        self.batch_size_limit = batch_size_limit
+
+    async def _run_delete_query(self, ids_to_delete: list[str]) -> set[str]:
+        query = """
+        MATCH ()-[e]-()
+        WHERE %(id_func_name)s(e) IN $ids_to_delete
+        DELETE e
+        RETURN %(id_func_name)s(e) AS deleted_id
+        """ % {"id_func_name": self.db.get_id_function_name()}
+        results = await self.db.execute_query(
+            query=query, params={"ids_to_delete": ids_to_delete}, type=QueryType.WRITE
+        )
+        deleted_ids: set[str] = set()
+        for result in results:
+            deleted_id = result.get("deleted_id")
+            deleted_ids.add(deleted_id)
+        return deleted_ids
+
+    async def execute(self, edges_to_delete: list[EdgeToDelete]) -> AsyncGenerator[set[str], None]:
+        for i in range(0, len(edges_to_delete), self.batch_size_limit):
+            ids_to_delete = [e.db_id for e in edges_to_delete[i : i + self.batch_size_limit]]
+            yield await self._run_delete_query(ids_to_delete=ids_to_delete)
infrahub/patch/edge_updater.py ADDED
@@ -0,0 +1,28 @@
+from dataclasses import asdict
+
+from infrahub.core.query import QueryType
+from infrahub.database import InfrahubDatabase
+
+from .models import EdgeToUpdate
+
+
+class PatchPlanEdgeUpdater:
+    def __init__(self, db: InfrahubDatabase, batch_size_limit: int = 1000) -> None:
+        self.db = db
+        self.batch_size_limit = batch_size_limit
+
+    async def _run_update_query(self, edges_to_update: list[EdgeToUpdate]) -> None:
+        query = """
+        UNWIND $edges_to_update AS edge_to_update
+        MATCH ()-[e]-()
+        WHERE %(id_func_name)s(e) = edge_to_update.db_id
+        SET e = edge_to_update.after_props
+        """ % {"id_func_name": self.db.get_id_function_name()}
+        await self.db.execute_query(
+            query=query, params={"edges_to_update": [asdict(e) for e in edges_to_update]}, type=QueryType.WRITE
+        )
+
+    async def execute(self, edges_to_update: list[EdgeToUpdate]) -> None:
+        for i in range(0, len(edges_to_update), self.batch_size_limit):
+            vertices_slice = edges_to_update[i : i + self.batch_size_limit]
+            await self._run_update_query(edges_to_update=vertices_slice)
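Unlike the adder, which groups edges by type before flushing, the deleter and updater batch by slicing the input list into batch_size_limit chunks. A minimal illustration of that slicing:

items = list(range(2500))
batch_size_limit = 1000
batches = [items[i : i + batch_size_limit] for i in range(0, len(items), batch_size_limit)]
assert [len(b) for b in batches] == [1000, 1000, 500]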