graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (33)
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +9 -4
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +6 -10
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +250 -189
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
graphiti_core/graph_queries.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Database query utilities for different graph database backends.
3
+
4
+ This module provides database-agnostic query generation for Neo4j and FalkorDB,
5
+ supporting index creation, fulltext search, and bulk operations.
6
+ """
7
+
8
+ from typing import Any
9
+
10
+ from typing_extensions import LiteralString
11
+
12
+ from graphiti_core.models.edges.edge_db_queries import (
13
+ ENTITY_EDGE_SAVE_BULK,
14
+ )
15
+ from graphiti_core.models.nodes.node_db_queries import (
16
+ ENTITY_NODE_SAVE_BULK,
17
+ )
18
+
19
+ # Mapping from Neo4j fulltext index names to FalkorDB node labels
20
+ NEO4J_TO_FALKORDB_MAPPING = {
21
+ 'node_name_and_summary': 'Entity',
22
+ 'community_name': 'Community',
23
+ 'episode_content': 'Episodic',
24
+ 'edge_name_and_fact': 'RELATES_TO',
25
+ }
26
+
27
+
28
+ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
29
+ if db_type == 'falkordb':
30
+ return [
31
+ # Entity node
32
+ 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
33
+ # Episodic node
34
+ 'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
35
+ # Community node
36
+ 'CREATE INDEX FOR (n:Community) ON (n.uuid)',
37
+ # RELATES_TO edge
38
+ 'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
39
+ # MENTIONS edge
40
+ 'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
41
+ # HAS_MEMBER edge
42
+ 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
43
+ ]
44
+ else:
45
+ return [
46
+ 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
47
+ 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
48
+ 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
49
+ 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
50
+ 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
51
+ 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
52
+ 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
53
+ 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
54
+ 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
55
+ 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
56
+ 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
57
+ 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
58
+ 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
59
+ 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
60
+ 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
61
+ 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
62
+ 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
63
+ 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
64
+ 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
65
+ ]
66
+
67
+
68
+ def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
69
+ if db_type == 'falkordb':
70
+ return [
71
+ """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
72
+ """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
73
+ """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
74
+ """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
75
+ ]
76
+ else:
77
+ return [
78
+ """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
79
+ FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
80
+ """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
81
+ FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
82
+ """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
83
+ FOR (n:Community) ON EACH [n.name, n.group_id]""",
84
+ """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
85
+ FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
86
+ ]
87
+
88
+
89
+ def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
90
+ if db_type == 'falkordb':
91
+ label = NEO4J_TO_FALKORDB_MAPPING[name]
92
+ return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
93
+ else:
94
+ return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
95
+
96
+
97
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
    """Build the Cypher expression that scores two vectors by cosine similarity.

    Args:
        vec1: Cypher expression for the first vector, spliced in verbatim.
        vec2: Cypher expression for the second vector. For FalkorDB it is
            wrapped in ``vecf32(...)``.
        db_type: ``'falkordb'`` selects the FalkorDB expression; any other
            value (default ``'neo4j'``) selects the Neo4j one.

    Returns:
        The similarity expression as a string.
    """
    if db_type != 'falkordb':
        return f'vector.similarity.cosine({vec1}, {vec2})'

    # FalkorDB exposes a cosine *distance*; it is rescaled as (2 - d) / 2.
    # Per the original author's note, this is to line up with the normalized
    # cosine similarity Neo4j returns.
    return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
103
+
104
+
105
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
    """Build the fulltext relationship-search ``CALL`` clause for a graph backend.

    Args:
        name: Neo4j fulltext index name. For FalkorDB it is translated to a
            relationship type through ``NEO4J_TO_FALKORDB_MAPPING``.
        db_type: ``'falkordb'`` selects the FalkorDB procedure; any other
            value (default ``'neo4j'``) selects the Neo4j procedure.

    Returns:
        The ``CALL ...`` clause. Both forms search with the ``$query``
        parameter; only the Neo4j form passes ``{limit: $limit}``.

    Raises:
        KeyError: If ``db_type`` is ``'falkordb'`` and ``name`` has no entry
            in ``NEO4J_TO_FALKORDB_MAPPING``.
    """
    if db_type != 'falkordb':
        return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'

    relationship_type = NEO4J_TO_FALKORDB_MAPPING[name]
    return f"CALL db.idx.fulltext.queryRelationships('{relationship_type}', $query)"
