vanna 0.7.9__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +461 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +224 -217
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0.dist-info/METADATA +485 -0
  251. vanna-2.0.0.dist-info/RECORD +289 -0
  252. vanna-2.0.0.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.9.dist-info/METADATA +0 -408
  265. vanna-0.7.9.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,428 @@
1
+ """
2
+ Weaviate vector database implementation of AgentMemory.
3
+
4
+ This implementation uses Weaviate for semantic search and storage of tool usage patterns.
5
+ """
6
+
7
+ import json
8
+ import uuid
9
+ from datetime import datetime
10
+ from typing import Any, Dict, List, Optional
11
+ import asyncio
12
+ from concurrent.futures import ThreadPoolExecutor
13
+
14
+ try:
15
+ import weaviate
16
+ from weaviate.classes.config import (
17
+ Configure,
18
+ Property,
19
+ DataType as WeaviateDataType,
20
+ )
21
+
22
+ WEAVIATE_AVAILABLE = True
23
+ except ImportError:
24
+ WEAVIATE_AVAILABLE = False
25
+
26
+ from vanna.capabilities.agent_memory import (
27
+ AgentMemory,
28
+ TextMemory,
29
+ TextMemorySearchResult,
30
+ ToolMemory,
31
+ ToolMemorySearchResult,
32
+ )
33
+ from vanna.core.tool import ToolContext
34
+
35
+
36
class WeaviateAgentMemory(AgentMemory):
    """Weaviate-based implementation of AgentMemory.

    Tool-usage memories and free-text memories share one Weaviate collection.
    Text memories are stored with an empty ``tool_name`` (and
    ``{"is_text_memory": true}`` in ``metadata_json``); tool-memory queries
    exclude them so their empty ``args_json`` is never parsed as tool args.

    The Weaviate client used here is synchronous, so every public coroutine
    runs its blocking work on a private two-worker thread pool.
    """

    def __init__(
        self,
        collection_name: str = "ToolMemory",
        url: str = "http://localhost:8080",
        api_key: Optional[str] = None,
        dimension: int = 384,
    ):
        """Configure the memory store; the connection is opened lazily.

        Args:
            collection_name: Weaviate collection holding all memories.
            url: Weaviate Cloud cluster URL when ``api_key`` is given,
                otherwise the URL of a local instance (scheme optional).
            api_key: Weaviate Cloud API key; ``None`` connects locally.
            dimension: Length of vectors produced by ``_create_embedding``.

        Raises:
            ImportError: If ``weaviate-client`` is not installed.
        """
        import threading

        if not WEAVIATE_AVAILABLE:
            raise ImportError(
                "Weaviate is required for WeaviateAgentMemory. Install with: pip install weaviate-client"
            )

        self.collection_name = collection_name
        self.url = url
        self.api_key = api_key
        self.dimension = dimension
        self._client = None
        # Guards lazy client creation: both pool workers may hit
        # _get_client() concurrently on first use.
        self._client_lock = threading.Lock()
        self._executor = ThreadPoolExecutor(max_workers=2)

    # ------------------------------------------------------------------
    # Connection helpers
    # ------------------------------------------------------------------

    def _get_client(self):
        """Return the shared Weaviate client, connecting on first use.

        On first use the collection is also created if it does not exist.
        A client whose setup fails is closed and not cached.
        """
        with self._client_lock:
            if self._client is None:
                client = self._connect()
                try:
                    self._ensure_collection(client)
                except Exception:
                    client.close()
                    raise
                self._client = client
            return self._client

    def _connect(self):
        """Open a new Weaviate connection based on ``url`` and ``api_key``."""
        if self.api_key:
            return weaviate.connect_to_weaviate_cloud(
                cluster_url=self.url,
                auth_credentials=weaviate.auth.AuthApiKey(self.api_key),
            )

        from urllib.parse import urlsplit

        # connect_to_local() takes host and port separately; passing the
        # whole "localhost:8080" as the host yields an invalid address.
        url = self.url if "://" in self.url else f"http://{self.url}"
        parts = urlsplit(url)
        return weaviate.connect_to_local(
            host=parts.hostname or "localhost",
            port=parts.port or 8080,
        )

    def _ensure_collection(self, client) -> None:
        """Create the memory collection (no server-side vectorizer) if missing."""
        if client.collections.exists(self.collection_name):
            return
        client.collections.create(
            name=self.collection_name,
            vectorizer_config=Configure.Vectorizer.none(),
            properties=[
                Property(name="question", data_type=WeaviateDataType.TEXT),
                Property(name="tool_name", data_type=WeaviateDataType.TEXT),
                Property(name="args_json", data_type=WeaviateDataType.TEXT),
                Property(name="timestamp", data_type=WeaviateDataType.TEXT),
                Property(name="success", data_type=WeaviateDataType.BOOL),
                Property(name="metadata_json", data_type=WeaviateDataType.TEXT),
            ],
        )

    async def _run_blocking(self, func):
        """Run the blocking callable ``func`` on the private pool and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, func)

    # ------------------------------------------------------------------
    # Conversion helpers
    # ------------------------------------------------------------------

    def _create_embedding(self, text: str) -> List[float]:
        """Create a deterministic placeholder embedding from ``text``.

        NOTE: the MD5 digest has only 128 bits, so components at index >= 128
        are always 0.0. This is not a semantic embedding. It is kept unchanged
        so vectors stay comparable with data already stored in the collection.
        """
        import hashlib

        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        return [(hash_val >> i) % 100 / 100.0 for i in range(self.dimension)]

    @staticmethod
    def _similarity(obj) -> float:
        """Convert a result's Weaviate distance to ``1 - distance``.

        A missing metadata object or distance is treated as distance 1.0.
        """
        metadata = obj.metadata
        distance = metadata.distance if metadata is not None else None
        if distance is None:
            distance = 1.0
        return 1 - distance

    @staticmethod
    def _load_json(raw: Optional[str]) -> Dict[str, Any]:
        """Parse a JSON property, treating ``None`` or ``""`` as ``{}``."""
        return json.loads(raw) if raw else {}

    @staticmethod
    def _is_text_object(obj) -> bool:
        """Return True for objects written by ``save_text_memory`` (empty tool_name)."""
        return not obj.properties.get("tool_name")

    @staticmethod
    def _tool_memories_only():
        """Filter excluding text memories, which are stored with an empty tool_name."""
        return weaviate.classes.query.Filter.by_property("tool_name").not_equal("")

    @classmethod
    def _tool_memory_from_object(cls, obj) -> ToolMemory:
        """Build a ToolMemory from a Weaviate result object."""
        properties = obj.properties
        return ToolMemory(
            memory_id=str(obj.uuid),
            question=properties.get("question"),
            tool_name=properties.get("tool_name"),
            args=cls._load_json(properties.get("args_json")),
            timestamp=properties.get("timestamp"),
            success=properties.get("success", True),
            metadata=cls._load_json(properties.get("metadata_json")),
        )

    @staticmethod
    def _text_memory_from_object(obj) -> TextMemory:
        """Build a TextMemory from a Weaviate result object."""
        properties = obj.properties
        return TextMemory(
            memory_id=str(obj.uuid),
            content=properties.get("question", ""),
            timestamp=properties.get("timestamp"),
        )

    # ------------------------------------------------------------------
    # Tool memories
    # ------------------------------------------------------------------

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a tool usage pattern, keyed by an embedding of ``question``."""

        def _save() -> None:
            collection = self._get_client().collections.get(self.collection_name)
            properties = {
                "question": question,
                "tool_name": tool_name,
                "args_json": json.dumps(args),
                "timestamp": datetime.now().isoformat(),
                "success": success,
                "metadata_json": json.dumps(metadata or {}),
            }
            collection.data.insert(
                properties=properties,
                vector=self._create_embedding(question),
                uuid=str(uuid.uuid4()),
            )

        await self._run_blocking(_save)

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search successful tool usages whose question resembles ``question``.

        Results scoring below ``similarity_threshold`` are dropped; ``rank`` is
        the 1-based position among the results actually returned.
        """

        def _search() -> List[ToolMemorySearchResult]:
            collection = self._get_client().collections.get(self.collection_name)
            query_api = weaviate.classes.query

            filters = query_api.Filter.by_property("success").equal(
                True
            ) & self._tool_memories_only()
            if tool_name_filter:
                filters = filters & query_api.Filter.by_property("tool_name").equal(
                    tool_name_filter
                )

            response = collection.query.near_vector(
                near_vector=self._create_embedding(question),
                limit=limit,
                filters=filters,
                return_metadata=query_api.MetadataQuery(distance=True),
            )

            results: List[ToolMemorySearchResult] = []
            for obj in response.objects:
                similarity_score = self._similarity(obj)
                # Second check is defensive: it keeps text memories out even
                # if the empty-string filter does not match as expected.
                if similarity_score < similarity_threshold or self._is_text_object(obj):
                    continue
                results.append(
                    ToolMemorySearchResult(
                        memory=self._tool_memory_from_object(obj),
                        similarity_score=similarity_score,
                        rank=len(results) + 1,
                    )
                )
            return results

        return await self._run_blocking(_search)

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Get up to ``limit`` tool memories, newest ``timestamp`` first."""

        def _get_recent() -> List[ToolMemory]:
            collection = self._get_client().collections.get(self.collection_name)
            # Sort server-side so the newest objects are returned regardless
            # of collection size. ISO-8601 timestamps sort lexicographically.
            response = collection.query.fetch_objects(
                filters=self._tool_memories_only(),
                sort=weaviate.classes.query.Sort.by_property(
                    name="timestamp", ascending=False
                ),
                limit=limit,
            )
            return [
                self._tool_memory_from_object(obj)
                for obj in response.objects
                if not self._is_text_object(obj)
            ]

        return await self._run_blocking(_get_recent)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a memory by its ID; return False if it was not deleted."""

        def _delete() -> bool:
            collection = self._get_client().collections.get(self.collection_name)
            try:
                deleted = collection.data.delete_by_id(uuid=memory_id)
            except Exception:
                return False
            # Recent clients return False when the object did not exist;
            # a None return from older clients still counts as success.
            return deleted is not False

        return await self._run_blocking(_delete)

    # ------------------------------------------------------------------
    # Text memories
    # ------------------------------------------------------------------

    async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory:
        """Save a free-text memory and return it."""

        def _save() -> TextMemory:
            collection = self._get_client().collections.get(self.collection_name)
            memory_id = str(uuid.uuid4())
            timestamp = datetime.now().isoformat()
            properties = {
                "question": content,  # text memories reuse the question field
                "tool_name": "",  # empty tool_name marks a text memory
                "args_json": "",
                "timestamp": timestamp,
                "success": True,
                "metadata_json": json.dumps({"is_text_memory": True}),
            }
            collection.data.insert(
                properties=properties,
                vector=self._create_embedding(content),
                uuid=memory_id,
            )
            return TextMemory(memory_id=memory_id, content=content, timestamp=timestamp)

        return await self._run_blocking(_save)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories similar to ``query``.

        Results scoring below ``similarity_threshold`` are dropped; ``rank`` is
        the 1-based position among the results actually returned.
        """

        def _search() -> List[TextMemorySearchResult]:
            collection = self._get_client().collections.get(self.collection_name)
            query_api = weaviate.classes.query

            response = collection.query.near_vector(
                near_vector=self._create_embedding(query),
                limit=limit,
                filters=query_api.Filter.by_property("tool_name").equal(""),
                return_metadata=query_api.MetadataQuery(distance=True),
            )

            results: List[TextMemorySearchResult] = []
            for obj in response.objects:
                similarity_score = self._similarity(obj)
                if similarity_score < similarity_threshold:
                    continue
                results.append(
                    TextMemorySearchResult(
                        memory=self._text_memory_from_object(obj),
                        similarity_score=similarity_score,
                        rank=len(results) + 1,
                    )
                )
            return results

        return await self._run_blocking(_search)

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Get up to ``limit`` text memories, newest ``timestamp`` first."""

        def _get_recent() -> List[TextMemory]:
            collection = self._get_client().collections.get(self.collection_name)
            query_api = weaviate.classes.query
            response = collection.query.fetch_objects(
                filters=query_api.Filter.by_property("tool_name").equal(""),
                sort=query_api.Sort.by_property(name="timestamp", ascending=False),
                limit=limit,
            )
            return [self._text_memory_from_object(obj) for obj in response.objects]

        return await self._run_blocking(_get_recent)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete a text memory by its ID; return False if it was not deleted."""

        def _delete() -> bool:
            collection = self._get_client().collections.get(self.collection_name)
            try:
                deleted = collection.data.delete_by_id(uuid=memory_id)
            except Exception:
                return False
            # Recent clients return False when the object did not exist;
            # a None return from older clients still counts as success.
            return deleted is not False

        return await self._run_blocking(_delete)

    # ------------------------------------------------------------------
    # Bulk deletion
    # ------------------------------------------------------------------

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Delete stored memories and return how many were deleted.

        Args:
            tool_name: Only delete memories for this tool.
            before_date: Only delete memories whose ISO timestamp string
                compares less than this value.

        With neither filter, every object in the collection (tool and text
        memories) is deleted.
        """

        def _clear() -> int:
            collection = self._get_client().collections.get(self.collection_name)
            filter_api = weaviate.classes.query.Filter

            conditions = []
            if tool_name:
                conditions.append(filter_api.by_property("tool_name").equal(tool_name))
            if before_date:
                conditions.append(
                    filter_api.by_property("timestamp").less_than(before_date)
                )

            where = None
            for condition in conditions:
                where = condition if where is None else where & condition
            if where is None:
                # Matches every object: success is always True or False.
                where = filter_api.by_property("success").contains_any([True, False])

            # delete_many is capped per call by the server, so repeat until
            # nothing more is deleted.
            total = 0
            while True:
                result = collection.data.delete_many(where=where)
                deleted = getattr(result, "successful", 0) or 0
                if deleted <= 0:
                    break
                total += deleted
            return total

        return await self._run_blocking(_clear)
@@ -3,6 +3,7 @@ from zhipuai import ZhipuAI
3
3
  from chromadb import Documents, EmbeddingFunction, Embeddings
4
4
  from ..base import VannaBase
5
5
 
6
+
6
7
  class ZhipuAI_Embeddings(VannaBase):
7
8
  """
8
9
  [future functionality] This function is used to generate embeddings from ZhipuAI.
@@ -10,6 +11,7 @@ class ZhipuAI_Embeddings(VannaBase):
10
11
  Args:
11
12
  VannaBase (_type_): _description_
12
13
  """
14
+
13
15
  def __init__(self, config=None):
14
16
  VannaBase.__init__(self, config=config)
15
17
  if "api_key" not in config:
@@ -18,39 +20,38 @@ class ZhipuAI_Embeddings(VannaBase):
18
20
  self.client = ZhipuAI(api_key=self.api_key)
19
21
 
20
22
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
21
-
22
23
  embedding = self.client.embeddings.create(
23
24
  model="embedding-2",
24
25
  input=data,
25
26
  )
26
27
 
27
28
  return embedding.data[0].embedding
28
-
29
29
 
30
30
 
31
31
  class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
32
32
  """
33
33
  A embeddingFunction that uses ZhipuAI to generate embeddings which can use in chromadb.
34
- usage:
34
+ usage:
35
35
  class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat):
36
36
  def __init__(self, config=None):
37
37
  ChromaDB_VectorStore.__init__(self, config=config)
38
38
  ZhipuAI_Chat.__init__(self, config=config)
39
-
39
+
40
40
  config={'api_key': 'xxx'}
