graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
106
106
  if len(top_logprobs) == 0:
107
107
  continue
108
108
  norm_logprobs = np.exp(top_logprobs[0].logprob)
109
- if bool(top_logprobs[0].token):
109
+ if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
110
110
  scores.append(norm_logprobs)
111
111
  else:
112
112
  scores.append(1 - norm_logprobs)
@@ -0,0 +1,17 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ __all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
@@ -0,0 +1,66 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from abc import ABC, abstractmethod
19
+ from collections.abc import Coroutine
20
+ from typing import Any
21
+
22
+ from graphiti_core.helpers import DEFAULT_DATABASE
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class GraphDriverSession(ABC):
    """Abstract async session interface that concrete graph backends implement.

    Instances are usable as async context managers: entering yields the
    session itself, and leaving delegates to the backend's ``__aexit__``.
    """

    async def __aenter__(self):
        """Enter the async context and hand back this session."""
        return self

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Leave the async context; the base implementation does nothing."""

    @abstractmethod
    async def run(self, query: str, **kwargs: Any) -> Any:
        """Execute ``query`` with ``kwargs``; must be overridden by subclasses."""
        raise NotImplementedError

    @abstractmethod
    async def close(self):
        """Release the session; must be overridden by subclasses."""
        raise NotImplementedError

    @abstractmethod
    async def execute_write(self, func, *args, **kwargs):
        """Run ``func`` as a write unit of work; must be overridden by subclasses."""
        raise NotImplementedError
47
+
48
+
49
class GraphDriver(ABC):
    """Abstract driver interface shared by the graph database backends.

    Every operation is abstract; the base bodies only raise
    ``NotImplementedError``.
    """

    # Backend identifier; concrete drivers assign a value.
    provider: str

    @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
        """Run ``cypher_query_`` with ``kwargs``; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def session(self, database: str) -> GraphDriverSession:
        """Return a session for ``database``; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Close the driver; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        """Drop all indexes in ``database_``; must be overridden."""
        raise NotImplementedError
@@ -0,0 +1,131 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from collections.abc import Coroutine
19
+ from datetime import datetime
20
+ from typing import Any
21
+
22
+ from falkordb import Graph as FalkorGraph # type: ignore
23
+ from falkordb.asyncio import FalkorDB # type: ignore
24
+
25
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
26
+ from graphiti_core.helpers import DEFAULT_DATABASE
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class FalkorDriverSession(GraphDriverSession):
    """Session wrapper that runs queries against a single FalkorDB graph."""

    def __init__(self, graph: FalkorGraph):
        """Store the graph handle that queries are issued against."""
        self.graph = graph

    async def __aenter__(self):
        """Enter the async context and return this session."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Leave the async context; nothing is cleaned up here."""
        return None

    async def close(self):
        """No-op; present to satisfy the session interface."""
        return None

    async def execute_write(self, func, *args, **kwargs):
        """Await ``func`` with this session passed as its first argument."""
        outcome = await func(self, *args, **kwargs)
        return outcome

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute one query, or a list of ``(cypher, params)`` pairs in order.

        When ``query`` is a list, each pair carries its own parameters and
        ``kwargs`` is not used; otherwise ``kwargs`` supplies the parameters.
        Datetime values in the parameters are converted to strings first.
        Always returns ``None``.
        """
        # Normalise both call shapes to a sequence of (cypher, params) pairs.
        if isinstance(query, list):
            statements = query
        else:
            statements = [(query, dict(kwargs))]

        for cypher, raw_params in statements:
            prepared = convert_datetimes_to_strings(raw_params)
            await self.graph.query(str(cypher), prepared)
        return None
62
+
63
+
64
class FalkorDriver(GraphDriver):
    """GraphDriver implementation backed by the async FalkorDB client."""

    provider: str = 'falkordb'

    def __init__(
        self,
        uri: str,
        user: str,
        password: str,
    ):
        """Create the FalkorDB client from the caller-supplied connection details.

        Args:
            uri: Connection URI such as ``falkor://host:6379``. A bare
                ``host:port`` (no scheme) is accepted as well.
            user: Username; when empty, a username embedded in ``uri`` is used.
            password: Password; when empty, a password embedded in ``uri`` is used.

        TLS is enabled when the URI scheme is ``rediss`` or ``falkors``.
        """
        # Function-scope import keeps the file's import block untouched.
        from urllib.parse import urlparse

        super().__init__()

        # urlparse only recognises the host when a scheme separator is present.
        parsed = urlparse(uri if '://' in uri else f'falkor://{uri}')

        # Connect with the values the caller passed instead of placeholder
        # host/password literals.
        self.client = FalkorDB(
            host=parsed.hostname or 'localhost',
            port=parsed.port or 6379,
            username=user or parsed.username or None,
            password=password or parsed.password or None,
            ssl=parsed.scheme in ('rediss', 'falkors'),
        )

    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
        """Select the graph named ``graph_name``, or the fallback graph when it is None."""
        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
        if graph_name is None:
            graph_name = 'DEFAULT_DATABASE'
        return self.client.select_graph(graph_name)

    async def execute_query(self, cypher_query_, **kwargs: Any):
        """Run a Cypher query against the graph named by ``database_``.

        ``database_`` is removed from ``kwargs``; the remaining keyword
        arguments become query parameters after datetime values are converted
        to ISO strings.

        Returns:
            ``(result_set, header, None)`` on success, or ``None`` when the
            backend reports that an index is already present.

        Raises:
            Exception: any other query error is logged and re-raised.
        """
        graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
        graph = self._get_graph(graph_name)

        # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
        params = convert_datetimes_to_strings(dict(kwargs))

        try:
            result = await graph.query(cypher_query_, params)
        except Exception as e:
            if 'already indexed' in str(e):
                # Creating an index that already exists is treated as success.
                logger.info(f'Index already exists: {e}')
                return None
            logger.error(f'Error executing FalkorDB query: {e}')
            raise

        # Convert the result header to a list of strings
        header = [h[1].decode('utf-8') for h in result.header]
        return result.result_set, header, None

    def session(self, database: str | None) -> GraphDriverSession:
        """Return a session bound to the graph for ``database``."""
        return FalkorDriverSession(self._get_graph(database))

    async def close(self) -> None:
        """Close the client's underlying connection."""
        await self.client.connection.close()

    async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Any:
        """Issue the drop-all-indexes query and return its result.

        The inner coroutine is awaited here, so a single
        ``await driver.delete_all_indexes()`` actually runs the query rather
        than returning a coroutine that is never awaited.
        """
        return await self.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
            database_=database_,
        )
119
+
120
+
121
def convert_datetimes_to_strings(obj):
    """Return a copy of ``obj`` with every ``datetime`` replaced by its ISO string.

    Dicts, lists and tuples are rebuilt recursively (dict keys are left as
    they are); any other value is returned unchanged.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {key: convert_datetimes_to_strings(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_datetimes_to_strings(element) for element in obj]
    if isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(element) for element in obj)
    return obj
@@ -0,0 +1,61 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from collections.abc import Coroutine
19
+ from typing import Any
20
+
21
+ from neo4j import AsyncGraphDatabase
22
+ from typing_extensions import LiteralString
23
+
24
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
25
+ from graphiti_core.helpers import DEFAULT_DATABASE
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class Neo4jDriver(GraphDriver):
    """GraphDriver implementation that delegates to the Neo4j async driver."""

    provider: str = 'neo4j'

    def __init__(
        self,
        uri: str,
        user: str | None,
        password: str | None,
    ):
        """Open an async Neo4j driver for ``uri``.

        A ``None`` user or password is sent as an empty string.
        """
        super().__init__()
        credentials = (user or '', password or '')
        self.client = AsyncGraphDatabase.driver(uri=uri, auth=credentials)

    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
        """Run ``cypher_query_`` through the client and return its result.

        A ``params`` keyword, if given, is removed from ``kwargs`` and passed
        on as ``parameters_``; all other keywords are forwarded unchanged.
        """
        query_parameters = kwargs.pop('params', None)
        outcome = await self.client.execute_query(
            cypher_query_,
            parameters_=query_parameters,
            **kwargs,
        )
        return outcome

    def session(self, database: str) -> GraphDriverSession:
        """Return the client's session for ``database``."""
        return self.client.session(database=database)  # type: ignore

    async def close(self) -> None:
        """Close the underlying client."""
        return await self.client.close()

    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        """Return the client's awaitable for the drop-all-indexes query."""
        drop_query = 'CALL db.indexes() YIELD name DROP INDEX name'
        return self.client.execute_query(drop_query, database_=database_)
graphiti_core/edges.py CHANGED
@@ -21,10 +21,10 @@ from time import time
21
21
  from typing import Any
22
22
  from uuid import uuid4
