infrahub-server 1.4.9__py3-none-any.whl → 1.4.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. infrahub/api/oauth2.py +13 -19
  2. infrahub/api/oidc.py +15 -21
  3. infrahub/artifacts/models.py +2 -1
  4. infrahub/auth.py +137 -3
  5. infrahub/cli/db.py +24 -0
  6. infrahub/cli/db_commands/clean_duplicate_schema_fields.py +212 -0
  7. infrahub/computed_attribute/tasks.py +1 -1
  8. infrahub/core/changelog/models.py +2 -2
  9. infrahub/core/diff/query/artifact.py +12 -9
  10. infrahub/core/ipam/utilization.py +1 -1
  11. infrahub/core/manager.py +6 -3
  12. infrahub/core/node/__init__.py +3 -1
  13. infrahub/core/node/constraints/attribute_uniqueness.py +3 -1
  14. infrahub/core/node/create.py +12 -3
  15. infrahub/core/registry.py +2 -2
  16. infrahub/core/relationship/constraints/count.py +1 -1
  17. infrahub/core/relationship/model.py +1 -1
  18. infrahub/core/schema/definitions/internal.py +4 -0
  19. infrahub/core/schema/manager.py +19 -1
  20. infrahub/core/schema/node_schema.py +4 -2
  21. infrahub/core/schema/schema_branch.py +8 -0
  22. infrahub/core/validators/determiner.py +12 -1
  23. infrahub/core/validators/relationship/peer.py +1 -1
  24. infrahub/core/validators/tasks.py +1 -1
  25. infrahub/generators/tasks.py +3 -7
  26. infrahub/git/integrator.py +1 -1
  27. infrahub/git/models.py +2 -1
  28. infrahub/git/repository.py +22 -5
  29. infrahub/git/tasks.py +14 -8
  30. infrahub/git/utils.py +123 -1
  31. infrahub/graphql/analyzer.py +1 -1
  32. infrahub/graphql/mutations/main.py +3 -3
  33. infrahub/graphql/mutations/schema.py +5 -5
  34. infrahub/message_bus/types.py +2 -1
  35. infrahub/middleware.py +26 -1
  36. infrahub/proposed_change/tasks.py +11 -12
  37. infrahub/server.py +12 -3
  38. infrahub/workers/dependencies.py +8 -1
  39. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/METADATA +17 -17
  40. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/RECORD +46 -45
  41. infrahub_testcontainers/container.py +1 -1
  42. infrahub_testcontainers/docker-compose-cluster.test.yml +1 -1
  43. infrahub_testcontainers/docker-compose.test.yml +1 -1
  44. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/LICENSE.txt +0 -0
  45. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/WHEEL +0 -0
  46. {infrahub_server-1.4.9.dist-info → infrahub_server-1.4.11.dist-info}/entry_points.txt +0 -0
@@ -152,4 +152,4 @@ class PrefixUtilizationGetter:
152
152
  grand_total_space += prefix_total_space
153
153
  if grand_total_space == 0:
154
154
  return 0.0
155
- return (grand_total_used / grand_total_space) * 100
155
+ return min((grand_total_used / grand_total_space) * 100, 100)
infrahub/core/manager.py CHANGED
@@ -63,12 +63,15 @@ def identify_node_class(node: NodeToProcess) -> type[Node]:
63
63
 
64
64
 
65
65
  def get_schema(
66
- db: InfrahubDatabase, branch: Branch, node_schema: type[SchemaProtocol] | MainSchemaTypes | str
66
+ db: InfrahubDatabase,
67
+ branch: Branch,
68
+ node_schema: type[SchemaProtocol] | MainSchemaTypes | str,
69
+ duplicate: bool = False,
67
70
  ) -> MainSchemaTypes:
68
71
  if isinstance(node_schema, str):
69
- return db.schema.get(name=node_schema, branch=branch.name)
72
+ return db.schema.get(name=node_schema, branch=branch.name, duplicate=duplicate)
70
73
  if hasattr(node_schema, "_is_runtime_protocol") and node_schema._is_runtime_protocol:
71
- return db.schema.get(name=node_schema.__name__, branch=branch.name)
74
+ return db.schema.get(name=node_schema.__name__, branch=branch.name, duplicate=duplicate)
72
75
  if not isinstance(node_schema, (MainSchemaTypes)):
73
76
  raise ValueError(f"Invalid schema provided {node_schema}")
74
77
 
