dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +342 -58
  4. dao_ai/config.py +1610 -380
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +158 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +67 -0
  26. dao_ai/middleware/guardrails.py +420 -0
  27. dao_ai/middleware/human_in_the_loop.py +233 -0
  28. dao_ai/middleware/message_validation.py +586 -0
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +197 -0
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/models.py +1306 -114
  36. dao_ai/nodes.py +240 -161
  37. dao_ai/optimization.py +674 -0
  38. dao_ai/orchestration/__init__.py +52 -0
  39. dao_ai/orchestration/core.py +294 -0
  40. dao_ai/orchestration/supervisor.py +279 -0
  41. dao_ai/orchestration/swarm.py +271 -0
  42. dao_ai/prompts.py +128 -31
  43. dao_ai/providers/databricks.py +584 -601
  44. dao_ai/state.py +157 -21
  45. dao_ai/tools/__init__.py +13 -5
  46. dao_ai/tools/agent.py +1 -3
  47. dao_ai/tools/core.py +64 -11
  48. dao_ai/tools/email.py +232 -0
  49. dao_ai/tools/genie.py +144 -294
  50. dao_ai/tools/mcp.py +223 -155
  51. dao_ai/tools/memory.py +50 -0
  52. dao_ai/tools/python.py +9 -14
  53. dao_ai/tools/search.py +14 -0
  54. dao_ai/tools/slack.py +22 -10
  55. dao_ai/tools/sql.py +202 -0
  56. dao_ai/tools/time.py +30 -7
  57. dao_ai/tools/unity_catalog.py +165 -88
  58. dao_ai/tools/vector_search.py +331 -221
  59. dao_ai/utils.py +166 -20
  60. dao_ai/vector_search.py +37 -0
  61. dao_ai-0.1.5.dist-info/METADATA +489 -0
  62. dao_ai-0.1.5.dist-info/RECORD +70 -0
  63. dao_ai/chat_models.py +0 -204
  64. dao_ai/guardrails.py +0 -112
  65. dao_ai/tools/human_in_the_loop.py +0 -100
  66. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  67. dao_ai-0.0.28.dist-info/RECORD +0 -41
  68. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
  69. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
  70. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,586 @@
1
+ """
2
+ Message validation middleware for DAO AI agents.
3
+
4
+ These middleware implementations validate incoming messages and context
5
+ before agent processing begins.
6
+
7
+ Factory functions are provided for consistent configuration via the
8
+ DAO AI middleware factory pattern.
9
+ """
10
+
11
+ import json
12
+ from typing import Any
13
+
14
+ from langchain.agents.middleware import hook_config
15
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, RemoveMessage
16
+ from langgraph.runtime import Runtime
17
+ from loguru import logger
18
+
19
+ from dao_ai.messages import last_human_message
20
+ from dao_ai.middleware.base import AgentMiddleware
21
+ from dao_ai.state import AgentState, Context
22
+
23
# Middleware classes and the field specification they consume.
__all__ = [
    "MessageValidationMiddleware",
    "UserIdValidationMiddleware",
    "ThreadIdValidationMiddleware",
    "CustomFieldValidationMiddleware",
    "RequiredField",
    "FilterLastHumanMessageMiddleware",
]
# Factory functions for the DAO AI middleware factory pattern.
__all__ += [
    "create_user_id_validation_middleware",
    "create_thread_id_validation_middleware",
    "create_custom_field_validation_middleware",
    "create_filter_last_human_message_middleware",
]
35
+
36
+
37
class MessageValidationMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Base class for middleware that checks a request before the agent runs.

    Concrete validators override :meth:`validate`. Raising ``ValueError`` from
    ``validate`` short-circuits the run: the error text is returned to the
    caller as an ``AIMessage`` and execution jumps to ``"end"``.
    """

    @hook_config(can_jump_to=["end"])
    def before_agent(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Run :meth:`validate`; on ``ValueError`` end the run with the error text."""
        try:
            outcome = self.validate(state, runtime)
        except ValueError as failure:
            reason = str(failure)
            logger.error("Message validation failed", error=reason)
            # Report the failure to the caller and stop before the agent runs.
            return {
                "is_valid": False,
                "message_error": reason,
                "messages": [AIMessage(content=reason)],
                "jump_to": "end",
            }
        return outcome

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """
        Subclass hook for validation logic; the base version accepts everything.

        Args:
            state: The current agent state
            runtime: The LangGraph runtime context

        Returns:
            An optional dict of state updates, or None for no update.

        Raises:
            ValueError: To signal a validation failure; its message is shown
                to the caller.
        """
        return None
