graphiti-core 0.3.15__py3-none-any.whl → 0.3.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package registry's advisory page for more details.

File without changes
@@ -0,0 +1,45 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import asyncio
18
+ from typing import List, Tuple
19
+
20
+ from sentence_transformers import CrossEncoder
21
+
22
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
23
+
24
+
25
class BGERerankerClient(CrossEncoderClient):
    """
    Cross-encoder reranker backed by a sentence-transformers ``CrossEncoder`` model.

    Every (query, passage) pair is scored by the model and the passages are
    returned sorted by descending score.
    """

    def __init__(self, model_name: str = 'BAAI/bge-reranker-v2-m3'):
        """
        Load the cross-encoder model.

        Args:
            model_name (str): Model name or path handed to ``CrossEncoder``.
                Defaults to 'BAAI/bge-reranker-v2-m3' (the previously hard-coded model),
                so existing ``BGERerankerClient()`` calls behave exactly as before.
        """
        self.model = CrossEncoder(model_name)

    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
        """
        Score each passage against the query and sort by relevance.

        Args:
            query (str): The query string.
            passages (List[str]): Passages to score.

        Returns:
            List[Tuple[str, float]]: ``(passage, score)`` pairs sorted by descending score;
            an empty list when ``passages`` is empty.
        """
        if not passages:
            return []

        input_pairs = [[query, passage] for passage in passages]

        # CrossEncoder.predict is synchronous; run it in the default executor so the
        # event loop is not blocked while the model scores the pairs.
        loop = asyncio.get_running_loop()
        scores = await loop.run_in_executor(None, self.model.predict, input_pairs)

        # float() turns the model's numeric outputs into plain Python floats.
        return sorted(
            ((passage, float(score)) for passage, score in zip(passages, scores)),
            key=lambda item: item[1],
            reverse=True,
        )
@@ -0,0 +1,41 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from abc import ABC, abstractmethod
18
+ from typing import List, Tuple
19
+
20
+
21
class CrossEncoderClient(ABC):
    """
    Abstract interface for cross-encoder rerankers.

    A cross-encoder judges a query jointly with each candidate passage and
    produces a relevance score per passage. Any concrete reranker only has to
    implement :meth:`rank`, so implementations are interchangeable for callers.
    """

    @abstractmethod
    async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
        """
        Score every passage against ``query`` and order them by relevance.

        Args:
            query (str): Text that the passages are judged against.
            passages (List[str]): Candidate passages to score.

        Returns:
            List[Tuple[str, float]]: ``(passage, score)`` pairs, most relevant first.
        """
        ...
@@ -0,0 +1,113 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import asyncio
18
+ import logging
19
+ from typing import Any
20
+
21
+ import openai
22
+ from openai import AsyncOpenAI
23
+ from pydantic import BaseModel
24
+
25
+ from ..llm_client import LLMConfig, RateLimitError
26
+ from ..prompts import Message
27
+ from .client import CrossEncoderClient
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ DEFAULT_MODEL = 'gpt-4o-mini'
32
+
33
+
34
class BooleanClassifier(BaseModel):
    """
    Pydantic model carrying a single boolean verdict.

    NOTE(review): not referenced by the reranker code visible here; presumably
    intended as a structured-output schema for a True/False classification — confirm.
    """

    # The verdict itself. The camelCase name is the model's field / serialized key,
    # so it is kept unchanged.
    isTrue: bool
