graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +132 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +66 -40
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +41 -14
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +9 -4
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +79 -4
- graphiti_core/prompts/extract_edges.py +50 -5
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +6 -10
- graphiti_core/search/search_filters.py +23 -9
- graphiti_core/search/search_utils.py +250 -189
- graphiti_core/utils/bulk_utils.py +38 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +149 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +52 -71
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database query utilities for different graph database backends.
|
|
3
|
+
|
|
4
|
+
This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
|
5
|
+
supporting index creation, fulltext search, and bulk operations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from typing_extensions import LiteralString
|
|
11
|
+
|
|
12
|
+
from graphiti_core.models.edges.edge_db_queries import (
|
|
13
|
+
ENTITY_EDGE_SAVE_BULK,
|
|
14
|
+
)
|
|
15
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
16
|
+
ENTITY_NODE_SAVE_BULK,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Mapping from Neo4j fulltext index names to FalkorDB node labels.
# The keys are the index names created by the Neo4j branch of
# get_fulltext_indices(); the values are the labels (or, for edges, the
# relationship type) that the FalkorDB fulltext procedures are called with.
NEO4J_TO_FALKORDB_MAPPING = {
    'node_name_and_summary': 'Entity',
    'community_name': 'Community',
    'episode_content': 'Episodic',
    # Relationship fulltext index -> relationship type, not a node label.
    'edge_name_and_fact': 'RELATES_TO',
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
    """Return the range-index creation statements for the given backend.

    Args:
        db_type: ``'falkordb'`` selects the FalkorDB statements; any other
            value (including the default ``'neo4j'``) selects the Neo4j ones.

    Returns:
        A list of ``CREATE INDEX`` statements, one per index.
    """
    if db_type != 'falkordb':
        # Neo4j: one named, idempotent (IF NOT EXISTS) index per property.
        return [
            'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
            'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
            'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
            'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
            'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
            'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
            'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
            'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
            'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
            'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
            'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
            'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
            'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
            'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
            'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
            'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
            'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
            'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
            'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
        ]

    # FalkorDB: one unnamed statement per label / relationship type, listing
    # every indexed property of that label together.
    return [
        # Entity node
        'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
        # Episodic node
        'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
        # Community node
        'CREATE INDEX FOR (n:Community) ON (n.uuid)',
        # RELATES_TO edge
        'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
        # MENTIONS edge
        'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
        # HAS_MEMBER edge
        'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
    ]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
    """Return the fulltext-index creation statements for the given backend.

    Args:
        db_type: ``'falkordb'`` selects the FalkorDB statements; any other
            value (including the default ``'neo4j'``) selects the Neo4j ones.

    Returns:
        Four ``CREATE FULLTEXT INDEX`` statements covering episodes, entities,
        communities and ``RELATES_TO`` edges, in that order.
    """
    if db_type != 'falkordb':
        # Neo4j indexes are named; those names are what the fulltext query
        # helpers in this module receive as their `name` argument.
        return [
            """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
            FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
            """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
            FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
            """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
            FOR (n:Community) ON EACH [n.name, n.group_id]""",
            """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
            FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
        ]

    # FalkorDB indexes are unnamed and declared per label / relationship type.
    return [
        """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
        """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
        """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
        """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
    ]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
|
|
90
|
+
if db_type == 'falkordb':
|
|
91
|
+
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
92
|
+
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
93
|
+
else:
|
|
94
|
+
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
    """Build the Cypher expression that scores the similarity of two vectors.

    Args:
        vec1: Cypher expression for the first vector, inserted verbatim.
        vec2: Cypher expression for the second vector, inserted verbatim. For
            FalkorDB it is wrapped in ``vecf32(...)``.
        db_type: ``'falkordb'`` selects the FalkorDB expression; any other
            value selects the Neo4j one.

    Returns:
        The similarity expression as a string.
    """
    if db_type != 'falkordb':
        return f'vector.similarity.cosine({vec1}, {vec2})'

    # FalkorDB exposes a cosine *distance*; it is rescaled here so that the
    # result is comparable with Neo4j's normalized cosine similarity.
    # NOTE(review): assumes vec.cosineDistance lies in [0, 2] — confirm
    # against the FalkorDB documentation.
    return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
    """Build the fulltext relationship-search ``CALL`` clause for the backend.

    Args:
        name: Neo4j fulltext index name. For FalkorDB it is translated through
            ``NEO4J_TO_FALKORDB_MAPPING``.
        db_type: ``'falkordb'`` selects the FalkorDB procedure; any other value
            selects the Neo4j procedure.

    Returns:
        The ``CALL ...`` clause. Both forms read the search text from the
        ``$query`` parameter; the Neo4j form also references ``$limit``.

    Raises:
        KeyError: If ``db_type`` is ``'falkordb'`` and ``name`` is not a key of
            ``NEO4J_TO_FALKORDB_MAPPING``.
    """
    if db_type != 'falkordb':
        return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'

    falkordb_label = NEO4J_TO_FALKORDB_MAPPING[name]
    return f"CALL db.idx.fulltext.queryRelationships('{falkordb_label}', $query)"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
|
114
|
+
if db_type == 'falkordb':
|
|
115
|
+
queries = []
|
|
116
|
+
for node in nodes:
|
|
117
|
+
for label in node['labels']:
|
|
118
|
+
queries.append(
|
|
119
|
+
(
|
|
120
|
+
f"""
|
|
121
|
+
UNWIND $nodes AS node
|
|
122
|
+
MERGE (n:Entity {{uuid: node.uuid}})
|
|
123
|
+
SET n:{label}
|
|
124
|
+
SET n = node
|
|
125
|
+
WITH n, node
|
|
126
|
+
SET n.name_embedding = vecf32(node.name_embedding)
|
|
127
|
+
RETURN n.uuid AS uuid
|
|
128
|
+
""",
|
|
129
|
+
{'nodes': [node]},
|
|
130
|
+
)
|
|
131
|
+
)
|
|
132
|
+
return queries
|
|
133
|
+
else:
|
|
134
|
+
return ENTITY_NODE_SAVE_BULK
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
    """Return the query used to bulk-save entity (``RELATES_TO``) edges.

    Args:
        db_type: ``'falkordb'`` selects the FalkorDB query; any other value
            selects the ``ENTITY_EDGE_SAVE_BULK`` query.

    Returns:
        A Cypher query string. The FalkorDB query reads its input from the
        ``$entity_edges`` parameter.
    """
    if db_type != 'falkordb':
        return ENTITY_EDGE_SAVE_BULK

    # The FalkorDB query stores the embedding through vecf32() inside the
    # property map it assigns to the relationship.
    falkordb_query = """
        UNWIND $entity_edges AS edge
        MATCH (source:Entity {uuid: edge.source_node_uuid})
        MATCH (target:Entity {uuid: edge.target_node_uuid})
        MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
        SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
        WITH r, edge
        RETURN edge.uuid AS uuid"""
    return falkordb_query
|
graphiti_core/graphiti.py
CHANGED
|
@@ -19,12 +19,13 @@ from datetime import datetime
|
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
21
|
from dotenv import load_dotenv
|
|
22
|
-
from neo4j import AsyncGraphDatabase
|
|
23
22
|
from pydantic import BaseModel
|
|
24
23
|
from typing_extensions import LiteralString
|
|
25
24
|
|
|
26
25
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
27
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
28
|
+
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|
28
29
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
29
30
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
31
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -41,6 +42,7 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
41
42
|
from graphiti_core.search.search_filters import SearchFilters
|
|
42
43
|
from graphiti_core.search.search_utils import (
|
|
43
44
|
RELEVANT_SCHEMA_LIMIT,
|
|
45
|
+
get_edge_invalidation_candidates,
|
|
44
46
|
get_mentioned_nodes,
|
|
45
47
|
get_relevant_edges,
|
|
46
48
|
)
|
|
@@ -62,9 +64,8 @@ from graphiti_core.utils.maintenance.community_operations import (
|
|
|
62
64
|
)
|
|
63
65
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
64
66
|
build_episodic_edges,
|
|
65
|
-
dedupe_extracted_edge,
|
|
66
67
|
extract_edges,
|
|
67
|
-
|
|
68
|
+
resolve_extracted_edge,
|
|
68
69
|
resolve_extracted_edges,
|
|
69
70
|
)
|
|
70
71
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
@@ -77,7 +78,6 @@ from graphiti_core.utils.maintenance.node_operations import (
|
|
|
77
78
|
extract_nodes,
|
|
78
79
|
resolve_extracted_nodes,
|
|
79
80
|
)
|
|
80
|
-
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
|
|
81
81
|
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
|
82
82
|
|
|
83
83
|
logger = logging.getLogger(__name__)
|
|
@@ -95,12 +95,13 @@ class Graphiti:
|
|
|
95
95
|
def __init__(
|
|
96
96
|
self,
|
|
97
97
|
uri: str,
|
|
98
|
-
user: str,
|
|
99
|
-
password: str,
|
|
98
|
+
user: str | None = None,
|
|
99
|
+
password: str | None = None,
|
|
100
100
|
llm_client: LLMClient | None = None,
|
|
101
101
|
embedder: EmbedderClient | None = None,
|
|
102
102
|
cross_encoder: CrossEncoderClient | None = None,
|
|
103
103
|
store_raw_episode_content: bool = True,
|
|
104
|
+
graph_driver: GraphDriver | None = None,
|
|
104
105
|
):
|
|
105
106
|
"""
|
|
106
107
|
Initialize a Graphiti instance.
|
|
@@ -138,7 +139,9 @@ class Graphiti:
|
|
|
138
139
|
Make sure to set the OPENAI_API_KEY environment variable before initializing
|
|
139
140
|
Graphiti if you're using the default OpenAIClient.
|
|
140
141
|
"""
|
|
141
|
-
|
|
142
|
+
|
|
143
|
+
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
|
|
144
|
+
|
|
142
145
|
self.database = DEFAULT_DATABASE
|
|
143
146
|
self.store_raw_episode_content = store_raw_episode_content
|
|
144
147
|
if llm_client:
|
|
@@ -274,6 +277,8 @@ class Graphiti:
|
|
|
274
277
|
update_communities: bool = False,
|
|
275
278
|
entity_types: dict[str, BaseModel] | None = None,
|
|
276
279
|
previous_episode_uuids: list[str] | None = None,
|
|
280
|
+
edge_types: dict[str, BaseModel] | None = None,
|
|
281
|
+
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
277
282
|
) -> AddEpisodeResults:
|
|
278
283
|
"""
|
|
279
284
|
Process an episode and update the graph.
|
|
@@ -356,6 +361,13 @@ class Graphiti:
|
|
|
356
361
|
)
|
|
357
362
|
)
|
|
358
363
|
|
|
364
|
+
# Create default edge type map
|
|
365
|
+
edge_type_map_default = (
|
|
366
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
367
|
+
if edge_types is not None
|
|
368
|
+
else {('Entity', 'Entity'): []}
|
|
369
|
+
)
|
|
370
|
+
|
|
359
371
|
# Extract entities as nodes
|
|
360
372
|
|
|
361
373
|
extracted_nodes = await extract_nodes(
|
|
@@ -371,7 +383,9 @@ class Graphiti:
|
|
|
371
383
|
previous_episodes,
|
|
372
384
|
entity_types,
|
|
373
385
|
),
|
|
374
|
-
extract_edges(
|
|
386
|
+
extract_edges(
|
|
387
|
+
self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
|
|
388
|
+
),
|
|
375
389
|
)
|
|
376
390
|
|
|
377
391
|
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
|
@@ -381,6 +395,9 @@ class Graphiti:
|
|
|
381
395
|
self.clients,
|
|
382
396
|
edges,
|
|
383
397
|
episode,
|
|
398
|
+
nodes,
|
|
399
|
+
edge_types or {},
|
|
400
|
+
edge_type_map or edge_type_map_default,
|
|
384
401
|
),
|
|
385
402
|
extract_attributes_from_nodes(
|
|
386
403
|
self.clients, nodes, episode, previous_episodes, entity_types
|
|
@@ -681,17 +698,27 @@ class Graphiti:
|
|
|
681
698
|
|
|
682
699
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
683
700
|
|
|
684
|
-
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters())
|
|
701
|
+
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
|
|
702
|
+
existing_edges = (
|
|
703
|
+
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
|
|
704
|
+
)[0]
|
|
685
705
|
|
|
686
|
-
resolved_edge = await dedupe_extracted_edge(
|
|
706
|
+
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
|
687
707
|
self.llm_client,
|
|
688
708
|
updated_edge,
|
|
689
|
-
related_edges
|
|
709
|
+
related_edges,
|
|
710
|
+
existing_edges,
|
|
711
|
+
EpisodicNode(
|
|
712
|
+
name='',
|
|
713
|
+
source=EpisodeType.text,
|
|
714
|
+
source_description='',
|
|
715
|
+
content='',
|
|
716
|
+
valid_at=edge.valid_at or utc_now(),
|
|
717
|
+
entity_edges=[],
|
|
718
|
+
group_id=edge.group_id,
|
|
719
|
+
),
|
|
690
720
|
)
|
|
691
721
|
|
|
692
|
-
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
|
|
693
|
-
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
|
694
|
-
|
|
695
722
|
await add_nodes_and_edges_bulk(
|
|
696
723
|
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
|
|
697
724
|
)
|
graphiti_core/graphiti_types.py
CHANGED
|
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from neo4j import AsyncDriver
|
|
18
17
|
from pydantic import BaseModel, ConfigDict
|
|
19
18
|
|
|
20
19
|
from graphiti_core.cross_encoder import CrossEncoderClient
|
|
20
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
21
21
|
from graphiti_core.embedder import EmbedderClient
|
|
22
22
|
from graphiti_core.llm_client import LLMClient
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class GraphitiClients(BaseModel):
|
|
26
|
-
driver: AsyncDriver
|
|
26
|
+
driver: GraphDriver
|
|
27
27
|
llm_client: LLMClient
|
|
28
28
|
embedder: EmbedderClient
|
|
29
29
|
cross_encoder: CrossEncoderClient
|
graphiti_core/helpers.py
CHANGED
|
@@ -18,7 +18,6 @@ import asyncio
|
|
|
18
18
|
import os
|
|
19
19
|
from collections.abc import Coroutine
|
|
20
20
|
from datetime import datetime
|
|
21
|
-
from typing import Any
|
|
22
21
|
|
|
23
22
|
import numpy as np
|
|
24
23
|
from dotenv import load_dotenv
|
|
@@ -28,7 +27,7 @@ from typing_extensions import LiteralString
|
|
|
28
27
|
|
|
29
28
|
load_dotenv()
|
|
30
29
|
|
|
31
|
-
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE',
|
|
30
|
+
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
|
32
31
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
|
33
32
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
|
34
33
|
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
|
@@ -39,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
|
|
|
39
38
|
)
|
|
40
39
|
|
|
41
40
|
|
|
42
|
-
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
43
|
-
return neo_date.to_native() if neo_date else None
|
|
41
|
+
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
    """Convert a date value read from the database into a ``datetime``.

    Args:
        neo_date: A ``neo4j_time.DateTime``, an ISO-8601 string, or ``None``.

    Returns:
        The native ``datetime`` for a ``neo4j_time.DateTime``, the parsed
        ``datetime`` for a non-empty string, and ``None`` for any falsy value
        that is not a ``neo4j_time.DateTime``.
    """
    # The driver-native type is checked first and converted directly.
    if isinstance(neo_date, neo4j_time.DateTime):
        return neo_date.to_native()
    # Any other truthy value is handed to the ISO-8601 parser.
    if neo_date:
        return datetime.fromisoformat(neo_date)
    return None
|
|
44
49
|
|
|
45
50
|
|
|
46
51
|
def lucene_sanitize(query: str) -> str:
|
|
@@ -1,3 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
1
17
|
from .client import LLMClient
|
|
2
18
|
from .config import LLMConfig
|
|
3
19
|
from .errors import RateLimitError
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from openai import AsyncAzureOpenAI
|
|
22
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
25
|
+
from ..prompts.models import Message
|
|
26
|
+
from .client import LLMClient
|
|
27
|
+
from .config import LLMConfig, ModelSize
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AzureOpenAILLMClient(LLMClient):
    """LLMClient implementation that delegates to an ``AsyncAzureOpenAI`` client."""

    def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
        """Store the Azure client and initialise the base class with ``cache=False``."""
        super().__init__(config, cache=False)
        self.azure_client = azure_client

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = 1024,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, Any]:
        """Send ``messages`` to Azure OpenAI and return the parsed JSON reply.

        Args:
            messages: Prompt messages. Each message's ``content`` is replaced
                in place with its cleaned form; only ``'user'`` and
                ``'system'`` roles are forwarded.
            response_model: Accepted for interface compatibility; not used here.
            max_tokens: Completion token limit passed to the API.
            model_size: Accepted for interface compatibility; not used here.

        Returns:
            The JSON object decoded from the first choice's message content,
            or ``{}`` when that content is empty.

        Raises:
            Exception: Any error raised while calling the API or decoding the
                reply is logged and re-raised unchanged.
        """
        # Translate to the OpenAI chat format; other roles are dropped.
        chat_messages: list[ChatCompletionMessageParam] = []
        for msg in messages:
            msg.content = self._clean_input(msg.content)
            if msg.role == 'user':
                chat_messages.append({'role': 'user', 'content': msg.content})
            elif msg.role == 'system':
                chat_messages.append({'role': 'system', 'content': msg.content})

        # Fall back to a default model when none is configured.
        model_name = self.model or 'gpt-4o-mini'

        try:
            if self.temperature is not None:
                temperature = float(self.temperature)
            else:
                temperature = 0.7

            completion = await self.azure_client.chat.completions.create(
                model=model_name,
                messages=chat_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                response_format={'type': 'json_object'},
            )
            raw_json = completion.choices[0].message.content or '{}'
            return json.loads(raw_json)
        except Exception as e:
            logger.error(f'Error in Azure OpenAI LLM response: {e}')
            raise