81
+
82
+
83
class UserIdValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence and format of user_id.

    Ensures that:
    - user_id is provided in the context (empty values are rejected)
    - user_id does not contain a dot ('.')

    On failure a ``ValueError`` is raised whose message embeds a
    copy-pasteable corrected configuration.
    """

    @staticmethod
    def _corrected_config_json(context: Context, user_id: str) -> str:
        """Render the suggested request configuration as indented JSON.

        Every non-None context field other than user_id/thread_id is carried
        over so the suggestion keeps what the caller already supplied.
        """
        thread_val = context.thread_id or "<your_thread_id>"
        extra_fields = {
            k: v
            for k, v in context.model_dump().items()
            if k not in {"user_id", "thread_id"} and v is not None
        }
        corrected_config: dict[str, Any] = {
            "configurable": {
                "thread_id": thread_val,
                "user_id": user_id,
                **extra_fields,
            },
            "session": {
                "conversation_id": thread_val,
            },
        }
        # This JSON is only displayed in an error message. default=str keeps a
        # non-JSON-serializable context value from turning the validation error
        # into a TypeError, which before_agent does not catch.
        return json.dumps(corrected_config, indent=2, default=str)

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate user_id is present and properly formatted.

        Raises:
            ValueError: If user_id is missing/empty or contains a dot.
        """
        logger.trace("Executing user_id validation")

        context: Context = runtime.context or Context()
        user_id: str | None = context.user_id

        if not user_id:
            logger.error("User ID is required but not provided in configuration")

            corrected_config_json = self._corrected_config_json(
                context, "<your_user_id>"
            )

            error_message = f"""
## Authentication Required

A **user_id** is required to process your request. Please provide your user ID in the configuration.

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
- **thread_id**: Thread identifier (required in configurable)
- **conversation_id**: Alias of thread_id (in session)
- **user_id**: Your unique user identifier (required)

Please update your configuration and try again.
""".strip()

            raise ValueError(error_message)

        if "." in user_id:
            logger.error("User ID contains invalid character '.'", user_id=user_id)

            # Suggest the same id with dots replaced by underscores.
            corrected_config_json = self._corrected_config_json(
                context, user_id.replace(".", "_")
            )

            error_message = f"""
## Invalid User ID Format

The **user_id** cannot contain a dot character ('.'). Please provide a valid user ID without dots.

### Corrected Configuration (Copy & Paste This)
```json
{corrected_config_json}
```

Please update your user_id and try again.
""".strip()

            raise ValueError(error_message)

        return None
189
+
190
+
191
class ThreadIdValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence of thread_id.

    Only ``context.thread_id`` is checked. The emitted error message describes
    ``conversation_id`` (in ``session``) as an alias of thread_id. A missing or
    empty thread_id raises ``ValueError`` with a copy-pasteable corrected
    configuration.
    """

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate thread_id/conversation_id is present.

        Raises:
            ValueError: If thread_id is missing or empty.
        """
        logger.trace("Executing thread_id/conversation_id validation")

        context: Context = runtime.context or Context()
        thread_id: str | None = context.thread_id

        if thread_id:
            return None

        logger.error("Thread ID / Conversation ID is required but not provided")

        # Carry over every other non-None value the caller already supplied.
        extra_fields = {
            k: v
            for k, v in context.model_dump().items()
            if k not in {"user_id", "thread_id"} and v is not None
        }

        corrected_config: dict[str, Any] = {
            "configurable": {
                "thread_id": "<your_thread_id>",
                "user_id": context.user_id or "<your_user_id>",
                **extra_fields,
            },
            "session": {
                "conversation_id": "<your_thread_id>",
            },
        }
        # Display-only JSON. default=str prevents a non-serializable context
        # value from raising TypeError, which before_agent does not catch.
        corrected_config_json = json.dumps(corrected_config, indent=2, default=str)

        error_message = f"""
## Configuration Required

A **thread_id** is required to process your request (or **conversation_id** as an alias).

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
- **thread_id**: Thread identifier (required in configurable)
- **conversation_id**: Alias of thread_id (in session)
- **user_id**: Your unique user identifier (required)

Please update your configuration and try again.
""".strip()

        raise ValueError(error_message)
254
+
255
+
256
+ class RequiredField:
257
+ """Definition of a field for validation.
258
+
259
+ Fields are marked as required or optional via the `required` flag:
260
+ - required=True (default): Field must be provided, validated
261
+ - required=False: Field is optional, not validated
262
+
263
+ For required fields, an `example_value` can be provided to show in error
264
+ messages, making it easy for users to copy-paste the configuration.
265
+
266
+ Args:
267
+ name: The field name (e.g., "store_num", "user_id")
268
+ description: Human-readable description for error messages
269
+ required: Whether this field is required (default: True)
270
+ example_value: Example value to show in error messages for missing fields
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ name: str,
276
+ description: str | None = None,
277
+ required: bool = True,
278
+ example_value: Any = None,
279
+ ):
280
+ self.name = name
281
+ self.description = description or f"Your {name}"
282
+ self.required = required
283
+ self.example_value = example_value
284
+
285
+ @property
286
+ def is_required(self) -> bool:
287
+ """A field is required based on the required flag."""
288
+ return self.required
289
+
290
+
291
class CustomFieldValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence of required custom fields.

    This is a generic validation middleware that can check for multiple
    required fields in the context object.

    Fields are defined in the `fields` list. Each field can have:
    - name: The field name (required)
    - description: Human-readable description for error messages
    - required: Whether field is required (default: True)
    - example_value: Example value to show in error messages

    Required fields (required=True) are validated; a field counts as missing
    when its context attribute is None. The example_value is used in error
    messages to help users copy-paste the correct configuration format.

    Args:
        fields: List of fields to validate/show. Each can be a RequiredField
            or a dict with 'name', 'description', 'required', and 'example_value' keys.

    Raises:
        TypeError: If an entry in ``fields`` is neither a RequiredField nor a dict.
    """

    def __init__(
        self,
        fields: list[RequiredField | dict[str, Any]],
    ):
        super().__init__()

        # Normalize every entry to a RequiredField. Unsupported entries are
        # rejected instead of skipped: silently dropping a spec would disable
        # validation for that field without any signal.
        self.fields: list[RequiredField] = []
        for field in fields:
            if isinstance(field, RequiredField):
                self.fields.append(field)
            elif isinstance(field, dict):
                self.fields.append(RequiredField(**field))
            else:
                raise TypeError(
                    "Custom field definitions must be RequiredField instances or "
                    f"dicts, got {type(field).__name__}: {field!r}"
                )

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that all required fields are present on the context.

        Raises:
            ValueError: If any required field is missing. The message lists the
                missing fields and a corrected configuration shaped like::

                    configurable:
                      thread_id: "abc-123"
                      user_id: "jane_doe"
                      <field_name>: <example_value>
                    session:
                      conversation_id: "abc-123"
        """
        logger.trace("Executing custom field validation")

        context: Context = runtime.context or Context()

        missing_fields: list[RequiredField] = [
            field
            for field in self.fields
            if field.is_required and getattr(context, field.name, None) is None
        ]
        if not missing_fields:
            return None

        missing_names = [f.name for f in missing_fields]
        logger.error("Required fields missing", fields=missing_names)

        configurable = self._build_configurable(context)
        corrected_config: dict[str, Any] = {
            "configurable": configurable,
            # conversation_id is an alias of thread_id and lives in session.
            "session": {
                "conversation_id": configurable["thread_id"],
            },
        }
        # Display-only JSON. default=str keeps a non-JSON example_value (e.g. a
        # date parsed from YAML) or context value from raising TypeError, which
        # before_agent does not catch.
        corrected_config_json = json.dumps(corrected_config, indent=2, default=str)

        field_descriptions_text = "\n".join(self._field_descriptions())
        missing_names_formatted = ", ".join(f"**{f.name}**" for f in missing_fields)

        error_message = f"""
## Configuration Required

The following required fields are missing: {missing_names_formatted}

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
{field_descriptions_text}

