graphiti-core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (37)
  1. graphiti_core/__init__.py +3 -0
  2. graphiti_core/edges.py +232 -0
  3. graphiti_core/graphiti.py +618 -0
  4. graphiti_core/helpers.py +7 -0
  5. graphiti_core/llm_client/__init__.py +5 -0
  6. graphiti_core/llm_client/anthropic_client.py +63 -0
  7. graphiti_core/llm_client/client.py +96 -0
  8. graphiti_core/llm_client/config.py +58 -0
  9. graphiti_core/llm_client/groq_client.py +64 -0
  10. graphiti_core/llm_client/openai_client.py +65 -0
  11. graphiti_core/llm_client/utils.py +22 -0
  12. graphiti_core/nodes.py +250 -0
  13. graphiti_core/prompts/__init__.py +4 -0
  14. graphiti_core/prompts/dedupe_edges.py +154 -0
  15. graphiti_core/prompts/dedupe_nodes.py +151 -0
  16. graphiti_core/prompts/extract_edge_dates.py +60 -0
  17. graphiti_core/prompts/extract_edges.py +138 -0
  18. graphiti_core/prompts/extract_nodes.py +145 -0
  19. graphiti_core/prompts/invalidate_edges.py +74 -0
  20. graphiti_core/prompts/lib.py +122 -0
  21. graphiti_core/prompts/models.py +31 -0
  22. graphiti_core/search/__init__.py +0 -0
  23. graphiti_core/search/search.py +142 -0
  24. graphiti_core/search/search_utils.py +454 -0
  25. graphiti_core/utils/__init__.py +15 -0
  26. graphiti_core/utils/bulk_utils.py +227 -0
  27. graphiti_core/utils/maintenance/__init__.py +16 -0
  28. graphiti_core/utils/maintenance/edge_operations.py +170 -0
  29. graphiti_core/utils/maintenance/graph_data_operations.py +133 -0
  30. graphiti_core/utils/maintenance/node_operations.py +199 -0
  31. graphiti_core/utils/maintenance/temporal_operations.py +184 -0
  32. graphiti_core/utils/maintenance/utils.py +0 -0
  33. graphiti_core/utils/utils.py +39 -0
  34. graphiti_core-0.1.0.dist-info/LICENSE +201 -0
  35. graphiti_core-0.1.0.dist-info/METADATA +199 -0
  36. graphiti_core-0.1.0.dist-info/RECORD +37 -0
  37. graphiti_core-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,618 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import asyncio
18
+ import logging
19
+ from datetime import datetime
20
+ from time import time
21
+ from typing import Callable
22
+
23
+ from dotenv import load_dotenv
24
+ from neo4j import AsyncGraphDatabase
25
+
26
+ from graphiti_core.edges import EntityEdge, EpisodicEdge
27
+ from graphiti_core.llm_client import LLMClient, OpenAIClient
28
+ from graphiti_core.llm_client.utils import generate_embedding
29
+ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
30
+ from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
31
+ from graphiti_core.search.search_utils import (
32
+ get_relevant_edges,
33
+ get_relevant_nodes,
34
+ hybrid_node_search,
35
+ )
36
+ from graphiti_core.utils import (
37
+ build_episodic_edges,
38
+ retrieve_episodes,
39
+ )
40
+ from graphiti_core.utils.bulk_utils import (
41
+ RawEpisode,
42
+ dedupe_edges_bulk,
43
+ dedupe_nodes_bulk,
44
+ extract_nodes_and_edges_bulk,
45
+ resolve_edge_pointers,
46
+ retrieve_previous_episodes_bulk,
47
+ )
48
+ from graphiti_core.utils.maintenance.edge_operations import (
49
+ dedupe_extracted_edges,
50
+ extract_edges,
51
+ )
52
+ from graphiti_core.utils.maintenance.graph_data_operations import (
53
+ EPISODE_WINDOW_LEN,
54
+ build_indices_and_constraints,
55
+ )
56
+ from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
57
+ from graphiti_core.utils.maintenance.temporal_operations import (
58
+ extract_edge_dates,
59
+ invalidate_edges,
60
+ prepare_edges_for_invalidation,
61
+ )
62
+
63
# Module-level logger for graph-update progress and diagnostics.
logger = logging.getLogger(__name__)

