graphiti-core 0.18.9__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (29) hide show
  1. graphiti_core/driver/driver.py +4 -0
  2. graphiti_core/driver/falkordb_driver.py +3 -14
  3. graphiti_core/driver/kuzu_driver.py +175 -0
  4. graphiti_core/driver/neptune_driver.py +301 -0
  5. graphiti_core/edges.py +155 -62
  6. graphiti_core/graph_queries.py +31 -2
  7. graphiti_core/graphiti.py +6 -1
  8. graphiti_core/helpers.py +8 -8
  9. graphiti_core/llm_client/config.py +1 -1
  10. graphiti_core/llm_client/openai_base_client.py +12 -2
  11. graphiti_core/llm_client/openai_client.py +10 -2
  12. graphiti_core/migrations/__init__.py +0 -0
  13. graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
  14. graphiti_core/models/edges/edge_db_queries.py +205 -76
  15. graphiti_core/models/nodes/node_db_queries.py +253 -74
  16. graphiti_core/nodes.py +271 -98
  17. graphiti_core/search/search.py +42 -12
  18. graphiti_core/search/search_config.py +4 -0
  19. graphiti_core/search/search_filters.py +35 -22
  20. graphiti_core/search/search_utils.py +1329 -392
  21. graphiti_core/utils/bulk_utils.py +50 -15
  22. graphiti_core/utils/datetime_utils.py +13 -0
  23. graphiti_core/utils/maintenance/community_operations.py +39 -32
  24. graphiti_core/utils/maintenance/edge_operations.py +47 -13
  25. graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
  26. {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
  27. {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +29 -25
  28. {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
  29. {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/edges.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -24,16 +25,16 @@ from uuid import uuid4
24
25
  from pydantic import BaseModel, Field
25
26
  from typing_extensions import LiteralString
26
27
 
27
- from graphiti_core.driver.driver import GraphDriver
28
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
28
29
  from graphiti_core.embedder import EmbedderClient
29
30
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
31
  from graphiti_core.helpers import parse_db_date
31
32
  from graphiti_core.models.edges.edge_db_queries import (
32
33
  COMMUNITY_EDGE_RETURN,
33
- ENTITY_EDGE_RETURN,
34
34
  EPISODIC_EDGE_RETURN,
35
35
  EPISODIC_EDGE_SAVE,
36
36
  get_community_edge_save_query,
37
+ get_entity_edge_return_query,
37
38
  get_entity_edge_save_query,
38
39
  )
39
40
  from graphiti_core.nodes import Node
@@ -52,33 +53,63 @@ class Edge(BaseModel, ABC):
52
53
  async def save(self, driver: GraphDriver): ...
53
54
 
54
55
  async def delete(self, driver: GraphDriver):
55
- result = await driver.execute_query(
56
- """
57
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
58
- DELETE e
59
- """,
60
- uuid=self.uuid,
61
- )
56
+ if driver.provider == GraphProvider.KUZU:
57
+ await driver.execute_query(
58
+ """
59
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
60
+ DELETE e
61
+ """,
62
+ uuid=self.uuid,
63
+ )
64
+ await driver.execute_query(
65
+ """
66
+ MATCH (e:RelatesToNode_ {uuid: $uuid})
67
+ DETACH DELETE e
68
+ """,
69
+ uuid=self.uuid,
70
+ )
71
+ else:
72
+ await driver.execute_query(
73
+ """
74
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
75
+ DELETE e
76
+ """,
77
+ uuid=self.uuid,
78
+ )
62
79
 
63
80
  logger.debug(f'Deleted Edge: {self.uuid}')
64
81
 
65
- return result
66
-
67
82
  @classmethod
68
83
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
69
- result = await driver.execute_query(
70
- """
71
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
72
- WHERE e.uuid IN $uuids
73
- DELETE e
74
- """,
75
- uuids=uuids,
76
- )
84
+ if driver.provider == GraphProvider.KUZU:
85
+ await driver.execute_query(
86
+ """
87
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
88
+ WHERE e.uuid IN $uuids
89
+ DELETE e
90
+ """,
91
+ uuids=uuids,
92
+ )
93
+ await driver.execute_query(
94
+ """
95
+ MATCH (e:RelatesToNode_)
96
+ WHERE e.uuid IN $uuids
97
+ DETACH DELETE e
98
+ """,
99
+ uuids=uuids,
100
+ )
101
+ else:
102
+ await driver.execute_query(
103
+ """
104
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
105
+ WHERE e.uuid IN $uuids
106
+ DELETE e
107
+ """,
108
+ uuids=uuids,
109
+ )
77
110
 
78
111
  logger.debug(f'Deleted Edges: {uuids}')
79
112
 
80
- return result
81
-
82
113
  def __hash__(self):
83
114
  return hash(self.uuid)
84
115
 
@@ -165,7 +196,7 @@ class EpisodicEdge(Edge):
165
196
  """
166
197
  + EPISODIC_EDGE_RETURN
167
198
  + """
168
- ORDER BY e.uuid DESC
199
+ ORDER BY e.uuid DESC
169
200
  """
170
201
  + limit_query,
171
202
  group_ids=group_ids,
@@ -214,11 +245,25 @@ class EntityEdge(Edge):
214
245
  return self.fact_embedding
215
246
 
216
247
  async def load_fact_embedding(self, driver: GraphDriver):
217
- records, _, _ = await driver.execute_query(
218
- """
248
+ query = """
219
249
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
220
250
  RETURN e.fact_embedding AS fact_embedding
221
- """,
251
+ """
252
+
253
+ if driver.provider == GraphProvider.NEPTUNE:
254
+ query = """
255
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
256
+ RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
257
+ """
258
+
259
+ if driver.provider == GraphProvider.KUZU:
260
+ query = """
261
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
262
+ RETURN e.fact_embedding AS fact_embedding
263
+ """
264
+
265
+ records, _, _ = await driver.execute_query(
266
+ query,
222
267
  uuid=self.uuid,
223
268
  routing_='r',
224
269
  )
@@ -244,12 +289,22 @@ class EntityEdge(Edge):
244
289
  'invalid_at': self.invalid_at,
245
290
  }
246
291
 
247
- edge_data.update(self.attributes or {})
292
+ if driver.provider == GraphProvider.KUZU:
293
+ edge_data['attributes'] = json.dumps(self.attributes)
294
+ result = await driver.execute_query(
295
+ get_entity_edge_save_query(driver.provider),
296
+ **edge_data,
297
+ )
298
+ else:
299
+ edge_data.update(self.attributes or {})
248
300
 
249
- result = await driver.execute_query(
250
- get_entity_edge_save_query(driver.provider),
251
- edge_data=edge_data,
252
- )
301
+ if driver.provider == GraphProvider.NEPTUNE:
302
+ driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
303
+
304
+ result = await driver.execute_query(
305
+ get_entity_edge_save_query(driver.provider),
306
+ edge_data=edge_data,
307
+ )
253
308
 
254
309
  logger.debug(f'Saved edge to Graph: {self.uuid}')
255
310
 
@@ -257,17 +312,25 @@ class EntityEdge(Edge):
257
312
 
258
313
  @classmethod
259
314
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
260
- records, _, _ = await driver.execute_query(
261
- """
315
+ match_query = """
262
316
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
317
+ """
318
+ if driver.provider == GraphProvider.KUZU:
319
+ match_query = """
320
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
321
+ """
322
+
323
+ records, _, _ = await driver.execute_query(
324
+ match_query
325
+ + """
263
326
  RETURN
264
327
  """
265
- + ENTITY_EDGE_RETURN,
328
+ + get_entity_edge_return_query(driver.provider),
266
329
  uuid=uuid,
267
330
  routing_='r',
268
331
  )
269
332
 
270
- edges = [get_entity_edge_from_record(record) for record in records]
333
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
271
334
 
272
335
  if len(edges) == 0:
273
336
  raise EdgeNotFoundError(uuid)
@@ -278,18 +341,26 @@ class EntityEdge(Edge):
278
341
  if len(uuids) == 0:
279
342
  return []
280
343
 
281
- records, _, _ = await driver.execute_query(
282
- """
344
+ match_query = """
283
345
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
346
+ """
347
+ if driver.provider == GraphProvider.KUZU:
348
+ match_query = """
349
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
350
+ """
351
+
352
+ records, _, _ = await driver.execute_query(
353
+ match_query
354
+ + """
284
355
  WHERE e.uuid IN $uuids
285
356
  RETURN
286
357
  """
287
- + ENTITY_EDGE_RETURN,
358
+ + get_entity_edge_return_query(driver.provider),
288
359
  uuids=uuids,
289
360
  routing_='r',
290
361
  )
291
362
 
292
- edges = [get_entity_edge_from_record(record) for record in records]
363
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
293
364
 
294
365
  return edges
295
366
 
@@ -312,19 +383,27 @@ class EntityEdge(Edge):
312
383
  else ''
313
384
  )
314
385
 
315
- records, _, _ = await driver.execute_query(
316
- """
386
+ match_query = """
317
387
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
388
+ """
389
+ if driver.provider == GraphProvider.KUZU:
390
+ match_query = """
391
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
392
+ """
393
+
394
+ records, _, _ = await driver.execute_query(
395
+ match_query
396
+ + """
318
397
  WHERE e.group_id IN $group_ids
319
398
  """
320
399
  + cursor_query
321
400
  + """
322
401
  RETURN
323
402
  """
324
- + ENTITY_EDGE_RETURN
403
+ + get_entity_edge_return_query(driver.provider)
325
404
  + with_embeddings_query
326
405
  + """
327
- ORDER BY e.uuid DESC
406
+ ORDER BY e.uuid DESC
328
407
  """
329
408
  + limit_query,
330
409
  group_ids=group_ids,
@@ -333,7 +412,7 @@ class EntityEdge(Edge):
333
412
  routing_='r',
334
413
  )
335
414
 
336
- edges = [get_entity_edge_from_record(record) for record in records]
415
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
337
416
 
338
417
  if len(edges) == 0:
339
418
  raise GroupsEdgesNotFoundError(group_ids)
@@ -341,17 +420,25 @@ class EntityEdge(Edge):
341
420
 
342
421
  @classmethod
343
422
  async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
344
- records, _, _ = await driver.execute_query(
345
- """
423
+ match_query = """
346
424
  MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
425
+ """
426
+ if driver.provider == GraphProvider.KUZU:
427
+ match_query = """
428
+ MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
429
+ """
430
+
431
+ records, _, _ = await driver.execute_query(
432
+ match_query
433
+ + """
347
434
  RETURN
348
435
  """
349
- + ENTITY_EDGE_RETURN,
436
+ + get_entity_edge_return_query(driver.provider),
350
437
  node_uuid=node_uuid,
351
438
  routing_='r',
352
439
  )
353
440
 
354
- edges = [get_entity_edge_from_record(record) for record in records]
441
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
355
442
 
356
443
  return edges
357
444
 
@@ -451,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
451
538
  )
452
539
 
453
540
 
454
- def get_entity_edge_from_record(record: Any) -> EntityEdge:
541
+ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
542
+ episodes = record['episodes']
543
+ if provider == GraphProvider.KUZU:
544
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
545
+ else:
546
+ attributes = record['attributes']
547
+ attributes.pop('uuid', None)
548
+ attributes.pop('source_node_uuid', None)
549
+ attributes.pop('target_node_uuid', None)
550
+ attributes.pop('fact', None)
551
+ attributes.pop('fact_embedding', None)
552
+ attributes.pop('name', None)
553
+ attributes.pop('group_id', None)
554
+ attributes.pop('episodes', None)
555
+ attributes.pop('created_at', None)
556
+ attributes.pop('expired_at', None)
557
+ attributes.pop('valid_at', None)
558
+ attributes.pop('invalid_at', None)
559
+
455
560
  edge = EntityEdge(
456
561
  uuid=record['uuid'],
457
562
  source_node_uuid=record['source_node_uuid'],
@@ -460,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
460
565
  fact_embedding=record.get('fact_embedding'),
461
566
  name=record['name'],
462
567
  group_id=record['group_id'],
463
- episodes=record['episodes'],
568
+ episodes=episodes,
464
569
  created_at=parse_db_date(record['created_at']), # type: ignore
465
570
  expired_at=parse_db_date(record['expired_at']),
466
571
  valid_at=parse_db_date(record['valid_at']),
467
572
  invalid_at=parse_db_date(record['invalid_at']),
468
- attributes=record['attributes'],
573
+ attributes=attributes,
469
574
  )
470
575
 
471
- edge.attributes.pop('uuid', None)
472
- edge.attributes.pop('source_node_uuid', None)
473
- edge.attributes.pop('target_node_uuid', None)
474
- edge.attributes.pop('fact', None)
475
- edge.attributes.pop('name', None)
476
- edge.attributes.pop('group_id', None)
477
- edge.attributes.pop('episodes', None)
478
- edge.attributes.pop('created_at', None)
479
- edge.attributes.pop('expired_at', None)
480
- edge.attributes.pop('valid_at', None)
481
- edge.attributes.pop('invalid_at', None)
482
-
483
576
  return edge
484
577
 
485
578
 
@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
16
16
  'episode_content': 'Episodic',
17
17
  'edge_name_and_fact': 'RELATES_TO',
18
18
  }
19
+ # Mapping from fulltext index names to Kuzu node labels
20
+ INDEX_TO_LABEL_KUZU_MAPPING = {
21
+ 'node_name_and_summary': 'Entity',
22
+ 'community_name': 'Community',
23
+ 'episode_content': 'Episodic',
24
+ 'edge_name_and_fact': 'RelatesToNode_',
25
+ }
19
26
 
20
27
 
21
28
  def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
35
42
  'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
36
43
  ]
37
44
 
45
+ if provider == GraphProvider.KUZU:
46
+ return []
47
+
38
48
  return [
39
49
  'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
40
50
  'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
68
78
  """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
69
79
  ]
70
80
 
81
+ if provider == GraphProvider.KUZU:
82
+ return [
83
+ "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
84
+ "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
85
+ "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
86
+ "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
87
+ ]
88
+
71
89
  return [
72
90
  """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
73
91
  FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
80
98
  ]
81
99
 
82
100
 
83
- def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
101
+ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
84
102
  if provider == GraphProvider.FALKORDB:
85
103
  label = NEO4J_TO_FALKORDB_MAPPING[name]
86
104
  return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
87
105
 
106
+ if provider == GraphProvider.KUZU:
107
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
108
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
109
+
88
110
  return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
89
111
 
90
112
 
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
93
115
  # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
94
116
  return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
95
117
 
118
+ if provider == GraphProvider.KUZU:
119
+ return f'array_cosine_similarity({vec1}, {vec2})'
120
+
96
121
  return f'vector.similarity.cosine({vec1}, {vec2})'
97
122
 
98
123
 
99
- def get_relationships_query(name: str, provider: GraphProvider) -> str:
124
+ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
100
125
  if provider == GraphProvider.FALKORDB:
101
126
  label = NEO4J_TO_FALKORDB_MAPPING[name]
102
127
  return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
103
128
 
129
+ if provider == GraphProvider.KUZU:
130
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
131
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
132
+
104
133
  return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
graphiti_core/graphiti.py CHANGED
@@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
89
89
  )
90
90
  from graphiti_core.utils.maintenance.graph_data_operations import (
91
91
  EPISODE_WINDOW_LEN,
92
+ build_dynamic_indexes,
92
93
  build_indices_and_constraints,
93
94
  retrieve_episodes,
94
95
  )
@@ -450,6 +451,7 @@ class Graphiti:
450
451
 
451
452
  validate_excluded_entity_types(excluded_entity_types, entity_types)
452
453
  validate_group_id(group_id)
454
+ await build_dynamic_indexes(self.driver, group_id)
453
455
 
454
456
  previous_episodes = (
455
457
  await self.retrieve_episodes(
@@ -625,6 +627,7 @@ class Graphiti:
625
627
  # if group_id is None, use the default group id by the provider
626
628
  group_id = group_id or get_default_group_id(self.driver.provider)
627
629
  validate_group_id(group_id)
630
+ await build_dynamic_indexes(self.driver, group_id)
628
631
 
629
632
  # Create default edge type map
630
633
  edge_type_map_default = (
@@ -1006,6 +1009,8 @@ class Graphiti:
1006
1009
  if edge.fact_embedding is None:
1007
1010
  await edge.generate_embedding(self.embedder)
1008
1011
 
1012
+ await build_dynamic_indexes(self.driver, source_node.group_id)
1013
+
1009
1014
  nodes, uuid_map, _ = await resolve_extracted_nodes(
1010
1015
  self.clients,
1011
1016
  [source_node, target_node],
@@ -1068,7 +1073,7 @@ class Graphiti:
1068
1073
  if record['episode_count'] == 1:
1069
1074
  nodes_to_delete.append(node)
1070
1075
 
1076
+ await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
1071
1077
  await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
1072
1078
 
1073
- await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
1074
1079
  await episode.delete(self.driver)
graphiti_core/helpers.py CHANGED
@@ -43,14 +43,14 @@ RUNTIME_QUERY: LiteralString = (
43
43
  )
44
44
 
45
45
 
46
- def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
47
- return (
48
- neo_date.to_native()
49
- if isinstance(neo_date, neo4j_time.DateTime)
50
- else datetime.fromisoformat(neo_date)
51
- if neo_date
52
- else None
53
- )
46
+ def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
47
+ if isinstance(input_date, neo4j_time.DateTime):
48
+ return input_date.to_native()
49
+
50
+ if isinstance(input_date, str):
51
+ return datetime.fromisoformat(input_date)
52
+
53
+ return input_date
54
54
 
55
55
 
56
56
  def get_default_group_id(provider: GraphProvider) -> str:
@@ -17,7 +17,7 @@ limitations under the License.
17
17
  from enum import Enum
18
18
 
19
19
  DEFAULT_MAX_TOKENS = 8192
20
- DEFAULT_TEMPERATURE = 0
20
+ DEFAULT_TEMPERATURE = 1
21
21
 
22
22
 
23
23
  class ModelSize(Enum):
@@ -31,8 +31,10 @@ from .errors import RateLimitError, RefusalError
31
31
 
32
32
  logger = logging.getLogger(__name__)
33
33
 
34
- DEFAULT_MODEL = 'gpt-4.1-mini'
35
- DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
34
+ DEFAULT_MODEL = 'gpt-5-mini'
35
+ DEFAULT_SMALL_MODEL = 'gpt-5-nano'
36
+ DEFAULT_REASONING = 'minimal'
37
+ DEFAULT_VERBOSITY = 'low'
36
38
 
37
39
 
38
40
  class BaseOpenAIClient(LLMClient):
@@ -51,6 +53,8 @@ class BaseOpenAIClient(LLMClient):
51
53
  config: LLMConfig | None = None,
52
54
  cache: bool = False,
53
55
  max_tokens: int = DEFAULT_MAX_TOKENS,
56
+ reasoning: str | None = DEFAULT_REASONING,
57
+ verbosity: str | None = DEFAULT_VERBOSITY,
54
58
  ):
55
59
  if cache:
56
60
  raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
@@ -60,6 +64,8 @@ class BaseOpenAIClient(LLMClient):
60
64
 
61
65
  super().__init__(config, cache)
62
66
  self.max_tokens = max_tokens
67
+ self.reasoning = reasoning
68
+ self.verbosity = verbosity
63
69
 
64
70
  @abstractmethod
65
71
  async def _create_completion(
@@ -81,6 +87,8 @@ class BaseOpenAIClient(LLMClient):
81
87
  temperature: float | None,
82
88
  max_tokens: int,
83
89
  response_model: type[BaseModel],
90
+ reasoning: str | None,
91
+ verbosity: str | None,
84
92
  ) -> Any:
85
93
  """Create a structured completion using the specific client implementation."""
86
94
  pass
@@ -140,6 +148,8 @@ class BaseOpenAIClient(LLMClient):
140
148
  temperature=self.temperature,
141
149
  max_tokens=max_tokens or self.max_tokens,
142
150
  response_model=response_model,
151
+ reasoning=self.reasoning,
152
+ verbosity=self.verbosity,
143
153
  )
144
154
  return self._handle_structured_response(response)
145
155
  else:
@@ -21,7 +21,7 @@ from openai.types.chat import ChatCompletionMessageParam
21
21
  from pydantic import BaseModel
22
22
 
23
23
  from .config import DEFAULT_MAX_TOKENS, LLMConfig
24
- from .openai_base_client import BaseOpenAIClient
24
+ from .openai_base_client import DEFAULT_REASONING, DEFAULT_VERBOSITY, BaseOpenAIClient
25
25
 
26
26
 
27
27
  class OpenAIClient(BaseOpenAIClient):
@@ -41,6 +41,8 @@ class OpenAIClient(BaseOpenAIClient):
41
41
  cache: bool = False,
42
42
  client: typing.Any = None,
43
43
  max_tokens: int = DEFAULT_MAX_TOKENS,
44
+ reasoning: str = DEFAULT_REASONING,
45
+ verbosity: str = DEFAULT_VERBOSITY,
44
46
  ):
45
47
  """
46
48
  Initialize the OpenAIClient with the provided configuration, cache setting, and client.
@@ -50,7 +52,7 @@ class OpenAIClient(BaseOpenAIClient):
50
52
  cache (bool): Whether to use caching for responses. Defaults to False.
51
53
  client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
52
54
  """
53
- super().__init__(config, cache, max_tokens)
55
+ super().__init__(config, cache, max_tokens, reasoning, verbosity)
54
56
 
55
57
  if config is None:
56
58
  config = LLMConfig()
@@ -67,6 +69,8 @@ class OpenAIClient(BaseOpenAIClient):
67
69
  temperature: float | None,
68
70
  max_tokens: int,
69
71
  response_model: type[BaseModel],
72
+ reasoning: str | None = None,
73
+ verbosity: str | None = None,
70
74
  ):
71
75
  """Create a structured completion using OpenAI's beta parse API."""
72
76
  response = await self.client.responses.parse(
@@ -75,6 +79,8 @@ class OpenAIClient(BaseOpenAIClient):
75
79
  temperature=temperature,
76
80
  max_output_tokens=max_tokens,
77
81
  text_format=response_model, # type: ignore
82
+ reasoning={'effort': reasoning} if reasoning is not None else None, # type: ignore
83
+ text={'verbosity': verbosity} if verbosity is not None else None, # type: ignore
78
84
  )
79
85
 
80
86
  return response
@@ -86,6 +92,8 @@ class OpenAIClient(BaseOpenAIClient):
86
92
  temperature: float | None,
87
93
  max_tokens: int,
88
94
  response_model: type[BaseModel] | None = None,
95
+ reasoning: str | None = None,
96
+ verbosity: str | None = None,
89
97
  ):
90
98
  """Create a regular completion with JSON format."""
91
99
  return await self.client.chat.completions.create(
File without changes