infrahub-server 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,236 @@
+ from __future__ import annotations
+
+ from itertools import chain
+ from typing import TYPE_CHECKING, Any
+
+ import ujson
+ from rich.progress import Progress, TaskID
+
+ from infrahub.core.branch import Branch
+ from infrahub.core.constants import GLOBAL_BRANCH_NAME, BranchSupportType, SchemaPathType
+ from infrahub.core.initialization import get_root_node
+ from infrahub.core.migrations.graph.m044_backfill_hfid_display_label_in_db import (
+     DefaultBranchNodeCount,
+     GetPathDetailsDefaultBranch,
+     GetResultMapQuery,
+     UpdateAttributeValuesQuery,
+ )
+ from infrahub.core.migrations.schema.node_attribute_add import NodeAttributeAddMigration
+ from infrahub.core.migrations.shared import ArbitraryMigration, MigrationResult, get_migration_console
+ from infrahub.core.path import SchemaPath
+ from infrahub.core.query import Query, QueryType
+
+ from .load_schema_branch import get_or_load_schema_branch
+
+ if TYPE_CHECKING:
+     from infrahub.core.schema import AttributeSchema, MainSchemaTypes, NodeSchema, SchemaAttributePath
+     from infrahub.core.schema.schema_branch import SchemaBranch
+     from infrahub.database import InfrahubDatabase
+
+
+ class DeleteBranchAwareAttrsForBranchAgnosticNodesQuery(Query):
+     name = "delete_branch_aware_attrs_for_branch_agnostic_nodes_query"
+     type = QueryType.WRITE
+     insert_return = False
+     raise_error_if_empty = False
+
+     async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
+         query = """
+         MATCH (n:Node {branch_support: "agnostic"})
+         MATCH (n)-[:HAS_ATTRIBUTE]->(attr:Attribute)
+         WHERE attr.name IN ["human_friendly_id", "display_label"]
+         WITH DISTINCT attr
+         CALL (attr) {
+             DETACH DELETE attr
+         } IN TRANSACTIONS
+         """
+         self.add_to_query(query)
+
+
+ class Migration046(ArbitraryMigration):
+     """
+     Delete any branch-aware human_friendly_id and display_label attributes added to branch-agnostic nodes
+     Add human_friendly_id and display_label attributes to branch-agnostic nodes
+     Set human_friendly_id and display_label attributes for branch-agnostic nodes on the global branch
+
+     Uses and duplicates code from Migration044
+     """
+
+     name: str = "046_fill_agnostic_hfid_display_labels"
+     minimum_version: int = 45
+     update_batch_size: int = 1000
+
+     async def _do_one_schema_all(
+         self,
+         db: InfrahubDatabase,
+         branch: Branch,
+         schema: MainSchemaTypes,
+         schema_branch: SchemaBranch,
+         attribute_schema_map: dict[AttributeSchema, AttributeSchema],
+         progress: Progress | None = None,
+         update_task: TaskID | None = None,
+     ) -> None:
+         print(f"Processing {schema.kind}...", end="")
+
+         schema_paths_by_name: dict[str, list[SchemaAttributePath]] = {}
+         for source_attribute_schema in attribute_schema_map.keys():
+             node_schema_property = getattr(schema, source_attribute_schema.name)
+             if not node_schema_property:
+                 continue
+             if isinstance(node_schema_property, list):
+                 schema_paths_by_name[source_attribute_schema.name] = [
+                     schema.parse_schema_path(path=str(path), schema=schema_branch) for path in node_schema_property
+                 ]
+             else:
+                 schema_paths_by_name[source_attribute_schema.name] = [
+                     schema.parse_schema_path(path=str(node_schema_property), schema=schema_branch)
+                 ]
+         all_schema_paths = list(chain(*schema_paths_by_name.values()))
+         offset = 0
+
+         # loop until we get no results from the get_details_query
+         while True:
+             get_details_query: GetResultMapQuery = await GetPathDetailsDefaultBranch.init(
+                 db=db,
+                 schema_kind=schema.kind,
+                 schema_paths=all_schema_paths,
+                 offset=offset,
+                 limit=self.update_batch_size,
+             )
+             await get_details_query.execute(db=db)
+
+             num_updates = 0
+             for source_attribute_schema, destination_attribute_schema in attribute_schema_map.items():
+                 schema_paths = schema_paths_by_name[source_attribute_schema.name]
+                 schema_path_values_map = get_details_query.get_result_map(schema_paths)
+                 num_updates = max(num_updates, len(schema_path_values_map))
+                 formatted_schema_path_values_map = {}
+                 for k, v in schema_path_values_map.items():
+                     if not v:
+                         continue
+                     if destination_attribute_schema.kind == "List":
+                         formatted_schema_path_values_map[k] = ujson.dumps(v)
+                     else:
+                         formatted_schema_path_values_map[k] = " ".join(item for item in v if item is not None)
+
+                 if not formatted_schema_path_values_map:
+                     continue
+
+                 update_display_label_query = await UpdateAttributeValuesQuery.init(
+                     db=db,
+                     branch=branch,
+                     attribute_schema=destination_attribute_schema,
+                     values_by_id_map=formatted_schema_path_values_map,
+                 )
+                 await update_display_label_query.execute(db=db)
+
+             if progress is not None and update_task is not None:
+                 progress.update(update_task, advance=num_updates)
+
+             if num_updates == 0:
+                 break
+
+             offset += self.update_batch_size
+
+         print("done")
+
+     async def execute(self, db: InfrahubDatabase) -> MigrationResult:
+         try:
+             return await self._do_execute(db=db)
+         except Exception as exc:
+             return MigrationResult(errors=[str(exc)])
+
+     async def _do_execute(self, db: InfrahubDatabase) -> MigrationResult:
+         console = get_migration_console()
+         result = MigrationResult()
+
+         root_node = await get_root_node(db=db, initialize=False)
+         default_branch_name = root_node.default_branch
+         default_branch = await Branch.get_by_name(db=db, name=default_branch_name)
+         main_schema_branch = await get_or_load_schema_branch(db=db, branch=default_branch)
+
+         console.print("Deleting branch-aware attributes for branch-agnostic nodes...", end="")
+         delete_query = await DeleteBranchAwareAttrsForBranchAgnosticNodesQuery.init(db=db)
+         await delete_query.execute(db=db)
+         console.print("done")
+
+         branch_agnostic_schemas: list[NodeSchema] = []
+         migrations = []
+         for node_schema_kind in main_schema_branch.node_names:
+             schema = main_schema_branch.get_node(name=node_schema_kind, duplicate=False)
+             if schema.branch is not BranchSupportType.AGNOSTIC:
+                 continue
+             branch_agnostic_schemas.append(schema)
+             migrations.extend(
+                 [
+                     NodeAttributeAddMigration(
+                         new_node_schema=schema,
+                         previous_node_schema=schema,
+                         schema_path=SchemaPath(
+                             schema_kind=schema.kind, path_type=SchemaPathType.ATTRIBUTE, field_name="human_friendly_id"
+                         ),
+                     ),
+                     NodeAttributeAddMigration(
+                         new_node_schema=schema,
+                         previous_node_schema=schema,
+                         schema_path=SchemaPath(
+                             schema_kind=schema.kind, path_type=SchemaPathType.ATTRIBUTE, field_name="display_label"
+                         ),
+                     ),
+                 ]
+             )
+
+         global_branch = await Branch.get_by_name(db=db, name=GLOBAL_BRANCH_NAME)
+         with Progress(console=console) as progress:
+             update_task = progress.add_task(
+                 "Adding HFID and display label attributes to branch-agnostic nodes", total=len(migrations)
+             )
+
+             for migration in migrations:
+                 try:
+                     execution_result = await migration.execute(db=db, branch=global_branch)
+                     result.errors.extend(execution_result.errors)
+                     progress.update(update_task, advance=1)
+                 except Exception as exc:
+                     result.errors.append(str(exc))
+                     return result
+
+         total_nodes_query = await DefaultBranchNodeCount.init(
+             db=db, kinds_to_include=[sch.kind for sch in branch_agnostic_schemas]
+         )
+         await total_nodes_query.execute(db=db)
+         total_nodes_count = total_nodes_query.get_num_nodes()
+
+         base_node_schema = main_schema_branch.get("SchemaNode", duplicate=False)
+         display_label_attribute_schema = base_node_schema.get_attribute("display_label")
+         display_labels_attribute_schema = base_node_schema.get_attribute("display_labels")
+         hfid_attribute_schema = base_node_schema.get_attribute("human_friendly_id")
+
+         with Progress(console=console) as progress:
+             update_task = progress.add_task(
+                 f"Set display_label and human_friendly_id for {total_nodes_count} branch-agnostic nodes on global branch",
+                 total=total_nodes_count,
+             )
+             for branch_agnostic_schema in branch_agnostic_schemas:
+                 attribute_schema_map = {}
+                 if branch_agnostic_schema.display_labels:
+                     attribute_schema_map[display_labels_attribute_schema] = display_label_attribute_schema
+                 if branch_agnostic_schema.human_friendly_id:
+                     attribute_schema_map[hfid_attribute_schema] = hfid_attribute_schema
+                 if not attribute_schema_map:
+                     continue
+
+                 await self._do_one_schema_all(
+                     db=db,
+                     branch=global_branch,
+                     schema=branch_agnostic_schema,
+                     schema_branch=main_schema_branch,
+                     attribute_schema_map=attribute_schema_map,
+                     progress=progress,
+                     update_task=update_task,
+                 )
+
+         return result
+
+     async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult:  # noqa: ARG002
+         return MigrationResult()
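
`_do_one_schema_all` drives the backfill with an offset-based batching loop: fetch one page of node details, write the formatted values, advance the offset, and stop once a page produces no updates. A minimal sketch of the same pattern, with hypothetical `fetch_page`/`apply_updates` helpers standing in for the migration's query classes:

```python
import asyncio

async def fetch_page(offset: int, limit: int) -> list[dict]:
    # Hypothetical stand-in for GetPathDetailsDefaultBranch
    data = [{"id": i} for i in range(2500)]  # pretend dataset
    return data[offset : offset + limit]

async def apply_updates(rows: list[dict]) -> None:
    # Hypothetical stand-in for UpdateAttributeValuesQuery
    print(f"updated {len(rows)} nodes")

async def process_in_batches(batch_size: int = 1000) -> None:
    offset = 0
    while True:  # loop until a page comes back empty, as in the migration
        rows = await fetch_page(offset=offset, limit=batch_size)
        if not rows:
            break
        await apply_updates(rows)
        offset += batch_size

asyncio.run(process_in_batches())
```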
@@ -38,7 +38,7 @@ class AttributeQuery(Query):
          if at:
              self.at = Timestamp(at)
          else:
-             self.at = self.attr.at
+             self.at = Timestamp()
 
          self.branch = branch or self.attr.get_branch_based_on_support_type()
 
@@ -247,10 +247,9 @@ class AttributeGetQuery(AttributeQuery):
          self.params["attr_uuid"] = self.attr.id
          self.params["node_uuid"] = self.attr.node.id
 
-         at = self.at or self.attr.at
-         self.params["at"] = at.to_string()
+         self.params["at"] = self.at.to_string()
 
-         rels_filter, rels_params = self.branch.get_query_filter_path(at=at.to_string())
+         rels_filter, rels_params = self.branch.get_query_filter_path(at=self.at.to_string())
          self.params.update(rels_params)
 
          query = (
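
The two hunks above change the fallback timestamp from the attribute's stored `at` to a fresh `Timestamp()`, so queries default to the current time when no explicit `at` is passed. A minimal sketch of the new fallback, with a hypothetical `Timestamp` stand-in for Infrahub's helper:

```python
from datetime import datetime, timezone

class Timestamp:
    # Hypothetical stand-in: constructing with no argument means "now"
    def __init__(self, value: str | None = None) -> None:
        self.value = value or datetime.now(timezone.utc).isoformat()

    def to_string(self) -> str:
        return self.value

def resolve_at(at: str | None) -> Timestamp:
    # Mirrors the new AttributeQuery fallback: an explicit value wins,
    # otherwise default to the current time rather than the attribute's own timestamp
    return Timestamp(at) if at else Timestamp()

print(resolve_at(None).to_string())  # current UTC time
print(resolve_at("2024-01-01T00:00:00+00:00").to_string())
```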
@@ -1,7 +1,5 @@
  from __future__ import annotations
 
- from typing import cast
-
  from infrahub_sdk.exceptions import URLNotFoundError
  from infrahub_sdk.template import Jinja2Template
  from prefect import flow
@@ -139,11 +137,32 @@ async def display_labels_setup_jinja2(
      )  # type: ignore[misc]
 
      # Configure all DisplayLabelTriggerDefinitions in Prefect
-     display_reports = [cast(DisplayLabelTriggerDefinition, entry) for entry in report.updated + report.created]
-     direct_target_triggers = [display_report for display_report in display_reports if display_report.target_kind]
+     all_triggers = report.triggers_with_type(trigger_type=DisplayLabelTriggerDefinition)
+     direct_target_triggers = [
+         display_report
+         for display_report in report.modified_triggers_with_type(trigger_type=DisplayLabelTriggerDefinition)
+         if display_report.target_kind
+     ]
 
      for display_report in direct_target_triggers:
          if event_name != BranchDeletedEvent.event_name and display_report.branch == branch_name:
+             if branch_name != registry.default_branch:
+                 default_branch_triggers = [
+                     trigger
+                     for trigger in all_triggers
+                     if trigger.branch == registry.default_branch
+                     and trigger.target_kind == display_report.target_kind
+                 ]
+                 if (
+                     default_branch_triggers
+                     and len(default_branch_triggers) == 1
+                     and default_branch_triggers[0].template_hash == display_report.template_hash
+                 ):
+                     log.debug(
+                         f"Skipping display label updates for {display_report.target_kind} [{branch_name}], schema is identical to default branch"
+                     )
+                     continue
+
              await get_workflow().submit_workflow(
                  workflow=TRIGGER_UPDATE_DISPLAY_LABELS,
                  context=context,
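
The new guard skips branch-level display-label updates when the branch's trigger is effectively the same as the default branch's, comparing `template_hash` for the same `target_kind`. A condensed, standalone sketch of that check, with a hypothetical `Trigger` dataclass standing in for `DisplayLabelTriggerDefinition`:

```python
from dataclasses import dataclass

@dataclass
class Trigger:
    # Hypothetical stand-in for DisplayLabelTriggerDefinition
    branch: str
    target_kind: str
    template_hash: str

def schema_identical_to_default(report: Trigger, all_triggers: list[Trigger], default_branch: str) -> bool:
    # Skip when exactly one default-branch trigger targets the same kind
    # and its template hash matches, i.e. the rendered template would be identical
    matches = [
        t for t in all_triggers
        if t.branch == default_branch and t.target_kind == report.target_kind
    ]
    return len(matches) == 1 and matches[0].template_hash == report.template_hash

existing = [Trigger("main", "InfraDevice", "abc123")]
print(schema_identical_to_default(Trigger("feature-x", "InfraDevice", "abc123"), existing, "main"))  # True
print(schema_identical_to_default(Trigger("feature-x", "InfraDevice", "def456"), existing, "main"))  # False
```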
infrahub/hfid/tasks.py CHANGED
@@ -1,7 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import cast
4
-
5
3
  from infrahub_sdk.exceptions import URLNotFoundError
6
4
  from prefect import flow
7
5
  from prefect.logging import get_run_logger
@@ -138,11 +136,32 @@ async def hfid_setup(context: InfrahubContext, branch_name: str | None = None, e
138
136
  ) # type: ignore[misc]
139
137
 
140
138
  # Configure all DisplayLabelTriggerDefinitions in Prefect
141
- hfid_reports = [cast(HFIDTriggerDefinition, entry) for entry in report.updated + report.created]
142
- direct_target_triggers = [hfid_report for hfid_report in hfid_reports if hfid_report.target_kind]
139
+ all_triggers = report.triggers_with_type(trigger_type=HFIDTriggerDefinition)
140
+ direct_target_triggers = [
141
+ hfid_report
142
+ for hfid_report in report.modified_triggers_with_type(trigger_type=HFIDTriggerDefinition)
143
+ if hfid_report.target_kind
144
+ ]
143
145
 
144
146
  for display_report in direct_target_triggers:
145
147
  if event_name != BranchDeletedEvent.event_name and display_report.branch == branch_name:
148
+ if branch_name != registry.default_branch:
149
+ default_branch_triggers = [
150
+ trigger
151
+ for trigger in all_triggers
152
+ if trigger.branch == registry.default_branch
153
+ and trigger.target_kind == display_report.target_kind
154
+ ]
155
+ if (
156
+ default_branch_triggers
157
+ and len(default_branch_triggers) == 1
158
+ and default_branch_triggers[0].hfid_hash == display_report.hfid_hash
159
+ ):
160
+ log.debug(
161
+ f"Skipping HFID updates for {display_report.target_kind} [{branch_name}], schema is identical to default branch"
162
+ )
163
+ continue
164
+
146
165
  await get_workflow().submit_workflow(
147
166
  workflow=TRIGGER_UPDATE_HFID,
148
167
  context=context,
@@ -1,5 +1,6 @@
  import asyncio
- import uuid
+ import hashlib
+ import json
  from datetime import datetime, timedelta, timezone
  from typing import Any
  from uuid import UUID
@@ -27,11 +28,14 @@ from prefect.client.schemas.sorting import (
      FlowRunSort,
  )
 
+ from infrahub import config
  from infrahub.core.constants import TaskConclusion
  from infrahub.core.query.node import NodeGetKindQuery
  from infrahub.database import InfrahubDatabase
  from infrahub.log import get_logger
+ from infrahub.message_bus.types import KVTTL
  from infrahub.utils import get_nested_dict
+ from infrahub.workers.dependencies import get_cache
  from infrahub.workflows.constants import TAG_NAMESPACE, WorkflowTag
 
  from .constants import CONCLUSION_STATE_MAPPING
@@ -44,6 +48,12 @@ PREFECT_MAX_LOGS_PER_CALL = 200
 
 
  class PrefectTask:
+     @staticmethod
+     def _build_flow_run_count_cache_key(body: dict[str, Any]) -> str:
+         serialized = json.dumps(body, sort_keys=True, separators=(",", ":"))
+         hashed = hashlib.sha256(serialized.encode()).hexdigest()
+         return f"task_manager:flow_run_count:{hashed}"
+
      @classmethod
      async def count_flow_runs(
          cls,
@@ -59,10 +69,24 @@ class PrefectTask:
              "flows": flow_filter.model_dump(mode="json") if flow_filter else None,
              "flow_runs": (flow_run_filter.model_dump(mode="json", exclude_unset=True) if flow_run_filter else None),
          }
+         cache_key = cls._build_flow_run_count_cache_key(body)
+
+         cache = await get_cache()
+         cached_value_raw = await cache.get(key=cache_key)
+         if cached_value_raw is not None:
+             try:
+                 return int(cached_value_raw)
+             except (TypeError, ValueError):
+                 await cache.delete(key=cache_key)
 
          response = await client._client.post("/flow_runs/count", json=body)
          response.raise_for_status()
-         return response.json()
+         count_value = int(response.json())
+
+         if count_value >= config.SETTINGS.workflow.flow_run_count_cache_threshold:
+             await cache.set(key=cache_key, value=str(count_value), expires=KVTTL.ONE_MINUTE)
+
+         return count_value
 
      @classmethod
      async def _get_related_nodes(cls, db: InfrahubDatabase, flows: list[FlowRun]) -> RelatedNodesInfo:
@@ -204,7 +228,7 @@ class PrefectTask:
              tags=FlowRunFilterTags(all_=filter_tags),
          )
          if ids:
-             flow_run_filter.id = FlowRunFilterId(any_=[uuid.UUID(id) for id in ids])
+             flow_run_filter.id = FlowRunFilterId(any_=[UUID(id) for id in ids])
 
          if statuses:
              flow_run_filter.state = FlowRunFilterState(type=FlowRunFilterStateType(any_=statuses))
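
`_build_flow_run_count_cache_key` canonicalizes the filter body (`sort_keys=True`, compact separators) before hashing, so logically equal filters map to the same cache entry regardless of key order. A standalone sketch of the key construction:

```python
import hashlib
import json
from typing import Any

def build_cache_key(body: dict[str, Any]) -> str:
    # Canonical serialization: sorted keys and compact separators
    # guarantee that equal dicts always produce the same digest
    serialized = json.dumps(body, sort_keys=True, separators=(",", ":"))
    return f"task_manager:flow_run_count:{hashlib.sha256(serialized.encode()).hexdigest()}"

a = build_cache_key({"flows": None, "flow_runs": {"state": "RUNNING"}})
b = build_cache_key({"flow_runs": {"state": "RUNNING"}, "flows": None})
assert a == b  # key order doesn't change the cache key
print(a)
```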
@@ -8,6 +8,7 @@ from prefect.client.schemas.objects import WorkerStatus
  from infrahub.events.utils import get_all_events
  from infrahub.trigger.constants import NAME_SEPARATOR
  from infrahub.trigger.models import TriggerType
+ from infrahub.trigger.setup import gather_all_automations
 
  from .models import TelemetryPrefectData, TelemetryWorkPoolData
 
@@ -53,7 +54,7 @@ async def gather_prefect_events(client: PrefectClient) -> dict[str, Any]:
 
  @task(name="telemetry-gather-automations", task_run_name="Gather Automations", cache_policy=NONE)
  async def gather_prefect_automations(client: PrefectClient) -> dict[str, Any]:
-     automations = await client.read_automations()
+     automations = await gather_all_automations(client=client)
 
      data: dict[str, Any] = {}
 
@@ -1,8 +1,8 @@
  from __future__ import annotations
 
  from datetime import timedelta
- from enum import Enum
- from typing import TYPE_CHECKING, Any
+ from enum import Enum, StrEnum
+ from typing import TYPE_CHECKING, Any, TypeVar
 
  from prefect.events.actions import RunDeployment
  from prefect.events.schemas.automations import Automation, Posture
@@ -18,16 +18,78 @@ from .constants import NAME_SEPARATOR
  if TYPE_CHECKING:
      from uuid import UUID
 
+ T = TypeVar("T", bound="TriggerDefinition")
+
+
+ class TriggerComparison(StrEnum):
+     MATCH = "match"  # The expected trigger and the actual trigger are identical
+     REFRESH = "refresh"  # The branch parameters don't match but the hash does; refresh in Prefect but don't run triggers
+     UPDATE = "update"  # Neither the branch nor the other data points match; update in Prefect and run triggers
+
+     @property
+     def update_prefect(self) -> bool:
+         return self in {TriggerComparison.REFRESH, TriggerComparison.UPDATE}
+
 
  class TriggerSetupReport(BaseModel):
      created: list[TriggerDefinition] = Field(default_factory=list)
+     refreshed: list[TriggerDefinition] = Field(default_factory=list)
      updated: list[TriggerDefinition] = Field(default_factory=list)
      deleted: list[Automation] = Field(default_factory=list)
      unchanged: list[TriggerDefinition] = Field(default_factory=list)
 
      @property
      def in_use_count(self) -> int:
-         return len(self.created + self.updated + self.unchanged)
+         return len(self.created + self.updated + self.unchanged + self.refreshed)
+
+     def add_with_comparison(self, trigger: TriggerDefinition, comparison: TriggerComparison) -> None:
+         match comparison:
+             case TriggerComparison.UPDATE:
+                 self.updated.append(trigger)
+             case TriggerComparison.REFRESH:
+                 self.refreshed.append(trigger)
+             case TriggerComparison.MATCH:
+                 self.unchanged.append(trigger)
+
+     def _created_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
+         return [trigger for trigger in self.created if isinstance(trigger, trigger_type)]
+
+     def _refreshed_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
+         return [trigger for trigger in self.refreshed if isinstance(trigger, trigger_type)]
+
+     def _unchanged_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
+         return [trigger for trigger in self.unchanged if isinstance(trigger, trigger_type)]
+
+     def _updated_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
+         return [trigger for trigger in self.updated if isinstance(trigger, trigger_type)]
+
+     def triggers_with_type(self, trigger_type: type[T]) -> list[T]:
+         """Return all triggers that match the specified type.
+
+         Args:
+             trigger_type: A TriggerDefinition class or subclass to filter by
+
+         Returns:
+             List of triggers of the specified type from all categories
+         """
+         created = self._created_triggers_with_type(trigger_type=trigger_type)
+         updated = self._updated_triggers_with_type(trigger_type=trigger_type)
+         refreshed = self._refreshed_triggers_with_type(trigger_type=trigger_type)
+         unchanged = self._unchanged_triggers_with_type(trigger_type=trigger_type)
+         return created + updated + refreshed + unchanged
+
+     def modified_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
+         """Return all created and updated triggers that match the specified type.
+
+         Args:
+             trigger_type: A TriggerDefinition class or subclass to filter by
+
+         Returns:
+             List of triggers of the specified type from both the created and updated lists
+         """
+         created = self._created_triggers_with_type(trigger_type=trigger_type)
+         updated = self._updated_triggers_with_type(trigger_type=trigger_type)
+         return created + updated
 
 
  class TriggerType(str, Enum):
@@ -41,6 +103,16 @@ class TriggerType(str, Enum):
      HUMAN_FRIENDLY_ID = "human_friendly_id"
      # OBJECT = "object"
 
+     @property
+     def is_branch_specific(self) -> bool:
+         return self in {
+             TriggerType.COMPUTED_ATTR_JINJA2,
+             TriggerType.COMPUTED_ATTR_PYTHON,
+             TriggerType.COMPUTED_ATTR_PYTHON_QUERY,
+             TriggerType.DISPLAY_LABEL_JINJA2,
+             TriggerType.HUMAN_FRIENDLY_ID,
+         }
+
 
  def _match_related_dict() -> dict:
      # Make Mypy happy as match related is a dict[str, Any] | list[dict[str, Any]]
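
The typed accessors on `TriggerSetupReport` replace the `cast`-based filtering the task modules used before: callers pass a concrete subclass and get back only triggers of that type. A small usage sketch with hypothetical trigger subclasses; the `isinstance` filter is the same mechanism the report's helpers use:

```python
class TriggerDefinition: ...
class HFIDTrigger(TriggerDefinition): ...          # hypothetical subclass
class DisplayLabelTrigger(TriggerDefinition): ...  # hypothetical subclass

created = [HFIDTrigger(), DisplayLabelTrigger()]
updated = [HFIDTrigger()]

# modified_triggers_with_type(HFIDTrigger) is equivalent to filtering
# created + updated by isinstance, with the result typed as list[HFIDTrigger]
modified = [t for t in created + updated if isinstance(t, HFIDTrigger)]
print(len(modified))  # 2
```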
infrahub/trigger/setup.py CHANGED
@@ -12,22 +12,36 @@ from infrahub import lock
  from infrahub.database import InfrahubDatabase
  from infrahub.trigger.models import TriggerDefinition
 
- from .models import TriggerSetupReport, TriggerType
+ from .models import TriggerComparison, TriggerSetupReport, TriggerType
 
  if TYPE_CHECKING:
      from uuid import UUID
 
 
- def compare_automations(target: AutomationCore, existing: Automation) -> bool:
-     """Compare an AutomationCore with an existing Automation object to identify if they are identical or not
-
-     Return True if the target is identical to the existing automation
+ def compare_automations(
+     target: AutomationCore, existing: Automation, trigger_type: TriggerType | None, force_update: bool = False
+ ) -> TriggerComparison:
+     """Compare an AutomationCore with an existing Automation to determine whether they are identical,
+     whether only the branch filter of a branch-specific automation differs, or whether they genuinely differ.
      """
 
+     if force_update:
+         return TriggerComparison.UPDATE
+
      target_dump = target.model_dump(exclude_defaults=True, exclude_none=True)
      existing_dump = existing.model_dump(exclude_defaults=True, exclude_none=True, exclude={"id"})
 
-     return target_dump == existing_dump
+     if target_dump == existing_dump:
+         return TriggerComparison.MATCH
+
+     if not trigger_type or not trigger_type.is_branch_specific:
+         return TriggerComparison.UPDATE
+
+     if target.description == existing.description:
+         # If only the branch-related info is different, we consider it a refresh
+         return TriggerComparison.REFRESH
+
+     return TriggerComparison.UPDATE
 
 
  @task(name="trigger-setup-specific", task_run_name="Setup triggers of a specific kind", cache_policy=NONE)  # type: ignore[arg-type]
@@ -63,10 +77,8 @@ async def setup_triggers(
 
      report = TriggerSetupReport()
 
-     if trigger_type:
-         log.debug(f"Setting up triggers of type {trigger_type.value}")
-     else:
-         log.debug("Setting up all triggers")
+     trigger_log_message = f"triggers of type {trigger_type.value}" if trigger_type else "all triggers"
+     log.debug(f"Setting up {trigger_log_message}")
 
      # -------------------------------------------------------------
      # Retrieve existing Deployments and Automation from the server
@@ -80,16 +92,14 @@ async def setup_triggers(
      }
      deployments_mapping: dict[str, UUID] = {name: item.id for name, item in deployments.items()}
 
-     # If a trigger type is provided, narrow down the list of existing triggers to know which one to delete
-     existing_automations: dict[str, Automation] = {}
+     existing_automations = {item.name: item for item in await gather_all_automations(client=client)}
      if trigger_type:
+         # If a trigger type is provided, narrow down the list of existing triggers to know which ones to delete
          existing_automations = {
-             item.name: item
-             for item in await client.read_automations()
-             if item.name.startswith(f"{trigger_type.value}::")
+             automation_name: automation
+             for automation_name, automation in existing_automations.items()
+             if automation_name.startswith(f"{trigger_type.value}::")
          }
-     else:
-         existing_automations = {item.name: item for item in await client.read_automations()}
 
      trigger_names = [trigger.generate_name() for trigger in triggers]
      automation_names = list(existing_automations.keys())
@@ -115,12 +125,13 @@ async def setup_triggers(
          existing_automation = existing_automations.get(trigger.generate_name(), None)
 
          if existing_automation:
-             if force_update or not compare_automations(target=automation, existing=existing_automation):
+             trigger_comparison = compare_automations(
+                 target=automation, existing=existing_automation, trigger_type=trigger_type, force_update=force_update
+             )
+             if trigger_comparison.update_prefect:
                  await client.update_automation(automation_id=existing_automation.id, automation=automation)
                  log.info(f"{trigger.generate_name()} Updated")
-                 report.updated.append(trigger)
-             else:
-                 report.unchanged.append(trigger)
+             report.add_with_comparison(trigger, trigger_comparison)
          else:
              await client.create_automation(automation=automation)
              log.info(f"{trigger.generate_name()} Created")
@@ -145,15 +156,34 @@ async def setup_triggers(
          else:
              raise
 
-     if trigger_type:
-         log.info(
-             f"Processed triggers of type {trigger_type.value}: "
-             f"{len(report.created)} created, {len(report.updated)} updated, {len(report.unchanged)} unchanged, {len(report.deleted)} deleted"
-         )
-     else:
-         log.info(
-             f"Processed all triggers: "
-             f"{len(report.created)} created, {len(report.updated)} updated, {len(report.unchanged)} unchanged, {len(report.deleted)} deleted"
-         )
+     log.info(
+         f"Processed {trigger_log_message}: {len(report.created)} created, {len(report.updated)} updated, "
+         f"{len(report.refreshed)} refreshed, {len(report.unchanged)} unchanged, {len(report.deleted)} deleted"
+     )
 
      return report
+
+
+ async def gather_all_automations(client: PrefectClient) -> list[Automation]:
+     """Gather all automations from the Prefect server.
+
+     By default the Prefect client only retrieves a limited number of automations; this function
+     retrieves them all by paginating through the results. The default within Prefect is 200 items,
+     and client.read_automations() doesn't support pagination parameters.
+     """
+     automation_count_response = await client.request("POST", "/automations/count")
+     automation_count_response.raise_for_status()
+     automation_count: int = automation_count_response.json()
+     offset = 0
+     limit = 200
+     missing_automations = True
+     automations: list[Automation] = []
+     while missing_automations:
+         response = await client.request("POST", "/automations/filter", json={"limit": limit, "offset": offset})
+         response.raise_for_status()
+         automations.extend(Automation.model_validate_list(response.json()))
+         if len(automations) >= automation_count:
+             missing_automations = False
+         offset += limit
+
+     return automations
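
`gather_all_automations` works around the Prefect client's fixed page size by asking `/automations/count` first and then paging `/automations/filter` until everything has been collected. The same count-then-paginate pattern in isolation, with hypothetical `count_items`/`fetch_items` helpers in place of the HTTP calls:

```python
import asyncio

DATASET = list(range(450))  # pretend the server holds 450 automations

async def count_items() -> int:
    # Hypothetical stand-in for POST /automations/count
    return len(DATASET)

async def fetch_items(offset: int, limit: int) -> list[int]:
    # Hypothetical stand-in for POST /automations/filter
    return DATASET[offset : offset + limit]

async def gather_all(limit: int = 200) -> list[int]:
    total = await count_items()
    items: list[int] = []
    offset = 0
    missing = True
    while missing:  # same loop shape as gather_all_automations
        items.extend(await fetch_items(offset=offset, limit=limit))
        if len(items) >= total:
            missing = False
        offset += limit
    return items

print(len(asyncio.run(gather_all())))  # 450
```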