36
+
37
+
38
class OpenAIRerankerClient(CrossEncoderClient):
    """
    Reranker that asks an OpenAI chat model whether each passage is relevant to a query.

    One single-token completion is requested per passage. The probability the model
    assigns to the "True" token is used as that passage's relevance score.
    """

    def __init__(self, config: LLMConfig | None = None, client: Any | None = None):
        """
        Initialize the OpenAIRerankerClient.

        Args:
            config (LLMConfig | None): Configuration supplying the API key and base URL.
                A default ``LLMConfig()`` is used when omitted.
            client (Any | None): Optional pre-built async OpenAI client. When omitted, a new
                ``AsyncOpenAI`` client is created from ``config``.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        if client is None:
            client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
        self.client = client

    @staticmethod
    def _relevance_score(top_logprobs: Any) -> float:
        """
        Turn the top log-probabilities of the single generated token into a score in [0, 1].

        Returns ``exp(logprob)`` of the "True" candidate when it appears among
        ``top_logprobs``; otherwise 0.0 (no evidence of relevance).
        """
        import math

        for candidate in top_logprobs:
            # Token text may carry surrounding whitespace or differ in case.
            if candidate.token.strip().lower() == 'true':
                return math.exp(candidate.logprob)
        return 0.0

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Rank passages by the model's probability that each is relevant to ``query``.

        Args:
            query (str): The query string.
            passages (list[str]): Passages to rank.

        Returns:
            list[tuple[str, float]]: ``(passage, score)`` pairs sorted by descending score.
            There is exactly one pair per input passage.

        Raises:
            RateLimitError: When the OpenAI API reports a rate limit.
        """
        if not passages:
            return []

        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    content=f"""
                    Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                    <PASSAGE>
                    {passage}
                    </PASSAGE>
                    <QUERY>
                    {query}
                    </QUERY>
                    """,
                ),
            ]
            for passage in passages
        ]
        try:
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        # NOTE(review): these token ids presumably map to "True"/"False"
                        # in the model's tokenizer — confirm if DEFAULT_MODEL changes.
                        logit_bias={'6432': 1, '7983': 1},
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )

            # Exactly one score per response (and hence per passage), so the zip below
            # keeps every passage paired with its own score.
            scores: list[float] = []
            for response in responses:
                logprobs = response.choices[0].logprobs
                if logprobs is None or not logprobs.content:
                    top_logprobs = []
                else:
                    top_logprobs = logprobs.content[0].top_logprobs
                scores.append(self._relevance_score(top_logprobs))

            results = list(zip(passages, scores))
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
graphiti_core/edges.py CHANGED
@@ -54,7 +54,7 @@ class Edge(BaseModel, ABC):
54
54
  DELETE e
55
55
  """,
56
56
  uuid=self.uuid,
57
- _database=DEFAULT_DATABASE,
57
+ database_=DEFAULT_DATABASE,
58
58
  )
59
59
 
60
60
  logger.debug(f'Deleted Edge: {self.uuid}')
@@ -82,7 +82,7 @@ class EpisodicEdge(Edge):
82
82
  uuid=self.uuid,
83
83
  group_id=self.group_id,
84
84
  created_at=self.created_at,
85
- _database=DEFAULT_DATABASE,
85
+ database_=DEFAULT_DATABASE,
86
86
  )
87
87
 
88
88
  logger.debug(f'Saved edge to neo4j: {self.uuid}')
@@ -102,7 +102,7 @@ class EpisodicEdge(Edge):
102
102
  e.created_at AS created_at
103
103
  """,
104
104
  uuid=uuid,
105
- _database=DEFAULT_DATABASE,
105
+ database_=DEFAULT_DATABASE,
106
106
  )
107
107
 
108
108
  edges = [get_episodic_edge_from_record(record) for record in records]
@@ -125,7 +125,7 @@ class EpisodicEdge(Edge):
125
125
  e.created_at AS created_at
126
126
  """,
127
127
  uuids=uuids,
128
- _database=DEFAULT_DATABASE,
128
+ database_=DEFAULT_DATABASE,
129
129
  )
130
130
 
131
131
  edges = [get_episodic_edge_from_record(record) for record in records]
@@ -148,7 +148,7 @@ class EpisodicEdge(Edge):
148
148
  e.created_at AS created_at
149
149
  """,
150
150
  group_ids=group_ids,
151
- _database=DEFAULT_DATABASE,
151
+ database_=DEFAULT_DATABASE,
152
152
  )
153
153
 
154
154
  edges = [get_episodic_edge_from_record(record) for record in records]
@@ -202,7 +202,7 @@ class EntityEdge(Edge):
202
202
  expired_at=self.expired_at,
203
203
  valid_at=self.valid_at,
204
204
  invalid_at=self.invalid_at,
205
- _database=DEFAULT_DATABASE,
205
+ database_=DEFAULT_DATABASE,
206
206
  )
207
207
 
208
208
  logger.debug(f'Saved edge to neo4j: {self.uuid}')
@@ -229,7 +229,7 @@ class EntityEdge(Edge):
229
229
  e.invalid_at AS invalid_at
230
230
  """,
231
231
  uuid=uuid,
232
- _database=DEFAULT_DATABASE,
232
+ database_=DEFAULT_DATABASE,
233
233
  )
234
234
 
235
235
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -259,7 +259,7 @@ class EntityEdge(Edge):
259
259
  e.invalid_at AS invalid_at
260
260
  """,
261
261
  uuids=uuids,
262
- _database=DEFAULT_DATABASE,
262
+ database_=DEFAULT_DATABASE,
263
263
  )
264
264
 
265
265
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -289,7 +289,7 @@ class EntityEdge(Edge):
289
289
  e.invalid_at AS invalid_at
