dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/core.py
CHANGED
@@ -38,7 +38,7 @@ def execute_sql_via_warehouse(
     w: WorkspaceClient = warehouse.workspace_client
     warehouse_id: str = str(warehouse.warehouse_id)

-    logger.
+    logger.trace("Executing cached SQL", layer=layer_name, sql_prefix=sql[:100])

     statement_response: StatementResponse = w.statement_execution.execute_statement(
         statement=sql,
@@ -57,7 +57,11 @@ def execute_sql_via_warehouse(

     if statement_response.status.state != StatementState.SUCCEEDED:
         error_msg: str = f"SQL execution failed: {statement_response.status}"
-        logger.error(
+        logger.error(
+            "SQL execution failed",
+            layer=layer_name,
+            status=str(statement_response.status),
+        )
         return error_msg

     # Convert to DataFrame
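
The change above only touches logging, but for context the surrounding `execute_sql_via_warehouse` helper follows the standard Databricks SQL Statement Execution flow: run the statement on a warehouse, check for `SUCCEEDED`, then build a DataFrame from the result. A minimal sketch of that pattern, assuming `databricks-sdk` and `pandas` are installed; `run_cached_sql` is an illustrative name, not a dao_ai API, and the real implementation handles more detail (typing, pagination).

    # Sketch only: mirrors the SUCCEEDED check visible in the hunk above.
    import pandas as pd
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.sql import StatementResponse, StatementState


    def run_cached_sql(w: WorkspaceClient, warehouse_id: str, sql: str) -> pd.DataFrame | str:
        # Execute synchronously and inspect the terminal state.
        resp: StatementResponse = w.statement_execution.execute_statement(
            statement=sql,
            warehouse_id=warehouse_id,
            wait_timeout="30s",
        )
        if resp.status.state != StatementState.SUCCEEDED:
            return f"SQL execution failed: {resp.status}"

        # Convert the row data into a DataFrame using the returned column names.
        columns = [c.name for c in resp.manifest.schema.columns]
        rows = resp.result.data_array or []
        return pd.DataFrame(rows, columns=columns)
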
dao_ai/genie/cache/lru.py
CHANGED
@@ -124,7 +124,9 @@ class LRUCacheService(GenieServiceBase):
         if self._cache:
             oldest_key: str = next(iter(self._cache))
             del self._cache[oldest_key]
-            logger.
+            logger.trace(
+                "Evicted cache entry", layer=self.name, key_prefix=oldest_key[:50]
+            )

     def _get(self, key: str) -> SQLCacheEntry | None:
         """Get from cache, returning None if not found or expired."""
@@ -135,7 +137,7 @@ class LRUCacheService(GenieServiceBase):

         if self._is_expired(entry):
             del self._cache[key]
-            logger.
+            logger.trace("Expired cache entry", layer=self.name, key_prefix=key[:50])
             return None

         self._cache.move_to_end(key)
@@ -156,9 +158,12 @@ class LRUCacheService(GenieServiceBase):
             created_at=datetime.now(),
         )
         logger.info(
-
-
-
+            "Stored cache entry",
+            layer=self.name,
+            key_prefix=key[:50],
+            sql_prefix=response.query[:50] if response.query else None,
+            cache_size=len(self._cache),
+            capacity=self.capacity,
         )

     @mlflow.trace(name="execute_cached_sql")
@@ -175,7 +180,7 @@ class LRUCacheService(GenieServiceBase):
         w: WorkspaceClient = self.warehouse.workspace_client
         warehouse_id: str = str(self.warehouse.warehouse_id)

-        logger.
+        logger.trace("Executing cached SQL", layer=self.name, sql_prefix=sql[:100])

         statement_response: StatementResponse = w.statement_execution.execute_statement(
             statement=sql,
@@ -194,7 +199,11 @@ class LRUCacheService(GenieServiceBase):

         if statement_response.status.state != StatementState.SUCCEEDED:
             error_msg: str = f"SQL execution failed: {statement_response.status}"
-            logger.error(
+            logger.error(
+                "SQL execution failed",
+                layer=self.name,
+                status=str(statement_response.status),
+            )
             return error_msg

         # Convert to DataFrame
@@ -250,8 +259,12 @@ class LRUCacheService(GenieServiceBase):

         if cached is not None:
             logger.info(
-
-
+                "Cache HIT",
+                layer=self.name,
+                question_prefix=question[:50],
+                conversation_id=conversation_id,
+                cache_size=self.size,
+                capacity=self.capacity,
             )

             # Re-execute the cached SQL to get fresh data
@@ -271,8 +284,13 @@ class LRUCacheService(GenieServiceBase):

         # Cache miss - delegate to wrapped service
         logger.info(
-
-
+            "Cache MISS",
+            layer=self.name,
+            question_prefix=question[:50],
+            conversation_id=conversation_id,
+            cache_size=self.size,
+            capacity=self.capacity,
+            delegating_to=type(self.impl).__name__,
         )

         result: CacheResult = self.impl.ask_question(question, conversation_id)
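
The LRU layer's mechanics are visible in the calls above: `next(iter(self._cache))` picks the eviction victim, `move_to_end` marks an entry as most recently used, and TTL expiry is checked on read. A minimal, self-contained sketch of that pattern; the class and field names are illustrative, not the package's actual API.

    from collections import OrderedDict
    from dataclasses import dataclass
    from datetime import datetime, timedelta
    from typing import Any


    @dataclass
    class Entry:
        value: Any
        created_at: datetime


    class TTLLRUCache:
        def __init__(self, capacity: int, ttl_seconds: float | None = None) -> None:
            self.capacity = capacity
            self.ttl_seconds = ttl_seconds
            self._cache: OrderedDict[str, Entry] = OrderedDict()

        def _is_expired(self, entry: Entry) -> bool:
            if self.ttl_seconds is None:
                return False
            return datetime.now() - entry.created_at > timedelta(seconds=self.ttl_seconds)

        def get(self, key: str) -> Any | None:
            entry = self._cache.get(key)
            if entry is None:
                return None
            if self._is_expired(entry):
                del self._cache[key]          # expired entries are dropped on read
                return None
            self._cache.move_to_end(key)      # mark as most recently used
            return entry.value

        def put(self, key: str, value: Any) -> None:
            self._cache[key] = Entry(value=value, created_at=datetime.now())
            self._cache.move_to_end(key)
            if len(self._cache) > self.capacity:
                oldest_key = next(iter(self._cache))   # least recently used
                del self._cache[oldest_key]
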
dao_ai/genie/cache/semantic.py
CHANGED
@@ -78,7 +78,9 @@ def get_conversation_history(
         return []
     except Exception as e:
         logger.warning(
-
+            "Failed to retrieve conversation history",
+            conversation_id=conversation_id,
+            error=str(e),
         )
         return []

@@ -137,10 +139,13 @@ def build_context_string(
     if estimated_tokens > max_tokens:
         # Truncate to fit max_tokens
         target_chars = max_tokens * 4
+        original_length = len(context_string)
         context_string = context_string[:target_chars] + "..."
-        logger.
-
-
+        logger.trace(
+            "Truncated context string",
+            original_chars=original_length,
+            target_chars=target_chars,
+            max_tokens=max_tokens,
         )

     return context_string
@@ -251,7 +256,9 @@ class SemanticCacheService(GenieServiceBase):
             sample_embedding: list[float] = self._embeddings.embed_query("test")
             self._embedding_dims = len(sample_embedding)
             logger.debug(
-
+                "Auto-detected embedding dimensions",
+                layer=self.name,
+                dims=self._embedding_dims,
             )
         else:
             self._embedding_dims = self.parameters.embedding_dims
@@ -264,8 +271,11 @@ class SemanticCacheService(GenieServiceBase):

         self._setup_complete = True
         logger.debug(
-
-
+            "Semantic cache initialized",
+            layer=self.name,
+            space_id=self.space_id,
+            table_name=self.table_name,
+            dims=self._embedding_dims,
         )

     @property
@@ -369,9 +379,11 @@ class SemanticCacheService(GenieServiceBase):
                 current_dims = row.get("atttypmod", 0)
                 if current_dims != self.embedding_dims:
                     logger.warning(
-
-
-
+                        "Embedding dimension mismatch, dropping and recreating table",
+                        layer=self.name,
+                        table_dims=current_dims,
+                        expected_dims=self.embedding_dims,
+                        table_name=self.table_name,
                     )
                     cur.execute(f"DROP TABLE {self.table_name}")
         except Exception:
@@ -448,13 +460,17 @@ class SemanticCacheService(GenieServiceBase):
                     conversation_context[:target_chars] + "..."
                 )

-            logger.
-
-
+            logger.trace(
+                "Using conversation context",
+                layer=self.name,
+                messages_count=len(conversation_messages),
+                window_size=self.parameters.context_window_size,
            )
        except Exception as e:
            logger.warning(
-
+                "Failed to build conversation context, using question only",
+                layer=self.name,
+                error=str(e),
            )
            conversation_context = ""

@@ -558,8 +574,10 @@ class SemanticCacheService(GenieServiceBase):

            if row is None:
                logger.info(
-
-
+                    "Cache MISS (no entries)",
+                    layer=self.name,
+                    question_prefix=question[:50],
+                    space=self.space_id,
                )
                return None

@@ -577,25 +595,33 @@ class SemanticCacheService(GenieServiceBase):
        is_valid: bool = row.get("is_valid", False)

        # Log best match info
-        logger.
-
-
-            f"
-            f"
+        logger.debug(
+            "Best match found",
+            layer=self.name,
+            question_sim=f"{question_similarity:.4f}",
+            context_sim=f"{context_similarity:.4f}",
+            combined_sim=f"{combined_similarity:.4f}",
+            is_valid=is_valid,
+            cached_question_prefix=cached_question[:50],
+            cached_context_prefix=cached_context[:80],
        )

        # Check BOTH similarity thresholds (dual embedding precision check)
        if question_similarity < self.parameters.similarity_threshold:
            logger.info(
-
-
+                "Cache MISS (question similarity too low)",
+                layer=self.name,
+                question_sim=f"{question_similarity:.4f}",
+                threshold=self.parameters.similarity_threshold,
            )
            return None

        if context_similarity < self.parameters.context_similarity_threshold:
            logger.info(
-
-
+                "Cache MISS (context similarity too low)",
+                layer=self.name,
+                context_sim=f"{context_similarity:.4f}",
+                threshold=self.parameters.context_similarity_threshold,
            )
            return None

@@ -605,14 +631,21 @@ class SemanticCacheService(GenieServiceBase):
                    delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
                    cur.execute(delete_sql, (entry_id,))
                    logger.info(
-
-
+                        "Cache MISS (expired, deleted for refresh)",
+                        layer=self.name,
+                        combined_sim=f"{combined_similarity:.4f}",
+                        ttl_seconds=ttl_seconds,
+                        cached_question_prefix=cached_question[:50],
                    )
            return None

        logger.info(
-
-
+            "Cache HIT",
+            layer=self.name,
+            question_sim=f"{question_similarity:.4f}",
+            context_sim=f"{context_similarity:.4f}",
+            combined_sim=f"{combined_similarity:.4f}",
+            cached_question_prefix=cached_question[:50],
        )

        entry = SQLCacheEntry(
@@ -664,9 +697,13 @@ class SemanticCacheService(GenieServiceBase):
            ),
        )
        logger.info(
-
-
-
+            "Stored cache entry",
+            layer=self.name,
+            question_prefix=question[:50],
+            context_prefix=conversation_context[:80],
+            sql_prefix=response.query[:50] if response.query else None,
+            space=self.space_id,
+            table=self.table_name,
        )

    @mlflow.trace(name="execute_cached_sql_semantic")
@@ -692,7 +729,7 @@ class SemanticCacheService(GenieServiceBase):
                if statement_response.status.error is not None
                else f"SQL execution failed with state: {statement_response.status.state}"
            )
-            logger.error(
+            logger.error("SQL execution failed", layer=self.name, error=error_msg)
            return error_msg

        if statement_response.result and statement_response.result.data_array:
@@ -765,7 +802,10 @@ class SemanticCacheService(GenieServiceBase):
        if cache_result is not None:
            cached, combined_similarity = cache_result
            logger.debug(
-
+                "Semantic cache hit",
+                layer=self.name,
+                combined_similarity=f"{combined_similarity:.3f}",
+                question_prefix=question[:50],
            )

            # Re-execute the cached SQL to get fresh data
@@ -785,15 +825,17 @@ class SemanticCacheService(GenieServiceBase):
            return CacheResult(response=response, cache_hit=True, served_by=self.name)

        # Cache miss - delegate to wrapped service
-        logger.
+        logger.trace("Cache miss", layer=self.name, question_prefix=question[:50])

        result: CacheResult = self.impl.ask_question(question, conversation_id)

        # Store in cache if we got a SQL query
        if result.response.query:
            logger.info(
-
-
+                "Storing new cache entry",
+                layer=self.name,
+                question_prefix=question[:50],
+                space=self.space_id,
            )
            self._store_entry(
                question,
@@ -804,8 +846,9 @@ class SemanticCacheService(GenieServiceBase):
            )
        elif not result.response.query:
            logger.warning(
-
-
+                "Not caching: response has no SQL query",
+                layer=self.name,
+                question_prefix=question[:50],
            )

        return CacheResult(response=result.response, cache_hit=False, served_by=None)
@@ -824,8 +867,10 @@ class SemanticCacheService(GenieServiceBase):

        # If TTL is disabled, nothing can expire
        if ttl_seconds is None or ttl_seconds < 0:
-            logger.
-
+            logger.trace(
+                "TTL disabled, no entries to expire",
+                layer=self.name,
+                space=self.space_id,
            )
            return 0

@@ -839,8 +884,11 @@ class SemanticCacheService(GenieServiceBase):
            with conn.cursor() as cur:
                cur.execute(delete_sql, (self.space_id, ttl_seconds))
                deleted: int = cur.rowcount
-                logger.
-
+                logger.trace(
+                    "Deleted expired entries",
+                    layer=self.name,
+                    deleted_count=deleted,
+                    space=self.space_id,
                )
        return deleted

@@ -854,7 +902,10 @@ class SemanticCacheService(GenieServiceBase):
                cur.execute(delete_sql, (self.space_id,))
                deleted: int = cur.rowcount
                logger.debug(
-
+                    "Cleared cache entries",
+                    layer=self.name,
+                    deleted_count=deleted,
+                    space=self.space_id,
                )
        return deleted

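The semantic layer only returns a hit when both embeddings agree: the question similarity and the conversation-context similarity must each clear their own threshold (the "dual embedding precision check" noted in the diff). A small sketch of that gate; the default threshold values here are assumptions, not the package's configured parameters.

    # Illustrative only: mirrors the two early-return MISS branches above.
    def is_cache_hit(
        question_similarity: float,
        context_similarity: float,
        similarity_threshold: float = 0.85,
        context_similarity_threshold: float = 0.80,
    ) -> bool:
        # Either check failing means the cached SQL is not reused.
        if question_similarity < similarity_threshold:
            return False
        if context_similarity < context_similarity_threshold:
            return False
        return True

Requiring both scores keeps a superficially similar question from reusing SQL that was generated under a different conversational context.
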
dao_ai/hooks/core.py
CHANGED
@@ -25,7 +25,7 @@ def create_hooks(
     Returns:
         Sequence of callable functions
     """
-    logger.
+    logger.trace("Creating hooks", function_hooks=function_hooks)
     hooks: list[Callable[..., Any]] = []
     if not function_hooks:
         return []
@@ -35,21 +35,21 @@ def create_hooks(
         if isinstance(function_hook, str):
             function_hook = PythonFunctionModel(name=function_hook)
         hooks.extend(function_hook.as_tools())
-    logger.
+    logger.trace("Created hooks", hooks_count=len(hooks))
    return hooks


 def null_hook(state: dict[str, Any], config: Any) -> dict[str, Any]:
     """A no-op hook that returns an empty dict."""
-    logger.
+    logger.trace("Executing null hook")
     return {}


 def null_initialization_hook(config: AppConfig) -> None:
     """A no-op initialization hook."""
-    logger.
+    logger.trace("Executing null initialization hook")


 def null_shutdown_hook(config: AppConfig) -> None:
     """A no-op shutdown hook."""
-    logger.
+    logger.trace("Executing null shutdown hook")
dao_ai/logging.py
ADDED
@@ -0,0 +1,56 @@
+"""Logging configuration for DAO AI."""
+
+import sys
+from typing import Any
+
+from loguru import logger
+
+# Re-export logger for convenience
+__all__ = ["logger", "configure_logging"]
+
+
+def format_extra(record: dict[str, Any]) -> str:
+    """Format extra fields as key=value pairs."""
+    extra: dict[str, Any] = record["extra"]
+    if not extra:
+        return ""
+
+    formatted_pairs: list[str] = []
+    for key, value in extra.items():
+        # Handle different value types
+        if isinstance(value, str):
+            formatted_pairs.append(f"{key}={value}")
+        elif isinstance(value, (list, tuple)):
+            formatted_pairs.append(f"{key}={','.join(str(v) for v in value)}")
+        else:
+            formatted_pairs.append(f"{key}={value}")
+
+    return " | ".join(formatted_pairs)
+
+
+def configure_logging(level: str = "INFO") -> None:
+    """
+    Configure loguru logging with structured output.
+
+    Args:
+        level: The log level (e.g., "INFO", "DEBUG", "WARNING")
+    """
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        level=level,
+        format=(
+            "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
+            "<level>{level: <8}</level> | "
+            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
+            "<level>{message}</level>"
+            "{extra}"
+        ),
+    )
+
+    # Add custom formatter for extra fields
+    logger.configure(
+        patcher=lambda record: record.update(
+            extra=" | " + format_extra(record) if record["extra"] else ""
+        )
+    )
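
A minimal usage sketch for the new module: configure once at startup, then pass structured fields as keyword arguments. Loguru captures call-site kwargs into `record["extra"]` by default, and the patcher above renders them as `key=value` pairs after the message. The field values below are made up for illustration.

    from dao_ai.logging import configure_logging, logger

    # Enable TRACE so the new logger.trace(...) calls in the cache layers are visible.
    configure_logging(level="TRACE")

    # Keyword arguments become structured fields on the record's "extra" dict.
    logger.info("Cache HIT", layer="lru_cache", cache_size=3, capacity=100)
    logger.trace("Executing cached SQL", layer="lru_cache", sql_prefix="SELECT region, SUM(amount)")

This is what lets the cache and memory modules in this release replace f-string log messages with constant messages plus structured fields.
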
dao_ai/memory/core.py
CHANGED
@@ -25,11 +25,13 @@ class InMemoryStoreManager(StoreManagerBase):
         self.store_model = store_model

     def store(self) -> BaseStore:
-
+        embedding_model: LLMModel = self.store_model.embedding_model

-
+        logger.debug(
+            "Creating in-memory store", embeddings_enabled=embedding_model is not None
+        )

-
+        index: dict[str, Any] = None

         if embedding_model:
             embeddings: Embeddings = DatabricksEmbeddings(endpoint=embedding_model.name)
@@ -39,6 +41,11 @@ class InMemoryStoreManager(StoreManagerBase):

             dims: int = self.store_model.dims
             index = {"dims": dims, "embed": embed_texts}
+            logger.debug(
+                "Store embeddings configured",
+                endpoint=embedding_model.name,
+                dimensions=dims,
+            )

         store: BaseStore = InMemoryStore(index=index)

@@ -59,32 +66,38 @@ class StoreManager:
     @classmethod
     def instance(cls, store_model: StoreModel) -> StoreManagerBase:
         store_manager: StoreManagerBase | None = None
-        match store_model.
+        match store_model.storage_type:
             case StorageType.MEMORY:
                 store_manager = cls.store_managers.get(store_model.name)
                 if store_manager is None:
                     store_manager = InMemoryStoreManager(store_model)
                 cls.store_managers[store_model.name] = store_manager
             case StorageType.POSTGRES:
-
+                # Route based on database configuration: instance_name -> Databricks, host -> Postgres
+                if store_model.database.is_lakebase:
+                    # Databricks Lakebase connection
+                    from dao_ai.memory.databricks import DatabricksStoreManager

-
-
-                )
-                if store_manager is None:
-                    store_manager = PostgresStoreManager(store_model)
-                cls.store_managers[store_model.database.instance_name] = (
-                    store_manager
+                    store_manager = cls.store_managers.get(
+                        store_model.database.instance_name
                    )
-
-
-
-
-
-
-
+                    if store_manager is None:
+                        store_manager = DatabricksStoreManager(store_model)
+                    cls.store_managers[store_model.database.instance_name] = (
+                        store_manager
+                    )
+                else:
+                    # Standard PostgreSQL connection
+                    from dao_ai.memory.postgres import PostgresStoreManager
+
+                    # Use database name as key for standard PostgreSQL
+                    cache_key = f"{store_model.database.name}"
+                    store_manager = cls.store_managers.get(cache_key)
+                    if store_manager is None:
+                        store_manager = PostgresStoreManager(store_model)
+                    cls.store_managers[cache_key] = store_manager
             case _:
-                raise ValueError(f"Unknown
+                raise ValueError(f"Unknown storage type: {store_model.storage_type}")

         return store_manager

@@ -95,7 +108,7 @@ class CheckpointManager:
     @classmethod
     def instance(cls, checkpointer_model: CheckpointerModel) -> CheckpointManagerBase:
         checkpointer_manager: CheckpointManagerBase | None = None
-        match checkpointer_model.
+        match checkpointer_model.storage_type:
             case StorageType.MEMORY:
                 checkpointer_manager = cls.checkpoint_managers.get(
                     checkpointer_model.name
@@ -108,32 +121,36 @@ class CheckpointManager:
                     checkpointer_manager
                 )
             case StorageType.POSTGRES:
-
+                # Route based on database configuration: instance_name -> Databricks, host -> Postgres
+                if checkpointer_model.database.is_lakebase:
+                    # Databricks Lakebase connection
+                    from dao_ai.memory.databricks import DatabricksCheckpointerManager

-
-                    checkpointer_model.database.instance_name
-                )
-                if checkpointer_manager is None:
-                    checkpointer_manager = AsyncPostgresCheckpointerManager(
-                        checkpointer_model
-                    )
-                cls.checkpoint_managers[
+                    checkpointer_manager = cls.checkpoint_managers.get(
                        checkpointer_model.database.instance_name
-                ] = checkpointer_manager
-            case StorageType.LAKEBASE:
-                from dao_ai.memory.databricks import DatabricksCheckpointerManager
-
-                checkpointer_manager = cls.checkpoint_managers.get(
-                    checkpointer_model.name
-                )
-                if checkpointer_manager is None:
-                    checkpointer_manager = DatabricksCheckpointerManager(
-                        checkpointer_model
-                    )
-                cls.checkpoint_managers[checkpointer_model.name] = (
-                    checkpointer_manager
                    )
+                    if checkpointer_manager is None:
+                        checkpointer_manager = DatabricksCheckpointerManager(
+                            checkpointer_model
+                        )
+                    cls.checkpoint_managers[
+                        checkpointer_model.database.instance_name
+                    ] = checkpointer_manager
+                else:
+                    # Standard PostgreSQL connection
+                    from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
+
+                    # Use database name as key for standard PostgreSQL
+                    cache_key = f"{checkpointer_model.database.name}"
+                    checkpointer_manager = cls.checkpoint_managers.get(cache_key)
+                    if checkpointer_manager is None:
+                        checkpointer_manager = AsyncPostgresCheckpointerManager(
+                            checkpointer_model
+                        )
+                    cls.checkpoint_managers[cache_key] = checkpointer_manager
             case _:
-                raise ValueError(
+                raise ValueError(
+                    f"Unknown storage type: {checkpointer_model.storage_type}"
+                )

         return checkpointer_manager