# Load environment variables (e.g. OPENAI_API_KEY for the default client)
# from a local .env file, if one is present.
load_dotenv()
66
+
67
+
68
class Graphiti:
    def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
        """Create a Graphiti instance backed by a Neo4j database.

        Parameters
        ----------
        uri : str
            Bolt URI of the Neo4j server.
        user : str
            Username for authenticating with Neo4j.
        password : str
            Password for authenticating with Neo4j.
        llm_client : LLMClient | None, optional
            LLM client used for extraction and deduplication tasks. When
            omitted, a default OpenAIClient is created; that client expects
            the OPENAI_API_KEY environment variable to be set.

        Notes
        -----
        The database name defaults to 'neo4j'; set `self.database` after
        construction if a different database is required.
        """
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
        self.database = 'neo4j'
        # Fall back to the default OpenAI-backed client when none is supplied.
        self.llm_client = llm_client if llm_client else OpenAIClient()
112
+
113
+ def close(self):
114
+ """
115
+ Close the connection to the Neo4j database.
116
+
117
+ This method safely closes the driver connection to the Neo4j database.
118
+ It should be called when the Graphiti instance is no longer needed or
119
+ when the application is shutting down.
120
+
121
+ Parameters
122
+ ----------
123
+ None
124
+
125
+ Returns
126
+ -------
127
+ None
128
+
129
+ Notes
130
+ -----
131
+ It's important to close the driver connection to release system resources
132
+ and ensure that all pending transactions are completed or rolled back.
133
+ This method should be called as part of a cleanup process, potentially
134
+ in a context manager or a shutdown hook.
135
+
136
+ Example:
137
+ graphiti = Graphiti(uri, user, password)
138
+ try:
139
+ # Use graphiti...
140
+ finally:
141
+ graphiti.close()
142
+ self.driver.close()
143
+ """
144
+
145
+ async def build_indices_and_constraints(self):
146
+ """
147
+ Build indices and constraints in the Neo4j database.
148
+
149
+ This method sets up the necessary indices and constraints in the Neo4j database
150
+ to optimize query performance and ensure data integrity for the knowledge graph.
151
+
152
+ Parameters
153
+ ----------
154
+ None
155
+
156
+ Returns
157
+ -------
158
+ None
159
+
160
+ Notes
161
+ -----
162
+ This method should typically be called once during the initial setup of the
163
+ knowledge graph or when updating the database schema. It uses the
164
+ `build_indices_and_constraints` function from the
165
+ `graphiti_core.utils.maintenance.graph_data_operations` module to perform
166
+ the actual database operations.
167
+
168
+ The specific indices and constraints created depend on the implementation
169
+ of the `build_indices_and_constraints` function. Refer to that function's
170
+ documentation for details on the exact database schema modifications.
171
+
172
+ Caution: Running this method on a large existing database may take some time
173
+ and could impact database performance during execution.
174
+ """
175
+ await build_indices_and_constraints(self.driver)
176
+
177
+ async def retrieve_episodes(
178
+ self,
179
+ reference_time: datetime,
180
+ last_n: int = EPISODE_WINDOW_LEN,
181
+ ) -> list[EpisodicNode]:
182
+ """
183
+ Retrieve the last n episodic nodes from the graph.
184
+
185
+ This method fetches a specified number of the most recent episodic nodes
186
+ from the graph, relative to the given reference time.
187
+
188
+ Parameters
189
+ ----------
190
+ reference_time : datetime
191
+ The reference time to retrieve episodes before.
192
+ last_n : int, optional
193
+ The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
194
+
195
+ Returns
196
+ -------
197
+ list[EpisodicNode]
198
+ A list of the most recent EpisodicNode objects.
199
+
200
+ Notes
201
+ -----
202
+ The actual retrieval is performed by the `retrieve_episodes` function
203
+ from the `graphiti_core.utils` module.
204
+ """
205
+ return await retrieve_episodes(self.driver, reference_time, last_n)
206
+
207
+ async def add_episode(
208
+ self,
209
+ name: str,
210
+ episode_body: str,
211
+ source_description: str,
212
+ reference_time: datetime,
213
+ source: EpisodeType = EpisodeType.message,
214
+ success_callback: Callable | None = None,
215
+ error_callback: Callable | None = None,
216
+ ):
217
+ """
218
+ Process an episode and update the graph.
219
+
220
+ This method extracts information from the episode, creates nodes and edges,
221
+ and updates the graph database accordingly.
222
+
223
+ Parameters
224
+ ----------
225
+ name : str
226
+ The name of the episode.
227
+ episode_body : str
228
+ The content of the episode.
229
+ source_description : str
230
+ A description of the episode's source.
231
+ reference_time : datetime
232
+ The reference time for the episode.
233
+ source : EpisodeType, optional
234
+ The type of the episode. Defaults to EpisodeType.message.
235
+ success_callback : Callable | None, optional
236
+ A callback function to be called upon successful processing.
237
+ error_callback : Callable | None, optional
238
+ A callback function to be called if an error occurs during processing.
239
+
240
+ Returns
241
+ -------
242
+ None
243
+
244
+ Notes
245
+ -----
246
+ This method performs several steps including node extraction, edge extraction,
247
+ deduplication, and database updates. It also handles embedding generation
248
+ and edge invalidation.
249
+
250
+ It is recommended to run this method as a background process, such as in a queue.
251
+ It's important that each episode is added sequentially and awaited before adding
252
+ the next one. For web applications, consider using FastAPI's background tasks
253
+ or a dedicated task queue like Celery for this purpose.
254
+
255
+ Example using FastAPI background tasks:
256
+ @app.post("/add_episode")
257
+ async def add_episode_endpoint(episode_data: EpisodeData):
258
+ background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
259
+ return {"message": "Episode processing started"}
260
+ """
261
+ try:
262
+ start = time()
263
+
264
+ nodes: list[EntityNode] = []
265
+ entity_edges: list[EntityEdge] = []
266
+ episodic_edges: list[EpisodicEdge] = []
267
+ embedder = self.llm_client.get_embedder()
268
+ now = datetime.now()
269
+
270
+ previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
271
+ episode = EpisodicNode(
272
+ name=name,
273
+ labels=[],
274
+ source=source,
275
+ content=episode_body,
276
+ source_description=source_description,
277
+ created_at=now,
278
+ valid_at=reference_time,
279
+ )
280
+
281
+ extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
282
+ logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
283
+
284
+ # Calculate Embeddings
285
+
286
+ await asyncio.gather(
287
+ *[node.generate_name_embedding(embedder) for node in extracted_nodes]
288
+ )
289
+ existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
290
+ logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
291
+ touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
292
+ self.llm_client, extracted_nodes, existing_nodes
293
+ )
294
+ logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
295
+ nodes.extend(touched_nodes)
296
+
297
+ extracted_edges = await extract_edges(
298
+ self.llm_client, episode, touched_nodes, previous_episodes
299
+ )
300
+
301
+ await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
302
+
303
+ existing_edges = await get_relevant_edges(extracted_edges, self.driver)
304
+ logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
305
+ logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
306
+
307
+ deduped_edges = await dedupe_extracted_edges(
308
+ self.llm_client,
309
+ extracted_edges,
310
+ existing_edges,
311
+ )
312
+
313
+ edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
314
+ for edge in deduped_edges:
315
+ edge_touched_node_uuids.append(edge.source_node_uuid)
316
+ edge_touched_node_uuids.append(edge.target_node_uuid)
317
+
318
+ for edge in deduped_edges:
319
+ valid_at, invalid_at, _ = await extract_edge_dates(
320
+ self.llm_client,
321
+ edge,
322
+ episode.valid_at,
323
+ episode,
324
+ previous_episodes,
325
+ )
326
+ edge.valid_at = valid_at
327
+ edge.invalid_at = invalid_at
328
+ if edge.invalid_at:
329
+ edge.expired_at = datetime.now()
330
+ for edge in existing_edges:
331
+ valid_at, invalid_at, _ = await extract_edge_dates(
332
+ self.llm_client,
333
+ edge,
334
+ episode.valid_at,
335
+ episode,
336
+ previous_episodes,
337
+ )
338
+ edge.valid_at = valid_at
339
+ edge.invalid_at = invalid_at
340
+ if edge.invalid_at:
341
+ edge.expired_at = datetime.now()
342
+ (
343
+ old_edges_with_nodes_pending_invalidation,
344
+ new_edges_with_nodes,
345
+ ) = prepare_edges_for_invalidation(
346
+ existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
347
+ )
348
+
349
+ invalidated_edges = await invalidate_edges(
350
+ self.llm_client,
351
+ old_edges_with_nodes_pending_invalidation,
352
+ new_edges_with_nodes,
353
+ episode,
354
+ previous_episodes,
355
+ )
356
+
357
+ for edge in invalidated_edges:
358
+ for existing_edge in existing_edges:
359
+ if existing_edge.uuid == edge.uuid:
360
+ existing_edge.expired_at = edge.expired_at
361
+ for deduped_edge in deduped_edges:
362
+ if deduped_edge.uuid == edge.uuid:
363
+ deduped_edge.expired_at = edge.expired_at
364
+ edge_touched_node_uuids.append(edge.source_node_uuid)
365
+ edge_touched_node_uuids.append(edge.target_node_uuid)
366
+ logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
367
+
368
+ edges_to_save = existing_edges + deduped_edges
369
+
370
+ entity_edges.extend(edges_to_save)
371
+
372
+ edge_touched_node_uuids = list(set(edge_touched_node_uuids))
373
+ involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
374
+
375
+ logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
376
+
377
+ logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
378
+
379
+ episodic_edges.extend(
380
+ build_episodic_edges(
381
+ # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
382
+ involved_nodes,
383
+ episode,
384
+ now,
385
+ )
386
+ )
387
+ # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
388
+ logger.info(f'Built episodic edges: {episodic_edges}')
389
+
390
+ # Future optimization would be using batch operations to save nodes and edges
391
+ await episode.save(self.driver)
392
+ await asyncio.gather(*[node.save(self.driver) for node in nodes])
393
+ await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
394
+ await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
395
+
396
+ end = time()
397
+ logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
398
+ # for node in nodes:
399
+ # if isinstance(node, EntityNode):
400
+ # await node.update_summary(self.driver)
401
+ if success_callback:
402
+ await success_callback(episode)
403
+ except Exception as e:
404
+ if error_callback:
405
+ await error_callback(episode, e)
406
+ else:
407
+ raise e
408
+
409
+ async def add_episode_bulk(
410
+ self,
411
+ bulk_episodes: list[RawEpisode],
412
+ ):
413
+ """
414
+ Process multiple episodes in bulk and update the graph.
415
+
416
+ This method extracts information from multiple episodes, creates nodes and edges,
417
+ and updates the graph database accordingly, all in a single batch operation.
418
+
419
+ Parameters
420
+ ----------
421
+ bulk_episodes : list[RawEpisode]
422
+ A list of RawEpisode objects to be processed and added to the graph.
423
+
424
+ Returns
425
+ -------
426
+ None
427
+
428
+ Notes
429
+ -----
430
+ This method performs several steps including:
431
+ - Saving all episodes to the database
432
+ - Retrieving previous episode context for each new episode
433
+ - Extracting nodes and edges from all episodes
434
+ - Generating embeddings for nodes and edges
435
+ - Deduplicating nodes and edges
436
+ - Saving nodes, episodic edges, and entity edges to the knowledge graph
437
+
438
+ This bulk operation is designed for efficiency when processing multiple episodes
439
+ at once. However, it's important to ensure that the bulk operation doesn't
440
+ overwhelm system resources. Consider implementing rate limiting or chunking for
441
+ very large batches of episodes.
442
+
443
+ Important: This method does not perform edge invalidation or date extraction steps.
444
+ If these operations are required, use the `add_episode` method instead for each
445
+ individual episode.
446
+ """
447
+ try:
448
+ start = time()
449
+ embedder = self.llm_client.get_embedder()
450
+ now = datetime.now()
451
+
452
+ episodes = [
453
+ EpisodicNode(
454
+ name=episode.name,
455
+ labels=[],
456
+ source=episode.source,
457
+ content=episode.content,
458
+ source_description=episode.source_description,
459
+ created_at=now,
460
+ valid_at=episode.reference_time,
461
+ )
462
+ for episode in bulk_episodes
463
+ ]
464
+
465
+ # Save all the episodes
466
+ await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
467
+
468
+ # Get previous episode context for each episode
469
+ episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
470
+
471
+ # Extract all nodes and edges
472
+ (
473
+ extracted_nodes,
474
+ extracted_edges,
475
+ episodic_edges,
476
+ ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
477
+
478
+ # Generate embeddings
479
+ await asyncio.gather(
480
+ *[node.generate_name_embedding(embedder) for node in extracted_nodes],
481
+ *[edge.generate_embedding(embedder) for edge in extracted_edges],
482
+ )
483
+
484
+ # Dedupe extracted nodes
485
+ nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
486
+
487
+ # save nodes to KG
488
+ await asyncio.gather(*[node.save(self.driver) for node in nodes])
489
+
490
+ # re-map edge pointers so that they don't point to discard dupe nodes
491
+ extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
492
+ extracted_edges, uuid_map
493
+ )
494
+ episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
495
+ episodic_edges, uuid_map
496
+ )
497
+
498
+ # save episodic edges to KG
499
+ await asyncio.gather(
500
+ *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
501
+ )
502
+
503
+ # Dedupe extracted edges
504
+ edges = await dedupe_edges_bulk(
505
+ self.driver, self.llm_client, extracted_edges_with_resolved_pointers
506
+ )
507
+ logger.info(f'extracted edge length: {len(edges)}')
508
+
509
+ # invalidate edges
510
+
511
+ # save edges to KG
512
+ await asyncio.gather(*[edge.save(self.driver) for edge in edges])
513
+
514
+ end = time()
515
+ logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
516
+
517
+ except Exception as e:
518
+ raise e
519
+
520
+ async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
521
+ """
522
+ Perform a hybrid search on the knowledge graph.
523
+
524
+ This method executes a search query on the graph, combining vector and
525
+ text-based search techniques to retrieve relevant facts.
526
+
527
+ Parameters
528
+ ----------
529
+ query : str
530
+ The search query string.
531
+ center_node_uuid: str, optional
532
+ Facts will be reranked based on proximity to this node
533
+ num_results : int, optional
534
+ The maximum number of results to return. Defaults to 10.
535
+
536
+ Returns
537
+ -------
538
+ list
539
+ A list of EntityEdge objects that are relevant to the search query.
540
+
541
+ Notes
542
+ -----
543
+ This method uses a SearchConfig with num_episodes set to 0 and
544
+ num_results set to the provided num_results parameter. It then calls
545
+ the hybrid_search function to perform the actual search operation.
546
+
547
+ The search is performed using the current date and time as the reference
548
+ point for temporal relevance.
549
+ """
550
+ reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
551
+ search_config = SearchConfig(
552
+ num_episodes=0,
553
+ num_edges=num_results,
554
+ num_nodes=0,
555
+ search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
556
+ reranker=reranker,
557
+ )
558
+ edges = (
559
+ await hybrid_search(
560
+ self.driver,
561
+ self.llm_client.get_embedder(),
562
+ query,
563
+ datetime.now(),
564
+ search_config,
565
+ center_node_uuid,
566
+ )
567
+ ).edges
568
+
569
+ return edges
570
+
571
+ async def _search(
572
+ self,
573
+ query: str,
574
+ timestamp: datetime,
575
+ config: SearchConfig,
576
+ center_node_uuid: str | None = None,
577
+ ):
578
+ return await hybrid_search(
579
+ self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
580
+ )
581
+
582
+ async def get_nodes_by_query(self, query: str, limit: int | None = None) -> list[EntityNode]:
583
+ """
584
+ Retrieve nodes from the graph database based on a text query.
585
+
586
+ This method performs a hybrid search using both text-based and
587
+ embedding-based approaches to find relevant nodes.
588
+
589
+ Parameters
590
+ ----------
591
+ query : str
592
+ The text query to search for in the graph.
593
+ limit : int | None, optional
594
+ The maximum number of results to return per search method.
595
+ If None, a default limit will be applied.
596
+
597
+ Returns
598
+ -------
599
+ list[EntityNode]
600
+ A list of EntityNode objects that match the search criteria.
601
+
602
+ Notes
603
+ -----
604
+ This method uses the following steps:
605
+ 1. Generates an embedding for the input query using the LLM client's embedder.
606
+ 2. Calls the hybrid_node_search function with both the text query and its embedding.
607
+ 3. The hybrid search combines fulltext search and vector similarity search
608
+ to find the most relevant nodes.
609
+
610
+ The method leverages the LLM client's embedding capabilities to enhance
611
+ the search with semantic similarity matching. The 'limit' parameter is applied
612
+ to each individual search method before results are combined and deduplicated.
613
+ If not specified, a default limit (defined in the search functions) will be used.
614
+ """
615
+ embedder = self.llm_client.get_embedder()
616
+ query_embedding = await generate_embedding(embedder, query)
617
+ relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
618
+ return relevant_nodes
@@ -0,0 +1,7 @@
1
+ from datetime import datetime
2
+
3
+ from neo4j import time as neo4j_time
4
+
5
+
6
+ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
7
+ return neo_date.to_native() if neo_date else None
@@ -0,0 +1,5 @@
1
from .client import LLMClient
from .config import LLMConfig
from .openai_client import OpenAIClient