290
290
  """,
291
291
  group_ids=group_ids,
292
- _database=DEFAULT_DATABASE,
292
+ database_=DEFAULT_DATABASE,
293
293
  )
294
294
 
295
295
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -308,7 +308,7 @@ class CommunityEdge(Edge):
308
308
  uuid=self.uuid,
309
309
  group_id=self.group_id,
310
310
  created_at=self.created_at,
311
- _database=DEFAULT_DATABASE,
311
+ database_=DEFAULT_DATABASE,
312
312
  )
313
313
 
314
314
  logger.debug(f'Saved edge to neo4j: {self.uuid}')
@@ -328,7 +328,7 @@ class CommunityEdge(Edge):
328
328
  e.created_at AS created_at
329
329
  """,
330
330
  uuid=uuid,
331
- _database=DEFAULT_DATABASE,
331
+ database_=DEFAULT_DATABASE,
332
332
  )
333
333
 
334
334
  edges = [get_community_edge_from_record(record) for record in records]
@@ -349,7 +349,7 @@ class CommunityEdge(Edge):
349
349
  e.created_at AS created_at
350
350
  """,
351
351
  uuids=uuids,
352
- _database=DEFAULT_DATABASE,
352
+ database_=DEFAULT_DATABASE,
353
353
  )
354
354
 
355
355
  edges = [get_community_edge_from_record(record) for record in records]
@@ -370,7 +370,7 @@ class CommunityEdge(Edge):
370
370
  e.created_at AS created_at
371
371
  """,
372
372
  group_ids=group_ids,
373
- _database=DEFAULT_DATABASE,
373
+ database_=DEFAULT_DATABASE,
374
374
  )
375
375
 
376
376
  edges = [get_community_edge_from_record(record) for record in records]
graphiti_core/graphiti.py CHANGED
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
23
23
  from neo4j import AsyncGraphDatabase
24
24
  from pydantic import BaseModel
25
25
 
26
+ from graphiti_core.cross_encoder.client import CrossEncoderClient
27
+ from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
26
28
  from graphiti_core.edges import EntityEdge, EpisodicEdge
27
29
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
+ from graphiti_core.helpers import DEFAULT_DATABASE
28
31
  from graphiti_core.llm_client import LLMClient, OpenAIClient
29
32
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
30
33
  from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ class Graphiti:
92
95
  password: str,
93
96
  llm_client: LLMClient | None = None,
94
97
  embedder: EmbedderClient | None = None,
98
+ cross_encoder: CrossEncoderClient | None = None,
95
99
  store_raw_episode_content: bool = True,
96
100
  ):
97
101
  """
@@ -131,7 +135,7 @@ class Graphiti:
131
135
  Graphiti if you're using the default OpenAIClient.
132
136
  """
133
137
  self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
134
- self.database = 'neo4j'
138
+ self.database = DEFAULT_DATABASE
135
139
  self.store_raw_episode_content = store_raw_episode_content
136
140
  if llm_client:
137
141
  self.llm_client = llm_client
@@ -141,6 +145,10 @@ class Graphiti:
141
145
  self.embedder = embedder
142
146
  else:
143
147
  self.embedder = OpenAIEmbedder()
148
+ if cross_encoder:
149
+ self.cross_encoder = cross_encoder
150
+ else:
151
+ self.cross_encoder = OpenAIRerankerClient()
144
152
 
145
153
  async def close(self):
146
154
  """
@@ -648,6 +656,7 @@ class Graphiti:
648
656
  await search(
649
657
  self.driver,
650
658
  self.embedder,
659
+ self.cross_encoder,
651
660
  query,
652
661
  group_ids,
653
662
  search_config,
@@ -663,8 +672,18 @@ class Graphiti:
663
672
  config: SearchConfig,
664
673
  group_ids: list[str] | None = None,
665
674
  center_node_uuid: str | None = None,
675
+ bfs_origin_node_uuids: list[str] | None = None,
666
676
  ) -> SearchResults:
667
- return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
677
+ return await search(
678
+ self.driver,
679
+ self.embedder,
680
+ self.cross_encoder,
681
+ query,
682
+ group_ids,
683
+ config,
684
+ center_node_uuid,
685
+ bfs_origin_node_uuids,
686
+ )
668
687
 
669
688
  async def get_nodes_by_query(
670
689
  self,
@@ -716,7 +735,13 @@ class Graphiti:
716
735
 
717
736
  nodes = (
718
737
  await search(
719
- self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
738
+ self.driver,
739
+ self.embedder,
740
+ self.cross_encoder,
741
+ query,
742
+ group_ids,
743
+ search_config,
744
+ center_node_uuid,
720
745
  )
721
746
  ).nodes
722
747
  return nodes
graphiti_core/nodes.py CHANGED
@@ -90,7 +90,7 @@ class Node(BaseModel, ABC):
90
90
  DETACH DELETE n
91
91
  """,