|
|
@@ -139,13 +139,16 @@ class GeminiClient(LLMClient):
|
|
|
139
139
|
# Generate content using the simple string approach
|
|
140
140
|
response = await self.client.aio.models.generate_content(
|
|
141
141
|
model=self.model or DEFAULT_MODEL,
|
|
142
|
-
contents=gemini_messages,
|
|
142
|
+
contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
|
|
143
143
|
config=generation_config,
|
|
144
144
|
)
|
|
145
145
|
|
|
146
146
|
# If this was a structured output request, parse the response into the Pydantic model
|
|
147
147
|
if response_model is not None:
|
|
148
148
|
try:
|
|
149
|
+
if not response.text:
|
|
150
|
+
raise ValueError('No response text')
|
|
151
|
+
|
|
149
152
|
validated_model = response_model.model_validate(json.loads(response.text))
|
|
150
153
|
|
|
151
154
|
# Return as a dictionary for API consistency
|
|
@@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
|
|
|
34
34
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
35
35
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
36
36
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
37
|
-
SET r =
|
|
38
|
-
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
|
37
|
+
SET r = $edge_data
|
|
39
38
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
|
40
39
|
RETURN r.uuid AS uuid"""
|
|
41
40
|
|
|
@@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
|
|
|
44
43
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
45
44
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
46
45
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
47
|
-
SET r =
|
|
48
|
-
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
|
46
|
+
SET r = edge
|
|
49
47
|
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
|
50
48
|
RETURN edge.uuid AS uuid
|
|
51
49
|
"""
|