# Public API of the llm_client package.
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
@@ -0,0 +1,63 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ import logging
19
+ import typing
20
+
21
+ from anthropic import AsyncAnthropic
22
+ from openai import AsyncOpenAI
23
+
24
+ from ..prompts.models import Message
25
+ from .client import LLMClient
26
+ from .config import LLMConfig
27
+
28
logger = logging.getLogger(__name__)

# Fallback Anthropic model used when LLMConfig does not specify one.
DEFAULT_MODEL = 'claude-3-5-sonnet-20240620'
31
+
32
+
33
class AnthropicClient(LLMClient):
    """LLM client backed by Anthropic's Messages API.

    Anthropic provides no embedding endpoint here, so `get_embedder` returns
    an OpenAI embeddings client instead.
    """

    def __init__(self, config: LLMConfig | None = None, cache: bool = False):
        # Fall back to the default LLMConfig when none is supplied.
        if config is None:
            config = LLMConfig()
        super().__init__(config, cache)
        self.client = AsyncAnthropic(api_key=config.api_key)

    def get_embedder(self) -> typing.Any:
        # Embeddings come from OpenAI; AsyncOpenAI() reads OPENAI_API_KEY
        # from the environment. NOTE(review): this presumes the OpenAI key is
        # configured even when using Anthropic for chat — confirm with callers.
        openai_client = AsyncOpenAI()
        return openai_client.embeddings

    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        """Send the chat messages to Anthropic and parse the reply as JSON.

        The first message is used as the system prompt; the remainder become
        the conversation. A trailing assistant turn containing '{' is
        appended so the model continues an already-opened JSON object
        ("prefill"); the '{' is re-attached before json.loads.

        Raises whatever the API call raises, or json.JSONDecodeError if the
        reply is not valid JSON after the prefill is restored.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
            {'role': 'assistant', 'content': '{'}
        ]

        try:
            result = await self.client.messages.create(
                system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
                + system_message.content,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                messages=user_messages,  # type: ignore
                model=self.model or DEFAULT_MODEL,
            )

            # Re-attach the '{' that was prefilled in the assistant turn.
            return json.loads('{' + result.content[0].text)  # type: ignore
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise