dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. dao_ai/agent_as_code.py +2 -5
  2. dao_ai/cli.py +65 -15
  3. dao_ai/config.py +672 -218
  4. dao_ai/genie/cache/core.py +6 -2
  5. dao_ai/genie/cache/lru.py +29 -11
  6. dao_ai/genie/cache/semantic.py +95 -44
  7. dao_ai/hooks/core.py +5 -5
  8. dao_ai/logging.py +56 -0
  9. dao_ai/memory/core.py +61 -44
  10. dao_ai/memory/databricks.py +54 -41
  11. dao_ai/memory/postgres.py +77 -36
  12. dao_ai/middleware/assertions.py +45 -17
  13. dao_ai/middleware/core.py +13 -7
  14. dao_ai/middleware/guardrails.py +30 -25
  15. dao_ai/middleware/human_in_the_loop.py +9 -5
  16. dao_ai/middleware/message_validation.py +61 -29
  17. dao_ai/middleware/summarization.py +16 -11
  18. dao_ai/models.py +172 -69
  19. dao_ai/nodes.py +148 -19
  20. dao_ai/optimization.py +26 -16
  21. dao_ai/orchestration/core.py +15 -8
  22. dao_ai/orchestration/supervisor.py +22 -8
  23. dao_ai/orchestration/swarm.py +57 -12
  24. dao_ai/prompts.py +17 -17
  25. dao_ai/providers/databricks.py +365 -155
  26. dao_ai/state.py +24 -6
  27. dao_ai/tools/__init__.py +2 -0
  28. dao_ai/tools/agent.py +1 -3
  29. dao_ai/tools/core.py +7 -7
  30. dao_ai/tools/email.py +29 -77
  31. dao_ai/tools/genie.py +18 -13
  32. dao_ai/tools/mcp.py +223 -156
  33. dao_ai/tools/python.py +5 -2
  34. dao_ai/tools/search.py +1 -1
  35. dao_ai/tools/slack.py +21 -9
  36. dao_ai/tools/sql.py +202 -0
  37. dao_ai/tools/time.py +30 -7
  38. dao_ai/tools/unity_catalog.py +129 -86
  39. dao_ai/tools/vector_search.py +318 -244
  40. dao_ai/utils.py +15 -10
  41. dao_ai-0.1.3.dist-info/METADATA +455 -0
  42. dao_ai-0.1.3.dist-info/RECORD +64 -0
  43. dao_ai-0.1.1.dist-info/METADATA +0 -1878
  44. dao_ai-0.1.1.dist-info/RECORD +0 -62
  45. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
  46. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
  47. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
@@ -38,7 +38,7 @@ def execute_sql_via_warehouse(
38
38
  w: WorkspaceClient = warehouse.workspace_client
39
39
  warehouse_id: str = str(warehouse.warehouse_id)
40
40
 
41
- logger.debug(f"[{layer_name}] Executing cached SQL: {sql[:100]}...")
41
+ logger.trace("Executing cached SQL", layer=layer_name, sql_prefix=sql[:100])
42
42
 
43
43
  statement_response: StatementResponse = w.statement_execution.execute_statement(
44
44
  statement=sql,
@@ -57,7 +57,11 @@ def execute_sql_via_warehouse(
57
57
 
58
58
  if statement_response.status.state != StatementState.SUCCEEDED:
59
59
  error_msg: str = f"SQL execution failed: {statement_response.status}"
60
- logger.error(f"[{layer_name}] {error_msg}")
60
+ logger.error(
61
+ "SQL execution failed",
62
+ layer=layer_name,
63
+ status=str(statement_response.status),
64
+ )
61
65
  return error_msg
62
66
 
63
67
  # Convert to DataFrame
dao_ai/genie/cache/lru.py CHANGED
@@ -124,7 +124,9 @@ class LRUCacheService(GenieServiceBase):
124
124
  if self._cache:
125
125
  oldest_key: str = next(iter(self._cache))
126
126
  del self._cache[oldest_key]
127
- logger.debug(f"[{self.name}] Evicted: {oldest_key[:50]}...")
127
+ logger.trace(
128
+ "Evicted cache entry", layer=self.name, key_prefix=oldest_key[:50]
129
+ )
128
130
 
129
131
  def _get(self, key: str) -> SQLCacheEntry | None:
130
132
  """Get from cache, returning None if not found or expired."""
@@ -135,7 +137,7 @@ class LRUCacheService(GenieServiceBase):
135
137
 
136
138
  if self._is_expired(entry):
137
139
  del self._cache[key]
138
- logger.debug(f"[{self.name}] Expired: {key[:50]}...")
140
+ logger.trace("Expired cache entry", layer=self.name, key_prefix=key[:50])
139
141
  return None
140
142
 
141
143
  self._cache.move_to_end(key)
@@ -156,9 +158,12 @@ class LRUCacheService(GenieServiceBase):
156
158
  created_at=datetime.now(),
157
159
  )
158
160
  logger.info(
159
- f"[{self.name}] Stored cache entry: key='{key[:50]}...' "
160
- f"sql='{response.query[:50] if response.query else 'None'}...' "
161
- f"(cache_size={len(self._cache)}/{self.capacity})"
161
+ "Stored cache entry",
162
+ layer=self.name,
163
+ key_prefix=key[:50],
164
+ sql_prefix=response.query[:50] if response.query else None,
165
+ cache_size=len(self._cache),
166
+ capacity=self.capacity,
162
167
  )
163
168
 
164
169
  @mlflow.trace(name="execute_cached_sql")
@@ -175,7 +180,7 @@ class LRUCacheService(GenieServiceBase):
175
180
  w: WorkspaceClient = self.warehouse.workspace_client
176
181
  warehouse_id: str = str(self.warehouse.warehouse_id)
177
182
 
178
- logger.debug(f"[{self.name}] Executing cached SQL: {sql[:100]}...")
183
+ logger.trace("Executing cached SQL", layer=self.name, sql_prefix=sql[:100])
179
184
 
180
185
  statement_response: StatementResponse = w.statement_execution.execute_statement(
181
186
  statement=sql,
@@ -194,7 +199,11 @@ class LRUCacheService(GenieServiceBase):
194
199
 
195
200
  if statement_response.status.state != StatementState.SUCCEEDED:
196
201
  error_msg: str = f"SQL execution failed: {statement_response.status}"
197
- logger.error(f"[{self.name}] {error_msg}")
202
+ logger.error(
203
+ "SQL execution failed",
204
+ layer=self.name,
205
+ status=str(statement_response.status),
206
+ )
198
207
  return error_msg
199
208
 
200
209
  # Convert to DataFrame
@@ -250,8 +259,12 @@ class LRUCacheService(GenieServiceBase):
250
259
 
251
260
  if cached is not None:
252
261
  logger.info(
253
- f"[{self.name}] Cache HIT: '{question[:50]}...' "
254
- f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity})"
262
+ "Cache HIT",
263
+ layer=self.name,
264
+ question_prefix=question[:50],
265
+ conversation_id=conversation_id,
266
+ cache_size=self.size,
267
+ capacity=self.capacity,
255
268
  )
256
269
 
257
270
  # Re-execute the cached SQL to get fresh data
@@ -271,8 +284,13 @@ class LRUCacheService(GenieServiceBase):
271
284
 
272
285
  # Cache miss - delegate to wrapped service
273
286
  logger.info(
274
- f"[{self.name}] Cache MISS: '{question[:50]}...' "
275
- f"(conversation_id={conversation_id}, cache_size={self.size}/{self.capacity}, delegating to {type(self.impl).__name__})"
287
+ "Cache MISS",
288
+ layer=self.name,
289
+ question_prefix=question[:50],
290
+ conversation_id=conversation_id,
291
+ cache_size=self.size,
292
+ capacity=self.capacity,
293
+ delegating_to=type(self.impl).__name__,
276
294
  )
277
295
 
278
296
  result: CacheResult = self.impl.ask_question(question, conversation_id)
@@ -78,7 +78,9 @@ def get_conversation_history(
78
78
  return []
79
79
  except Exception as e:
80
80
  logger.warning(
81
- f"Failed to retrieve conversation history for conversation_id={conversation_id}: {e}"
81
+ "Failed to retrieve conversation history",
82
+ conversation_id=conversation_id,
83
+ error=str(e),
82
84
  )
83
85
  return []
84
86
 
@@ -137,10 +139,13 @@ def build_context_string(
137
139
  if estimated_tokens > max_tokens:
138
140
  # Truncate to fit max_tokens
139
141
  target_chars = max_tokens * 4
142
+ original_length = len(context_string)
140
143
  context_string = context_string[:target_chars] + "..."
141
- logger.debug(
142
- f"Truncated context string from {len(context_string)} to {target_chars} chars "
143
- f"(estimated {max_tokens} tokens)"
144
+ logger.trace(
145
+ "Truncated context string",
146
+ original_chars=original_length,
147
+ target_chars=target_chars,
148
+ max_tokens=max_tokens,
144
149
  )
145
150
 
146
151
  return context_string
@@ -251,7 +256,9 @@ class SemanticCacheService(GenieServiceBase):
251
256
  sample_embedding: list[float] = self._embeddings.embed_query("test")
252
257
  self._embedding_dims = len(sample_embedding)
253
258
  logger.debug(
254
- f"[{self.name}] Auto-detected embedding dimensions: {self._embedding_dims}"
259
+ "Auto-detected embedding dimensions",
260
+ layer=self.name,
261
+ dims=self._embedding_dims,
255
262
  )
256
263
  else:
257
264
  self._embedding_dims = self.parameters.embedding_dims
@@ -264,8 +271,11 @@ class SemanticCacheService(GenieServiceBase):
264
271
 
265
272
  self._setup_complete = True
266
273
  logger.debug(
267
- f"[{self.name}] Semantic cache initialized for space '{self.space_id}' "
268
- f"with table '{self.table_name}' (dims={self._embedding_dims})"
274
+ "Semantic cache initialized",
275
+ layer=self.name,
276
+ space_id=self.space_id,
277
+ table_name=self.table_name,
278
+ dims=self._embedding_dims,
269
279
  )
270
280
 
271
281
  @property
@@ -369,9 +379,11 @@ class SemanticCacheService(GenieServiceBase):
369
379
  current_dims = row.get("atttypmod", 0)
370
380
  if current_dims != self.embedding_dims:
371
381
  logger.warning(
372
- f"[{self.name}] Embedding dimension mismatch: "
373
- f"table has {current_dims}, expected {self.embedding_dims}. "
374
- f"Dropping and recreating table '{self.table_name}'."
382
+ "Embedding dimension mismatch, dropping and recreating table",
383
+ layer=self.name,
384
+ table_dims=current_dims,
385
+ expected_dims=self.embedding_dims,
386
+ table_name=self.table_name,
375
387
  )
376
388
  cur.execute(f"DROP TABLE {self.table_name}")
377
389
  except Exception:
@@ -448,13 +460,17 @@ class SemanticCacheService(GenieServiceBase):
448
460
  conversation_context[:target_chars] + "..."
449
461
  )
450
462
 
451
- logger.debug(
452
- f"[{self.name}] Using conversation context: {len(conversation_messages)} messages "
453
- f"(window_size={self.parameters.context_window_size})"
463
+ logger.trace(
464
+ "Using conversation context",
465
+ layer=self.name,
466
+ messages_count=len(conversation_messages),
467
+ window_size=self.parameters.context_window_size,
454
468
  )
455
469
  except Exception as e:
456
470
  logger.warning(
457
- f"[{self.name}] Failed to build conversation context, using question only: {e}"
471
+ "Failed to build conversation context, using question only",
472
+ layer=self.name,
473
+ error=str(e),
458
474
  )
459
475
  conversation_context = ""
460
476
 
@@ -558,8 +574,10 @@ class SemanticCacheService(GenieServiceBase):
558
574
 
559
575
  if row is None:
560
576
  logger.info(
561
- f"[{self.name}] MISS (no entries): "
562
- f"question='{question[:50]}...' space='{self.space_id}'"
577
+ "Cache MISS (no entries)",
578
+ layer=self.name,
579
+ question_prefix=question[:50],
580
+ space=self.space_id,
563
581
  )
564
582
  return None
565
583
 
@@ -577,25 +595,33 @@ class SemanticCacheService(GenieServiceBase):
577
595
  is_valid: bool = row.get("is_valid", False)
578
596
 
579
597
  # Log best match info
580
- logger.info(
581
- f"[{self.name}] Best match: "
582
- f"question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
583
- f"combined_sim={combined_similarity:.4f}, is_valid={is_valid}, "
584
- f"question='{cached_question[:50]}...', context='{cached_context[:80]}...'"
598
+ logger.debug(
599
+ "Best match found",
600
+ layer=self.name,
601
+ question_sim=f"{question_similarity:.4f}",
602
+ context_sim=f"{context_similarity:.4f}",
603
+ combined_sim=f"{combined_similarity:.4f}",
604
+ is_valid=is_valid,
605
+ cached_question_prefix=cached_question[:50],
606
+ cached_context_prefix=cached_context[:80],
585
607
  )
586
608
 
587
609
  # Check BOTH similarity thresholds (dual embedding precision check)
588
610
  if question_similarity < self.parameters.similarity_threshold:
589
611
  logger.info(
590
- f"[{self.name}] MISS (question similarity too low): "
591
- f"question_sim={question_similarity:.4f} < threshold={self.parameters.similarity_threshold}"
612
+ "Cache MISS (question similarity too low)",
613
+ layer=self.name,
614
+ question_sim=f"{question_similarity:.4f}",
615
+ threshold=self.parameters.similarity_threshold,
592
616
  )
593
617
  return None
594
618
 
595
619
  if context_similarity < self.parameters.context_similarity_threshold:
596
620
  logger.info(
597
- f"[{self.name}] MISS (context similarity too low): "
598
- f"context_sim={context_similarity:.4f} < threshold={self.parameters.context_similarity_threshold}"
621
+ "Cache MISS (context similarity too low)",
622
+ layer=self.name,
623
+ context_sim=f"{context_similarity:.4f}",
624
+ threshold=self.parameters.context_similarity_threshold,
599
625
  )
600
626
  return None
601
627
 
@@ -605,14 +631,21 @@ class SemanticCacheService(GenieServiceBase):
605
631
  delete_sql = f"DELETE FROM {self.table_name} WHERE id = %s"
606
632
  cur.execute(delete_sql, (entry_id,))
607
633
  logger.info(
608
- f"[{self.name}] MISS (expired, deleted for refresh): "
609
- f"combined_sim={combined_similarity:.4f}, ttl={ttl_seconds}s, question='{cached_question[:50]}...'"
634
+ "Cache MISS (expired, deleted for refresh)",
635
+ layer=self.name,
636
+ combined_sim=f"{combined_similarity:.4f}",
637
+ ttl_seconds=ttl_seconds,
638
+ cached_question_prefix=cached_question[:50],
610
639
  )
611
640
  return None
612
641
 
613
642
  logger.info(
614
- f"[{self.name}] HIT: question_sim={question_similarity:.4f}, context_sim={context_similarity:.4f}, "
615
- f"combined_sim={combined_similarity:.4f} (cached_question='{cached_question[:50]}...')"
643
+ "Cache HIT",
644
+ layer=self.name,
645
+ question_sim=f"{question_similarity:.4f}",
646
+ context_sim=f"{context_similarity:.4f}",
647
+ combined_sim=f"{combined_similarity:.4f}",
648
+ cached_question_prefix=cached_question[:50],
616
649
  )
617
650
 
618
651
  entry = SQLCacheEntry(
@@ -664,9 +697,13 @@ class SemanticCacheService(GenieServiceBase):
664
697
  ),
665
698
  )
666
699
  logger.info(
667
- f"[{self.name}] Stored cache entry: question='{question[:50]}...' "
668
- f"context='{conversation_context[:80]}...' "
669
- f"sql='{response.query[:50]}...' (space={self.space_id}, table={self.table_name})"
700
+ "Stored cache entry",
701
+ layer=self.name,
702
+ question_prefix=question[:50],
703
+ context_prefix=conversation_context[:80],
704
+ sql_prefix=response.query[:50] if response.query else None,
705
+ space=self.space_id,
706
+ table=self.table_name,
670
707
  )
671
708
 
672
709
  @mlflow.trace(name="execute_cached_sql_semantic")
@@ -692,7 +729,7 @@ class SemanticCacheService(GenieServiceBase):
692
729
  if statement_response.status.error is not None
693
730
  else f"SQL execution failed with state: {statement_response.status.state}"
694
731
  )
695
- logger.error(f"[{self.name}] {error_msg}")
732
+ logger.error("SQL execution failed", layer=self.name, error=error_msg)
696
733
  return error_msg
697
734
 
698
735
  if statement_response.result and statement_response.result.data_array:
@@ -765,7 +802,10 @@ class SemanticCacheService(GenieServiceBase):
765
802
  if cache_result is not None:
766
803
  cached, combined_similarity = cache_result
767
804
  logger.debug(
768
- f"[{self.name}] Semantic cache hit (combined_similarity={combined_similarity:.3f}): {question[:50]}..."
805
+ "Semantic cache hit",
806
+ layer=self.name,
807
+ combined_similarity=f"{combined_similarity:.3f}",
808
+ question_prefix=question[:50],
769
809
  )
770
810
 
771
811
  # Re-execute the cached SQL to get fresh data
@@ -785,15 +825,17 @@ class SemanticCacheService(GenieServiceBase):
785
825
  return CacheResult(response=response, cache_hit=True, served_by=self.name)
786
826
 
787
827
  # Cache miss - delegate to wrapped service
788
- logger.debug(f"[{self.name}] Miss: {question[:50]}...")
828
+ logger.trace("Cache miss", layer=self.name, question_prefix=question[:50])
789
829
 
790
830
  result: CacheResult = self.impl.ask_question(question, conversation_id)
791
831
 
792
832
  # Store in cache if we got a SQL query
793
833
  if result.response.query:
794
834
  logger.info(
795
- f"[{self.name}] Storing new cache entry for question: '{question[:50]}...' "
796
- f"(space={self.space_id})"
835
+ "Storing new cache entry",
836
+ layer=self.name,
837
+ question_prefix=question[:50],
838
+ space=self.space_id,
797
839
  )
798
840
  self._store_entry(
799
841
  question,
@@ -804,8 +846,9 @@ class SemanticCacheService(GenieServiceBase):
804
846
  )
805
847
  elif not result.response.query:
806
848
  logger.warning(
807
- f"[{self.name}] Not caching: response has no SQL query "
808
- f"(question='{question[:50]}...')"
849
+ "Not caching: response has no SQL query",
850
+ layer=self.name,
851
+ question_prefix=question[:50],
809
852
  )
810
853
 
811
854
  return CacheResult(response=result.response, cache_hit=False, served_by=None)
@@ -824,8 +867,10 @@ class SemanticCacheService(GenieServiceBase):
824
867
 
825
868
  # If TTL is disabled, nothing can expire
826
869
  if ttl_seconds is None or ttl_seconds < 0:
827
- logger.debug(
828
- f"[{self.name}] TTL disabled, no entries to expire for space {self.space_id}"
870
+ logger.trace(
871
+ "TTL disabled, no entries to expire",
872
+ layer=self.name,
873
+ space=self.space_id,
829
874
  )
830
875
  return 0
831
876
 
@@ -839,8 +884,11 @@ class SemanticCacheService(GenieServiceBase):
839
884
  with conn.cursor() as cur:
840
885
  cur.execute(delete_sql, (self.space_id, ttl_seconds))
841
886
  deleted: int = cur.rowcount
842
- logger.debug(
843
- f"[{self.name}] Deleted {deleted} expired entries for space {self.space_id}"
887
+ logger.trace(
888
+ "Deleted expired entries",
889
+ layer=self.name,
890
+ deleted_count=deleted,
891
+ space=self.space_id,
844
892
  )
845
893
  return deleted
846
894
 
@@ -854,7 +902,10 @@ class SemanticCacheService(GenieServiceBase):
854
902
  cur.execute(delete_sql, (self.space_id,))
855
903
  deleted: int = cur.rowcount
856
904
  logger.debug(
857
- f"[{self.name}] Cleared {deleted} entries for space {self.space_id}"
905
+ "Cleared cache entries",
906
+ layer=self.name,
907
+ deleted_count=deleted,
908
+ space=self.space_id,
858
909
  )
859
910
  return deleted
860
911
 
dao_ai/hooks/core.py CHANGED
@@ -25,7 +25,7 @@ def create_hooks(
25
25
  Returns:
26
26
  Sequence of callable functions
27
27
  """
28
- logger.debug(f"Creating hooks from: {function_hooks}")
28
+ logger.trace("Creating hooks", function_hooks=function_hooks)
29
29
  hooks: list[Callable[..., Any]] = []
30
30
  if not function_hooks:
31
31
  return []
@@ -35,21 +35,21 @@ def create_hooks(
35
35
  if isinstance(function_hook, str):
36
36
  function_hook = PythonFunctionModel(name=function_hook)
37
37
  hooks.extend(function_hook.as_tools())
38
- logger.debug(f"Created hooks: {hooks}")
38
+ logger.trace("Created hooks", hooks_count=len(hooks))
39
39
  return hooks
40
40
 
41
41
 
42
42
  def null_hook(state: dict[str, Any], config: Any) -> dict[str, Any]:
43
43
  """A no-op hook that returns an empty dict."""
44
- logger.debug("Executing null hook")
44
+ logger.trace("Executing null hook")
45
45
  return {}
46
46
 
47
47
 
48
48
  def null_initialization_hook(config: AppConfig) -> None:
49
49
  """A no-op initialization hook."""
50
- logger.debug("Executing null initialization hook")
50
+ logger.trace("Executing null initialization hook")
51
51
 
52
52
 
53
53
  def null_shutdown_hook(config: AppConfig) -> None:
54
54
  """A no-op shutdown hook."""
55
- logger.debug("Executing null shutdown hook")
55
+ logger.trace("Executing null shutdown hook")
dao_ai/logging.py ADDED
@@ -0,0 +1,56 @@
1
+ """Logging configuration for DAO AI."""
2
+
3
+ import sys
4
+ from typing import Any
5
+
6
+ from loguru import logger
7
+
8
+ # Re-export logger for convenience
9
+ __all__ = ["logger", "configure_logging"]
10
+
11
+
12
+ def format_extra(record: dict[str, Any]) -> str:
13
+ """Format extra fields as key=value pairs."""
14
+ extra: dict[str, Any] = record["extra"]
15
+ if not extra:
16
+ return ""
17
+
18
+ formatted_pairs: list[str] = []
19
+ for key, value in extra.items():
20
+ # Handle different value types
21
+ if isinstance(value, str):
22
+ formatted_pairs.append(f"{key}={value}")
23
+ elif isinstance(value, (list, tuple)):
24
+ formatted_pairs.append(f"{key}={','.join(str(v) for v in value)}")
25
+ else:
26
+ formatted_pairs.append(f"{key}={value}")
27
+
28
+ return " | ".join(formatted_pairs)
29
+
30
+
31
+ def configure_logging(level: str = "INFO") -> None:
32
+ """
33
+ Configure loguru logging with structured output.
34
+
35
+ Args:
36
+ level: The log level (e.g., "INFO", "DEBUG", "WARNING")
37
+ """
38
+ logger.remove()
39
+ logger.add(
40
+ sys.stderr,
41
+ level=level,
42
+ format=(
43
+ "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
44
+ "<level>{level: <8}</level> | "
45
+ "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
46
+ "<level>{message}</level>"
47
+ "{extra}"
48
+ ),
49
+ )
50
+
51
+ # Add custom formatter for extra fields
52
+ logger.configure(
53
+ patcher=lambda record: record.update(
54
+ extra=" | " + format_extra(record) if record["extra"] else ""
55
+ )
56
+ )
dao_ai/memory/core.py CHANGED
@@ -25,11 +25,13 @@ class InMemoryStoreManager(StoreManagerBase):
25
25
  self.store_model = store_model
26
26
 
27
27
  def store(self) -> BaseStore:
28
- logger.debug("Creating InMemory store")
28
+ embedding_model: LLMModel = self.store_model.embedding_model
29
29
 
30
- index: dict[str, Any] = None
30
+ logger.debug(
31
+ "Creating in-memory store", embeddings_enabled=embedding_model is not None
32
+ )
31
33
 
32
- embedding_model: LLMModel = self.store_model.embedding_model
34
+ index: dict[str, Any] = None
33
35
 
34
36
  if embedding_model:
35
37
  embeddings: Embeddings = DatabricksEmbeddings(endpoint=embedding_model.name)
@@ -39,6 +41,11 @@ class InMemoryStoreManager(StoreManagerBase):
39
41
 
40
42
  dims: int = self.store_model.dims
41
43
  index = {"dims": dims, "embed": embed_texts}
44
+ logger.debug(
45
+ "Store embeddings configured",
46
+ endpoint=embedding_model.name,
47
+ dimensions=dims,
48
+ )
42
49
 
43
50
  store: BaseStore = InMemoryStore(index=index)
44
51
 
@@ -59,32 +66,38 @@ class StoreManager:
59
66
  @classmethod
60
67
  def instance(cls, store_model: StoreModel) -> StoreManagerBase:
61
68
  store_manager: StoreManagerBase | None = None
62
- match store_model.type:
69
+ match store_model.storage_type:
63
70
  case StorageType.MEMORY:
64
71
  store_manager = cls.store_managers.get(store_model.name)
65
72
  if store_manager is None:
66
73
  store_manager = InMemoryStoreManager(store_model)
67
74
  cls.store_managers[store_model.name] = store_manager
68
75
  case StorageType.POSTGRES:
69
- from dao_ai.memory.postgres import PostgresStoreManager
76
+ # Route based on database configuration: instance_name -> Databricks, host -> Postgres
77
+ if store_model.database.is_lakebase:
78
+ # Databricks Lakebase connection
79
+ from dao_ai.memory.databricks import DatabricksStoreManager
70
80
 
71
- store_manager = cls.store_managers.get(
72
- store_model.database.instance_name
73
- )
74
- if store_manager is None:
75
- store_manager = PostgresStoreManager(store_model)
76
- cls.store_managers[store_model.database.instance_name] = (
77
- store_manager
81
+ store_manager = cls.store_managers.get(
82
+ store_model.database.instance_name
78
83
  )
79
- case StorageType.LAKEBASE:
80
- from dao_ai.memory.databricks import DatabricksStoreManager
81
-
82
- store_manager = cls.store_managers.get(store_model.name)
83
- if store_manager is None:
84
- store_manager = DatabricksStoreManager(store_model)
85
- cls.store_managers[store_model.name] = store_manager
84
+ if store_manager is None:
85
+ store_manager = DatabricksStoreManager(store_model)
86
+ cls.store_managers[store_model.database.instance_name] = (
87
+ store_manager
88
+ )
89
+ else:
90
+ # Standard PostgreSQL connection
91
+ from dao_ai.memory.postgres import PostgresStoreManager
92
+
93
+ # Use database name as key for standard PostgreSQL
94
+ cache_key = f"{store_model.database.name}"
95
+ store_manager = cls.store_managers.get(cache_key)
96
+ if store_manager is None:
97
+ store_manager = PostgresStoreManager(store_model)
98
+ cls.store_managers[cache_key] = store_manager
86
99
  case _:
87
- raise ValueError(f"Unknown store type: {store_model.type}")
100
+ raise ValueError(f"Unknown storage type: {store_model.storage_type}")
88
101
 
89
102
  return store_manager
90
103
 
@@ -95,7 +108,7 @@ class CheckpointManager:
95
108
  @classmethod
96
109
  def instance(cls, checkpointer_model: CheckpointerModel) -> CheckpointManagerBase:
97
110
  checkpointer_manager: CheckpointManagerBase | None = None
98
- match checkpointer_model.type:
111
+ match checkpointer_model.storage_type:
99
112
  case StorageType.MEMORY:
100
113
  checkpointer_manager = cls.checkpoint_managers.get(
101
114
  checkpointer_model.name
@@ -108,32 +121,36 @@ class CheckpointManager:
108
121
  checkpointer_manager
109
122
  )
110
123
  case StorageType.POSTGRES:
111
- from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
124
+ # Route based on database configuration: instance_name -> Databricks, host -> Postgres
125
+ if checkpointer_model.database.is_lakebase:
126
+ # Databricks Lakebase connection
127
+ from dao_ai.memory.databricks import DatabricksCheckpointerManager
112
128
 
113
- checkpointer_manager = cls.checkpoint_managers.get(
114
- checkpointer_model.database.instance_name
115
- )
116
- if checkpointer_manager is None:
117
- checkpointer_manager = AsyncPostgresCheckpointerManager(
118
- checkpointer_model
119
- )
120
- cls.checkpoint_managers[
129
+ checkpointer_manager = cls.checkpoint_managers.get(
121
130
  checkpointer_model.database.instance_name
122
- ] = checkpointer_manager
123
- case StorageType.LAKEBASE:
124
- from dao_ai.memory.databricks import DatabricksCheckpointerManager
125
-
126
- checkpointer_manager = cls.checkpoint_managers.get(
127
- checkpointer_model.name
128
- )
129
- if checkpointer_manager is None:
130
- checkpointer_manager = DatabricksCheckpointerManager(
131
- checkpointer_model
132
- )
133
- cls.checkpoint_managers[checkpointer_model.name] = (
134
- checkpointer_manager
135
131
  )
132
+ if checkpointer_manager is None:
133
+ checkpointer_manager = DatabricksCheckpointerManager(
134
+ checkpointer_model
135
+ )
136
+ cls.checkpoint_managers[
137
+ checkpointer_model.database.instance_name
138
+ ] = checkpointer_manager
139
+ else:
140
+ # Standard PostgreSQL connection
141
+ from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
142
+
143
+ # Use database name as key for standard PostgreSQL
144
+ cache_key = f"{checkpointer_model.database.name}"
145
+ checkpointer_manager = cls.checkpoint_managers.get(cache_key)
146
+ if checkpointer_manager is None:
147
+ checkpointer_manager = AsyncPostgresCheckpointerManager(
148
+ checkpointer_model
149
+ )
150
+ cls.checkpoint_managers[cache_key] = checkpointer_manager
136
151
  case _:
137
- raise ValueError(f"Unknown store type: {checkpointer_model.type}")
152
+ raise ValueError(
153
+ f"Unknown storage type: {checkpointer_model.storage_type}"
154
+ )
138
155
 
139
156
  return checkpointer_manager