@@ -408,7 +408,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
408
408
  for attribute_name in template._attributes:
409
409
  if attribute_name in list(fields) + [OBJECT_TEMPLATE_NAME_ATTR]:
410
410
  continue
411
- fields[attribute_name] = {"value": getattr(template, attribute_name).value, "source": template.id}
411
+ attr_value = getattr(template, attribute_name).value
412
+ if attr_value is not None:
413
+ fields[attribute_name] = {"value": attr_value, "source": template.id}
412
414
 
413
415
  for relationship_name in template._relationships:
414
416
  relationship_schema = template._schema.get_relationship(name=relationship_name)
@@ -29,7 +29,9 @@ class NodeAttributeUniquenessConstraint(NodeConstraintInterface):
29
29
  attr = getattr(node, unique_attr.name)
30
30
  if unique_attr.inherited:
31
31
  for generic_parent_schema_name in node_schema.inherit_from:
32
- generic_parent_schema = self.db.schema.get(generic_parent_schema_name, branch=self.branch)
32
+ generic_parent_schema = self.db.schema.get(
33
+ generic_parent_schema_name, branch=self.branch, duplicate=False
34
+ )
33
35
  parent_attr = generic_parent_schema.get_attribute_or_none(unique_attr.name)
34
36
  if parent_attr is None:
35
37
  continue
@@ -41,10 +41,19 @@ async def extract_peer_data(
41
41
  ) -> Mapping[str, Any]:
42
42
  obj_peer_data: dict[str, Any] = {}
43
43
 
44
- for attr in template_peer.get_schema().attribute_names:
45
- if attr not in obj_peer_schema.attribute_names:
44
+ for attr_name in template_peer.get_schema().attribute_names:
45
+ template_attr = getattr(template_peer, attr_name)
46
+ if template_attr.value is None:
46
47
  continue
47
- obj_peer_data[attr] = {"value": getattr(template_peer, attr).value, "source": template_peer.id}
48
+ if template_attr.is_default:
49
+ # if template attr is_default and the value matches the object schema, then do not set the source
50
+ try:
51
+ if obj_peer_schema.get_attribute(name=attr_name).default_value == template_attr.value:
52
+ continue
53
+ except ValueError:
54
+ pass
55
+
56
+ obj_peer_data[attr_name] = {"value": template_attr.value, "source": template_peer.id}
48
57
 
49
58
  for rel in template_peer.get_schema().relationship_names:
50
59
  rel_manager: RelationshipManager = getattr(template_peer, rel)
infrahub/core/registry.py CHANGED
@@ -113,8 +113,8 @@ class Registry:
113
113
  return True
114
114
  return False
115
115
 
116
- def get_node_schema(self, name: str, branch: Branch | str | None = None) -> NodeSchema:
117
- return self.schema.get_node_schema(name=name, branch=branch)
116
+ def get_node_schema(self, name: str, branch: Branch | str | None = None, duplicate: bool = False) -> NodeSchema:
117
+ return self.schema.get_node_schema(name=name, branch=branch, duplicate=duplicate)
118
118
 
119
119
  def get_data_type(self, name: str) -> type[InfrahubDataType]:
120
120
  if name not in self.data_type:
@@ -40,7 +40,7 @@ class RelationshipCountConstraint(RelationshipManagerConstraintInterface):
40
40
  # peer_ids_present_database_only:
41
41
  # relationship to be deleted, need to check if the schema on the other side has a min_count defined
42
42
  # TODO see how to manage Generic node
43
- peer_schema = registry.schema.get(name=relm.schema.peer, branch=branch)
43
+ peer_schema = registry.schema.get(name=relm.schema.peer, branch=branch, duplicate=False)
44
44
  peer_rels = peer_schema.get_relationships_by_identifier(id=relm.schema.get_identifier())
45
45
  if not peer_rels:
46
46
  return
@@ -440,7 +440,7 @@ class Relationship(FlagPropertyMixin, NodePropertyMixin):
440
440
  self.set_peer(value=peer)
441
441
 
442
442
  if not self.peer_id and self.peer_hfid:
443
- peer_schema = db.schema.get(name=self.schema.peer, branch=self.branch)
443
+ peer_schema = db.schema.get(name=self.schema.peer, branch=self.branch, duplicate=False)
444
444
  kind = (
445
445
  self.data["kind"]
446
446
  if isinstance(self.data, dict) and "kind" in self.data and peer_schema.is_generic_schema
@@ -180,6 +180,7 @@ class SchemaNode(BaseModel):
180
180
  attributes: list[SchemaAttribute]
181
181
  relationships: list[SchemaRelationship]
182
182
  display_labels: list[str]
183
+ uniqueness_constraints: list[list[str]] | None = None
183
184
 
184
185
  def to_dict(self) -> dict[str, Any]:
185
186
  return {
@@ -195,6 +196,7 @@ class SchemaNode(BaseModel):
195
196
  ],
196
197
  "relationships": [relationship.to_dict() for relationship in self.relationships],
197
198
  "display_labels": self.display_labels,
199
+ "uniqueness_constraints": self.uniqueness_constraints,
198
200
  }
199
201
 
200
202
  def without_duplicates(self, other: SchemaNode) -> SchemaNode:
@@ -465,6 +467,7 @@ attribute_schema = SchemaNode(
465
467
  include_in_menu=False,
466
468
  default_filter=None,
467
469
  display_labels=["name__value"],
470
+ uniqueness_constraints=[["name__value", "node"]],
468
471
  attributes=[
469
472
  SchemaAttribute(
470
473
  name="id",
@@ -669,6 +672,7 @@ relationship_schema = SchemaNode(
669
672
  include_in_menu=False,
670
673
  default_filter=None,
671
674
  display_labels=["name__value"],
675
+ uniqueness_constraints=[["name__value", "node"]],
672
676
  attributes=[
673
677
  SchemaAttribute(
674
678
  name="id",
@@ -2,6 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING, Any
4
4
 
5
+ from cachetools import LRUCache
6
+ from infrahub_sdk.schema import BranchSchema as SDKBranchSchema
7
+
5
8
  from infrahub import lock
6
9
  from infrahub.core.manager import NodeManager
7
10
  from infrahub.core.models import (
@@ -40,6 +43,8 @@ class SchemaManager(NodeManager):
40
43
  def __init__(self) -> None:
41
44
  self._cache: dict[int, Any] = {}
42
45
  self._branches: dict[str, SchemaBranch] = {}
46
+ self._branch_hash_by_name: dict[str, str] = {}
47
+ self._sdk_branches: LRUCache[str, SDKBranchSchema] = LRUCache(maxsize=10)
43
48
 
44
49
  def _get_from_cache(self, key: int) -> Any:
45
50
  return self._cache[key]
@@ -140,12 +145,23 @@ class SchemaManager(NodeManager):
140
145
  if name in self._branches:
141
146
  return self._branches[name]
142
147
 
143
- self._branches[name] = SchemaBranch(cache=self._cache, name=name)
148
+ self.set_schema_branch(name, schema=SchemaBranch(cache=self._cache, name=name))
144
149
  return self._branches[name]
145
150
 
151
+ def get_sdk_schema_branch(self, name: str) -> SDKBranchSchema:
152
+ schema_hash = self._branch_hash_by_name[name]
153
+ branch_schema = self._sdk_branches.get(schema_hash)
154
+ if not branch_schema:
155
+ self._sdk_branches[schema_hash] = SDKBranchSchema.from_api_response(
156
+ data=self._branches[name].to_dict_api_schema_object()
157
+ )
158
+
159
+ return self._sdk_branches[schema_hash]
160
+
146
161
  def set_schema_branch(self, name: str, schema: SchemaBranch) -> None:
147
162
  schema.name = name
148
163
  self._branches[name] = schema
164
+ self._branch_hash_by_name[name] = schema.get_hash()
149
165
 
150
166
  def process_schema_branch(self, name: str) -> None:
151
167
  schema_branch = self.get_schema_branch(name=name)
@@ -764,6 +780,8 @@ class SchemaManager(NodeManager):
764
780
  for branch_name in list(self._branches.keys()):
765
781
  if branch_name not in active_branches:
766
782
  del self._branches[branch_name]
783
+ if branch_name in self._branch_hash_by_name:
784
+ del self._branch_hash_by_name[branch_name]
767
785
  removed_branches.append(branch_name)
768
786
 
769
787
  for hash_key in list(self._cache.keys()):
@@ -129,10 +129,12 @@ class NodeSchema(GeneratedNodeSchema):
129
129
  item_idx = existing_inherited_relationships[relationship.name]
130
130
  self.relationships[item_idx].update_from_generic(other=new_relationship)
131
131
 
132
- def get_hierarchy_schema(self, db: InfrahubDatabase, branch: Branch | str | None = None) -> GenericSchema:
132
+ def get_hierarchy_schema(
133
+ self, db: InfrahubDatabase, branch: Branch | str | None = None, duplicate: bool = False
134
+ ) -> GenericSchema:
133
135
  if not self.hierarchy:
134
136
  raise ValueError("The node is not part of a hierarchy")
135
- schema = db.schema.get(name=self.hierarchy, branch=branch)
137
+ schema = db.schema.get(name=self.hierarchy, branch=branch, duplicate=duplicate)
136
138
  if not isinstance(schema, GenericSchema):
137
139
  raise TypeError
138
140
  return schema
@@ -162,6 +162,14 @@ class SchemaBranch:
162
162
  "templates": {name: self.get(name, duplicate=duplicate) for name in self.templates},
163
163
  }
164
164
 
165
+ def to_dict_api_schema_object(self) -> dict[str, list[dict]]:
166
+ return {
167
+ "nodes": [self.get(name, duplicate=False).model_dump() for name in self.nodes],
168
+ "profiles": [self.get(name, duplicate=False).model_dump() for name in self.profiles],
169
+ "generics": [self.get(name, duplicate=False).model_dump() for name in self.generics],
170
+ "templates": [self.get(name, duplicate=False).model_dump() for name in self.templates],
171
+ }
172
+
165
173
  @classmethod
166
174
  def from_dict_schema_object(cls, data: dict) -> Self:
167
175
  type_mapping = {
@@ -10,6 +10,7 @@ from infrahub.core.schema.attribute_parameters import AttributeParameters
10
10
  from infrahub.core.schema.relationship_schema import RelationshipSchema
11
11
  from infrahub.core.schema.schema_branch import SchemaBranch
12
12
  from infrahub.core.validators import CONSTRAINT_VALIDATOR_MAP
13
+ from infrahub.exceptions import SchemaNotFoundError
13
14
  from infrahub.log import get_logger
14
15
 
15
16
  if TYPE_CHECKING:
@@ -81,7 +82,17 @@ class ConstraintValidatorDeterminer:
81
82
 
82
83
  async def _get_all_property_constraints(self) -> list[SchemaUpdateConstraintInfo]:
83
84
  constraints: list[SchemaUpdateConstraintInfo] = []
84
- for schema in self.schema_branch.get_all().values():
85
+ schemas = list(self.schema_branch.get_all(duplicate=False).values())
86
+ # added here to check their uniqueness constraints
87
+ try:
88
+ schemas.append(self.schema_branch.get_node(name="SchemaAttribute", duplicate=False))
89
+ except SchemaNotFoundError:
90
+ pass
91
+ try:
92
+ schemas.append(self.schema_branch.get_node(name="SchemaRelationship", duplicate=False))
93
+ except SchemaNotFoundError:
94
+ pass
95
+ for schema in schemas:
85
96
  constraints.extend(await self._get_property_constraints_for_one_schema(schema=schema))
86
97
  return constraints
87
98
 
@@ -22,7 +22,7 @@ class RelationshipPeerUpdateValidatorQuery(RelationshipSchemaValidatorQuery):
22
22
  name = "relationship_constraints_peer_validator"
23
23
 
24
24
  async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
25
- peer_schema = db.schema.get(name=self.relationship_schema.peer, branch=self.branch)
25
+ peer_schema = db.schema.get(name=self.relationship_schema.peer, branch=self.branch, duplicate=False)
26
26
  allowed_peer_kinds = [peer_schema.kind]
27
27
  if isinstance(peer_schema, GenericSchema):
28
28
  allowed_peer_kinds += peer_schema.used_by
@@ -36,7 +36,7 @@ async def schema_validate_migrations(message: SchemaValidateMigrationData) -> li
36
36
  log.info(f"{len(message.constraints)} constraint(s) to validate")
37
37
  # NOTE this task is a good candidate to add a progress bar
38
38
  for constraint in message.constraints:
39
- schema = message.schema_branch.get(name=constraint.path.schema_kind)
39
+ schema = message.schema_branch.get(name=constraint.path.schema_kind, duplicate=False)
40
40
  if not isinstance(schema, GenericSchema | NodeSchema):
41
41
  continue
42
42
  batch.add(
@@ -21,6 +21,7 @@ from infrahub.generators.models import (
21
21
  )
22
22
  from infrahub.git.base import extract_repo_file_information
23
23
  from infrahub.git.repository import get_initialized_repo
24
+ from infrahub.git.utils import fetch_proposed_change_generator_definition_targets
24
25
  from infrahub.workers.dependencies import get_client, get_workflow
25
26
  from infrahub.workflows.catalogue import REQUEST_GENERATOR_DEFINITION_RUN, REQUEST_GENERATOR_RUN
26
27
  from infrahub.workflows.utils import add_tags
@@ -177,14 +178,9 @@ async def request_generator_definition_run(
177
178
  branch=model.branch,
178
179
  )
179
180
 
180
- group = await client.get(
181
- kind=InfrahubKind.GENERICGROUP,
182
- prefetch_relationships=True,
183
- populate_store=True,
184
- id=model.generator_definition.group_id,
185
- branch=model.branch,
181
+ group = await fetch_proposed_change_generator_definition_targets(
182
+ client=client, branch=model.branch, definition=model.generator_definition
186
183
  )
187
- await group.members.fetch()
188
184
 
189
185
  instance_by_member = {}
190
186
  for instance in existing_instances:
@@ -1363,7 +1363,7 @@ class InfrahubRepositoryIntegrator(InfrahubRepositoryBase):
1363
1363
  message: CheckArtifactCreate | RequestArtifactGenerate,
1364
1364
  ) -> ArtifactGenerateResult:
1365
1365
  response = await self.sdk.query_gql_query(
1366
- name=message.query,
1366
+ name=message.query_id,
1367
1367
  variables=message.variables,
1368
1368
  update_group=True,
1369
1369
  subscribers=[artifact.id],
infrahub/git/models.py CHANGED
@@ -38,7 +38,8 @@ class RequestArtifactGenerate(BaseModel):
38
38
  target_kind: str = Field(..., description="The kind of the target object for this artifact")
39
39
  target_name: str = Field(..., description="Name of the artifact target")
40
40
  artifact_id: str | None = Field(default=None, description="The id of the artifact if it previously existed")
41
- query: str = Field(..., description="The name of the query to use when collecting data")
41
+ query: str = Field(..., description="The name of the query to use when collecting data") # Deprecated
42
+ query_id: str = Field(..., description="The id of the query to use when collecting data")
42
43
  timeout: int = Field(..., description="Timeout for requests used to generate this artifact")
43
44
  variables: dict = Field(..., description="Input variables when generating the artifact")
44
45
  context: InfrahubContext = Field(..., description="The context of the task")
@@ -2,6 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING, Any
4
4
 
5
+ from cachetools import TTLCache
6
+ from cachetools.keys import hashkey
7
+ from cachetools_async import cached
5
8
  from git.exc import BadName, GitCommandError
6
9
  from infrahub_sdk.exceptions import GraphQLError
7
10
  from prefect import task
@@ -248,12 +251,13 @@ class InfrahubReadOnlyRepository(InfrahubRepositoryIntegrator):
248
251
  await self.update_commit_value(branch_name=self.infrahub_branch_name, commit=commit)
249
252
 
250
253
 
251
- @task(
252
- name="Fetch repository commit",
253
- description="Retrieve a git repository at a given commit, if it does not already exist locally",
254
- cache_policy=NONE,
254
+ @cached(
255
+ TTLCache(maxsize=100, ttl=30),
256
+ key=lambda *_, **kwargs: hashkey(
257
+ kwargs.get("repository_id"), kwargs.get("name"), kwargs.get("repository_kind"), kwargs.get("commit")
258
+ ),
255
259
  )
256
- async def get_initialized_repo(
260
+ async def _get_initialized_repo(
257
261
  client: InfrahubClient, repository_id: str, name: str, repository_kind: str, commit: str | None = None
258
262
  ) -> InfrahubReadOnlyRepository | InfrahubRepository:
259
263
  if repository_kind == InfrahubKind.REPOSITORY:
@@ -263,3 +267,16 @@ async def get_initialized_repo(
263
267
  return await InfrahubReadOnlyRepository.init(id=repository_id, name=name, commit=commit, client=client)
264
268
 
265
269
  raise NotImplementedError(f"The repository kind {repository_kind} has not been implemented")
270
+
271
+
272
+ @task(
273
+ name="Fetch repository commit",
274
+ description="Retrieve a git repository at a given commit, if it does not already exist locally",
275
+ cache_policy=NONE,
276
+ )
277
+ async def get_initialized_repo(
278
+ client: InfrahubClient, repository_id: str, name: str, repository_kind: str, commit: str | None = None
279
+ ) -> InfrahubReadOnlyRepository | InfrahubRepository:
280
+ return await _get_initialized_repo(
281
+ client=client, repository_id=repository_id, name=name, repository_kind=repository_kind, commit=commit
282
+ )
infrahub/git/tasks.py CHANGED
@@ -53,6 +53,7 @@ from .models import (
53
53
  UserCheckDefinitionData,
54
54
  )
55
55
  from .repository import InfrahubReadOnlyRepository, InfrahubRepository, get_initialized_repo
56
+ from .utils import fetch_artifact_definition_targets, fetch_check_definition_targets
56
57
 
57
58
 
58
59
  @flow(
@@ -323,9 +324,8 @@ async def generate_request_artifact_definition(
323
324
  kind=CoreArtifactDefinition, id=model.artifact_definition_id, branch=model.branch
324
325
  )
325
326
 
326
- await artifact_definition.targets.fetch()
327
- group = artifact_definition.targets.peer
328
- await group.members.fetch()
327
+ group = await fetch_artifact_definition_targets(client=client, branch=model.branch, definition=artifact_definition)
328
+
329
329
  current_members = [member.id for member in group.members.peers]
330
330
 
331
331
  artifacts_by_member = {}
@@ -356,6 +356,7 @@ async def generate_request_artifact_definition(
356
356
  transform_location = f"{transform.file_path.value}::{transform.class_name.value}"
357
357
  convert_query_response = transform.convert_query_response.value
358
358
 
359
+ batch = await client.create_batch()
359
360
  for relationship in group.members.peers:
360
361
  member = relationship.peer
361
362
  artifact_id = artifacts_by_member.get(member.id)
@@ -376,6 +377,7 @@ async def generate_request_artifact_definition(
376
377
  repository_kind=repository.get_kind(),
377
378
  branch_name=model.branch,
378
379
  query=query.name.value,
380
+ query_id=query.id,
379
381
  variables=await member.extract(params=artifact_definition.parameters.value),
380
382
  target_id=member.id,
381
383
  target_name=member.display_label,
@@ -385,10 +387,16 @@ async def generate_request_artifact_definition(
385
387
  context=context,
386
388
  )
387
389
 
388
- await get_workflow().submit_workflow(
389
- workflow=REQUEST_ARTIFACT_GENERATE, context=context, parameters={"model": request_artifact_generate_model}
390
+ batch.add(
391
+ task=get_workflow().submit_workflow,
392
+ workflow=REQUEST_ARTIFACT_GENERATE,
393
+ context=context,
394
+ parameters={"model": request_artifact_generate_model},
390
395
  )
391
396
 
397
+ async for _, _ in batch.execute():
398
+ pass
399
+
392
400
 
393
401
  @flow(name="git-repository-pull-read-only", flow_run_name="Pull latest commit on {model.repository_name}")
394
402
  async def pull_read_only(model: GitRepositoryPullReadOnly) -> None:
@@ -569,9 +577,7 @@ async def trigger_repository_user_checks_definitions(model: UserCheckDefinitionD
569
577
 
570
578
  if definition.targets.id:
571
579
  # Check against a group of targets
572
- await definition.targets.fetch()
573
- group = definition.targets.peer
574
- await group.members.fetch()
580
+ group = await fetch_check_definition_targets(client=client, branch=model.branch_name, definition=definition)
575
581
  check_models = []
576
582
  for relationship in group.members.peers:
577
583
  member = relationship.peer
infrahub/git/utils.py CHANGED
@@ -1,9 +1,16 @@
1
- from typing import TYPE_CHECKING
1
+ from collections import defaultdict
2
+ from typing import TYPE_CHECKING, Any
3
+
4
+ from infrahub_sdk import InfrahubClient
5
+ from infrahub_sdk.node import RelationshipManager
6
+ from infrahub_sdk.protocols import CoreArtifactDefinition, CoreCheckDefinition, CoreGroup
7
+ from infrahub_sdk.types import Order
2
8
 
3
9
  from infrahub.core import registry
4
10
  from infrahub.core.constants import InfrahubKind
5
11
  from infrahub.core.manager import NodeManager
6
12
  from infrahub.database import InfrahubDatabase
13
+ from infrahub.generators.models import ProposedChangeGeneratorDefinition
7
14
 
8
15
  from .models import RepositoryBranchInfo, RepositoryData
9
16
 
@@ -46,3 +53,118 @@ async def get_repositories_commit_per_branch(
46
53
  )
47
54
 
48
55
  return repositories
56
+
57
+
58
+ def _collect_parameter_first_segments(params: Any) -> set[str]:
59
+ segments: set[str] = set()
60
+
61
+ def _walk(value: Any) -> None:
62
+ if isinstance(value, str):
63
+ segment = value.split("__", 1)[0]
64
+ if segment:
65
+ segments.add(segment)
66
+ elif isinstance(value, dict):
67
+ for nested in value.values():
68
+ _walk(nested)
69
+ elif isinstance(value, (list, tuple, set)):
70
+ for nested in value:
71
+ _walk(nested)
72
+
73
+ _walk(params)
74
+ return segments
75
+
76
+
77
+ async def _prefetch_group_member_nodes(
78
+ client: InfrahubClient,
79
+ members: RelationshipManager,
80
+ branch: str,
81
+ required_fields: set[str],
82
+ ) -> None:
83
+ ids_per_kind: dict[str, set[str]] = defaultdict(set)
84
+ for peer in members.peers:
85
+ if peer.id and peer.typename:
86
+ ids_per_kind[peer.typename].add(peer.id)
87
+
88
+ if not ids_per_kind:
89
+ return
90
+
91
+ batch = await client.create_batch()
92
+
93
+ for kind, ids in ids_per_kind.items():
94
+ schema = await client.schema.get(kind=kind, branch=branch)
95
+
96
+ # FIXME: https://github.com/opsmill/infrahub-sdk-python/pull/205
97
+ valid_fields = set(schema.attribute_names) | set(schema.relationship_names)
98
+ keep_relationships = set(schema.relationship_names) & required_fields
99
+ cleaned_fields = valid_fields - required_fields
100
+
101
+ kwargs: dict[str, Any] = {
102
+ "kind": kind,
103
+ "ids": list(ids),
104
+ "branch": branch,
105
+ "exclude": list(cleaned_fields),
106
+ "populate_store": True,
107
+ "order": Order(disable=True),
108
+ }
109
+
110
+ if keep_relationships:
111
+ kwargs["include"] = list(keep_relationships)
112
+
113
+ batch.add(task=client.filters, **kwargs)
114
+
115
+ async for _ in batch.execute():
116
+ pass
117
+
118
+
119
+ async def _fetch_definition_targets(
120
+ client: InfrahubClient,
121
+ branch: str,
122
+ group_id: str,
123
+ parameters: Any,
124
+ ) -> CoreGroup:
125
+ group = await client.get(
126
+ kind=CoreGroup,
127
+ id=group_id,
128
+ branch=branch,
129
+ include=["members"],
130
+ )
131
+
132
+ parameter_fields = _collect_parameter_first_segments(parameters)
133
+ await _prefetch_group_member_nodes(
134
+ client=client,
135
+ members=group.members,
136
+ branch=branch,
137
+ required_fields=parameter_fields,
138
+ )
139
+
140
+ return group
141
+
142
+
143
+ async def fetch_artifact_definition_targets(
144
+ client: InfrahubClient,
145
+ branch: str,
146
+ definition: CoreArtifactDefinition,
147
+ ) -> CoreGroup:
148
+ return await _fetch_definition_targets(
149
+ client=client, branch=branch, group_id=definition.targets.id, parameters=definition.parameters.value
150
+ )
151
+
152
+
153
+ async def fetch_check_definition_targets(
154
+ client: InfrahubClient,
155
+ branch: str,
156
+ definition: CoreCheckDefinition,
157
+ ) -> CoreGroup:
158
+ return await _fetch_definition_targets(
159
+ client=client, branch=branch, group_id=definition.targets.id, parameters=definition.parameters.value
160
+ )
161
+
162
+
163
+ async def fetch_proposed_change_generator_definition_targets(
164
+ client: InfrahubClient,
165
+ branch: str,
166
+ definition: ProposedChangeGeneratorDefinition,
167
+ ) -> CoreGroup:
168
+ return await _fetch_definition_targets(
169
+ client=client, branch=branch, group_id=definition.group_id, parameters=definition.parameters
170
+ )
@@ -639,7 +639,7 @@ class InfrahubGraphQLQueryAnalyzer(GraphQLQueryAnalyzer):
639
639
  self, node: InlineFragmentNode, query_node: GraphQLQueryNode
640
640
  ) -> GraphQLQueryNode:
641
641
  context_type = query_node.context_type
642
- infrahub_model = self.schema_branch.get(name=node.type_condition.name.value)
642
+ infrahub_model = self.schema_branch.get(name=node.type_condition.name.value, duplicate=False)
643
643
  context_type = ContextType.DIRECT
644
644
  current_node = GraphQLQueryNode(
645
645
  parent=query_node,
@@ -479,7 +479,7 @@ def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch
479
479
  it means node schema overrided this constraint, in which case we only need to lock on the generic.
480
480
  """
481
481
 
482
- node_schema = schema_branch.get(name=kind)
482
+ node_schema = schema_branch.get(name=kind, duplicate=False)
483
483
 
484
484
  schema_uc = None
485
485
  kinds = []
@@ -494,7 +494,7 @@ def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch
494
494
 
495
495
  node_schema_kind_removed = False
496
496
  for generic_kind in generics_kinds:
497
- generic_uc = schema_branch.get(name=generic_kind).uniqueness_constraints
497
+ generic_uc = schema_branch.get(name=generic_kind, duplicate=False).uniqueness_constraints
498
498
  if generic_uc:
499
499
  kinds.append(generic_kind)
500
500
  if not node_schema_kind_removed and generic_uc == schema_uc:
@@ -513,7 +513,7 @@ def _should_kind_be_locked_on_any_branch(kind: str, schema_branch: SchemaBranch)
513
513
  if kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED:
514
514
  return True
515
515
 
516
- node_schema = schema_branch.get(name=kind)
516
+ node_schema = schema_branch.get(name=kind, duplicate=False)
517
517
  if node_schema.is_generic_schema:
518
518
  return False
519
519
 
@@ -81,7 +81,7 @@ class SchemaDropdownAdd(Mutation):
81
81
  _validate_schema_permission(graphql_context=graphql_context)
82
82
  await apply_external_context(graphql_context=graphql_context, context_input=context)
83
83
 
84
- kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name)
84
+ kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name, duplicate=False)
85
85
  attribute = str(data.attribute)
86
86
  validate_kind_dropdown(kind=kind, attribute=attribute)
87
87
  dropdown = str(data.dropdown)
@@ -104,7 +104,7 @@ class SchemaDropdownAdd(Mutation):
104
104
  context=graphql_context.get_context(),
105
105
  )
106
106
 
107
- kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name)
107
+ kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name, duplicate=False)
108
108
  attrib = kind.get_attribute(attribute)
109
109
  dropdown_entry = {}
110
110
  success = False
@@ -141,7 +141,7 @@ class SchemaDropdownRemove(Mutation):
141
141
  graphql_context: GraphqlContext = info.context
142
142
 
143
143
  _validate_schema_permission(graphql_context=graphql_context)
144
- kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name)
144
+ kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name, duplicate=False)
145
145
  await apply_external_context(graphql_context=graphql_context, context_input=context)
146
146
 
147
147
  attribute = str(data.attribute)
@@ -197,7 +197,7 @@ class SchemaEnumAdd(Mutation):
197
197
  graphql_context: GraphqlContext = info.context
198
198
 
199
199
  _validate_schema_permission(graphql_context=graphql_context)
200
- kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name)
200
+ kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name, duplicate=False)
201
201
  await apply_external_context(graphql_context=graphql_context, context_input=context)
202
202
 
203
203
  attribute = str(data.attribute)
@@ -243,7 +243,7 @@ class SchemaEnumRemove(Mutation):
243
243
  graphql_context: GraphqlContext = info.context
244
244
 
245
245
  _validate_schema_permission(graphql_context=graphql_context)
246
- kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name)
246
+ kind = graphql_context.db.schema.get(name=str(data.kind), branch=graphql_context.branch.name, duplicate=False)
247
247
  await apply_external_context(graphql_context=graphql_context, context_input=context)
248
248
 
249
249
  attribute = str(data.attribute)
@@ -89,7 +89,8 @@ class ProposedChangeArtifactDefinition(BaseModel):
89
89
  definition_id: str
90
90
  definition_name: str
91
91
  artifact_name: str
92
- query_name: str
92
+ query_name: str # Deprecated
93
+ query_id: str
93
94
  query_models: list[str]
94
95
  repository_id: str
95
96
  transform_kind: str