graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (33)
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +17 -30
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +25 -55
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +360 -195
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
@@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
106
106
  if len(top_logprobs) == 0:
107
107
  continue
108
108
  norm_logprobs = np.exp(top_logprobs[0].logprob)
109
- if bool(top_logprobs[0].token):
109
+ if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
110
110
  scores.append(norm_logprobs)
111
111
  else:
112
112
  scores.append(1 - norm_logprobs)
@@ -0,0 +1,17 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
import importlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Visible to type checkers and IDEs only; runtime access goes through __getattr__.
    from graphiti_core.driver.driver import GraphDriver
    from graphiti_core.driver.falkordb_driver import FalkorDriver
    from graphiti_core.driver.neo4j_driver import Neo4jDriver

__all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']

# Public name -> submodule that defines it. The names are resolved lazily so
# that importing this package does not import every backend; FalkorDriver
# pulls in the third-party `falkordb` package, which may not be installed.
_EXPORTS: dict[str, str] = {
    'GraphDriver': 'graphiti_core.driver.driver',
    'Neo4jDriver': 'graphiti_core.driver.neo4j_driver',
    'FalkorDriver': 'graphiti_core.driver.falkordb_driver',
}


def __getattr__(name: str) -> Any:
    """Import and return one of the exported driver classes on first access.

    Module-level ``__getattr__`` (PEP 562) makes ``graphiti_core.driver.X`` and
    ``from graphiti_core.driver import *`` work for every name in ``__all__``.

    Raises:
        AttributeError: if ``name`` is not one of the exported driver names.
    """
    module_path = _EXPORTS.get(name)
    if module_path is None:
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
    value = getattr(importlib.import_module(module_path), name)
    # Cache on the module so later lookups bypass __getattr__.
    globals()[name] = value
    return value
@@ -0,0 +1,66 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from abc import ABC, abstractmethod
19
+ from collections.abc import Coroutine
20
+ from typing import Any
21
+
22
+ from graphiti_core.helpers import DEFAULT_DATABASE
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class GraphDriverSession(ABC):
28
+ async def __aenter__(self):
29
+ return self
30
+
31
+ @abstractmethod
32
+ async def __aexit__(self, exc_type, exc, tb):
33
+ # No cleanup needed for Falkor, but method must exist
34
+ pass
35
+
36
+ @abstractmethod
37
+ async def run(self, query: str, **kwargs: Any) -> Any:
38
+ raise NotImplementedError()
39
+
40
+ @abstractmethod
41
+ async def close(self):
42
+ raise NotImplementedError()
43
+
44
+ @abstractmethod
45
+ async def execute_write(self, func, *args, **kwargs):
46
+ raise NotImplementedError()
47
+
48
+
49
class GraphDriver(ABC):
    """Abstract interface that every graph database backend implements.

    Subclasses set ``provider`` to a short backend identifier
    (for example ``'neo4j'`` or ``'falkordb'``).
    """

    provider: str

    @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
        """Run a Cypher query and return an awaitable for its result.

        Implementations are ``async def`` methods, so calling this returns a
        coroutine that the caller must await.
        """
        raise NotImplementedError()

    @abstractmethod
    def session(self, database: str) -> GraphDriverSession:
        """Open a session bound to ``database``."""
        raise NotImplementedError()

    @abstractmethod
    async def close(self) -> None:
        """Close the underlying client connection.

        Declared ``async`` because implementations await their client's close;
        callers must ``await driver.close()``.
        """
        raise NotImplementedError()

    @abstractmethod
    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        """Drop all indexes in ``database_``; returns an awaitable."""
        raise NotImplementedError()
@@ -0,0 +1,132 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from collections.abc import Coroutine
19
+ from datetime import datetime
20
+ from typing import Any
21
+
22
+ from falkordb import Graph as FalkorGraph # type: ignore
23
+ from falkordb.asyncio import FalkorDB # type: ignore
24
+
25
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
26
+ from graphiti_core.helpers import DEFAULT_DATABASE
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class FalkorDriverSession(GraphDriverSession):
    """Session wrapper that sends queries straight to one FalkorDB graph.

    The wrapper holds no resources of its own, so entering, exiting and
    closing it are no-ops.
    """

    def __init__(self, graph: FalkorGraph):
        self.graph = graph

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release when leaving the context.
        pass

    async def close(self):
        # Intentionally empty: this wrapper owns no connection to close.
        pass

    async def execute_write(self, func, *args, **kwargs):
        # There is no separate transaction object; the session itself is passed to `func`.
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute a single query, or a list of ``(cypher, params)`` pairs in order.

        FalkorDB cannot take a label set as a query argument, so one logical
        query may be expanded into several statements. In list mode ``kwargs``
        is ignored and each pair supplies its own parameters. Datetime values
        in the parameters are converted to ISO strings before sending.
        Always returns ``None``.
        """
        if isinstance(query, list):
            # Generator keeps unpacking interleaved with execution.
            statements = ((cypher, params) for cypher, params in query)
        else:
            statements = iter([(query, dict(kwargs))])

        for cypher, params in statements:
            converted = convert_datetimes_to_strings(params)
            await self.graph.query(str(cypher), converted)
        return None
62
+
63
+
64
class FalkorDriver(GraphDriver):
    """GraphDriver implementation backed by FalkorDB's asyncio client."""

    provider: str = 'falkordb'

    # Neo4j-only execute_query options with no FalkorDB meaning; dropped rather
    # than forwarded as (unused) query parameters.
    _IGNORED_QUERY_OPTIONS = ('routing_',)

    def __init__(
        self,
        uri: str,
        user: str,
        password: str,
    ):
        """Create a client for ``uri``, embedding credentials when both are given.

        Raises:
            ValueError: if credentials are supplied but ``uri`` has no
                ``scheme://`` prefix to insert them after.
        """
        super().__init__()
        if user and password:
            from urllib.parse import quote

            scheme, sep, rest = uri.partition('://')
            if not sep:
                raise ValueError(f'FalkorDB URI must include a scheme (e.g. redis://): {uri!r}')
            # Percent-encode so '@', ':' or '/' inside the credentials cannot
            # corrupt the URL structure.
            uri = f'{scheme}://{quote(user, safe="")}:{quote(password, safe="")}@{rest}'

        self.client = FalkorDB.from_url(
            url=uri,
        )

    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
        # FalkorDB needs a concrete graph name; fall back to the literal
        # 'DEFAULT_DATABASE' when none is given.
        if graph_name is None:
            graph_name = 'DEFAULT_DATABASE'
        return self.client.select_graph(graph_name)

    @staticmethod
    def _column_name(column: Any) -> str:
        """Return the name from a FalkorDB header entry ``(type, name)``.

        The name is decoded if it arrives as ``bytes``; a ``str`` is returned as-is.
        """
        name = column[1]
        return name.decode('utf-8') if isinstance(name, bytes) else name

    async def execute_query(self, cypher_query_, **kwargs: Any):
        """Run ``cypher_query_`` and return ``(records, header, None)``.

        ``records`` is a list of dicts keyed by column name, so callers can
        access fields by name (``record['uuid']``) as with the Neo4j driver.

        Keyword arguments become query parameters, with these conventions:
        - ``database_`` selects the graph.
        - ``params`` is an optional dict of extra parameters; individual
          keyword arguments take precedence over it.
        - ``routing_`` is ignored.

        Returns ``None`` when the query failed only because an index already
        exists. Any other error is logged and re-raised.
        """
        graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
        graph = self._get_graph(graph_name)

        for option in self._IGNORED_QUERY_OPTIONS:
            kwargs.pop(option, None)
        extra_params = kwargs.pop('params', None) or {}

        # FalkorDB does not accept datetime objects directly, so send ISO strings.
        params = convert_datetimes_to_strings({**extra_params, **kwargs})

        try:
            result = await graph.query(cypher_query_, params)
        except Exception as e:
            if 'already indexed' in str(e):
                # Treat re-creating an existing index as a no-op.
                logger.info(f'Index already exists: {e}')
                return None
            logger.error(f'Error executing FalkorDB query: {e}')
            raise

        header = [self._column_name(column) for column in result.header]
        records = [dict(zip(header, row)) for row in result.result_set]
        return records, header, None

    def session(self, database: str | None) -> GraphDriverSession:
        return FalkorDriverSession(self._get_graph(database))

    async def close(self) -> None:
        await self.client.connection.close()

    async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Any:
        """Drop every index in ``database_`` and return the query result."""
        # Await here: returning the bare coroutine from an `async def` would
        # give `await driver.delete_all_indexes()` a coroutine that never runs.
        # NOTE(review): confirm FalkorDB supports this procedure / DROP INDEX form.
        return await self.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
            database_=database_,
        )
120
+
121
+
122
def convert_datetimes_to_strings(obj):
    """Return a copy of ``obj`` with every ``datetime`` replaced by its ISO string.

    Dicts, lists and tuples are walked recursively and rebuilt as plain
    ``dict``/``list``/``tuple``. Any other value is returned unchanged.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {key: convert_datetimes_to_strings(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [convert_datetimes_to_strings(element) for element in obj]
        return tuple(converted) if isinstance(obj, tuple) else converted
    return obj
@@ -0,0 +1,61 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from collections.abc import Coroutine
19
+ from typing import Any
20
+
21
+ from neo4j import AsyncGraphDatabase
22
+ from typing_extensions import LiteralString
23
+
24
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
25
+ from graphiti_core.helpers import DEFAULT_DATABASE
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class Neo4jDriver(GraphDriver):
    """GraphDriver implementation backed by the async Neo4j Python driver."""

    provider: str = 'neo4j'

    def __init__(
        self,
        uri: str,
        user: str | None,
        password: str | None,
    ):
        super().__init__()
        # Missing credentials are sent as empty strings.
        self.client = AsyncGraphDatabase.driver(
            uri=uri,
            auth=(user or '', password or ''),
        )

    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any:
        """Run ``cypher_query_`` and return the awaited driver result.

        Callers unpack the result as a 3-tuple (records first). An optional
        ``params`` dict is forwarded as ``parameters_``. Every other keyword
        (query parameters and driver options such as ``database_``) is passed
        through unchanged.
        """
        params = kwargs.pop('params', None)
        result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)

        return result

    def session(self, database: str) -> GraphDriverSession:
        # The native Neo4j async session is returned; it is not a
        # GraphDriverSession subclass, hence the type: ignore.
        return self.client.session(database=database)  # type: ignore

    async def close(self) -> None:
        return await self.client.close()

    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        """Return an awaitable that drops all indexes in ``database_``."""
        # NOTE(review): `db.indexes()` is not available on Neo4j 5, and DROP INDEX
        # normally takes a literal index name, so this query likely fails on
        # current servers. Confirm before relying on it.
        return self.client.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
            database_=database_,
        )
graphiti_core/edges.py CHANGED
@@ -21,10 +21,10 @@ from time import time
21
21
  from typing import Any
22
22
  from uuid import uuid4
23
23
 
24
- from neo4j import AsyncDriver
25
24
  from pydantic import BaseModel, Field
26
25
  from typing_extensions import LiteralString
27
26
 
27
+ from graphiti_core.driver.driver import GraphDriver
28
28
  from graphiti_core.embedder import EmbedderClient
29
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
30
  from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
@@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
49
49
  e.episodes AS episodes,
50
50
  e.expired_at AS expired_at,
51
51
  e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at"""
52
+ e.invalid_at AS invalid_at,
53
+ properties(e) AS attributes
54
+ """
53
55
 
54
56
 
55
57
  class Edge(BaseModel, ABC):
@@ -60,9 +62,9 @@ class Edge(BaseModel, ABC):
60
62
  created_at: datetime
61
63
 
62
64
  @abstractmethod
63
- async def save(self, driver: AsyncDriver): ...
65
+ async def save(self, driver: GraphDriver): ...
64
66
 
65
- async def delete(self, driver: AsyncDriver):
67
+ async def delete(self, driver: GraphDriver):
66
68
  result = await driver.execute_query(
67
69
  """
68
70
  MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
@@ -85,11 +87,11 @@ class Edge(BaseModel, ABC):
85
87
  return False
86
88
 
87
89
  @classmethod
88
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
90
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
89
91
 
90
92
 
91
93
  class EpisodicEdge(Edge):
92
- async def save(self, driver: AsyncDriver):
94
+ async def save(self, driver: GraphDriver):
93
95
  result = await driver.execute_query(
94
96
  EPISODIC_EDGE_SAVE,
95
97
  episode_uuid=self.source_node_uuid,
@@ -100,12 +102,12 @@ class EpisodicEdge(Edge):
100
102
  database_=DEFAULT_DATABASE,
101
103
  )
102
104
 
103
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
105
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
104
106
 
105
107
  return result
106
108
 
107
109
  @classmethod
108
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
110
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
109
111
  records, _, _ = await driver.execute_query(
110
112
  """
111
113
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
@@ -128,7 +130,7 @@ class EpisodicEdge(Edge):
128
130
  return edges[0]
129
131
 
130
132
  @classmethod
131
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
133
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
132
134
  records, _, _ = await driver.execute_query(
133
135
  """
134
136
  MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@@ -154,7 +156,7 @@ class EpisodicEdge(Edge):
154
156
  @classmethod
155
157
  async def get_by_group_ids(
156
158
  cls,
157
- driver: AsyncDriver,
159
+ driver: GraphDriver,
158
160
  group_ids: list[str],
159
161
  limit: int | None = None,
160
162
  uuid_cursor: str | None = None,
@@ -209,6 +211,9 @@ class EntityEdge(Edge):
209
211
  invalid_at: datetime | None = Field(
210
212
  default=None, description='datetime of when the fact stopped being true'
211
213
  )
214
+ attributes: dict[str, Any] = Field(
215
+ default={}, description='Additional attributes of the edge. Dependent on edge name'
216
+ )
212
217
 
213
218
  async def generate_embedding(self, embedder: EmbedderClient):
214
219
  start = time()
@@ -221,7 +226,7 @@ class EntityEdge(Edge):
221
226
 
222
227
  return self.fact_embedding
223
228
 
224
- async def load_fact_embedding(self, driver: AsyncDriver):
229
+ async def load_fact_embedding(self, driver: GraphDriver):
225
230
  query: LiteralString = """
226
231
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
227
232
  RETURN e.fact_embedding AS fact_embedding
@@ -235,30 +240,36 @@ class EntityEdge(Edge):
235
240
 
236
241
  self.fact_embedding = records[0]['fact_embedding']
237
242
 
238
- async def save(self, driver: AsyncDriver):
243
+ async def save(self, driver: GraphDriver):
244
+ edge_data: dict[str, Any] = {
245
+ 'source_uuid': self.source_node_uuid,
246
+ 'target_uuid': self.target_node_uuid,
247
+ 'uuid': self.uuid,
248
+ 'name': self.name,
249
+ 'group_id': self.group_id,
250
+ 'fact': self.fact,
251
+ 'fact_embedding': self.fact_embedding,
252
+ 'episodes': self.episodes,
253
+ 'created_at': self.created_at,
254
+ 'expired_at': self.expired_at,
255
+ 'valid_at': self.valid_at,
256
+ 'invalid_at': self.invalid_at,
257
+ }
258
+
259
+ edge_data.update(self.attributes or {})
260
+
239
261
  result = await driver.execute_query(
240
262
  ENTITY_EDGE_SAVE,
241
- source_uuid=self.source_node_uuid,
242
- target_uuid=self.target_node_uuid,
243
- uuid=self.uuid,
244
- name=self.name,
245
- group_id=self.group_id,
246
- fact=self.fact,
247
- fact_embedding=self.fact_embedding,
248
- episodes=self.episodes,
249
- created_at=self.created_at,
250
- expired_at=self.expired_at,
251
- valid_at=self.valid_at,
252
- invalid_at=self.invalid_at,
263
+ edge_data=edge_data,
253
264
  database_=DEFAULT_DATABASE,
254
265
  )
255
266
 
256
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
267
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
257
268
 
258
269
  return result
259
270
 
260
271
  @classmethod
261
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
272
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
262
273
  records, _, _ = await driver.execute_query(
263
274
  """
264
275
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
@@ -276,7 +287,7 @@ class EntityEdge(Edge):
276
287
  return edges[0]
277
288
 
278
289
  @classmethod
279
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
290
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
280
291
  if len(uuids) == 0:
281
292
  return []
282
293
 
@@ -298,7 +309,7 @@ class EntityEdge(Edge):
298
309
  @classmethod
299
310
  async def get_by_group_ids(
300
311
  cls,
301
- driver: AsyncDriver,
312
+ driver: GraphDriver,
302
313
  group_ids: list[str],
303
314
  limit: int | None = None,
304
315
  uuid_cursor: str | None = None,
@@ -331,11 +342,11 @@ class EntityEdge(Edge):
331
342
  return edges
332
343
 
333
344
  @classmethod
334
- async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
345
+ async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
335
346
  query: LiteralString = (
336
347
  """
337
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
338
- """
348
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
+ """
339
350
  + ENTITY_EDGE_RETURN
340
351
  )
341
352
  records, _, _ = await driver.execute_query(
@@ -348,7 +359,7 @@ class EntityEdge(Edge):
348
359
 
349
360
 
350
361
  class CommunityEdge(Edge):
351
- async def save(self, driver: AsyncDriver):
362
+ async def save(self, driver: GraphDriver):
352
363
  result = await driver.execute_query(
353
364
  COMMUNITY_EDGE_SAVE,
354
365
  community_uuid=self.source_node_uuid,
@@ -359,12 +370,12 @@ class CommunityEdge(Edge):
359
370
  database_=DEFAULT_DATABASE,
360
371
  )
361
372
 
362
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
373
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
363
374
 
364
375
  return result
365
376
 
366
377
  @classmethod
367
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
378
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
368
379
  records, _, _ = await driver.execute_query(
369
380
  """
370
381
  MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
@@ -385,7 +396,7 @@ class CommunityEdge(Edge):
385
396
  return edges[0]
386
397
 
387
398
  @classmethod
388
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
399
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
389
400
  records, _, _ = await driver.execute_query(
390
401
  """
391
402
  MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
@@ -409,7 +420,7 @@ class CommunityEdge(Edge):
409
420
  @classmethod
410
421
  async def get_by_group_ids(
411
422
  cls,
412
- driver: AsyncDriver,
423
+ driver: GraphDriver,
413
424
  group_ids: list[str],
414
425
  limit: int | None = None,
415
426
  uuid_cursor: str | None = None,
@@ -452,12 +463,12 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
452
463
  group_id=record['group_id'],
453
464
  source_node_uuid=record['source_node_uuid'],
454
465
  target_node_uuid=record['target_node_uuid'],
455
- created_at=record['created_at'].to_native(),
466
+ created_at=parse_db_date(record['created_at']), # type: ignore
456
467
  )
457
468
 
458
469
 
459
470
  def get_entity_edge_from_record(record: Any) -> EntityEdge:
460
- return EntityEdge(
471
+ edge = EntityEdge(
461
472
  uuid=record['uuid'],
462
473
  source_node_uuid=record['source_node_uuid'],
463
474
  target_node_uuid=record['target_node_uuid'],
@@ -465,12 +476,27 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
465
476
  name=record['name'],
466
477
  group_id=record['group_id'],
467
478
  episodes=record['episodes'],
468
- created_at=record['created_at'].to_native(),
479
+ created_at=parse_db_date(record['created_at']), # type: ignore
469
480
  expired_at=parse_db_date(record['expired_at']),
470
481
  valid_at=parse_db_date(record['valid_at']),
471
482
  invalid_at=parse_db_date(record['invalid_at']),
483
+ attributes=record['attributes'],
472
484
  )
473
485
 
486
+ edge.attributes.pop('uuid', None)
487
+ edge.attributes.pop('source_node_uuid', None)
488
+ edge.attributes.pop('target_node_uuid', None)
489
+ edge.attributes.pop('fact', None)
490
+ edge.attributes.pop('name', None)
491
+ edge.attributes.pop('group_id', None)
492
+ edge.attributes.pop('episodes', None)
493
+ edge.attributes.pop('created_at', None)
494
+ edge.attributes.pop('expired_at', None)
495
+ edge.attributes.pop('valid_at', None)
496
+ edge.attributes.pop('invalid_at', None)
497
+
498
+ return edge
499
+
474
500
 
475
501
  def get_community_edge_from_record(record: Any):
476
502
  return CommunityEdge(
@@ -478,7 +504,7 @@ def get_community_edge_from_record(record: Any):
478
504
  group_id=record['group_id'],
479
505
  source_node_uuid=record['source_node_uuid'],
480
506
  target_node_uuid=record['target_node_uuid'],
481
- created_at=record['created_at'].to_native(),
507
+ created_at=parse_db_date(record['created_at']), # type: ignore
482
508
  )
483
509
 
484
510
 
@@ -0,0 +1,64 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from typing import Any
19
+
20
+ from openai import AsyncAzureOpenAI
21
+
22
+ from .client import EmbedderClient
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class AzureOpenAIEmbedderClient(EmbedderClient):
    """Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface."""

    def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
        self.azure_client = azure_client
        # NOTE(review): with Azure OpenAI this is usually the deployment name,
        # not the base model name. Confirm against the target resource.
        self.model = model

    async def create(self, input_data: str | list[str] | Any) -> list[float]:
        """Embed ``input_data`` and return the first resulting vector.

        Input handling:
        - A ``str`` is wrapped in a one-element list.
        - A ``list`` or ``tuple`` (of strings, or of token ids such as
          ``list[int]`` / ``list[list[int]]``) is sent as a list unchanged.
        - Anything else is embedded as its ``str()`` form.

        Raises:
            Exception: any error from the Azure client is logged and re-raised.
        """
        try:
            if isinstance(input_data, str):
                request_input: Any = [input_data]
            elif isinstance(input_data, (list, tuple)):
                # Pass token-id sequences through: str() would embed the
                # literal text "[1, 2, 3]" instead of the tokens.
                request_input = list(input_data)
            else:
                request_input = [str(input_data)]

            response = await self.azure_client.embeddings.create(
                model=self.model, input=request_input
            )

            # Return the first embedding as a list of floats
            return response.data[0].embedding
        except Exception as e:
            logger.error(f'Error in Azure OpenAI embedding: {e}')
            raise

    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
        """Embed every string in ``input_data_list`` with a single request.

        Returns one vector per entry of ``response.data``. An empty input list
        returns ``[]`` without calling the service.

        Raises:
            Exception: any error from the Azure client is logged and re-raised.
        """
        if not input_data_list:
            # Nothing to embed; the embeddings endpoint would reject an empty input array.
            return []
        try:
            response = await self.azure_client.embeddings.create(
                model=self.model, input=input_data_list
            )

            return [embedding.embedding for embedding in response.data]
        except Exception as e:
            logger.error(f'Error in Azure OpenAI batch embedding: {e}')
            raise
@@ -61,18 +61,29 @@ class GeminiEmbedder(EmbedderClient):
61
61
  # Generate embeddings
62
62
  result = await self.client.aio.models.embed_content(
63
63
  model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
64
- contents=[input_data],
64
+ contents=[input_data], # type: ignore[arg-type] # mypy fails on broad union type
65
65
  config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
66
66
  )
67
67
 
68
+ if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
69
+ raise ValueError('No embeddings returned from Gemini API in create()')
70
+
68
71
  return result.embeddings[0].values
69
72
 
70
73
  async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
71
74
  # Generate embeddings
72
75
  result = await self.client.aio.models.embed_content(
73
76
  model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
74
- contents=input_data_list,
77
+ contents=input_data_list, # type: ignore[arg-type] # mypy fails on broad union type
75
78
  config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
76
79
  )
77
80
 
78
- return [embedding.values for embedding in result.embeddings]
81
+ if not result.embeddings or len(result.embeddings) == 0:
82
+ raise Exception('No embeddings returned')
83
+
84
+ embeddings = []
85
+ for embedding in result.embeddings:
86
+ if not embedding.values:
87
+ raise ValueError('Empty embedding values returned')
88
+ embeddings.append(embedding.values)
89
+ return embeddings