graphiti-core 0.12.4.tar.gz → 0.13.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry's advisory page for more details.

Files changed (75)
  1. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/PKG-INFO +51 -2
  2. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/README.md +49 -0
  3. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/errors.py +8 -0
  4. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/graphiti.py +60 -12
  5. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/helpers.py +31 -2
  6. graphiti_core-0.13.1/graphiti_core/llm_client/azure_openai_client.py +77 -0
  7. graphiti_core-0.12.4/graphiti_core/llm_client/openai_client.py → graphiti_core-0.13.1/graphiti_core/llm_client/openai_base_client.py +88 -62
  8. graphiti_core-0.13.1/graphiti_core/llm_client/openai_client.py +95 -0
  9. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/models/edges/edge_db_queries.py +1 -1
  10. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/dedupe_nodes.py +8 -8
  11. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/extract_edges.py +1 -2
  12. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/search_utils.py +11 -11
  13. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/node_operations.py +1 -1
  14. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/pyproject.toml +2 -2
  15. graphiti_core-0.12.4/graphiti_core/llm_client/azure_openai_client.py +0 -73
  16. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/LICENSE +0 -0
  17. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/__init__.py +0 -0
  18. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/cross_encoder/__init__.py +0 -0
  19. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  20. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/cross_encoder/client.py +0 -0
  21. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
  22. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/driver/__init__.py +0 -0
  23. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/driver/driver.py +0 -0
  24. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/driver/falkordb_driver.py +0 -0
  25. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/driver/neo4j_driver.py +0 -0
  26. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/edges.py +0 -0
  27. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/embedder/__init__.py +0 -0
  28. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/embedder/azure_openai.py +0 -0
  29. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/embedder/client.py +0 -0
  30. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/embedder/gemini.py +0 -0
  31. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/embedder/openai.py +0 -0
  32. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/embedder/voyage.py +0 -0
  33. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/graph_queries.py +0 -0
  34. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/graphiti_types.py +0 -0
  35. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/__init__.py +0 -0
  36. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/anthropic_client.py +0 -0
  37. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/client.py +0 -0
  38. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/config.py +0 -0
  39. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/errors.py +0 -0
  40. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/gemini_client.py +0 -0
  41. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/groq_client.py +0 -0
  42. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/openai_generic_client.py +0 -0
  43. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/llm_client/utils.py +0 -0
  44. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/models/__init__.py +0 -0
  45. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/models/edges/__init__.py +0 -0
  46. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/models/nodes/__init__.py +0 -0
  47. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  48. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/nodes.py +0 -0
  49. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/__init__.py +0 -0
  50. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/dedupe_edges.py +0 -0
  51. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/eval.py +0 -0
  52. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  53. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/extract_nodes.py +0 -0
  54. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/invalidate_edges.py +0 -0
  55. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/lib.py +0 -0
  56. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/models.py +0 -0
  57. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/prompt_helpers.py +0 -0
  58. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/prompts/summarize_nodes.py +0 -0
  59. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/py.typed +0 -0
  60. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/__init__.py +0 -0
  61. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/search.py +0 -0
  62. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/search_config.py +0 -0
  63. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/search_config_recipes.py +0 -0
  64. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/search_filters.py +0 -0
  65. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/search/search_helpers.py +0 -0
  66. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/__init__.py +0 -0
  67. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/bulk_utils.py +0 -0
  68. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/datetime_utils.py +0 -0
  69. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/__init__.py +0 -0
  70. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/community_operations.py +0 -0
  71. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
  72. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  73. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  74. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/maintenance/utils.py +0 -0
  75. {graphiti_core-0.12.4 → graphiti_core-0.13.1}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: graphiti-core
3
- Version: 0.12.4
3
+ Version: 0.13.1
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -23,7 +23,7 @@ Requires-Dist: google-genai (>=1.8.0) ; extra == "google-genai"
23
23
  Requires-Dist: groq (>=0.2.0) ; extra == "groq"
24
24
  Requires-Dist: neo4j (>=5.26.0)
25
25
  Requires-Dist: numpy (>=1.0.0)
26
- Requires-Dist: openai (>=1.53.0)
26
+ Requires-Dist: openai (>=1.91.0)
27
27
  Requires-Dist: pydantic (>=2.11.5)
28
28
  Requires-Dist: python-dotenv (>=1.0.1)
29
29
  Requires-Dist: tenacity (>=9.0.0)
@@ -331,6 +331,55 @@ graphiti = Graphiti(
331
331
  # Now you can use Graphiti with Google Gemini
332
332
  ```
333
333
 
334
+ ## Using Graphiti with Ollama (Local LLM)
335
+
336
+ Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.
337
+
338
+
339
+ Install the models:
340
+ ollama pull deepseek-r1:7b # LLM
341
+ ollama pull nomic-embed-text # embeddings
342
+
343
+ ```python
344
+ from graphiti_core import Graphiti
345
+ from graphiti_core.llm_client.config import LLMConfig
346
+ from graphiti_core.llm_client.openai_client import OpenAIClient
347
+ from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
348
+ from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
349
+
350
+ # Configure Ollama LLM client
351
+ llm_config = LLMConfig(
352
+ api_key="abc", # Ollama doesn't require a real API key
353
+ model="deepseek-r1:7b",
354
+ small_model="deepseek-r1:7b",
355
+ base_url="http://localhost:11434/v1", # Ollama provides this port
356
+ )
357
+
358
+ llm_client = OpenAIClient(config=llm_config)
359
+
360
+ # Initialize Graphiti with Ollama clients
361
+ graphiti = Graphiti(
362
+ "bolt://localhost:7687",
363
+ "neo4j",
364
+ "password",
365
+ llm_client=llm_client,
366
+ embedder=OpenAIEmbedder(
367
+ config=OpenAIEmbedderConfig(
368
+ api_key="abc",
369
+ embedding_model="nomic-embed-text",
370
+ embedding_dim=768,
371
+ base_url="http://localhost:11434/v1",
372
+ )
373
+ ),
374
+ cross_encoder=OpenAIRerankerClient(client=llm_client, config=llm_config),
375
+ )
376
+
377
+ # Now you can use Graphiti with local Ollama models
378
+ ```
379
+
380
+ Ensure Ollama is running (`ollama serve`) and that you have pulled the models you want to use.
381
+
382
+
334
383
  ## Documentation
335
384
 
336
385
  - [Guides and API documentation](https://help.getzep.com/graphiti).
@@ -298,6 +298,55 @@ graphiti = Graphiti(
298
298
  # Now you can use Graphiti with Google Gemini
299
299
  ```
300
300
 
301
+ ## Using Graphiti with Ollama (Local LLM)
302
+
303
+ Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.
304
+
305
+
306
+ Install the models:
307
+ ollama pull deepseek-r1:7b # LLM
308
+ ollama pull nomic-embed-text # embeddings
309
+
310
+ ```python
311
+ from graphiti_core import Graphiti
312
+ from graphiti_core.llm_client.config import LLMConfig
313
+ from graphiti_core.llm_client.openai_client import OpenAIClient
314
+ from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
315
+ from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
316
+
317
+ # Configure Ollama LLM client
318
+ llm_config = LLMConfig(
319
+ api_key="abc", # Ollama doesn't require a real API key
320
+ model="deepseek-r1:7b",
321
+ small_model="deepseek-r1:7b",
322
+ base_url="http://localhost:11434/v1", # Ollama provides this port
323
+ )
324
+
325
+ llm_client = OpenAIClient(config=llm_config)
326
+
327
+ # Initialize Graphiti with Ollama clients
328
+ graphiti = Graphiti(
329
+ "bolt://localhost:7687",
330
+ "neo4j",
331
+ "password",
332
+ llm_client=llm_client,
333
+ embedder=OpenAIEmbedder(
334
+ config=OpenAIEmbedderConfig(
335
+ api_key="abc",
336
+ embedding_model="nomic-embed-text",
337
+ embedding_dim=768,
338
+ base_url="http://localhost:11434/v1",
339
+ )
340
+ ),
341
+ cross_encoder=OpenAIRerankerClient(client=llm_client, config=llm_config),
342
+ )
343
+
344
+ # Now you can use Graphiti with local Ollama models
345
+ ```
346
+
347
+ Ensure Ollama is running (`ollama serve`) and that you have pulled the models you want to use.
348
+
349
+
301
350
  ## Documentation
302
351
 
303
352
  - [Guides and API documentation](https://help.getzep.com/graphiti).
@@ -73,3 +73,11 @@ class EntityTypeValidationError(GraphitiError):
73
73
  def __init__(self, entity_type: str, entity_type_attribute: str):
74
74
  self.message = f'{entity_type_attribute} cannot be used as an attribute for {entity_type} as it is a protected attribute name.'
75
75
  super().__init__(self.message)
76
+
77
+
78
+ class GroupIdValidationError(GraphitiError):
79
+ """Raised when a group_id contains invalid characters."""
80
+
81
+ def __init__(self, group_id: str):
82
+ self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores'
83
+ super().__init__(self.message)
@@ -29,7 +29,7 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
29
29
  from graphiti_core.edges import EntityEdge, EpisodicEdge
30
30
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
31
31
  from graphiti_core.graphiti_types import GraphitiClients
32
- from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
32
+ from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
33
33
  from graphiti_core.llm_client import LLMClient, OpenAIClient
34
34
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
35
35
  from graphiti_core.search.search import SearchConfig, search
@@ -103,6 +103,7 @@ class Graphiti:
103
103
  cross_encoder: CrossEncoderClient | None = None,
104
104
  store_raw_episode_content: bool = True,
105
105
  graph_driver: GraphDriver | None = None,
106
+ max_coroutines: int | None = None,
106
107
  ):
107
108
  """
108
109
  Initialize a Graphiti instance.
@@ -121,6 +122,20 @@ class Graphiti:
121
122
  llm_client : LLMClient | None, optional
122
123
  An instance of LLMClient for natural language processing tasks.
123
124
  If not provided, a default OpenAIClient will be initialized.
125
+ embedder : EmbedderClient | None, optional
126
+ An instance of EmbedderClient for embedding tasks.
127
+ If not provided, a default OpenAIEmbedder will be initialized.
128
+ cross_encoder : CrossEncoderClient | None, optional
129
+ An instance of CrossEncoderClient for reranking tasks.
130
+ If not provided, a default OpenAIRerankerClient will be initialized.
131
+ store_raw_episode_content : bool, optional
132
+ Whether to store the raw content of episodes. Defaults to True.
133
+ graph_driver : GraphDriver | None, optional
134
+ An instance of GraphDriver for database operations.
135
+ If not provided, a default Neo4jDriver will be initialized.
136
+ max_coroutines : int | None, optional
137
+ The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
138
+ If not set, the Graphiti default is used.
124
139
 
125
140
  Returns
126
141
  -------
@@ -145,6 +160,7 @@ class Graphiti:
145
160
 
146
161
  self.database = DEFAULT_DATABASE
147
162
  self.store_raw_episode_content = store_raw_episode_content
163
+ self.max_coroutines = max_coroutines
148
164
  if llm_client:
149
165
  self.llm_client = llm_client
150
166
  else:
@@ -335,6 +351,7 @@ class Graphiti:
335
351
  now = utc_now()
336
352
 
337
353
  validate_entity_types(entity_types)
354
+ validate_group_id(group_id)
338
355
 
339
356
  previous_episodes = (
340
357
  await self.retrieve_episodes(
@@ -393,6 +410,7 @@ class Graphiti:
393
410
  group_id,
394
411
  edge_types,
395
412
  ),
413
+ max_coroutines=self.max_coroutines,
396
414
  )
397
415
 
398
416
  edges = resolve_edge_pointers(extracted_edges, uuid_map)
@@ -409,6 +427,7 @@ class Graphiti:
409
427
  extract_attributes_from_nodes(
410
428
  self.clients, nodes, episode, previous_episodes, entity_types
411
429
  ),
430
+ max_coroutines=self.max_coroutines,
412
431
  )
413
432
 
414
433
  duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
@@ -432,7 +451,8 @@ class Graphiti:
432
451
  *[
433
452
  update_community(self.driver, self.llm_client, self.embedder, node)
434
453
  for node in nodes
435
- ]
454
+ ],
455
+ max_coroutines=self.max_coroutines,
436
456
  )
437
457
  end = time()
438
458
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
@@ -484,6 +504,8 @@ class Graphiti:
484
504
  start = time()
485
505
  now = utc_now()
486
506
 
507
+ validate_group_id(group_id)
508
+
487
509
  episodes = [
488
510
  EpisodicNode(
489
511
  name=episode.name,
@@ -499,7 +521,10 @@ class Graphiti:
499
521
  ]
500
522
 
501
523
  # Save all the episodes
502
- await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
524
+ await semaphore_gather(
525
+ *[episode.save(self.driver) for episode in episodes],
526
+ max_coroutines=self.max_coroutines,
527
+ )
503
528
 
504
529
  # Get previous episode context for each episode
505
530
  episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -515,16 +540,21 @@ class Graphiti:
515
540
  await semaphore_gather(
516
541
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
517
542
  *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
543
+ max_coroutines=self.max_coroutines,
518
544
  )
519
545
 
520
546
  # Dedupe extracted nodes, compress extracted edges
521
547
  (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
522
548
  dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
523
549
  extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
550
+ max_coroutines=self.max_coroutines,
524
551
  )
525
552
 
526
553
  # save nodes to KG
527
- await semaphore_gather(*[node.save(self.driver) for node in nodes])
554
+ await semaphore_gather(
555
+ *[node.save(self.driver) for node in nodes],
556
+ max_coroutines=self.max_coroutines,
557
+ )
528
558
 
529
559
  # re-map edge pointers so that they don't point to discard dupe nodes
530
560
  extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -536,7 +566,8 @@ class Graphiti:
536
566
 
537
567
  # save episodic edges to KG
538
568
  await semaphore_gather(
539
- *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
569
+ *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
570
+ max_coroutines=self.max_coroutines,
540
571
  )
541
572
 
542
573
  # Dedupe extracted edges
@@ -548,7 +579,10 @@ class Graphiti:
548
579
  # invalidate edges
549
580
 
550
581
  # save edges to KG
551
- await semaphore_gather(*[edge.save(self.driver) for edge in edges])
582
+ await semaphore_gather(
583
+ *[edge.save(self.driver) for edge in edges],
584
+ max_coroutines=self.max_coroutines,
585
+ )
552
586
 
553
587
  end = time()
554
588
  logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -572,11 +606,18 @@ class Graphiti:
572
606
  )
573
607
 
574
608
  await semaphore_gather(
575
- *[node.generate_name_embedding(self.embedder) for node in community_nodes]
609
+ *[node.generate_name_embedding(self.embedder) for node in community_nodes],
610
+ max_coroutines=self.max_coroutines,
576
611
  )
577
612
 
578
- await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
579
- await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
613
+ await semaphore_gather(
614
+ *[node.save(self.driver) for node in community_nodes],
615
+ max_coroutines=self.max_coroutines,
616
+ )
617
+ await semaphore_gather(
618
+ *[edge.save(self.driver) for edge in community_edges],
619
+ max_coroutines=self.max_coroutines,
620
+ )
580
621
 
581
622
  return community_nodes
582
623
 
@@ -683,7 +724,8 @@ class Graphiti:
683
724
  episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
684
725
 
685
726
  edges_list = await semaphore_gather(
686
- *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
727
+ *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
728
+ max_coroutines=self.max_coroutines,
687
729
  )
688
730
 
689
731
  edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
@@ -759,6 +801,12 @@ class Graphiti:
759
801
  if record['episode_count'] == 1:
760
802
  nodes_to_delete.append(node)
761
803
 
762
- await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
763
- await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
804
+ await semaphore_gather(
805
+ *[node.delete(self.driver) for node in nodes_to_delete],
806
+ max_coroutines=self.max_coroutines,
807
+ )
808
+ await semaphore_gather(
809
+ *[edge.delete(self.driver) for edge in edges_to_delete],
810
+ max_coroutines=self.max_coroutines,
811
+ )
764
812
  await episode.delete(self.driver)
@@ -16,6 +16,7 @@ limitations under the License.
16
16
 
17
17
  import asyncio
18
18
  import os
19
+ import re
19
20
  from collections.abc import Coroutine
20
21
  from datetime import datetime
21
22
 
@@ -25,6 +26,8 @@ from neo4j import time as neo4j_time
25
26
  from numpy._typing import NDArray
26
27
  from typing_extensions import LiteralString
27
28
 
29
+ from graphiti_core.errors import GroupIdValidationError
30
+
28
31
  load_dotenv()
29
32
 
30
33
  DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
@@ -94,12 +97,38 @@ def normalize_l2(embedding: list[float]) -> NDArray:
94
97
  # Use this instead of asyncio.gather() to bound coroutines
95
98
  async def semaphore_gather(
96
99
  *coroutines: Coroutine,
97
- max_coroutines: int = SEMAPHORE_LIMIT,
100
+ max_coroutines: int | None = None,
98
101
  ):
99
- semaphore = asyncio.Semaphore(max_coroutines)
102
+ semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
100
103
 
101
104
  async def _wrap_coroutine(coroutine):
102
105
  async with semaphore:
103
106
  return await coroutine
104
107
 
105
108
  return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
109
+
110
+
111
+ def validate_group_id(group_id: str) -> bool:
112
+ """
113
+ Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
114
+
115
+ Args:
116
+ group_id: The group_id to validate
117
+
118
+ Returns:
119
+ True if valid, False otherwise
120
+
121
+ Raises:
122
+ GroupIdValidationError: If group_id contains invalid characters
123
+ """
124
+
125
+ # Allow empty string (default case)
126
+ if not group_id:
127
+ return True
128
+
129
+ # Check if string contains only ASCII alphanumeric characters, dashes, or underscores
130
+ # Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
131
+ if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
132
+ raise GroupIdValidationError(group_id)
133
+
134
+ return True
@@ -0,0 +1,77 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from typing import ClassVar
19
+
20
+ from openai import AsyncAzureOpenAI
21
+ from openai.types.chat import ChatCompletionMessageParam
22
+ from pydantic import BaseModel
23
+
24
+ from .config import DEFAULT_MAX_TOKENS, LLMConfig
25
+ from .openai_base_client import BaseOpenAIClient
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class AzureOpenAILLMClient(BaseOpenAIClient):
31
+ """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
32
+
33
+ # Class-level constants
34
+ MAX_RETRIES: ClassVar[int] = 2
35
+
36
+ def __init__(
37
+ self,
38
+ azure_client: AsyncAzureOpenAI,
39
+ config: LLMConfig | None = None,
40
+ max_tokens: int = DEFAULT_MAX_TOKENS,
41
+ ):
42
+ super().__init__(config, cache=False, max_tokens=max_tokens)
43
+ self.client = azure_client
44
+
45
+ async def _create_structured_completion(
46
+ self,
47
+ model: str,
48
+ messages: list[ChatCompletionMessageParam],
49
+ temperature: float | None,
50
+ max_tokens: int,
51
+ response_model: type[BaseModel],
52
+ ):
53
+ """Create a structured completion using Azure OpenAI's beta parse API."""
54
+ return await self.client.beta.chat.completions.parse(
55
+ model=model,
56
+ messages=messages,
57
+ temperature=temperature,
58
+ max_tokens=max_tokens,
59
+ response_format=response_model, # type: ignore
60
+ )
61
+
62
+ async def _create_completion(
63
+ self,
64
+ model: str,
65
+ messages: list[ChatCompletionMessageParam],
66
+ temperature: float | None,
67
+ max_tokens: int,
68
+ response_model: type[BaseModel] | None = None,
69
+ ):
70
+ """Create a regular completion with JSON format using Azure OpenAI."""
71
+ return await self.client.chat.completions.create(
72
+ model=model,
73
+ messages=messages,
74
+ temperature=temperature,
75
+ max_tokens=max_tokens,
76
+ response_format={'type': 'json_object'},
77
+ )
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  import typing
19
- from typing import ClassVar
20
+ from abc import abstractmethod
21
+ from typing import Any, ClassVar
20
22
 
21
23
  import openai
22
- from openai import AsyncOpenAI
23
24
  from openai.types.chat import ChatCompletionMessageParam
24
25
  from pydantic import BaseModel
25
26
 
@@ -34,25 +35,12 @@ DEFAULT_MODEL = 'gpt-4.1-mini'
34
35
  DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
35
36
 
36
37
 
37
- class OpenAIClient(LLMClient):
38
+ class BaseOpenAIClient(LLMClient):
38
39
  """
39
- OpenAIClient is a client class for interacting with OpenAI's language models.
40
+ Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).
40
41
 
41
- This class extends the LLMClient and provides methods to initialize the client,
42
- get an embedder, and generate responses from the language model.
43
-
44
- Attributes:
45
- client (AsyncOpenAI): The OpenAI client used to interact with the API.
46
- model (str): The model name to use for generating responses.
47
- temperature (float): The temperature to use for generating responses.
48
- max_tokens (int): The maximum number of tokens to generate in a response.
49
-
50
- Methods:
51
- __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
52
- Initializes the OpenAIClient with the provided configuration, cache setting, and client.
53
-
54
- _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
55
- Generates a response from the language model based on the provided messages.
42
+ This class contains shared logic for both OpenAI and Azure OpenAI clients,
43
+ reducing code duplication while allowing for implementation-specific differences.
56
44
  """
57
45
 
58
46
  # Class-level constants
@@ -62,41 +50,45 @@ class OpenAIClient(LLMClient):
62
50
  self,
63
51
  config: LLMConfig | None = None,
64
52
  cache: bool = False,
65
- client: typing.Any = None,
66
53
  max_tokens: int = DEFAULT_MAX_TOKENS,
67
54
  ):
68
- """
69
- Initialize the OpenAIClient with the provided configuration, cache setting, and client.
70
-
71
- Args:
72
- config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
73
- cache (bool): Whether to use caching for responses. Defaults to False.
74
- client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
75
-
76
- """
77
- # removed caching to simplify the `generate_response` override
78
55
  if cache:
79
- raise NotImplementedError('Caching is not implemented for OpenAI')
56
+ raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
80
57
 
81
58
  if config is None:
82
59
  config = LLMConfig()
83
60
 
84
61
  super().__init__(config, cache)
85
-
86
- if client is None:
87
- self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
88
- else:
89
- self.client = client
90
-
91
62
  self.max_tokens = max_tokens
92
63
 
93
- async def _generate_response(
64
+ @abstractmethod
65
+ async def _create_completion(
94
66
  self,
95
- messages: list[Message],
67
+ model: str,
68
+ messages: list[ChatCompletionMessageParam],
69
+ temperature: float | None,
70
+ max_tokens: int,
96
71
  response_model: type[BaseModel] | None = None,
97
- max_tokens: int = DEFAULT_MAX_TOKENS,
98
- model_size: ModelSize = ModelSize.medium,
99
- ) -> dict[str, typing.Any]:
72
+ ) -> Any:
73
+ """Create a completion using the specific client implementation."""
74
+ pass
75
+
76
+ @abstractmethod
77
+ async def _create_structured_completion(
78
+ self,
79
+ model: str,
80
+ messages: list[ChatCompletionMessageParam],
81
+ temperature: float | None,
82
+ max_tokens: int,
83
+ response_model: type[BaseModel],
84
+ ) -> Any:
85
+ """Create a structured completion using the specific client implementation."""
86
+ pass
87
+
88
+ def _convert_messages_to_openai_format(
89
+ self, messages: list[Message]
90
+ ) -> list[ChatCompletionMessageParam]:
91
+ """Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
100
92
  openai_messages: list[ChatCompletionMessageParam] = []
101
93
  for m in messages:
102
94
  m.content = self._clean_input(m.content)
@@ -104,28 +96,61 @@ class OpenAIClient(LLMClient):
104
96
  openai_messages.append({'role': 'user', 'content': m.content})
105
97
  elif m.role == 'system':
106
98
  openai_messages.append({'role': 'system', 'content': m.content})
99
+ return openai_messages
100
+
101
+ def _get_model_for_size(self, model_size: ModelSize) -> str:
102
+ """Get the appropriate model name based on the requested size."""
103
+ if model_size == ModelSize.small:
104
+ return self.small_model or DEFAULT_SMALL_MODEL
105
+ else:
106
+ return self.model or DEFAULT_MODEL
107
+
108
+ def _handle_structured_response(self, response: Any) -> dict[str, Any]:
109
+ """Handle structured response parsing and validation."""
110
+ response_object = response.choices[0].message
111
+
112
+ if response_object.parsed:
113
+ return response_object.parsed.model_dump()
114
+ elif response_object.refusal:
115
+ raise RefusalError(response_object.refusal)
116
+ else:
117
+ raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
118
+
119
+ def _handle_json_response(self, response: Any) -> dict[str, Any]:
120
+ """Handle JSON response parsing."""
121
+ result = response.choices[0].message.content or '{}'
122
+ return json.loads(result)
123
+
124
+ async def _generate_response(
125
+ self,
126
+ messages: list[Message],
127
+ response_model: type[BaseModel] | None = None,
128
+ max_tokens: int = DEFAULT_MAX_TOKENS,
129
+ model_size: ModelSize = ModelSize.medium,
130
+ ) -> dict[str, Any]:
131
+ """Generate a response using the appropriate client implementation."""
132
+ openai_messages = self._convert_messages_to_openai_format(messages)
133
+ model = self._get_model_for_size(model_size)
134
+
107
135
  try:
108
- if model_size == ModelSize.small:
109
- model = self.small_model or DEFAULT_SMALL_MODEL
110
- else:
111
- model = self.model or DEFAULT_MODEL
112
-
113
- response = await self.client.beta.chat.completions.parse(
114
- model=model,
115
- messages=openai_messages,
116
- temperature=self.temperature,
117
- max_tokens=max_tokens or self.max_tokens,
118
- response_format=response_model, # type: ignore
119
- )
120
-
121
- response_object = response.choices[0].message
122
-
123
- if response_object.parsed:
124
- return response_object.parsed.model_dump()
125
- elif response_object.refusal:
126
- raise RefusalError(response_object.refusal)
136
+ if response_model:
137
+ response = await self._create_structured_completion(
138
+ model=model,
139
+ messages=openai_messages,
140
+ temperature=self.temperature,
141
+ max_tokens=max_tokens or self.max_tokens,
142
+ response_model=response_model,
143
+ )
144
+ return self._handle_structured_response(response)
127
145
  else:
128
- raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
146
+ response = await self._create_completion(
147
+ model=model,
148
+ messages=openai_messages,
149
+ temperature=self.temperature,
150
+ max_tokens=max_tokens or self.max_tokens,
151
+ )
152
+ return self._handle_json_response(response)
153
+
129
154
  except openai.LengthFinishReasonError as e:
130
155
  raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
131
156
  except openai.RateLimitError as e:
@@ -141,6 +166,7 @@ class OpenAIClient(LLMClient):
141
166
  max_tokens: int | None = None,
142
167
  model_size: ModelSize = ModelSize.medium,
143
168
  ) -> dict[str, typing.Any]:
169
+ """Generate a response with retry logic and error handling."""
144
170
  if max_tokens is None:
145
171
  max_tokens = self.max_tokens
146
172
 
@@ -0,0 +1,95 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import typing
18
+
19
+ from openai import AsyncOpenAI
20
+ from openai.types.chat import ChatCompletionMessageParam
21
+ from pydantic import BaseModel
22
+
23
+ from .config import DEFAULT_MAX_TOKENS, LLMConfig
24
+ from .openai_base_client import BaseOpenAIClient
25
+
26
+
27
+ class OpenAIClient(BaseOpenAIClient):
28
+ """
29
+ OpenAIClient is a client class for interacting with OpenAI's language models.
30
+
31
+ This class extends the BaseOpenAIClient and provides OpenAI-specific implementation
32
+ for creating completions.
33
+
34
+ Attributes:
35
+ client (AsyncOpenAI): The OpenAI client used to interact with the API.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ config: LLMConfig | None = None,
41
+ cache: bool = False,
42
+ client: typing.Any = None,
43
+ max_tokens: int = DEFAULT_MAX_TOKENS,
44
+ ):
45
+ """
46
+ Initialize the OpenAIClient with the provided configuration, cache setting, and client.
47
+
48
+ Args:
49
+ config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
50
+ cache (bool): Whether to use caching for responses. Defaults to False.
51
+ client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
52
+ """
53
+ super().__init__(config, cache, max_tokens)
54
+
55
+ if config is None:
56
+ config = LLMConfig()
57
+
58
+ if client is None:
59
+ self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
60
+ else:
61
+ self.client = client
62
+
63
+ async def _create_structured_completion(
64
+ self,
65
+ model: str,
66
+ messages: list[ChatCompletionMessageParam],
67
+ temperature: float | None,
68
+ max_tokens: int,
69
+ response_model: type[BaseModel],
70
+ ):
71
+ """Create a structured completion using OpenAI's beta parse API."""
72
+ return await self.client.beta.chat.completions.parse(
73
+ model=model,
74
+ messages=messages,
75
+ temperature=temperature,
76
+ max_tokens=max_tokens,
77
+ response_format=response_model, # type: ignore
78
+ )
79
+
80
+ async def _create_completion(
81
+ self,
82
+ model: str,
83
+ messages: list[ChatCompletionMessageParam],
84
+ temperature: float | None,
85
+ max_tokens: int,
86
+ response_model: type[BaseModel] | None = None,
87
+ ):
88
+ """Create a regular completion with JSON format."""
89
+ return await self.client.chat.completions.create(
90
+ model=model,
91
+ messages=messages,
92
+ temperature=temperature,
93
+ max_tokens=max_tokens,
94
+ response_format={'type': 'json_object'},
95
+ )
@@ -35,7 +35,7 @@ ENTITY_EDGE_SAVE = """
35
35
  MATCH (target:Entity {uuid: $target_uuid})
36
36
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
37
37
  SET r = $edge_data
38
- WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
38
+ WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
39
39
  RETURN r.uuid AS uuid"""
40
40
 
41
41
  ENTITY_EDGE_SAVE_BULK = """
@@ -30,7 +30,7 @@ class NodeDuplicate(BaseModel):
30
30
  )
31
31
  name: str = Field(
32
32
  ...,
33
- description='Name of the entity. Should be the most complete and descriptive name possible. Do not include any JSON formatting in the Entity name.',
33
+ description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
34
34
  )
35
35
  additional_duplicates: list[int] = Field(
36
36
  ...,
@@ -84,19 +84,19 @@ def node(context: dict[str, Any]) -> list[Message]:
84
84
  is a duplicate entity of one of the EXISTING ENTITIES.
85
85
 
86
86
  Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
87
+ Semantic Equivalence: if a descriptive label in existing_entities clearly refers to a named entity in context, treat them as duplicates.
87
88
 
88
89
  Do NOT mark entities as duplicates if:
89
90
  - They are related but distinct.
90
91
  - They have similar names or purposes but refer to separate instances or concepts.
91
92
 
92
- Task:
93
- If the NEW ENTITY represents a duplicate entity of any entity in EXISTING ENTITIES, set duplicate_entity_id to the
94
- id of the EXISTING ENTITY that is the duplicate.
95
-
96
- If the NEW ENTITY is not a duplicate of any of the EXISTING ENTITIES,
97
- duplicate_entity_id should be set to -1.
93
+ TASK:
94
+ 1. Compare `new_entity` against each item in `existing_entities`.
95
+ 2. If it refers to the same real‐world object or concept, collect its index.
96
+ 3. Let `duplicate_idx` = the *first* collected index, or –1 if none.
97
+ 4. Let `additional_duplicates` = the list of *any other* collected indices (empty list if none).
98
98
 
99
- Also return the name that best describes the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
99
+ Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
100
100
  is a duplicate of, or a combination of the two).
101
101
  """,
