dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/evaluation.py ADDED
@@ -0,0 +1,543 @@
+ """
+ DAO AI Evaluation Module
+
+ Provides reusable scorers and helper functions for MLflow GenAI evaluation.
+ Implements MLflow 3.8+ best practices for trace linking and scorer patterns.
+ """
+
+ from typing import Any, Callable, Optional, TypedDict, Union
+
+ import mlflow
+ from loguru import logger
+ from mlflow.entities import Feedback, SpanStatus, SpanStatusCode, Trace
+ from mlflow.entities.span import Span
+ from mlflow.genai.scorers import Guidelines, Safety, scorer
+
+ # -----------------------------------------------------------------------------
+ # Type Definitions
+ # -----------------------------------------------------------------------------
+
+
+ class ResponseOutput(TypedDict, total=False):
+     """Expected output format with response content."""
+
+     response: str
+
+
+ class NestedOutput(TypedDict, total=False):
+     """Nested output format from some MLflow configurations."""
+
+     outputs: ResponseOutput
+
+
+ # Union type for all possible output formats from MLflow
+ ScorerOutputs = Union[str, ResponseOutput, NestedOutput, dict[str, Any]]
+
+
+ # -----------------------------------------------------------------------------
+ # Helper for extracting response content
+ # -----------------------------------------------------------------------------
+
+
+ def _extract_response_content(outputs: ScorerOutputs) -> tuple[str, Optional[str]]:
+     """
+     Extract response content from various output formats.
+
+     Args:
+         outputs: Model outputs in any supported format
+
+     Returns:
+         Tuple of (content_string, error_message). If error_message is not None,
+         content_string will be empty and the error should be returned.
+     """
+     if isinstance(outputs, str):
+         return outputs, None
+
+     if isinstance(outputs, dict):
+         # Check for nested format first: {"outputs": {"response": "..."}}
+         nested_outputs = outputs.get("outputs")
+         if isinstance(nested_outputs, dict):
+             content = nested_outputs.get("response", "")
+             return str(content) if content else "", None
+
+         # Flat format: {"response": "..."}
+         content = outputs.get("response", "")
+         return str(content) if content else "", None
+
+     return "", f"Unexpected output type: {type(outputs).__name__}"
+
+
+ # -----------------------------------------------------------------------------
+ # Custom Scorers
+ # -----------------------------------------------------------------------------
+
+
+ @scorer
+ def response_completeness(outputs: ScorerOutputs) -> Feedback:
+     """
+     Evaluate if the response appears complete and meaningful.
+
+     This scorer checks:
+     - Response is not too short (< 10 characters)
+     - Response doesn't end with incomplete markers
+
+     Args:
+         outputs: Model outputs - can be a string, dict with "response" key,
+             or nested dict with "outputs.response".
+
+     Returns:
+         Feedback: Pass/Fail feedback with rationale
+     """
+     content, error = _extract_response_content(outputs)
+
+     if error:
+         return Feedback(value=False, rationale=error)
+
+     if not content:
+         return Feedback(value=False, rationale="No response content found in outputs")
+
+     if len(content.strip()) < 10:
+         return Feedback(value=False, rationale="Response too short to be meaningful")
+
+     incomplete_markers = ("...", "etc", "and so on", "to be continued")
+     if content.lower().rstrip().endswith(incomplete_markers):
+         return Feedback(value=False, rationale="Response appears incomplete")
+
+     return Feedback(value=True, rationale="Response appears complete")
+
+
+ @scorer
+ def tool_call_efficiency(trace: Optional[Trace]) -> Feedback:
+     """
+     Evaluate how effectively the agent uses tools.
+
+     This trace-based scorer checks:
+     - Presence of tool calls
+     - Redundant tool calls (same tool called multiple times)
+     - Failed tool calls
+
+     Args:
+         trace: MLflow Trace object containing span information.
+             May be None if tracing is not enabled.
+
+     Returns:
+         Feedback: Pass/Fail feedback with rationale, or error if trace unavailable
+     """
+     if trace is None:
+         # MLflow 3.8+ requires error instead of value=None
+         return Feedback(error=Exception("No trace available for tool call analysis"))
+
+     try:
+         # Retrieve all tool call spans from the trace
+         tool_calls: list[Span] = trace.search_spans(span_type="TOOL")
+     except Exception as e:
+         logger.warning(f"Error searching trace spans: {e}")
+         return Feedback(error=Exception(f"Error accessing trace spans: {str(e)}"))
+
+     if not tool_calls:
+         # No tools used is valid but not evaluable - return True with note
+         return Feedback(
+             value=True,
+             rationale="No tool usage to evaluate - agent responded without tools",
+         )
+
+     # Check for redundant calls (same tool name called multiple times)
+     tool_names: list[str] = [span.name for span in tool_calls]
+     if len(tool_names) != len(set(tool_names)):
+         # Count duplicates for better feedback
+         duplicates = [name for name in set(tool_names) if tool_names.count(name) > 1]
+         return Feedback(
+             value=False, rationale=f"Redundant tool calls detected: {duplicates}"
+         )
+
+     # Check for failed tool calls using typed SpanStatus
+     failed_calls: list[str] = []
+     for span in tool_calls:
+         span_status: SpanStatus = span.status
+         if span_status.status_code != SpanStatusCode.OK:
+             failed_calls.append(span.name)
+
+     if failed_calls:
+         return Feedback(
+             value=False,
+             rationale=f"{len(failed_calls)} tool calls failed: {failed_calls}",
+         )
+
+     return Feedback(
+         value=True,
+         rationale=f"Efficient tool usage: {len(tool_calls)} successful calls",
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # Response Clarity Scorer
+ # -----------------------------------------------------------------------------
+
+ # Instructions for the response clarity judge
+ RESPONSE_CLARITY_INSTRUCTIONS = """Evaluate the clarity and readability of the response in {{ outputs }}.
+
+ Consider:
+ - Is the response easy to understand?
+ - Is the information well-organized?
+ - Does it avoid unnecessary jargon or explain technical terms when used?
+ - Is the sentence structure clear and coherent?
+ - Does the response directly address what was asked in {{ inputs }}?
+
+ Return "clear" if the response is clear and readable, "unclear" if it is confusing or poorly structured."""
+
+
+ def create_response_clarity_scorer(
+     judge_model: str,
+     name: str = "response_clarity",
+ ) -> Any:
+     """
+     Create a response clarity scorer using MLflow's make_judge.
+
+     This scorer evaluates whether a response is clear, well-organized,
+     and easy to understand. It uses an LLM judge for nuanced assessment
+     of qualities like:
+     - Sentence structure and coherence
+     - Information organization
+     - Appropriate use of technical language
+     - Overall readability
+
+     Args:
+         judge_model: The model endpoint to use for evaluation.
+             Example: "databricks:/databricks-claude-3-7-sonnet"
+         name: Name for this scorer instance.
+
+     Returns:
+         A judge scorer created by make_judge
+
+     Example:
+         ```python
+         from dao_ai.evaluation import create_response_clarity_scorer
+
+         # Create the scorer with a judge model
+         clarity_scorer = create_response_clarity_scorer(
+             judge_model="databricks:/databricks-claude-3-7-sonnet",
+         )
+
+         # Use in evaluation
+         mlflow.genai.evaluate(
+             data=eval_data,
+             predict_fn=predict_fn,
+             scorers=[clarity_scorer],
+         )
+         ```
+     """
+     from mlflow.genai.judges import make_judge
+
+     return make_judge(
+         name=name,
+         instructions=RESPONSE_CLARITY_INSTRUCTIONS,
+         # No feedback_value_type - avoids response_schema parameter
+         # which Databricks endpoints don't support
+         model=judge_model,
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # Agent Routing Scorer
+ # -----------------------------------------------------------------------------
+
+ # Instructions for the agent routing judge
+ # Note: When using {{ trace }}, it must be the ONLY template variable per MLflow make_judge rules
+ AGENT_ROUTING_INSTRUCTIONS = """Evaluate whether the agent routing was appropriate for the user's request.
+
+ Analyze the {{ trace }} to determine:
+ 1. What was the user's original query or request
+ 2. Which agents, chains, or components were invoked to handle it
+ 3. Whether the routing sequence was logical and appropriate
+
+ Consider:
+ - **Relevance**: Based on the names of the components invoked, do they seem appropriate for the question type?
+ - **Logical Flow**: Does the sequence of invocations make sense for answering the query?
+ - **Completeness**: Were the right types of components invoked to fully address the query?
+
+ Note: You may not know all possible agents in the system. Focus on whether the components that WERE invoked seem reasonable given the user's query found in the trace.
+
+ Return "appropriate" if routing was appropriate, "inappropriate" if it was clearly wrong."""
+
+
+ def create_agent_routing_scorer(
+     judge_model: str,
+     name: str = "agent_routing",
+ ) -> Any:
+     """
+     Create an agent routing scorer using MLflow's make_judge.
+
+     This scorer analyzes the execution trace to evaluate whether the routing
+     decisions were appropriate for the user's query. It uses MLflow's built-in
+     trace-based judge functionality.
+
+     The scorer is general-purpose and does not require knowledge of specific
+     agent names - it relies on the LLM to interpret whether the components
+     invoked (based on their names and context) were suitable for the query.
+
+     Args:
+         judge_model: The model endpoint to use for evaluation.
+             Example: "databricks:/databricks-claude-3-7-sonnet"
+         name: Name for this scorer instance.
+
+     Returns:
+         A judge scorer created by make_judge
+
+     Example:
+         ```python
+         from dao_ai.evaluation import create_agent_routing_scorer
+
+         # Create the scorer with a judge model
+         agent_routing = create_agent_routing_scorer(
+             judge_model="databricks:/databricks-claude-3-7-sonnet",
+         )
+
+         # Use in evaluation
+         mlflow.genai.evaluate(
+             data=eval_data,
+             predict_fn=predict_fn,
+             scorers=[agent_routing],
+         )
+         ```
+     """
+     from mlflow.genai.judges import make_judge
+
+     return make_judge(
+         name=name,
+         instructions=AGENT_ROUTING_INSTRUCTIONS,
+         # No feedback_value_type - avoids response_schema parameter
+         # which Databricks endpoints don't support
+         model=judge_model,
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # Helper Functions
+ # -----------------------------------------------------------------------------
+
+
+ def create_traced_predict_fn(
+     predict_callable: Callable[[dict[str, Any]], dict[str, Any]],
+     span_name: str = "predict",
+ ) -> Callable[[dict[str, Any]], dict[str, Any]]:
+     """
+     Wrap a predict function with MLflow tracing.
+
+     This ensures traces are created for each prediction, allowing
+     trace-based scorers to access span information.
+
+     Args:
+         predict_callable: The original prediction function
+         span_name: Name for the trace span
+
+     Returns:
+         Wrapped function with MLflow tracing enabled
+     """
+
+     @mlflow.trace(name=span_name, span_type="CHAIN")
+     def traced_predict(inputs: dict[str, Any]) -> dict[str, Any]:
+         result = predict_callable(inputs)
+         # Normalize output format to flat structure
+         if "outputs" in result and isinstance(result["outputs"], dict):
+             # Extract from nested format
+             return result["outputs"]
+         return result
+
+     return traced_predict
+
+
+ def create_guidelines_scorers(
+     guidelines_config: list[Any],
+     judge_model: str,
+ ) -> list[Guidelines]:
+     """
+     Create Guidelines scorers from configuration with proper judge model.
+
+     Args:
+         guidelines_config: List of guideline configurations with name and guidelines
+         judge_model: The model endpoint to use for evaluation (e.g., "databricks:/model-name")
+
+     Returns:
+         List of configured Guidelines scorers
+     """
+     scorers = []
+     for guideline in guidelines_config:
+         scorer_instance = Guidelines(
+             name=guideline.name,
+             guidelines=guideline.guidelines,
+             model=judge_model,
+         )
+         scorers.append(scorer_instance)
+         logger.debug(
+             f"Created Guidelines scorer: {guideline.name} with model {judge_model}"
+         )
+
+     return scorers
+
+
+ def get_default_scorers(
+     include_trace_scorers: bool = True,
+     include_agent_routing: bool = False,
+     judge_model: Optional[str] = None,
+ ) -> list[Any]:
+     """
+     Get the default set of scorers for evaluation.
+
+     Args:
+         include_trace_scorers: Whether to include trace-based scorers like tool_call_efficiency
+         include_agent_routing: Whether to include the agent routing scorer
+         judge_model: The model endpoint to use for LLM-based scorers (Safety, clarity, routing).
+             Example: "databricks-gpt-5-2"
+
+     Returns:
+         List of scorer instances
+     """
+     # Safety requires a judge model for LLM-based evaluation
+     if judge_model:
+         safety_scorer = Safety(model=judge_model)
+     else:
+         safety_scorer = Safety()
+         logger.warning(
+             "No judge_model provided for Safety scorer. "
+             "This may cause errors if no default model is configured."
+         )
+
+     scorers: list[Any] = [
+         safety_scorer,
+         response_completeness,
+     ]
+
+     # TODO: Re-enable when Databricks endpoints support make_judge
+     # if judge_model:
+     #     scorers.append(create_response_clarity_scorer(judge_model=judge_model))
+
+     if include_trace_scorers:
+         scorers.append(tool_call_efficiency)
+
+     # TODO: Re-enable when Databricks endpoints support make_judge
+     # if include_agent_routing and judge_model:
+     #     scorers.append(create_agent_routing_scorer(judge_model=judge_model))
+
+     return scorers
+
+
+ def setup_evaluation_tracking(
+     experiment_id: Optional[str] = None,
+     experiment_name: Optional[str] = None,
+ ) -> None:
+     """
+     Set up MLflow tracking for evaluation.
+
+     Configures:
+     - Registry URI to databricks-uc
+     - Experiment context
+     - Autologging for LangChain
+
+     Args:
+         experiment_id: Optional experiment ID to use
+         experiment_name: Optional experiment name to use (creates if doesn't exist)
+     """
+     mlflow.set_registry_uri("databricks-uc")
+
+     if experiment_id:
+         mlflow.set_experiment(experiment_id=experiment_id)
+     elif experiment_name:
+         mlflow.set_experiment(experiment_name=experiment_name)
+
+     # Enable autologging with trace support
+     mlflow.langchain.autolog(log_traces=True)
+     logger.debug("MLflow evaluation tracking configured")
+
+
+ def run_evaluation(
+     data: Any,
+     predict_fn: Callable,
+     model_id: Optional[str] = None,
+     scorers: Optional[list[Any]] = None,
+     judge_model: Optional[str] = None,
+     guidelines: Optional[list[Any]] = None,
+ ) -> Any:
+     """
+     Run MLflow GenAI evaluation with proper configuration.
+
+     This is a convenience wrapper around mlflow.genai.evaluate() that:
+     - Wraps predict_fn with tracing
+     - Configures default scorers
+     - Sets up Guidelines with judge model
+
+     Args:
+         data: Evaluation dataset (DataFrame or list of dicts)
+         predict_fn: Function to generate predictions
+         model_id: Optional model ID for linking
+         scorers: Optional list of scorers (uses defaults if not provided)
+         judge_model: Model endpoint for LLM-based scorers
+         guidelines: Optional list of guideline configurations
+
+     Returns:
+         EvaluationResult from mlflow.genai.evaluate()
+     """
+     # Wrap predict function with tracing
+     traced_fn = create_traced_predict_fn(predict_fn)
+
+     # Build scorer list
+     if scorers is None:
+         scorers = get_default_scorers(include_trace_scorers=True)
+
+     # Add Guidelines scorers if provided
+     if guidelines and judge_model:
+         guideline_scorers = create_guidelines_scorers(guidelines, judge_model)
+         scorers.extend(guideline_scorers)
+     elif guidelines:
+         logger.warning(
+             "Guidelines provided but no judge_model specified - Guidelines scorers will not be created"
+         )
+
+     # Run evaluation
+     eval_kwargs: dict[str, Any] = {
+         "data": data,
+         "predict_fn": traced_fn,
+         "scorers": scorers,
+     }
+
+     if model_id:
+         eval_kwargs["model_id"] = model_id
+
+     return mlflow.genai.evaluate(**eval_kwargs)
+
+
+ def prepare_eval_results_for_display(eval_results: Any) -> Any:
+     """
+     Prepare evaluation results DataFrame for display in Databricks.
+
+     The 'assessments' column and other complex object columns can't be
+     directly converted to Arrow format for display. This function converts
+     them to string representation.
+
+     Args:
+         eval_results: EvaluationResult from mlflow.genai.evaluate()
+
+     Returns:
+         DataFrame copy with complex columns converted to strings
+     """
+     try:
+         import pandas as pd
+     except ImportError:
+         logger.warning("pandas not available, returning original results")
+         return eval_results.tables.get("eval_results")
+
+     results_df: pd.DataFrame = eval_results.tables["eval_results"].copy()
+
+     # Convert complex columns to string for display compatibility
+     if "assessments" in results_df.columns:
+         results_df["assessments"] = results_df["assessments"].astype(str)
+
+     # Convert any other object columns that might cause Arrow conversion issues
+     for col in results_df.columns:
+         if results_df[col].dtype == "object":
+             try:
+                 # Try to keep as-is first, only convert if it fails
+                 results_df[col].to_list()
+             except Exception:
+                 results_df[col] = results_df[col].astype(str)
+
+     return results_df
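For orientation, here is a minimal usage sketch of the module added above, wiring `setup_evaluation_tracking`, `get_default_scorers`, `run_evaluation`, and `prepare_eval_results_for_display` together. The experiment path, judge endpoint, dataset rows, and the toy `predict_fn` are illustrative assumptions rather than part of the package, and the dataset shape assumes `mlflow.genai.evaluate` hands each row's `inputs` mapping to the predict function as keyword arguments.

```python
# Sketch only: names marked "assumed" are not defined by the diff above.
from typing import Any

from dao_ai.evaluation import (
    get_default_scorers,
    prepare_eval_results_for_display,
    run_evaluation,
    setup_evaluation_tracking,
)

setup_evaluation_tracking(experiment_name="/Shared/dao-ai-eval")  # assumed experiment path


def predict_fn(inputs: dict[str, Any]) -> dict[str, Any]:
    # Stand-in agent; run_evaluation wraps this with @mlflow.trace via
    # create_traced_predict_fn, so the trace-based scorers can run.
    return {"response": f"You asked: {inputs['question']}"}


# Assumed dataset shape: the outer "inputs" is the mlflow.genai.evaluate column,
# and its single "inputs" field is the dict that predict_fn receives.
eval_data = [
    {"inputs": {"inputs": {"question": "What are total sales by region?"}}},
    {"inputs": {"inputs": {"question": "List the top five stores by revenue."}}},
]

judge_model = "databricks:/databricks-claude-3-7-sonnet"  # endpoint name taken from the docstrings

results = run_evaluation(
    data=eval_data,
    predict_fn=predict_fn,
    scorers=get_default_scorers(include_trace_scorers=True, judge_model=judge_model),
)

# Convert complex columns (e.g. "assessments") to strings before displaying in Databricks.
display_df = prepare_eval_results_for_display(results)
```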
dao_ai/genie/__init__.py CHANGED
@@ -5,34 +5,82 @@ This package provides core Genie functionality that can be used across
  different contexts (tools, direct integration, etc.).

  Main exports:
- - GenieService: Core service implementation wrapping Databricks Genie SDK
+ - Genie: Extended Genie class that captures message_id in responses
+ - GenieResponse: Extended response class with message_id field
+ - GenieService: Service implementation wrapping Genie
  - GenieServiceBase: Abstract base class for service implementations
+ - GenieFeedbackRating: Enum for feedback ratings (POSITIVE, NEGATIVE, NONE)
+
+ Original databricks_ai_bridge classes (aliased):
+ - DatabricksGenie: Original Genie from databricks_ai_bridge
+ - DatabricksGenieResponse: Original GenieResponse from databricks_ai_bridge

  Cache implementations are available in the cache subpackage:
  - dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
- - dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
+ - dao_ai.genie.cache.context_aware.postgres: PostgreSQL context-aware cache
+ - dao_ai.genie.cache.context_aware.in_memory: In-memory context-aware cache

  Example usage:
-     from dao_ai.genie import GenieService
-     from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
+     from dao_ai.genie import Genie, GenieService, GenieFeedbackRating
+
+     # Create Genie with message_id support
+     genie = Genie(space_id="my-space")
+     response = genie.ask_question("What are total sales?")
+     print(response.message_id)  # Now available!
+
+     # Use with GenieService
+     service = GenieService(genie)
+     result = service.ask_question("What are total sales?")
+
+     # Send feedback using captured message_id
+     service.send_feedback(
+         conversation_id=result.response.conversation_id,
+         rating=GenieFeedbackRating.POSITIVE,
+         message_id=result.message_id,  # Available from CacheResult
+         was_cache_hit=result.cache_hit,
+     )
  """

+ from databricks.sdk.service.dashboards import GenieFeedbackRating
+
  from dao_ai.genie.cache import (
      CacheResult,
+     ContextAwareGenieService,
      GenieServiceBase,
+     InMemoryContextAwareGenieService,
      LRUCacheService,
-     SemanticCacheService,
+     PostgresContextAwareGenieService,
      SQLCacheEntry,
  )
- from dao_ai.genie.core import GenieService
+ from dao_ai.genie.cache.base import get_latest_message_id, get_message_content
+ from dao_ai.genie.core import (
+     DatabricksGenie,
+     DatabricksGenieResponse,
+     Genie,
+     GenieResponse,
+     GenieService,
+ )

  __all__ = [
+     # Extended Genie classes (primary - use these)
+     "Genie",
+     "GenieResponse",
+     # Original databricks_ai_bridge classes (aliased)
+     "DatabricksGenie",
+     "DatabricksGenieResponse",
      # Service classes
      "GenieService",
      "GenieServiceBase",
+     # Feedback
+     "GenieFeedbackRating",
+     # Helper functions
+     "get_latest_message_id",
+     "get_message_content",
      # Cache types (from cache subpackage)
      "CacheResult",
+     "ContextAwareGenieService",
+     "InMemoryContextAwareGenieService",
      "LRUCacheService",
-     "SemanticCacheService",
+     "PostgresContextAwareGenieService",
      "SQLCacheEntry",
  ]
dao_ai/genie/cache/__init__.py CHANGED
@@ -6,15 +6,16 @@ chained together using the decorator pattern.
  Available cache implementations:
  - LRUCacheService: In-memory LRU cache with O(1) exact match lookup
- - SemanticCacheService: PostgreSQL pg_vector-based semantic similarity cache
+ - PostgresContextAwareGenieService: PostgreSQL pg_vector-based context-aware cache
+ - InMemoryContextAwareGenieService: In-memory context-aware cache

  Example usage:
-     from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
+     from dao_ai.genie.cache import LRUCacheService, PostgresContextAwareGenieService

-     # Chain caches: LRU (checked first) -> Semantic (checked second) -> Genie
-     genie_service = SemanticCacheService(
+     # Chain caches: LRU (checked first) -> Context-aware (checked second) -> Genie
+     genie_service = PostgresContextAwareGenieService(
          impl=GenieService(genie),
-         parameters=semantic_params,
+         parameters=context_aware_params,
      )
      genie_service = LRUCacheService(
          impl=genie_service,
@@ -27,9 +28,23 @@ from dao_ai.genie.cache.base import (
      GenieServiceBase,
      SQLCacheEntry,
  )
+ from dao_ai.genie.cache.context_aware import (
+     ContextAwareGenieService,
+     InMemoryContextAwareGenieService,
+     PersistentContextAwareGenieCacheService,
+     PostgresContextAwareGenieService,
+ )
  from dao_ai.genie.cache.core import execute_sql_via_warehouse
  from dao_ai.genie.cache.lru import LRUCacheService
- from dao_ai.genie.cache.semantic import SemanticCacheService
+ from dao_ai.genie.cache.optimization import (
+     SemanticCacheEvalDataset,
+     SemanticCacheEvalEntry,
+     ThresholdOptimizationResult,
+     clear_judge_cache,
+     generate_eval_dataset_from_cache,
+     optimize_semantic_cache_thresholds,
+     semantic_match_judge,
+ )

  __all__ = [
      # Base types
@@ -37,7 +52,19 @@ __all__ = [
      "GenieServiceBase",
      "SQLCacheEntry",
      "execute_sql_via_warehouse",
+     # Context-aware base classes
+     "ContextAwareGenieService",
+     "PersistentContextAwareGenieCacheService",
      # Cache implementations
+     "InMemoryContextAwareGenieService",
      "LRUCacheService",
-     "SemanticCacheService",
+     "PostgresContextAwareGenieService",
+     # Optimization
+     "SemanticCacheEvalDataset",
+     "SemanticCacheEvalEntry",
+     "ThresholdOptimizationResult",
+     "clear_judge_cache",
+     "generate_eval_dataset_from_cache",
+     "optimize_semantic_cache_thresholds",
+     "semantic_match_judge",
  ]
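The chaining example in the docstring hunk above is cut off at the hunk boundary. Below is a completed sketch of the decorator-pattern chain it describes; the `space_id` value is illustrative, `context_aware_params` is an assumed configuration object defined elsewhere, and any further `LRUCacheService` constructor arguments are omitted because they are not shown in this diff.

```python
# Sketch of chaining the new cache services, following the docstring above.
from dao_ai.genie import Genie, GenieService
from dao_ai.genie.cache import LRUCacheService, PostgresContextAwareGenieService

genie = Genie(space_id="my-space")  # space_id is illustrative

# Innermost layer: the real Genie-backed service.
genie_service = GenieService(genie)

# Second layer: context-aware cache backed by PostgreSQL / pg_vector.
genie_service = PostgresContextAwareGenieService(
    impl=genie_service,
    parameters=context_aware_params,  # assumed config object, defined elsewhere
)

# Outermost layer: in-memory LRU cache, consulted first on each question.
genie_service = LRUCacheService(
    impl=genie_service,
)

result = genie_service.ask_question("What are total sales?")
```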