111
+
112
+
113
+ def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
114
+ if db_type == 'falkordb':
115
+ queries = []
116
+ for node in nodes:
117
+ for label in node['labels']:
118
+ queries.append(
119
+ (
120
+ f"""
121
+ UNWIND $nodes AS node
122
+ MERGE (n:Entity {{uuid: node.uuid}})
123
+ SET n:{label}
124
+ SET n = node
125
+ WITH n, node
126
+ SET n.name_embedding = vecf32(node.name_embedding)
127
+ RETURN n.uuid AS uuid
128
+ """,
129
+ {'nodes': [node]},
130
+ )
131
+ )
132
+ return queries
133
+ else:
134
+ return ENTITY_NODE_SAVE_BULK
135
+
136
+
137
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
    """Return the bulk-save query for entity (``RELATES_TO``) edges.

    Args:
        db_type: ``'falkordb'`` selects the FalkorDB query; any other value
            (default ``'neo4j'``) selects the shared ``ENTITY_EDGE_SAVE_BULK``
            query.

    Returns:
        A Cypher query that expects an ``$entity_edges`` list parameter and
        returns each saved edge's ``uuid``.
    """
    if db_type == 'falkordb':
        # Copy the whole edge map onto the relationship, as the Neo4j query
        # does with `SET r = edge`, so fields beyond the fixed base set are
        # not dropped. The embedding is then re-assigned through vecf32(),
        # following the same pattern as the FalkorDB node save query.
        return """
            UNWIND $entity_edges AS edge
            MATCH (source:Entity {uuid: edge.source_node_uuid})
            MATCH (target:Entity {uuid: edge.target_node_uuid})
            MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
            SET r = edge
            WITH r, edge
            SET r.fact_embedding = vecf32(edge.fact_embedding)
            RETURN edge.uuid AS uuid"""

    return ENTITY_EDGE_SAVE_BULK
graphiti_core/graphiti.py CHANGED
@@ -19,12 +19,13 @@ from datetime import datetime
19
19
  from time import time
20
20
 
21
21
  from dotenv import load_dotenv
22
- from neo4j import AsyncGraphDatabase
23
22
  from pydantic import BaseModel
24
23
  from typing_extensions import LiteralString
25
24
 
26
25
  from graphiti_core.cross_encoder.client import CrossEncoderClient
27
26
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
27
+ from graphiti_core.driver.driver import GraphDriver
28
+ from graphiti_core.driver.neo4j_driver import Neo4jDriver
28
29
  from graphiti_core.edges import EntityEdge, EpisodicEdge
29
30
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
31
  from graphiti_core.graphiti_types import GraphitiClients
@@ -41,6 +42,7 @@ from graphiti_core.search.search_config_recipes import (
41
42
  from graphiti_core.search.search_filters import SearchFilters
42
43
  from graphiti_core.search.search_utils import (
43
44
  RELEVANT_SCHEMA_LIMIT,
45
+ get_edge_invalidation_candidates,
44
46
  get_mentioned_nodes,
45
47
  get_relevant_edges,
46
48
  )
@@ -62,9 +64,8 @@ from graphiti_core.utils.maintenance.community_operations import (
62
64
  )
63
65
  from graphiti_core.utils.maintenance.edge_operations import (
64
66
  build_episodic_edges,
65
- dedupe_extracted_edge,
66
67
  extract_edges,
67
- resolve_edge_contradictions,
68
+ resolve_extracted_edge,
68
69
  resolve_extracted_edges,
69
70
  )
70
71
  from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -77,7 +78,6 @@ from graphiti_core.utils.maintenance.node_operations import (
77
78
  extract_nodes,
78
79
  resolve_extracted_nodes,
79
80
  )
80
- from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
81
81
  from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
82
82
 
83
83
  logger = logging.getLogger(__name__)
@@ -95,12 +95,13 @@ class Graphiti:
95
95
  def __init__(
96
96
  self,
97
97
  uri: str,
98
- user: str,
99
- password: str,
98
+ user: str | None = None,
99
+ password: str | None = None,
100
100
  llm_client: LLMClient | None = None,
101
101
  embedder: EmbedderClient | None = None,
102
102
  cross_encoder: CrossEncoderClient | None = None,
103
103
  store_raw_episode_content: bool = True,
104
+ graph_driver: GraphDriver | None = None,
104
105
  ):
105
106
  """
106
107
  Initialize a Graphiti instance.
@@ -138,7 +139,9 @@ class Graphiti:
138
139
  Make sure to set the OPENAI_API_KEY environment variable before initializing
139
140
  Graphiti if you're using the default OpenAIClient.
140
141
  """
141
- self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
142
+
143
+ self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
144
+
142
145
  self.database = DEFAULT_DATABASE
143
146
  self.store_raw_episode_content = store_raw_episode_content
144
147
  if llm_client:
@@ -274,6 +277,8 @@ class Graphiti:
274
277
  update_communities: bool = False,
275
278
  entity_types: dict[str, BaseModel] | None = None,
276
279
  previous_episode_uuids: list[str] | None = None,
280
+ edge_types: dict[str, BaseModel] | None = None,
281
+ edge_type_map: dict[tuple[str, str], list[str]] | None = None,
277
282
  ) -> AddEpisodeResults:
278
283
  """
279
284
  Process an episode and update the graph.
@@ -356,6 +361,13 @@ class Graphiti:
356
361
  )
357
362
  )
358
363
 
364
+ # Create default edge type map
365
+ edge_type_map_default = (
366
+ {('Entity', 'Entity'): list(edge_types.keys())}
367
+ if edge_types is not None
368
+ else {('Entity', 'Entity'): []}
369
+ )
370
+
359
371
  # Extract entities as nodes
360
372
 
361
373
  extracted_nodes = await extract_nodes(
@@ -371,7 +383,9 @@ class Graphiti:
371
383
  previous_episodes,
372
384
  entity_types,
373
385
  ),
374
- extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
386
+ extract_edges(
387
+ self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
388
+ ),
375
389
  )
376
390
 
377
391
  edges = resolve_edge_pointers(extracted_edges, uuid_map)
@@ -381,6 +395,9 @@ class Graphiti:
381
395
  self.clients,
382
396
  edges,
383
397
  episode,
398
+ nodes,
399
+ edge_types or {},
400
+ edge_type_map or edge_type_map_default,
384
401
  ),
385
402
  extract_attributes_from_nodes(
386
403
  self.clients, nodes, episode, previous_episodes, entity_types
@@ -681,17 +698,27 @@ class Graphiti:
681
698
 
682
699
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
683
700
 
684
- related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
701
+ related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
702
+ existing_edges = (
703
+ await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
704
+ )[0]
685
705
 
686
- resolved_edge = await dedupe_extracted_edge(
706
+ resolved_edge, invalidated_edges = await resolve_extracted_edge(
687
707
  self.llm_client,
688
708
  updated_edge,
689
- related_edges[0],
709
+ related_edges,
710
+ existing_edges,
711
+ EpisodicNode(
712
+ name='',
713
+ source=EpisodeType.text,
714
+ source_description='',
715
+ content='',
716
+ valid_at=edge.valid_at or utc_now(),
717
+ entity_edges=[],
718
+ group_id=edge.group_id,
719
+ ),
690
720
  )
691
721
 
692
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
693
- invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
694
-
695
722
  await add_nodes_and_edges_bulk(
696
723
  self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
697
724
  )
graphiti_core/graphiti_types.py CHANGED
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from neo4j import AsyncDriver
18
17
  from pydantic import BaseModel, ConfigDict
19
18
 
20
19
  from graphiti_core.cross_encoder import CrossEncoderClient
20
+ from graphiti_core.driver.driver import GraphDriver
21
21
  from graphiti_core.embedder import EmbedderClient
22
22
  from graphiti_core.llm_client import LLMClient
23
23
 
24
24
 
25
25
  class GraphitiClients(BaseModel):
26
- driver: AsyncDriver
26
+ driver: GraphDriver
27
27
  llm_client: LLMClient
28
28
  embedder: EmbedderClient
29
29
  cross_encoder: CrossEncoderClient
graphiti_core/helpers.py CHANGED
@@ -18,7 +18,6 @@ import asyncio
18
18
  import os
19
19
  from collections.abc import Coroutine
20
20
  from datetime import datetime
21
- from typing import Any
22
21
 
23
22
  import numpy as np
24
23
  from dotenv import load_dotenv
@@ -28,7 +27,7 @@ from typing_extensions import LiteralString
28
27
 
29
28
  load_dotenv()
30
29
 
31
- DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
30
+ DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
32
31
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
33
32
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
34
33
  MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
@@ -39,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
39
38
  )
40
39
 
41
40
 
42
- def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
43
- return neo_date.to_native() if neo_date else None
41
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
    """Convert a date value read from the graph database to a ``datetime``.

    Args:
        neo_date: A ``neo4j_time.DateTime``, an ISO-format string, or a falsy
            value such as ``None`` or ``''``.

    Returns:
        ``neo_date.to_native()`` for a ``neo4j_time.DateTime``,
        ``datetime.fromisoformat(neo_date)`` for any other truthy value, and
        ``None`` otherwise.
    """
    # The driver-native type is checked first, before any truthiness test.
    if isinstance(neo_date, neo4j_time.DateTime):
        return neo_date.to_native()

    if not neo_date:
        return None

    return datetime.fromisoformat(neo_date)
44
49
 
45
50
 
46
51
  def lucene_sanitize(query: str) -> str:
graphiti_core/llm_client/__init__.py CHANGED
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from .client import LLMClient
2
18
  from .config import LLMConfig
3
19
  from .errors import RateLimitError
graphiti_core/llm_client/azure_openai_client.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ import logging
19
+ from typing import Any
20
+
21
+ from openai import AsyncAzureOpenAI
22
+ from openai.types.chat import ChatCompletionMessageParam
23
+ from pydantic import BaseModel
24
+
25
+ from ..prompts.models import Message
26
+ from .client import LLMClient
27
+ from .config import LLMConfig, ModelSize
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
class AzureOpenAILLMClient(LLMClient):
    """LLMClient implementation backed by a caller-supplied AsyncAzureOpenAI client."""

    def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
        """Store the Azure client and initialise the base class with caching disabled.

        Args:
            azure_client: Pre-configured ``AsyncAzureOpenAI`` instance used
                for every request.
            config: Optional LLM configuration forwarded to ``LLMClient``.
        """
        super().__init__(config, cache=False)
        self.azure_client = azure_client

    def _to_openai_messages(self, messages: list[Message]) -> list[ChatCompletionMessageParam]:
        """Clean each message in place and convert user/system ones to OpenAI dicts."""
        converted: list[ChatCompletionMessageParam] = []
        for msg in messages:
            # The cleaned text is written back onto the message object itself.
            msg.content = self._clean_input(msg.content)
            if msg.role == 'user':
                converted.append({'role': 'user', 'content': msg.content})
            elif msg.role == 'system':
                converted.append({'role': 'system', 'content': msg.content})
            # Messages with any other role are cleaned but not sent.
        return converted

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = 1024,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, Any]:
        """Request a JSON-object chat completion and return it parsed.

        Args:
            messages: Prompt messages; their ``content`` is cleaned in place.
            response_model: Accepted for interface compatibility; not used in
                this method.
            max_tokens: Completion token limit passed to the API.
            model_size: Accepted for interface compatibility; not used in
                this method.

        Returns:
            The decoded JSON content of the first choice. Empty or missing
            content is treated as ``'{}'``.

        Raises:
            Exception: Any error from the API call or JSON decoding is logged
                and re-raised unchanged.
        """
        chat_messages = self._to_openai_messages(messages)

        # Fall back to a default deployment name when none is configured.
        model_name = self.model or 'gpt-4o-mini'

        try:
            completion = await self.azure_client.chat.completions.create(
                model=model_name,
                messages=chat_messages,
                temperature=0.7 if self.temperature is None else float(self.temperature),
                max_tokens=max_tokens,
                response_format={'type': 'json_object'},
            )
            raw_json = completion.choices[0].message.content or '{}'
            return json.loads(raw_json)
        except Exception as e:
            logger.error(f'Error in Azure OpenAI LLM response: {e}')
            raise
graphiti_core/llm_client/gemini_client.py CHANGED
@@ -139,13 +139,16 @@ class GeminiClient(LLMClient):
139
139
  # Generate content using the simple string approach
140
140
  response = await self.client.aio.models.generate_content(
141
141
  model=self.model or DEFAULT_MODEL,
142
- contents=gemini_messages,
142
+ contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
143
143
  config=generation_config,
144
144
  )
145
145
 
146
146
  # If this was a structured output request, parse the response into the Pydantic model
147
147
  if response_model is not None:
148
148
  try:
149
+ if not response.text:
150
+ raise ValueError('No response text')
151
+
149
152
  validated_model = response_model.model_validate(json.loads(response.text))
150
153
 
151
154
  # Return as a dictionary for API consistency
graphiti_core/models/edges/edge_db_queries.py CHANGED
@@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
34
34
  MATCH (source:Entity {uuid: $source_uuid})
35
35
  MATCH (target:Entity {uuid: $target_uuid})
36
36
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
37
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
38
- created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
37
+ SET r = $edge_data
39
38
  WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
40
39
  RETURN r.uuid AS uuid"""
41
40
 
@@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
44
43
  MATCH (source:Entity {uuid: edge.source_node_uuid})
45
44
  MATCH (target:Entity {uuid: edge.target_node_uuid})
46
45
  MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
47
- SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
48
- created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
46
+ SET r = edge
49
47
  WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
50
48
  RETURN edge.uuid AS uuid
51
49
  """