Please update your configuration and try again.
""".strip()

        raise ValueError(error_message)

    def _build_configurable(self, context: Context) -> dict[str, Any]:
        """Build the suggested ``configurable`` section for the error message.

        Values the caller supplied are kept as-is. For declared fields the
        caller did not supply, the field's example_value is used; required
        fields without one get a ``<your_<name>>`` placeholder.
        """
        configurable: dict[str, Any] = {}
        supplied: set[str] = set()

        # thread_id and user_id always appear; empty values count as not supplied.
        for key in ("thread_id", "user_id"):
            value = getattr(context, key)
            if value:
                configurable[key] = value
                supplied.add(key)
            else:
                configurable[key] = f"<your_{key}>"

        # Carry over every other value the caller already provided.
        for key, value in context.model_dump().items():
            if key not in {"user_id", "thread_id"} and value is not None:
                configurable[key] = value
                supplied.add(key)

        # Fill in declared fields the caller did not supply. Placeholders
        # (including those for thread_id/user_id) are not user values, so a
        # declared example_value replaces them.
        for field in self.fields:
            if field.name in supplied:
                continue
            if field.example_value is not None:
                configurable[field.name] = field.example_value
            elif field.is_required:
                configurable[field.name] = f"<your_{field.name}>"

        return configurable

    def _field_descriptions(self) -> list[str]:
        """Return markdown bullet lines describing the fields, without duplicates."""
        declared = {field.name for field in self.fields}
        lines: list[str] = []

        # Built-in descriptions are skipped when a declared field of the same
        # name supplies its own, so no field is listed twice.
        if "thread_id" not in declared:
            lines.append("- **thread_id**: Thread identifier (required in configurable)")
        lines.append("- **conversation_id**: Alias of thread_id (in session)")
        if "user_id" not in declared:
            lines.append("- **user_id**: Your unique user identifier (required)")

        for field in self.fields:
            required_text = "(required)" if field.is_required else "(optional)"
            lines.append(f"- **{field.name}**: {field.description} {required_text}")

        return lines
450
+
451
+
452
class FilterLastHumanMessageMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Middleware that trims the conversation down to the most recent human message.

    In the ``before_model`` hook, every other message in ``state["messages"]``
    is targeted with a ``RemoveMessage`` so only the latest user input remains,
    without conversation history.
    """

    def before_model(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Return RemoveMessage updates for all messages except the last human one."""
        logger.trace("Executing filter_last_human_message middleware")

        history: list[BaseMessage] = state.get("messages", [])
        if not history:
            logger.trace("No messages found in state")
            return None

        keep: HumanMessage | None = last_human_message(history)
        if keep is None:
            logger.trace("No human messages found in state")
            # No anchor message: return an update with no RemoveMessage entries.
            return {"messages": []}

        logger.trace(
            "Filtered messages to last human message", original_count=len(history)
        )

        keep_id = keep.id
        return {
            "messages": [
                RemoveMessage(id=msg.id) for msg in history if msg.id != keep_id
            ]
        }
489
+
490
+
491
+ # =============================================================================
492
+ # Factory Functions
493
+ # =============================================================================
494
+
495
+
496
def create_user_id_validation_middleware() -> UserIdValidationMiddleware:
    """
    Create a UserIdValidationMiddleware instance.

    Factory function for creating middleware that validates the presence
    and format of user_id in the runtime context.

    Returns:
        A new UserIdValidationMiddleware instance.

    Example:
        middleware = create_user_id_validation_middleware()
    """
    logger.trace("Creating user_id validation middleware")
    return UserIdValidationMiddleware()
511
+
512
+
513
def create_thread_id_validation_middleware() -> ThreadIdValidationMiddleware:
    """
    Create a ThreadIdValidationMiddleware instance.

    Factory function for creating middleware that validates the presence
    of thread_id in the runtime context.

    Returns:
        A new ThreadIdValidationMiddleware instance.

    Example:
        middleware = create_thread_id_validation_middleware()
    """
    logger.trace("Creating thread_id validation middleware")
    return ThreadIdValidationMiddleware()
528
+
529
+
530
def create_custom_field_validation_middleware(
    fields: list[RequiredField | dict[str, Any]],
) -> CustomFieldValidationMiddleware:
    """
    Create a CustomFieldValidationMiddleware instance.

    Factory function for creating middleware that validates the presence
    of required custom fields in the context object.

    Each field in the list should have:
    - name: The field name (required)
    - description: Human-readable description for error messages (optional)
    - required: Whether field is required (default: True)
    - example_value: Example value to show in error messages (optional)

    Required fields (required=True or not specified) will be validated.
    The example_value is used in error messages to help users copy-paste.

    Args:
        fields: List of field definitions. Each is a RequiredField or a dict
            with 'name' and optionally 'description', 'required', and
            'example_value' keys.

    Returns:
        A CustomFieldValidationMiddleware configured with the specified fields.

    Example:
        middleware = create_custom_field_validation_middleware(
            fields=[
                # Required field with example value for easy copy-paste
                {"name": "store_num", "description": "Your store number", "example_value": "12345"},
                # Optional fields (required=False)
                {"name": "thread_id", "description": "Thread ID", "required": False, "example_value": "1"},
                {"name": "user_id", "description": "User ID", "required": False, "example_value": "my_user_id"},
            ],
        )
    """
    # Names are only used for the trace log. Handle both spec forms the
    # middleware accepts rather than assuming dicts, so logging never fails
    # before the constructor gets to validate the entries.
    field_names = [
        f.get("name", "unknown") if isinstance(f, dict) else getattr(f, "name", "unknown")
        for f in fields
    ]
    logger.trace("Creating custom field validation middleware", fields=field_names)
    return CustomFieldValidationMiddleware(fields=fields)
569
+
570
+
571
def create_filter_last_human_message_middleware() -> FilterLastHumanMessageMiddleware:
    """
    Create a FilterLastHumanMessageMiddleware instance.

    Factory function for creating middleware that filters messages to keep
    only the last human message, useful for scenarios where you want to
    process only the latest user input without conversation history.

    Returns:
        A new FilterLastHumanMessageMiddleware instance.

    Example:
        middleware = create_filter_last_human_message_middleware()
    """
    logger.trace("Creating filter_last_human_message middleware")
    return FilterLastHumanMessageMiddleware()
@@ -0,0 +1,77 @@
1
+ """
2
+ Model call limit middleware for DAO AI agents.
3
+
4
+ Limits the number of model (LLM) calls to prevent infinite loops or excessive costs.
5
+
6
+ Example:
7
+ from dao_ai.middleware import create_model_call_limit_middleware
8
+
9
+ # Limit model calls per run and thread
10
+ middleware = create_model_call_limit_middleware(
11
+ thread_limit=10,
12
+ run_limit=5,
13
+ )
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Literal
19
+
20
+ from langchain.agents.middleware import ModelCallLimitMiddleware
21
+ from loguru import logger
22
+
23
# Re-export the LangChain middleware class alongside the factory that builds it.
__all__ = ["ModelCallLimitMiddleware", "create_model_call_limit_middleware"]
27
+
28
+
29
def create_model_call_limit_middleware(
    thread_limit: int | None = None,
    run_limit: int | None = None,
    exit_behavior: Literal["error", "end"] = "end",
) -> ModelCallLimitMiddleware:
    """
    Create a ModelCallLimitMiddleware to limit LLM API calls.

    Prevents runaway agents from making too many API calls and helps
    enforce cost controls on production deployments.

    Args:
        thread_limit: Max model calls per thread (conversation).
            Requires checkpointer. None = no limit.
        run_limit: Max model calls per run (single invocation).
            None = no limit.
        exit_behavior: What to do when limit hit:
            - "end": Stop execution gracefully (default)
            - "error": Raise ModelCallLimitExceededError immediately

    Returns:
        A ModelCallLimitMiddleware configured with the given limits.

    Raises:
        ValueError: If neither thread_limit nor run_limit is specified.

    Example:
        # Limit to 5 model calls per run, 10 per thread
        limiter = create_model_call_limit_middleware(
            run_limit=5,
            thread_limit=10,
            exit_behavior="end",
        )
    """
    # A middleware with no limits would never trigger; reject it up front.
    if thread_limit is None and run_limit is None:
        raise ValueError("At least one of thread_limit or run_limit must be specified.")

    logger.debug(
        "Creating model call limit middleware",
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )

    return ModelCallLimitMiddleware(
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )