dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59):
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +770 -244
  4. dao_ai/genie/__init__.py +1 -22
  5. dao_ai/genie/cache/__init__.py +1 -2
  6. dao_ai/genie/cache/base.py +20 -70
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +44 -21
  9. dao_ai/genie/cache/semantic.py +390 -109
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +8 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +47 -24
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/genie/__init__.py +0 -236
  54. dao_ai/tools/human_in_the_loop.py +0 -100
  55. dao_ai-0.0.36.dist-info/METADATA +0 -951
  56. dao_ai-0.0.36.dist-info/RECORD +0 -47
  57. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  58. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  59. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,554 @@
1
+ """
2
+ Message validation middleware for DAO AI agents.
3
+
4
+ These middleware implementations validate incoming messages and context
5
+ before agent processing begins.
6
+
7
+ Factory functions are provided for consistent configuration via the
8
+ DAO AI middleware factory pattern.
9
+ """
10
+
11
+ import json
12
+ from typing import Any
13
+
14
+ from langchain.agents.middleware import hook_config
15
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, RemoveMessage
16
+ from langgraph.runtime import Runtime
17
+ from loguru import logger
18
+
19
+ from dao_ai.messages import last_human_message
20
+ from dao_ai.middleware.base import AgentMiddleware
21
+ from dao_ai.state import AgentState, Context
22
+
23
# Public API of this module: the validation middleware classes, the
# RequiredField spec, and the factory functions used by the DAO AI
# middleware factory pattern.
__all__ = [
    "MessageValidationMiddleware",
    "UserIdValidationMiddleware",
    "ThreadIdValidationMiddleware",
    "CustomFieldValidationMiddleware",
    "RequiredField",
    "FilterLastHumanMessageMiddleware",
    "create_user_id_validation_middleware",
    "create_thread_id_validation_middleware",
    "create_custom_field_validation_middleware",
    "create_filter_last_human_message_middleware",
]
35
+
36
+
37
class MessageValidationMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Base class for middleware that validates input before the agent runs.

    Concrete subclasses override :meth:`validate`. A ``ValueError`` raised by
    ``validate`` is turned by :meth:`before_agent` into a state update that
    marks the request invalid, records the error text, surfaces it to the
    caller as an ``AIMessage`` and jumps directly to ``"end"``.
    """

    @hook_config(can_jump_to=["end"])
    def before_agent(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Run :meth:`validate`, short-circuiting to ``"end"`` on ``ValueError``."""
        try:
            outcome = self.validate(state, runtime)
        except ValueError as err:
            reason = str(err)
            logger.error(f"Message validation failed: {err}")
            return {
                "is_valid": False,
                "message_error": reason,
                "messages": [AIMessage(content=reason)],
                "jump_to": "end",
            }
        return outcome

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """
        Hook for subclass-specific validation.

        The base implementation performs no checks.

        Args:
            state: The current agent state.
            runtime: The LangGraph runtime carrying the request ``Context``.

        Returns:
            An optional dict of state updates; ``None`` here.

        Raises:
            ValueError: Raised by subclasses to signal a validation failure.
        """
        return None
81
+
82
+
83
class UserIdValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence and format of ``user_id``.

    Ensures that:
    - ``user_id`` is provided in the runtime context
    - ``user_id`` does not contain a dot character (``'.'``)

    On failure a ``ValueError`` carrying a Markdown help message is raised.
    The message includes a copy-pasteable corrected configuration. The base
    class turns the error into a user-visible ``AIMessage`` and ends the run.
    """

    @staticmethod
    def _corrected_config_json(context: Context, user_id_value: str) -> str:
        """Render the suggested request configuration as indented JSON.

        Args:
            context: The request context; its ``thread_id`` and ``custom``
                values are echoed back so the user only fixes ``user_id``.
            user_id_value: The ``user_id`` to show (a placeholder or a
                corrected value).

        Returns:
            The configuration serialized with ``json.dumps(indent=2)``.
        """
        thread_val = context.thread_id or "<your_thread_id>"
        corrected_config: dict[str, Any] = {
            "configurable": {
                "thread_id": thread_val,
                "user_id": user_id_value,
                **context.custom,
            },
            "session": {
                "conversation_id": thread_val,
            },
        }
        # default=str: values in context.custom are arbitrary. A TypeError from
        # a non-JSON-serializable value would escape before_agent's ValueError
        # handler and hide the actual validation message.
        return json.dumps(corrected_config, indent=2, default=str)

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that ``user_id`` is present and contains no dots.

        Args:
            state: The current agent state (unused).
            runtime: The LangGraph runtime; ``runtime.context`` supplies
                ``user_id``, ``thread_id`` and ``custom``.

        Returns:
            ``None`` when ``user_id`` is valid.

        Raises:
            ValueError: If ``user_id`` is missing/empty or contains ``'.'``.
                The message includes a corrected configuration.
        """
        logger.debug("Executing user_id validation")

        context: Context = runtime.context or Context()
        user_id: str | None = context.user_id

        if not user_id:
            logger.error("User ID is required but not provided in the configuration.")

            corrected_config_json = self._corrected_config_json(
                context, "<your_user_id>"
            )

            error_message = f"""
## Authentication Required

A **user_id** is required to process your request. Please provide your user ID in the configuration.

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
- **thread_id**: Thread identifier (required in configurable)
- **conversation_id**: Alias of thread_id (in session)
- **user_id**: Your unique user identifier (required)

Please update your configuration and try again.
""".strip()

            raise ValueError(error_message)

        if "." in user_id:
            logger.error(f"User ID '{user_id}' contains invalid character '.'")

            # Suggest the same id with dots replaced by underscores.
            corrected_config_json = self._corrected_config_json(
                context, user_id.replace(".", "_")
            )

            error_message = f"""
## Invalid User ID Format

The **user_id** cannot contain a dot character ('.'). Please provide a valid user ID without dots.

### Corrected Configuration (Copy & Paste This)
```json
{corrected_config_json}
```

Please update your user_id and try again.
""".strip()

            raise ValueError(error_message)

        return None
173
+
174
+
175
class ThreadIdValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence of ``thread_id``.

    ``conversation_id`` is accepted as an alias of ``thread_id``. The check
    itself reads ``context.thread_id``. On failure a ``ValueError`` with a
    Markdown help message and a corrected configuration is raised.
    """

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that ``thread_id``/``conversation_id`` is present.

        Args:
            state: The current agent state (unused).
            runtime: The LangGraph runtime; ``runtime.context`` supplies
                ``thread_id``, ``user_id`` and ``custom``.

        Returns:
            ``None`` when a thread id is present.

        Raises:
            ValueError: If ``context.thread_id`` is missing or empty.
        """
        logger.debug("Executing thread_id/conversation_id validation")

        context: Context = runtime.context or Context()
        thread_id: str | None = context.thread_id

        if not thread_id:
            logger.error("Thread ID / Conversation ID is required but not provided.")

            corrected_config: dict[str, Any] = {
                "configurable": {
                    "thread_id": "<your_thread_id>",
                    "user_id": context.user_id or "<your_user_id>",
                    **context.custom,
                },
                "session": {
                    "conversation_id": "<your_thread_id>",
                },
            }
            # default=str: context.custom may hold non-JSON values. A TypeError
            # here would bypass before_agent's ValueError handling.
            corrected_config_json = json.dumps(corrected_config, indent=2, default=str)

            error_message = f"""
## Configuration Required

A **thread_id** is required to process your request (or **conversation_id** as an alias).

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
- **thread_id**: Thread identifier (required in configurable)
- **conversation_id**: Alias of thread_id (in session)
- **user_id**: Your unique user identifier (required)

Please update your configuration and try again.
""".strip()

            raise ValueError(error_message)

        return None
230
+
231
+
232
+ class RequiredField:
233
+ """Definition of a field for validation.
234
+
235
+ Fields are marked as required or optional via the `required` flag:
236
+ - required=True (default): Field must be provided, validated
237
+ - required=False: Field is optional, not validated
238
+
239
+ For required fields, an `example_value` can be provided to show in error
240
+ messages, making it easy for users to copy-paste the configuration.
241
+
242
+ Args:
243
+ name: The field name (e.g., "store_num", "user_id")
244
+ description: Human-readable description for error messages
245
+ required: Whether this field is required (default: True)
246
+ example_value: Example value to show in error messages for missing fields
247
+ """
248
+
249
+ def __init__(
250
+ self,
251
+ name: str,
252
+ description: str | None = None,
253
+ required: bool = True,
254
+ example_value: Any = None,
255
+ ):
256
+ self.name = name
257
+ self.description = description or f"Your {name}"
258
+ self.required = required
259
+ self.example_value = example_value
260
+
261
+ @property
262
+ def is_required(self) -> bool:
263
+ """A field is required based on the required flag."""
264
+ return self.required
265
+
266
+
267
class CustomFieldValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence of required custom fields.

    This generic validator checks any number of fields in ``context.custom``.

    Each entry in ``fields`` can define:
    - name: The field name (required)
    - description: Human-readable description for error messages
    - required: Whether the field is required (default: True)
    - example_value: Example value to show in error messages

    Only fields with ``required=True`` are validated. ``example_value`` is used
    in error messages to help users copy-paste the correct configuration.

    Args:
        fields: Fields to validate/show. Each item is a ``RequiredField`` or a
            dict with 'name', 'description', 'required' and 'example_value' keys.

    Raises:
        TypeError: If an item in ``fields`` is neither a ``RequiredField`` nor
            a dict. Such items were previously dropped silently, which
            disabled their validation without notice.
    """

    # Identifiers that every error message already documents in fixed lines.
    # User-declared fields with these names are not listed a second time.
    _BUILTIN_FIELD_NAMES: frozenset[str] = frozenset(
        {"thread_id", "conversation_id", "user_id"}
    )

    def __init__(
        self,
        fields: list[RequiredField | dict[str, Any]],
    ):
        super().__init__()

        # Normalize every field spec to a RequiredField.
        self.fields: list[RequiredField] = []
        for field in fields:
            if isinstance(field, RequiredField):
                self.fields.append(field)
            elif isinstance(field, dict):
                self.fields.append(RequiredField(**field))
            else:
                raise TypeError(
                    "Each field must be a RequiredField or a dict, "
                    f"got {type(field).__name__}"
                )

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that all required fields are present in ``context.custom``.

        A field counts as missing when ``context.custom.get(name)`` is ``None``.
        On failure the error message contains a corrected configuration shaped
        like::

            {
              "configurable": {
                "thread_id": "abc-123",
                "user_id": "my_user_id",
                "<field_name>": "<example_value>"
              },
              "session": {
                "conversation_id": "abc-123"
              }
            }

        Values the caller already supplied are preserved. Missing fields are
        filled with their ``example_value``, or a ``<your_<name>>``
        placeholder for required fields without one.

        Args:
            state: The current agent state (unused).
            runtime: The LangGraph runtime; ``runtime.context`` supplies
                ``thread_id``, ``user_id`` and ``custom``.

        Returns:
            ``None`` when every required field is present.

        Raises:
            ValueError: If one or more required fields are missing.
        """
        logger.debug("Executing custom field validation")

        context: Context = runtime.context or Context()

        missing_fields: list[RequiredField] = [
            field
            for field in self.fields
            if field.is_required and context.custom.get(field.name) is None
        ]
        if not missing_fields:
            return None

        missing_names = [f.name for f in missing_fields]
        logger.error(f"Required fields missing: {', '.join(missing_names)}")

        # Only thread_id lives in configurable; conversation_id goes in session.
        thread_val = context.thread_id or "<your_thread_id>"
        configurable: dict[str, Any] = {
            "thread_id": thread_val,
            "user_id": context.user_id or "<your_user_id>",
        }
        # Values the user already provided in custom take precedence.
        configurable.update(context.custom)

        for field in self.fields:
            if field.name in configurable:
                # Field was provided by the user; keep their value.
                continue
            if field.example_value is not None:
                configurable[field.name] = field.example_value
            elif field.is_required:
                configurable[field.name] = f"<your_{field.name}>"

        corrected_config: dict[str, Any] = {
            "configurable": configurable,
            "session": {
                "conversation_id": thread_val,
            },
        }
        # default=str: custom values and example_value are arbitrary. A
        # TypeError here would bypass before_agent's ValueError handling.
        corrected_config_json = json.dumps(corrected_config, indent=2, default=str)

        field_descriptions: list[str] = [
            "- **thread_id**: Thread identifier (required in configurable)",
            "- **conversation_id**: Alias of thread_id (in session)",
            "- **user_id**: Your unique user identifier (required)",
        ]
        for field in self.fields:
            if field.name in self._BUILTIN_FIELD_NAMES:
                # Already described by the fixed entries above.
                continue
            required_text = "(required)" if field.is_required else "(optional)"
            field_descriptions.append(
                f"- **{field.name}**: {field.description} {required_text}"
            )

        field_descriptions_text = "\n".join(field_descriptions)
        missing_names_formatted = ", ".join(f"**{name}**" for name in missing_names)

        error_message = f"""
## Configuration Required

The following required fields are missing: {missing_names_formatted}

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
{field_descriptions_text}

Please update your configuration and try again.
""".strip()

        raise ValueError(error_message)
416
+
417
+
418
class FilterLastHumanMessageMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Middleware that trims the conversation to the most recent human message.

    In ``before_model`` it emits a ``RemoveMessage`` for every message in
    state except the one returned by ``last_human_message``. The goal is to
    process only the latest user input without prior history. This assumes
    the ``messages`` channel applies ``RemoveMessage`` updates (LangGraph's
    ``add_messages`` reducer); verify against the state definition.
    """

    def before_model(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Return removals for all messages other than the last human message."""
        logger.debug("Executing filter_last_human_message middleware")

        history: list[BaseMessage] = state.get("messages", [])
        if not history:
            logger.debug("No messages found in state")
            return None

        keep: HumanMessage | None = last_human_message(history)
        if keep is None:
            logger.debug("No human messages found in state")
            return {"messages": []}

        logger.debug(
            f"Filtered {len(history)} messages down to 1 (last human message)"
        )

        keep_id = keep.id
        deletions = [
            RemoveMessage(id=msg.id) for msg in history if msg.id != keep_id
        ]
        return {"messages": deletions}
455
+
456
+
457
+ # =============================================================================
458
+ # Factory Functions
459
+ # =============================================================================
460
+
461
+
462
def create_user_id_validation_middleware() -> UserIdValidationMiddleware:
    """
    Build the middleware that requires a present, dot-free ``user_id``.

    Factory hook for the DAO AI middleware factory pattern. The returned
    middleware checks ``user_id`` in the runtime context.

    Returns:
        A new UserIdValidationMiddleware.

    Example:
        middleware = create_user_id_validation_middleware()
    """
    logger.debug("Creating user_id validation middleware")
    middleware = UserIdValidationMiddleware()
    return middleware
477
+
478
+
479
def create_thread_id_validation_middleware() -> ThreadIdValidationMiddleware:
    """
    Build the middleware that requires a ``thread_id`` in the runtime context.

    Factory hook for the DAO AI middleware factory pattern.

    Returns:
        A new ThreadIdValidationMiddleware.

    Example:
        middleware = create_thread_id_validation_middleware()
    """
    logger.debug("Creating thread_id validation middleware")
    middleware = ThreadIdValidationMiddleware()
    return middleware
494
+
495
+
496
def create_custom_field_validation_middleware(
    fields: list[RequiredField | dict[str, Any]],
) -> CustomFieldValidationMiddleware:
    """
    Create a CustomFieldValidationMiddleware instance.

    Factory function for middleware that validates the presence of required
    custom fields in ``context.custom``.

    Each field definition may specify:
    - name: The field name (required)
    - description: Human-readable description for error messages (optional)
    - required: Whether the field is required (default: True)
    - example_value: Example value to show in error messages (optional)

    Fields with ``required=True`` (or ``required`` unspecified) are validated.
    ``example_value`` is used in error messages to help users copy-paste.

    ``thread_id`` and ``user_id`` are read from the context itself rather than
    ``context.custom``. They are always shown in error messages, so they do not
    need to be declared here.

    Args:
        fields: Field definitions. Each item is a dict with a 'name' key and
            optional 'description', 'required' and 'example_value' keys, or a
            ``RequiredField`` instance.

    Returns:
        CustomFieldValidationMiddleware configured with the specified fields.

    Example:
        middleware = create_custom_field_validation_middleware(
            fields=[
                # Required field with example value for easy copy-paste
                {"name": "store_num", "description": "Your store number", "example_value": "12345"},
                # Optional field (required=False): shown but not validated
                {"name": "region", "description": "Your sales region", "required": False, "example_value": "us-west"},
            ],
        )
    """
    # Accept both dict specs and RequiredField objects (the middleware accepts
    # both); calling .get() on a RequiredField would raise AttributeError.
    field_names = [
        f.name if isinstance(f, RequiredField) else f.get("name", "unknown")
        for f in fields
    ]
    logger.debug(
        f"Creating custom field validation middleware for fields: {field_names}"
    )
    return CustomFieldValidationMiddleware(fields=fields)
537
+
538
+
539
def create_filter_last_human_message_middleware() -> FilterLastHumanMessageMiddleware:
    """
    Build the middleware that keeps only the last human message.

    Factory hook for the DAO AI middleware factory pattern. It is useful when
    only the latest user input should be processed, without conversation
    history.

    Returns:
        A new FilterLastHumanMessageMiddleware.

    Example:
        middleware = create_filter_last_human_message_middleware()
    """
    logger.debug("Creating filter_last_human_message middleware")
    middleware = FilterLastHumanMessageMiddleware()
    return middleware