102
102
  ),
@@ -97,8 +97,7 @@ Only extract facts that:
97
97
  - The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
98
98
  - The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
99
99
  of the FACT TYPES
100
- - The FACT TYPES each contain their fact_type_signature which represents the entity types which that fact_type is defined for.
101
- A Type of Entity in the signature represents any extracted entity (it is a generic universal type for all entities).
100
+ - The FACT TYPES each contain their fact_type_signature which represents the source and target entity types.
102
101
 
103
102
  You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
104
103
 
@@ -63,7 +63,7 @@ MAX_QUERY_LENGTH = 32
63
63
 
64
64
  def fulltext_query(query: str, group_ids: list[str] | None = None):
65
65
  group_ids_filter_list = (
66
- [f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
66
+ [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
67
67
  )
68
68
  group_ids_filter = ''
69
69
  for f in group_ids_filter_list:
@@ -301,12 +301,12 @@ async def edge_bfs_search(
301
301
 
302
302
  query = (
303
303
  """
304
- UNWIND $bfs_origin_node_uuids AS origin_uuid
305
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
306
- UNWIND relationships(path) AS rel
307
- MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
308
- WHERE r.uuid = rel.uuid
309
- """
304
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
305
+ MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
306
+ UNWIND relationships(path) AS rel
307
+ MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
308
+ WHERE r.uuid = rel.uuid
309
+ """
310
310
  + filter_query
311
311
  + """
312
312
  RETURN DISTINCT
@@ -455,10 +455,10 @@ async def node_bfs_search(
455
455
 
456
456
  query = (
457
457
  """
458
- UNWIND $bfs_origin_node_uuids AS origin_uuid
459
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
460
- WHERE n.group_id = origin.group_id
461
- """
458
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
459
+ MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
460
+ WHERE n.group_id = origin.group_id
461
+ """
462
462
  + filter_query
463
463
  + ENTITY_NODE_RETURN
464
464
  + """
@@ -310,7 +310,7 @@ async def resolve_extracted_nodes(
310
310
  else extracted_node
311
311
  )
312
312
 
313
- resolved_node.name = resolution.get('name')
313
+ # resolved_node.name = resolution.get('name')
314
314
 
315
315
  resolved_nodes.append(resolved_node)
316
316
  uuid_map[extracted_node.uuid] = resolved_node.uuid
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "graphiti-core"
3
3
  description = "A temporal graph building library"
4
- version = "0.12.4"
4
+ version = "0.13.1"
5
5
  authors = [
6
6
  { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
7
7
  { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
@@ -15,7 +15,7 @@ dependencies = [
15
15
  "pydantic>=2.11.5",
16
16
  "neo4j>=5.26.0",
17
17
  "diskcache>=5.6.3",
18
- "openai>=1.53.0",
18
+ "openai>=1.91.0",
19
19
  "tenacity>=9.0.0",
20
20
  "numpy>=1.0.0",
21
21
  "python-dotenv>=1.0.1",
@@ -1,73 +0,0 @@
1
- """
2
- Copyright 2024, Zep Software, Inc.
3
-
4
- Licensed under the Apache License, Version 2.0 (the "License");
5
- you may not use this file except in compliance with the License.
6
- You may obtain a copy of the License at
7
-
8
- http://www.apache.org/licenses/LICENSE-2.0
9
-
10
- Unless required by applicable law or agreed to in writing, software
11
- distributed under the License is distributed on an "AS IS" BASIS,
12
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- See the License for the specific language governing permissions and
14
- limitations under the License.
15
- """
16
-
17
- import json
18
- import logging
19
- from typing import Any
20
-
21
- from openai import AsyncAzureOpenAI
22
- from openai.types.chat import ChatCompletionMessageParam
23
- from pydantic import BaseModel
24
-
25
- from ..prompts.models import Message
26
- from .client import LLMClient
27
- from .config import LLMConfig, ModelSize
28
-
29
- logger = logging.getLogger(__name__)
30
-
31
-
32
- class AzureOpenAILLMClient(LLMClient):
33
- """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
34
-
35
- def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
36
- super().__init__(config, cache=False)
37
- self.azure_client = azure_client
38
-
39
- async def _generate_response(
40
- self,
41
- messages: list[Message],
42
- response_model: type[BaseModel] | None = None,
43
- max_tokens: int = 1024,
44
- model_size: ModelSize = ModelSize.medium,
45
- ) -> dict[str, Any]:
46
- """Generate response using Azure OpenAI client."""
47
- # Convert messages to OpenAI format
48
- openai_messages: list[ChatCompletionMessageParam] = []
49
- for message in messages:
50
- message.content = self._clean_input(message.content)
51
- if message.role == 'user':
52
- openai_messages.append({'role': 'user', 'content': message.content})
53
- elif message.role == 'system':
54
- openai_messages.append({'role': 'system', 'content': message.content})
55
-
56
- # Ensure model is a string
57
- model_name = self.model if self.model else 'gpt-4o-mini'
58
-
59
- try:
60
- response = await self.azure_client.chat.completions.create(
61
- model=model_name,
62
- messages=openai_messages,
63
- temperature=float(self.temperature) if self.temperature is not None else 0.7,
64
- max_tokens=max_tokens,
65
- response_format={'type': 'json_object'},
66
- )
67
- result = response.choices[0].message.content or '{}'
68
-
69
- # Parse JSON response
70
- return json.loads(result)
71
- except Exception as e:
72
- logger.error(f'Error in Azure OpenAI LLM response: {e}')
73
- raise
File without changes