92
92
  uuid=self.uuid,
93
- _database=DEFAULT_DATABASE,
93
+ database_=DEFAULT_DATABASE,
94
94
  )
95
95
 
96
96
  logger.debug(f'Deleted Node: {self.uuid}')
@@ -136,7 +136,7 @@ class EpisodicNode(Node):
136
136
  created_at=self.created_at,
137
137
  valid_at=self.valid_at,
138
138
  source=self.source.value,
139
- _database=DEFAULT_DATABASE,
139
+ database_=DEFAULT_DATABASE,
140
140
  )
141
141
 
142
142
  logger.debug(f'Saved Node to neo4j: {self.uuid}')
@@ -158,7 +158,7 @@ class EpisodicNode(Node):
158
158
  e.source AS source
159
159
  """,
160
160
  uuid=uuid,
161
- _database=DEFAULT_DATABASE,
161
+ database_=DEFAULT_DATABASE,
162
162
  )
163
163
 
164
164
  episodes = [get_episodic_node_from_record(record) for record in records]
@@ -184,7 +184,7 @@ class EpisodicNode(Node):
184
184
  e.source AS source
185
185
  """,
186
186
  uuids=uuids,
187
- _database=DEFAULT_DATABASE,
187
+ database_=DEFAULT_DATABASE,
188
188
  )
189
189
 
190
190
  episodes = [get_episodic_node_from_record(record) for record in records]
@@ -207,7 +207,7 @@ class EpisodicNode(Node):
207
207
  e.source AS source
208
208
  """,
209
209
  group_ids=group_ids,
210
- _database=DEFAULT_DATABASE,
210
+ database_=DEFAULT_DATABASE,
211
211
  )
212
212
 
213
213
  episodes = [get_episodic_node_from_record(record) for record in records]
@@ -237,7 +237,7 @@ class EntityNode(Node):
237
237
  summary=self.summary,
238
238
  name_embedding=self.name_embedding,
239
239
  created_at=self.created_at,
240
- _database=DEFAULT_DATABASE,
240
+ database_=DEFAULT_DATABASE,
241
241
  )
242
242
 
243
243
  logger.debug(f'Saved Node to neo4j: {self.uuid}')
@@ -258,7 +258,7 @@ class EntityNode(Node):
258
258
  n.summary AS summary
259
259
  """,
260
260
  uuid=uuid,
261
- _database=DEFAULT_DATABASE,
261
+ database_=DEFAULT_DATABASE,
262
262
  )
263
263
 
264
264
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -282,7 +282,7 @@ class EntityNode(Node):
282
282
  n.summary AS summary
283
283
  """,
284
284
  uuids=uuids,
285
- _database=DEFAULT_DATABASE,
285
+ database_=DEFAULT_DATABASE,
286
286
  )
287
287
 
288
288
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -303,7 +303,7 @@ class EntityNode(Node):
303
303
  n.summary AS summary
304
304
  """,
305
305
  group_ids=group_ids,
306
- _database=DEFAULT_DATABASE,
306
+ database_=DEFAULT_DATABASE,
307
307
  )
308
308
 
309
309
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -324,7 +324,7 @@ class CommunityNode(Node):
324
324
  summary=self.summary,
325
325
  name_embedding=self.name_embedding,
326
326
  created_at=self.created_at,
327
- _database=DEFAULT_DATABASE,
327
+ database_=DEFAULT_DATABASE,
328
328
  )
329
329
 
330
330
  logger.debug(f'Saved Node to neo4j: {self.uuid}')
@@ -354,7 +354,7 @@ class CommunityNode(Node):
354
354
  n.summary AS summary
355
355
  """,
356
356
  uuid=uuid,
357
- _database=DEFAULT_DATABASE,
357
+ database_=DEFAULT_DATABASE,
358
358
  )
359
359
 
360
360
  nodes = [get_community_node_from_record(record) for record in records]
@@ -378,7 +378,7 @@ class CommunityNode(Node):
378
378
  n.summary AS summary
379
379
  """,
380
380
  uuids=uuids,
381
- _database=DEFAULT_DATABASE,
381
+ database_=DEFAULT_DATABASE,
382
382
  )
383
383
 
384
384
  communities = [get_community_node_from_record(record) for record in records]
@@ -399,7 +399,7 @@ class CommunityNode(Node):
399
399
  n.summary AS summary
400
400
  """,
401
401
  group_ids=group_ids,
402
- _database=DEFAULT_DATABASE,
402
+ database_=DEFAULT_DATABASE,
403
403
  )
404
404
 
405
405
  communities = [get_community_node_from_record(record) for record in records]