23
23
 
24
- from neo4j import AsyncDriver
25
24
  from pydantic import BaseModel, Field
26
25
  from typing_extensions import LiteralString
27
26
 
27
+ from graphiti_core.driver.driver import GraphDriver
28
28
  from graphiti_core.embedder import EmbedderClient
29
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
30
  from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
@@ -62,9 +62,9 @@ class Edge(BaseModel, ABC):
62
62
  created_at: datetime
63
63
 
64
64
  @abstractmethod
65
- async def save(self, driver: AsyncDriver): ...
65
+ async def save(self, driver: GraphDriver): ...
66
66
 
67
- async def delete(self, driver: AsyncDriver):
67
+ async def delete(self, driver: GraphDriver):
68
68
  result = await driver.execute_query(
69
69
  """
70
70
  MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
@@ -87,11 +87,11 @@ class Edge(BaseModel, ABC):
87
87
  return False
88
88
 
89
89
  @classmethod
90
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
90
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
91
91
 
92
92
 
93
93
  class EpisodicEdge(Edge):
94
- async def save(self, driver: AsyncDriver):
94
+ async def save(self, driver: GraphDriver):
95
95
  result = await driver.execute_query(
96
96
  EPISODIC_EDGE_SAVE,
97
97
  episode_uuid=self.source_node_uuid,
@@ -102,12 +102,12 @@ class EpisodicEdge(Edge):
102
102
  database_=DEFAULT_DATABASE,
103
103
  )
104
104
 
105
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
105
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
106
106
 
107
107
  return result
108
108
 
109
109
  @classmethod
110
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
110
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
111
111
  records, _, _ = await driver.execute_query(
112
112
  """
113
113
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
@@ -130,7 +130,7 @@ class EpisodicEdge(Edge):
130
130
  return edges[0]
131
131
 
132
132
  @classmethod
133
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
133
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
134
134
  records, _, _ = await driver.execute_query(
135
135
  """
136
136
  MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@@ -156,7 +156,7 @@ class EpisodicEdge(Edge):
156
156
  @classmethod
157
157
  async def get_by_group_ids(
158
158
  cls,
159
- driver: AsyncDriver,
159
+ driver: GraphDriver,
160
160
  group_ids: list[str],
161
161
  limit: int | None = None,
162
162
  uuid_cursor: str | None = None,
@@ -226,7 +226,7 @@ class EntityEdge(Edge):
226
226
 
227
227
  return self.fact_embedding
228
228
 
229
- async def load_fact_embedding(self, driver: AsyncDriver):
229
+ async def load_fact_embedding(self, driver: GraphDriver):
230
230
  query: LiteralString = """
231
231
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
232
232
  RETURN e.fact_embedding AS fact_embedding
@@ -240,7 +240,7 @@ class EntityEdge(Edge):
240
240
 
241
241
  self.fact_embedding = records[0]['fact_embedding']
242
242
 
243
- async def save(self, driver: AsyncDriver):
243
+ async def save(self, driver: GraphDriver):
244
244
  edge_data: dict[str, Any] = {
245
245
  'source_uuid': self.source_node_uuid,
246
246
  'target_uuid': self.target_node_uuid,
@@ -264,12 +264,12 @@ class EntityEdge(Edge):
264
264
  database_=DEFAULT_DATABASE,
265
265
  )
266
266
 
267
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
267
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
268
268
 
269
269
  return result
270
270
 
271
271
  @classmethod
272
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
272
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
273
273
  records, _, _ = await driver.execute_query(
274
274
  """
275
275
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
@@ -287,7 +287,7 @@ class EntityEdge(Edge):
287
287
  return edges[0]
288
288
 
289
289
  @classmethod
290
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
290
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
291
291
  if len(uuids) == 0:
292
292
  return []
293
293
 
@@ -309,7 +309,7 @@ class EntityEdge(Edge):
309
309
  @classmethod
310
310
  async def get_by_group_ids(
311
311
  cls,
312
- driver: AsyncDriver,
312
+ driver: GraphDriver,
313
313
  group_ids: list[str],
314
314
  limit: int | None = None,
315
315
  uuid_cursor: str | None = None,
@@ -342,11 +342,11 @@ class EntityEdge(Edge):
342
342
  return edges
343
343
 
344
344
  @classmethod
345
- async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
345
+ async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
346
346
  query: LiteralString = (
347
347
  """
348
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
- """
348
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
+ """
350
350
  + ENTITY_EDGE_RETURN
351
351
  )
352
352
  records, _, _ = await driver.execute_query(
@@ -359,7 +359,7 @@ class EntityEdge(Edge):
359
359
 
360
360
 
361
361
  class CommunityEdge(Edge):
362
- async def save(self, driver: AsyncDriver):
362
+ async def save(self, driver: GraphDriver):
363
363
  result = await driver.execute_query(
364
364
  COMMUNITY_EDGE_SAVE,
365
365
  community_uuid=self.source_node_uuid,
@@ -370,12 +370,12 @@ class CommunityEdge(Edge):
370
370
  database_=DEFAULT_DATABASE,
371
371
  )
372
372
 
373
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
373
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
374
374
 
375
375
  return result
376
376
 
377
377
  @classmethod
378
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
378
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
379
379
  records, _, _ = await driver.execute_query(
380
380
  """
381
381
  MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
@@ -396,7 +396,7 @@ class CommunityEdge(Edge):
396
396
  return edges[0]
397
397
 
398
398
  @classmethod
399
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
399
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
400
400
  records, _, _ = await driver.execute_query(
401
401
  """
402
402
  MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
@@ -420,7 +420,7 @@ class CommunityEdge(Edge):
420
420
  @classmethod
421
421
  async def get_by_group_ids(
422
422
  cls,
423
- driver: AsyncDriver,
423
+ driver: GraphDriver,
424
424
  group_ids: list[str],
425
425
  limit: int | None = None,
426
426
  uuid_cursor: str | None = None,
@@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
463
463
  group_id=record['group_id'],
464
464
  source_node_uuid=record['source_node_uuid'],
465
465
  target_node_uuid=record['target_node_uuid'],
466
- created_at=record['created_at'].to_native(),
466
+ created_at=parse_db_date(record['created_at']), # type: ignore
467
467
  )
468
468
 
469
469
 
@@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
476
476
  name=record['name'],
477
477
  group_id=record['group_id'],
478
478
  episodes=record['episodes'],
479
- created_at=record['created_at'].to_native(),
479
+ created_at=parse_db_date(record['created_at']), # type: ignore
480
480
  expired_at=parse_db_date(record['expired_at']),
481
481
  valid_at=parse_db_date(record['valid_at']),
482
482
  invalid_at=parse_db_date(record['invalid_at']),
@@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
504
504
  group_id=record['group_id'],
505
505
  source_node_uuid=record['source_node_uuid'],
506
506
  target_node_uuid=record['target_node_uuid'],
507
- created_at=record['created_at'].to_native(),
507
+ created_at=parse_db_date(record['created_at']), # type: ignore
508
508
  )
509
509
 
510
510
 
@@ -0,0 +1,64 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from typing import Any
19
+
20
+ from openai import AsyncAzureOpenAI
21
+
22
+ from .client import EmbedderClient
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class AzureOpenAIEmbedderClient(EmbedderClient):
    """Adapter exposing an AsyncAzureOpenAI client through the EmbedderClient interface."""

    def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
        """Keep the Azure client and the embedding model name to request."""
        self.model = model
        self.azure_client = azure_client

    async def create(self, input_data: str | list[str] | Any) -> list[float]:
        """Embed ``input_data`` and return the first embedding of the response.

        A string is wrapped in a one-element list, a list made up only of
        strings is sent as is, and anything else is stringified and wrapped.
        Errors are logged and re-raised.
        """
        try:
            is_string_list = isinstance(input_data, list) and all(
                isinstance(entry, str) for entry in input_data
            )
            if is_string_list:
                texts = input_data
            elif isinstance(input_data, str):
                texts = [input_data]
            else:
                # Any other type is coerced to a single string.
                texts = [str(input_data)]

            result = await self.azure_client.embeddings.create(model=self.model, input=texts)

            # Only the first embedding is returned, even for multi-item input.
            first_item = result.data[0]
            return first_item.embedding
        except Exception as e:
            logger.error(f'Error in Azure OpenAI embedding: {e}')
            raise

    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
        """Embed every string in ``input_data_list``; one vector per response item.

        Errors are logged and re-raised.
        """
        try:
            result = await self.azure_client.embeddings.create(
                model=self.model,
                input=input_data_list,
            )
            vectors = []
            for item in result.data:
                vectors.append(item.embedding)
            return vectors
        except Exception as e:
            logger.error(f'Error in Azure OpenAI batch embedding: {e}')
            raise