dao-ai 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +114 -33
- dao_ai/genie/cache/__init__.py +11 -9
- dao_ai/genie/cache/context_aware/__init__.py +21 -0
- dao_ai/genie/cache/context_aware/base.py +54 -1
- dao_ai/genie/cache/context_aware/in_memory.py +112 -0
- dao_ai/genie/cache/{optimization.py → context_aware/optimization.py} +83 -43
- dao_ai/genie/cache/context_aware/postgres.py +177 -0
- dao_ai/middleware/__init__.py +8 -1
- dao_ai/middleware/tool_call_observability.py +227 -0
- dao_ai/utils.py +7 -3
- {dao_ai-0.1.20.dist-info → dao_ai-0.1.21.dist-info}/METADATA +1 -1
- {dao_ai-0.1.20.dist-info → dao_ai-0.1.21.dist-info}/RECORD +15 -14
- {dao_ai-0.1.20.dist-info → dao_ai-0.1.21.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.20.dist-info → dao_ai-0.1.21.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.20.dist-info → dao_ai-0.1.21.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/{optimization.py → context_aware/optimization.py} CHANGED
```diff
@@ -1,7 +1,7 @@
 """
-
+Context-aware semantic cache threshold optimization using Optuna Bayesian optimization.
 
-This module provides optimization for Genie
+This module provides optimization for context-aware Genie cache thresholds using
 Optuna's Tree-structured Parzen Estimator (TPE) algorithm with LLM-as-Judge
 evaluation for semantic match validation.
 
@@ -11,10 +11,23 @@ The optimizer tunes these thresholds:
 - question_weight: Weight for question similarity in combined score
 
 Usage:
-    from dao_ai.genie.cache.optimization import
+    from dao_ai.genie.cache.context_aware.optimization import (
+        optimize_context_aware_cache_thresholds,
+        generate_eval_dataset_from_cache,
+    )
+
+    # Get entries from your cache
+    entries = cache_service.get_entries(include_embeddings=True, limit=100)
+
+    # Generate evaluation dataset
+    eval_dataset = generate_eval_dataset_from_cache(
+        cache_entries=entries,
+        dataset_name="my_cache_eval",
+    )
 
-
-
+    # Optimize thresholds
+    result = optimize_context_aware_cache_thresholds(
+        dataset=eval_dataset,
         judge_model="databricks-meta-llama-3-3-70b-instruct",
         n_trials=50,
         metric="f1",
```
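Taken together, the rewritten usage docstring describes a three-step workflow: export entries from a running cache, turn them into a labelled evaluation dataset, and let Optuna search for better thresholds. A consolidated sketch of that flow follows; `cache_service` stands in for an already-configured context-aware Genie cache and is not defined in this diff.

```python
# Minimal sketch of the new optimization workflow, assuming `cache_service`
# is an already-configured context-aware Genie cache with stored entries.
from dao_ai.genie.cache.context_aware.optimization import (
    generate_eval_dataset_from_cache,
    optimize_context_aware_cache_thresholds,
)

# 1. Pull cached question/context pairs, including their embeddings.
entries = cache_service.get_entries(include_embeddings=True, limit=100)

# 2. Build an evaluation dataset of matching and non-matching pairs.
eval_dataset = generate_eval_dataset_from_cache(
    cache_entries=entries,
    dataset_name="my_cache_eval",
)

# 3. Run the Optuna TPE search with an LLM judge scoring candidate matches.
result = optimize_context_aware_cache_thresholds(
    dataset=eval_dataset,
    judge_model="databricks-meta-llama-3-3-70b-instruct",
    n_trials=50,
    metric="f1",
)

if result.improved:
    print(f"New thresholds: {result.optimized_thresholds}")
```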
```diff
@@ -29,37 +42,31 @@ import hashlib
 import math
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Callable, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Sequence
 
 import mlflow
-import optuna
 from loguru import logger
-from optuna.samplers import TPESampler
-
-# Optional MLflow integration - requires optuna-integration[mlflow]
-try:
-    from optuna.integration import MLflowCallback
-
-    MLFLOW_CALLBACK_AVAILABLE = True
-except ModuleNotFoundError:
-    MLFLOW_CALLBACK_AVAILABLE = False
-    MLflowCallback = None  # type: ignore
 
 from dao_ai.config import GenieContextAwareCacheParametersModel, LLMModel
 from dao_ai.utils import dao_ai_version
 
+# Type-only import for optuna.Trial to support type hints without runtime dependency
+if TYPE_CHECKING:
+    import optuna
+
 __all__ = [
-    "
-    "
+    "ContextAwareCacheEvalEntry",
+    "ContextAwareCacheEvalDataset",
     "ThresholdOptimizationResult",
-    "
+    "optimize_context_aware_cache_thresholds",
     "generate_eval_dataset_from_cache",
     "semantic_match_judge",
+    "clear_judge_cache",
 ]
 
 
 @dataclass
-class
+class ContextAwareCacheEvalEntry:
     """Single evaluation entry for threshold optimization.
 
     Represents a pair of question/context combinations to evaluate
```
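The import hunk above drops the hard dependency on optuna at import time: the module-level `import optuna` and the `MLflowCallback` probe are removed, and `optuna` is now imported only under `TYPE_CHECKING` so that `optuna.Trial` can still appear in annotations. A generic, self-contained sketch of that typing pattern (not dao-ai's actual code) follows.

```python
# Generic sketch of the type-only import pattern used above: the optional
# package is imported solely for static type checkers, and annotations that
# mention it are written as string forward references.
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
    import optuna  # never imported at runtime


def make_objective(target: float) -> Callable[["optuna.Trial"], float]:
    def objective(trial: "optuna.Trial") -> float:
        x = trial.suggest_float("x", -10.0, 10.0)
        return -((x - target) ** 2)

    return objective
```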
```diff
@@ -90,7 +97,7 @@ class SemanticCacheEvalEntry:
 
 
 @dataclass
-class
+class ContextAwareCacheEvalDataset:
     """Dataset for semantic cache threshold optimization.
 
     Attributes:
@@ -100,13 +107,13 @@ class SemanticCacheEvalDataset:
     """
 
     name: str
-    entries: list[
+    entries: list[ContextAwareCacheEvalEntry]
     description: str = ""
 
     def __len__(self) -> int:
         return len(self.entries)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[ContextAwareCacheEvalEntry]:
         return iter(self.entries)
 
 
@@ -272,7 +279,7 @@ def _compute_l2_similarity(embedding1: list[float], embedding2: list[float]) ->
 
 
 def _evaluate_thresholds(
-    dataset:
+    dataset: ContextAwareCacheEvalDataset,
     similarity_threshold: float,
     context_similarity_threshold: float,
     question_weight: float,
@@ -370,14 +377,14 @@ def _evaluate_thresholds(
 
 
 def _create_objective(
-    dataset:
+    dataset: ContextAwareCacheEvalDataset,
     judge_model: LLMModel | str | None,
     metric: Literal["f1", "precision", "recall", "fbeta"],
     beta: float = 1.0,
-) -> Callable[[optuna.Trial], float]:
+) -> Callable[["optuna.Trial"], float]:
     """Create the Optuna objective function."""
 
-    def objective(trial: optuna.Trial) -> float:
+    def objective(trial: "optuna.Trial") -> float:
         # Sample parameters
         similarity_threshold = trial.suggest_float(
             "similarity_threshold", 0.5, 0.99, log=False
@@ -423,8 +430,8 @@ def _create_objective(
     return objective
 
 
-def
-    dataset:
+def optimize_context_aware_cache_thresholds(
+    dataset: ContextAwareCacheEvalDataset,
     original_thresholds: dict[str, float]
     | GenieContextAwareCacheParametersModel
     | None = None,
```
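As the `_create_objective` hunk shows, each Optuna trial samples candidate thresholds with `trial.suggest_float` and returns the chosen metric, which the TPE sampler then uses to propose the next candidates. The standalone sketch below shows that loop end to end with a toy scoring function; only the `similarity_threshold` range (0.5–0.99) is taken from the diff, the other ranges and the score are illustrative.

```python
# Standalone illustration of a TPE-driven threshold search; the scoring
# function is a toy stand-in for dao-ai's F1-over-eval-dataset objective.
import optuna
from optuna.samplers import TPESampler


def objective(trial: optuna.Trial) -> float:
    similarity_threshold = trial.suggest_float("similarity_threshold", 0.5, 0.99)
    context_similarity_threshold = trial.suggest_float(
        "context_similarity_threshold", 0.5, 0.99  # illustrative range
    )
    question_weight = trial.suggest_float("question_weight", 0.0, 1.0)  # illustrative range
    # Toy score peaking at an arbitrary point; the real objective evaluates
    # the sampled thresholds against the eval dataset with an LLM judge.
    return (
        1.0
        - abs(similarity_threshold - 0.85)
        - abs(context_similarity_threshold - 0.75)
        - 0.5 * abs(question_weight - 0.6)
    )


study = optuna.create_study(direction="maximize", sampler=TPESampler(seed=42))
study.optimize(objective, n_trials=50)
print(study.best_params, study.best_value)
```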
```diff
@@ -461,12 +468,12 @@ def optimize_semantic_cache_thresholds(
         ThresholdOptimizationResult with optimized thresholds and metrics
 
     Example:
-        from dao_ai.genie.cache.optimization import (
-
-
+        from dao_ai.genie.cache.context_aware.optimization import (
+            optimize_context_aware_cache_thresholds,
+            ContextAwareCacheEvalDataset,
         )
 
-        result =
+        result = optimize_context_aware_cache_thresholds(
             dataset=my_dataset,
             judge_model="databricks-meta-llama-3-3-70b-instruct",
             n_trials=50,
@@ -476,6 +483,26 @@ def optimize_semantic_cache_thresholds(
         if result.improved:
             print(f"New thresholds: {result.optimized_thresholds}")
     """
+    # Lazy import optuna - only loaded when optimization is actually called
+    # This allows the cache module to be imported without optuna installed
+    try:
+        import optuna
+        from optuna.samplers import TPESampler
+    except ImportError as e:
+        raise ImportError(
+            "optuna is required for cache threshold optimization. "
+            "Install it with: pip install optuna"
+        ) from e
+
+    # Optional MLflow integration - requires optuna-integration[mlflow]
+    try:
+        from optuna.integration import MLflowCallback
+
+        mlflow_callback_available = True
+    except ModuleNotFoundError:
+        mlflow_callback_available = False
+        MLflowCallback = None  # type: ignore
+
     logger.info(
         "Starting semantic cache threshold optimization",
         dataset_name=dataset.name,
```
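The hunk above replaces the removed module-level imports with a function-local guard: optuna is imported only when optimization actually runs, and the MLflow callback stays optional. A compact sketch of that guard pattern in isolation (`run_study` is a hypothetical wrapper, not a dao-ai function):

```python
# Sketch of the lazy optional-dependency guard used above: hard-require optuna
# only at call time, and degrade gracefully if optuna-integration[mlflow] is absent.
def run_study(n_trials: int = 50) -> None:
    try:
        import optuna
        from optuna.samplers import TPESampler
    except ImportError as e:
        raise ImportError(
            "optuna is required for cache threshold optimization. "
            "Install it with: pip install optuna"
        ) from e

    callbacks = []
    try:
        import mlflow
        from optuna.integration import MLflowCallback  # optuna-integration[mlflow]

        callbacks.append(MLflowCallback(tracking_uri=mlflow.get_tracking_uri()))
    except ModuleNotFoundError:
        pass  # MLflow logging is optional

    study = optuna.create_study(direction="maximize", sampler=TPESampler(seed=42))
    study.optimize(
        lambda trial: trial.suggest_float("x", 0.0, 1.0),  # toy objective
        n_trials=n_trials,
        callbacks=callbacks,
    )
```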
```diff
@@ -539,7 +566,7 @@ def optimize_semantic_cache_thresholds(
     # Create study name if not provided
     if study_name is None:
         timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
-        study_name = f"
+        study_name = f"context_aware_cache_threshold_optimization_{timestamp}"
 
     # Create Optuna study
     sampler = TPESampler(seed=seed)
@@ -562,7 +589,7 @@ def optimize_semantic_cache_thresholds(
 
     # Set up MLflow callback if available
     callbacks = []
-    if
+    if mlflow_callback_available and MLflowCallback is not None:
         try:
             mlflow_callback = MLflowCallback(
                 tracking_uri=mlflow.get_tracking_uri(),
@@ -743,7 +770,7 @@ def generate_eval_dataset_from_cache(
     num_negative_pairs: int = 50,
     paraphrase_model: LLMModel | str | None = None,
     dataset_name: str = "generated_eval_dataset",
-) ->
+) -> ContextAwareCacheEvalDataset:
     """
     Generate an evaluation dataset from existing cache entries.
 
@@ -752,7 +779,8 @@ def generate_eval_dataset_from_cache(
 
     Args:
         cache_entries: List of cache entries with 'question', 'conversation_context',
-            'question_embedding', and 'context_embedding' keys
+            'question_embedding', and 'context_embedding' keys. Use cache.get_entries()
+            with include_embeddings=True to retrieve these.
         embedding_model: Model for generating embeddings for paraphrased questions
         num_positive_pairs: Number of positive (matching) pairs to generate
        num_negative_pairs: Number of negative (non-matching) pairs to generate
@@ -760,7 +788,19 @@ def generate_eval_dataset_from_cache(
         dataset_name: Name for the generated dataset
 
     Returns:
-
+        ContextAwareCacheEvalDataset with generated entries
+
+    Example:
+        # Get entries from cache with embeddings
+        entries = cache_service.get_entries(include_embeddings=True, limit=100)
+
+        # Generate evaluation dataset
+        eval_dataset = generate_eval_dataset_from_cache(
+            cache_entries=entries,
+            num_positive_pairs=50,
+            num_negative_pairs=50,
+            dataset_name="my_cache_eval",
+        )
     """
     import random
 
@@ -787,7 +827,7 @@ def generate_eval_dataset_from_cache(
     )
     chat = para_model.as_chat_model()
 
-    entries: list[
+    entries: list[ContextAwareCacheEvalEntry] = []
 
     # Generate positive pairs (paraphrases)
     logger.info(
@@ -824,7 +864,7 @@ Rephrased question:"""
         )
 
         entries.append(
-
+            ContextAwareCacheEvalEntry(
                 question=paraphrased_question,
                 question_embedding=para_q_emb,
                 context=original_context,
@@ -855,7 +895,7 @@ Rephrased question:"""
 
         # Use entry1 as the "question" and entry2 as the "cached" entry
         entries.append(
-
+            ContextAwareCacheEvalEntry(
                 question=entry1.get("question", ""),
                 question_embedding=entry1.get("question_embedding", []),
                 context=entry1.get("conversation_context", ""),
@@ -876,7 +916,7 @@ Rephrased question:"""
         negative_pairs=sum(1 for e in entries if e.expected_match is False),
     )
 
-    return
+    return ContextAwareCacheEvalDataset(
         name=dataset_name,
         entries=entries,
         description=f"Generated from {len(cache_entries)} cache entries",
```
dao_ai/genie/cache/context_aware/postgres.py CHANGED
```diff
@@ -849,6 +849,144 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
             }
         return {}
 
+    def get_entries(
+        self,
+        limit: int | None = None,
+        offset: int | None = None,
+        include_embeddings: bool = False,
+        conversation_id: str | None = None,
+        created_after: datetime | None = None,
+        created_before: datetime | None = None,
+        question_contains: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Get cache entries with optional filtering.
+
+        This method retrieves cache entries for inspection, debugging, or
+        generating evaluation datasets for threshold optimization.
+
+        Args:
+            limit: Maximum number of entries to return (None = no limit)
+            offset: Number of entries to skip for pagination (None = 0)
+            include_embeddings: Whether to include embedding vectors in results.
+                Embeddings are large, so set False for general inspection.
+            conversation_id: Filter by conversation ID (None = all conversations)
+            created_after: Only entries created after this time (None = no filter)
+            created_before: Only entries created before this time (None = no filter)
+            question_contains: Case-insensitive text search on question field
+
+        Returns:
+            List of cache entry dicts. See base class for full key documentation.
+
+        Example:
+            # Get entries with embeddings for evaluation dataset generation
+            entries = cache.get_entries(include_embeddings=True, limit=100)
+            eval_dataset = generate_eval_dataset_from_cache(entries)
+        """
+        self._setup()
+
+        # Build column list
+        base_columns = [
+            "id",
+            "question",
+            "conversation_context",
+            "sql_query",
+            "description",
+            "conversation_id",
+            "created_at",
+        ]
+
+        if include_embeddings:
+            columns = base_columns + ["question_embedding", "context_embedding"]
+        else:
+            columns = base_columns
+
+        columns_str = ", ".join(columns)
+
+        # Build WHERE clause with parameters
+        where_clauses = ["genie_space_id = %s"]
+        params: list[Any] = [self.space_id]
+
+        if conversation_id is not None:
+            where_clauses.append("conversation_id = %s")
+            params.append(conversation_id)
+
+        if created_after is not None:
+            where_clauses.append("created_at > %s")
+            params.append(created_after)
+
+        if created_before is not None:
+            where_clauses.append("created_at < %s")
+            params.append(created_before)
+
+        if question_contains is not None:
+            where_clauses.append("question ILIKE %s")
+            params.append(f"%{question_contains}%")
+
+        where_str = " AND ".join(where_clauses)
+
+        # Build full query
+        query = f"""
+            SELECT {columns_str}
+            FROM {self.table_name}
+            WHERE {where_str}
+            ORDER BY created_at DESC
+        """
+
+        if limit is not None:
+            query += f" LIMIT {int(limit)}"
+
+        if offset is not None:
+            query += f" OFFSET {int(offset)}"
+
+        # Execute query
+        with self._pool.connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(query, params)
+                rows = cur.fetchall()
+
+        entries: list[dict[str, Any]] = []
+        for row in rows:
+            entry: dict[str, Any] = {
+                "id": row.get("id"),
+                "question": row.get("question"),
+                "conversation_context": row.get("conversation_context"),
+                "sql_query": row.get("sql_query"),
+                "description": row.get("description"),
+                "conversation_id": row.get("conversation_id"),
+                "created_at": row.get("created_at"),
+            }
+
+            if include_embeddings:
+                # Convert pgvector to list
+                q_emb = row.get("question_embedding")
+                c_emb = row.get("context_embedding")
+                entry["question_embedding"] = (
+                    list(q_emb) if q_emb is not None else []
+                )
+                entry["context_embedding"] = (
+                    list(c_emb) if c_emb is not None else []
+                )
+
+            entries.append(entry)
+
+        logger.debug(
+            "Retrieved cache entries",
+            layer=self.name,
+            count=len(entries),
+            include_embeddings=include_embeddings,
+            filters={
+                "conversation_id": conversation_id,
+                "created_after": str(created_after) if created_after else None,
+                "created_before": (
+                    str(created_before) if created_before else None
+                ),
+                "question_contains": question_contains,
+            },
+        )
+
+        return entries
+
     def from_space(
         self,
         space_id: str | None = None,
```
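`get_entries` composes a parameterized SELECT over the cache table, applies the optional filters, and only fetches the pgvector columns when `include_embeddings=True`. A hedged usage sketch, assuming `cache` is an already-configured `PostgresContextAwareGenieService` (its construction is outside this diff):

```python
# Hypothetical usage of the new get_entries filters; `cache` is assumed to be
# a configured PostgresContextAwareGenieService bound to a Genie space.
from datetime import datetime, timedelta, timezone

recent = cache.get_entries(
    question_contains="revenue",  # illustrative ILIKE filter
    created_after=datetime.now(timezone.utc) - timedelta(days=7),
    include_embeddings=False,     # skip the large vector columns for inspection
    limit=25,
)

for entry in recent:
    print(entry["created_at"], entry["question"], entry["sql_query"])
```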
```diff
@@ -857,6 +995,7 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
         from_datetime: datetime | None = None,
         to_datetime: datetime | None = None,
         max_messages: int | None = None,
+        max_conversations: int | None = None,
     ) -> Self:
         """Populate cache from existing Genie space conversations.
 
@@ -872,6 +1011,7 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
             from_datetime: Only include messages after this time
             to_datetime: Only include messages before this time
             max_messages: Limit to last N messages (most recent first)
+            max_conversations: Limit to N conversations (stops pagination after reaching limit)
 
         Returns:
             self for method chaining
@@ -916,8 +1056,21 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
                 break
 
             if response.conversations is None:
+                logger.debug(
+                    "No conversations in response",
+                    layer=self.name,
+                    space_id=target_space_id,
+                )
                 break
 
+            logger.debug(
+                "Fetched conversations page",
+                layer=self.name,
+                conversations_in_page=len(response.conversations),
+                total_conversations_so_far=stats["conversations_processed"],
+                has_next_page=response.next_page_token is not None,
+            )
+
             for conversation in response.conversations:
                 if conversation.conversation_id is None:
                     continue
@@ -945,11 +1098,35 @@ class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
                 if max_messages and len(all_messages) >= max_messages:
                     break
 
+                if (
+                    max_conversations
+                    and stats["conversations_processed"] >= max_conversations
+                ):
+                    break
+
             if max_messages and len(all_messages) >= max_messages:
                 break
 
+            if (
+                max_conversations
+                and stats["conversations_processed"] >= max_conversations
+            ):
+                logger.debug(
+                    "Reached max_conversations limit",
+                    layer=self.name,
+                    max_conversations=max_conversations,
+                    total_conversations=stats["conversations_processed"],
+                )
+                break
+
             page_token = response.next_page_token
             if page_token is None:
+                logger.debug(
+                    "No more pages to fetch",
+                    layer=self.name,
+                    total_conversations=stats["conversations_processed"],
+                    total_messages=len(all_messages),
+                )
                 break
 
         # Sort and limit
```
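`from_space` gains a `max_conversations` cap that stops paginating through the Genie space once enough conversations have been processed, alongside the existing `max_messages` limit. A hedged sketch, again assuming a configured `cache` instance:

```python
# Sketch of seeding the cache from an existing Genie space with the new cap;
# the space ID is a placeholder and `cache` is assumed to be configured already.
from datetime import datetime, timedelta, timezone

cache.from_space(
    space_id="your-genie-space-id",  # placeholder
    from_datetime=datetime.now(timezone.utc) - timedelta(days=30),
    max_messages=500,        # existing limit on messages pulled in
    max_conversations=50,    # new in 0.1.21: stop paginating after 50 conversations
)
```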
|
dao_ai/middleware/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# DAO AI Middleware Module
|
|
2
|
-
#
|
|
2
|
+
# Middleware implementations compatible with LangChain v1's create_agent
|
|
3
3
|
|
|
4
4
|
# Re-export LangChain built-in middleware
|
|
5
5
|
from langchain.agents.middleware import (
|
|
@@ -82,6 +82,10 @@ from dao_ai.middleware.summarization import (
|
|
|
82
82
|
create_summarization_middleware,
|
|
83
83
|
)
|
|
84
84
|
from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
|
|
85
|
+
from dao_ai.middleware.tool_call_observability import (
|
|
86
|
+
ToolCallObservabilityMiddleware,
|
|
87
|
+
create_tool_call_observability_middleware,
|
|
88
|
+
)
|
|
85
89
|
from dao_ai.middleware.tool_retry import create_tool_retry_middleware
|
|
86
90
|
from dao_ai.middleware.tool_selector import create_llm_tool_selector_middleware
|
|
87
91
|
|
|
@@ -160,4 +164,7 @@ __all__ = [
|
|
|
160
164
|
"create_clear_tool_uses_edit",
|
|
161
165
|
# PII middleware factory functions
|
|
162
166
|
"create_pii_middleware",
|
|
167
|
+
# Tool call observability middleware
|
|
168
|
+
"ToolCallObservabilityMiddleware",
|
|
169
|
+
"create_tool_call_observability_middleware",
|
|
163
170
|
]
|
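The middleware package now re-exports `ToolCallObservabilityMiddleware` and its factory alongside the LangChain built-ins, so it can be passed straight into a LangChain v1 agent. The sketch below assumes the factory accepts default arguments and that a model identifier string is enough for `create_agent`; the actual signatures live in `dao_ai/middleware/tool_call_observability.py`.

```python
# Hedged sketch: wiring the new observability middleware into a LangChain v1
# agent. Factory arguments and the model identifier are assumptions, not
# taken from this diff.
from langchain.agents import create_agent

from dao_ai.middleware import create_tool_call_observability_middleware

agent = create_agent(
    model="databricks:databricks-meta-llama-3-3-70b-instruct",  # illustrative
    tools=[],  # your tools here
    middleware=[create_tool_call_observability_middleware()],
)
```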