41
41
  zhipu_embedding_function = ZhipuAIEmbeddingFunction(config=config)
42
42
  config = {"api_key": "xxx", "model": "glm-4","path":"xy","embedding_function":zhipu_embedding_function}
43
-
43
+
44
44
  vn = MyVanna(config)
45
-
45
+
46
46
  """
47
+
47
48
  def __init__(self, config=None):
48
49
  if config is None or "api_key" not in config:
49
50
  raise ValueError("Missing 'api_key' in config")
50
-
51
+
51
52
  self.api_key = config["api_key"]
52
53
  self.model_name = config.get("model_name", "embedding-2")
53
-
54
+
54
55
  try:
55
56
  self.client = ZhipuAI(api_key=self.api_key)
56
57
  except Exception as e:
@@ -66,8 +67,7 @@ class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
66
67
  for document in input:
67
68
  try:
68
69
  response = self.client.embeddings.create(
69
- model=self.model_name,
70
- input=document
70
+ model=self.model_name, input=document
71
71
  )
72
72
  # print(response)
73
73
  embedding = response.data[0].embedding
@@ -76,4 +76,4 @@ class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
76
76
  except Exception as e:
77
77
  raise ValueError(f"Error generating embedding for document: {e}")
78
78
 
79
- return all_embeddings
79
+ return all_embeddings