vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +439 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0rc1.dist-info/METADATA +868 -0
- vanna-2.0.0rc1.dist-info/RECORD +289 -0
- vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,10 +2,7 @@ import datetime
|
|
|
2
2
|
import os
|
|
3
3
|
import uuid
|
|
4
4
|
from typing import List, Optional
|
|
5
|
-
from vertexai.language_models import
|
|
6
|
-
TextEmbeddingInput,
|
|
7
|
-
TextEmbeddingModel
|
|
8
|
-
)
|
|
5
|
+
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
|
|
9
6
|
|
|
10
7
|
import pandas as pd
|
|
11
8
|
from google.cloud import bigquery
|
|
@@ -18,7 +15,9 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
18
15
|
self.config = config
|
|
19
16
|
|
|
20
17
|
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
|
|
21
|
-
self.n_results_documentation = config.get(
|
|
18
|
+
self.n_results_documentation = config.get(
|
|
19
|
+
"n_results_documentation", config.get("n_results", 10)
|
|
20
|
+
)
|
|
22
21
|
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
|
|
23
22
|
|
|
24
23
|
if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
|
|
@@ -47,7 +46,7 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
47
46
|
|
|
48
47
|
self.conn = bigquery.Client(project=self.project_id)
|
|
49
48
|
|
|
50
|
-
dataset_name = self.config.get(
|
|
49
|
+
dataset_name = self.config.get("bigquery_dataset_name", "vanna_managed")
|
|
51
50
|
self.dataset_id = f"{self.project_id}.{dataset_name}"
|
|
52
51
|
dataset = bigquery.Dataset(self.dataset_id)
|
|
53
52
|
|
|
@@ -101,21 +100,35 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
101
100
|
# except Exception as e:
|
|
102
101
|
# print(f"Failed to create vector index: {e}")
|
|
103
102
|
|
|
104
|
-
def store_training_data(
|
|
103
|
+
def store_training_data(
|
|
104
|
+
self,
|
|
105
|
+
training_data_type: str,
|
|
106
|
+
question: str,
|
|
107
|
+
content: str,
|
|
108
|
+
embedding: List[float],
|
|
109
|
+
**kwargs,
|
|
110
|
+
) -> str:
|
|
105
111
|
id = str(uuid.uuid4())
|
|
106
112
|
created_at = datetime.datetime.now()
|
|
107
|
-
self.conn.insert_rows_json(
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
113
|
+
self.conn.insert_rows_json(
|
|
114
|
+
self.table_id,
|
|
115
|
+
[
|
|
116
|
+
{
|
|
117
|
+
"id": id,
|
|
118
|
+
"training_data_type": training_data_type,
|
|
119
|
+
"question": question,
|
|
120
|
+
"content": content,
|
|
121
|
+
"embedding": embedding,
|
|
122
|
+
"created_at": created_at.isoformat(),
|
|
123
|
+
}
|
|
124
|
+
],
|
|
125
|
+
)
|
|
115
126
|
|
|
116
127
|
return id
|
|
117
128
|
|
|
118
|
-
def fetch_similar_training_data(
|
|
129
|
+
def fetch_similar_training_data(
|
|
130
|
+
self, training_data_type: str, question: str, n_results, **kwargs
|
|
131
|
+
) -> pd.DataFrame:
|
|
119
132
|
question_embedding = self.generate_question_embedding(question)
|
|
120
133
|
|
|
121
134
|
query = f"""
|
|
@@ -145,29 +158,28 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
145
158
|
embeddings = None
|
|
146
159
|
|
|
147
160
|
if self.type == "VERTEX_AI":
|
|
148
|
-
|
|
149
|
-
|
|
161
|
+
input = [TextEmbeddingInput(data, task)]
|
|
162
|
+
model = TextEmbeddingModel.from_pretrained("text-embedding-004")
|
|
150
163
|
|
|
151
|
-
|
|
164
|
+
result = model.get_embeddings(input)
|
|
152
165
|
|
|
153
|
-
|
|
154
|
-
|
|
166
|
+
if len(result) > 0:
|
|
167
|
+
embeddings = result[0].values
|
|
155
168
|
else:
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
task_type=task)
|
|
169
|
+
# Use Gemini Consumer API
|
|
170
|
+
result = self.genai.embed_content(
|
|
171
|
+
model="models/text-embedding-004", content=data, task_type=task
|
|
172
|
+
)
|
|
161
173
|
|
|
162
|
-
|
|
163
|
-
|
|
174
|
+
if "embedding" in result:
|
|
175
|
+
embeddings = result["embedding"]
|
|
164
176
|
|
|
165
177
|
return embeddings
|
|
166
178
|
|
|
167
179
|
def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
|
|
168
180
|
result = self.get_embeddings(data, "RETRIEVAL_QUERY")
|
|
169
181
|
|
|
170
|
-
if result
|
|
182
|
+
if result is not None:
|
|
171
183
|
return result
|
|
172
184
|
else:
|
|
173
185
|
raise ValueError("No embeddings returned")
|
|
@@ -175,7 +187,7 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
175
187
|
def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
|
|
176
188
|
result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")
|
|
177
189
|
|
|
178
|
-
if result
|
|
190
|
+
if result is not None:
|
|
179
191
|
return result
|
|
180
192
|
else:
|
|
181
193
|
raise ValueError("No embeddings returned")
|
|
@@ -195,45 +207,66 @@ class BigQuery_VectorStore(VannaBase):
|
|
|
195
207
|
return self.generate_storage_embedding(data, **kwargs)
|
|
196
208
|
|
|
197
209
|
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
198
|
-
df = self.fetch_similar_training_data(
|
|
210
|
+
df = self.fetch_similar_training_data(
|
|
211
|
+
training_data_type="sql", question=question, n_results=self.n_results_sql
|
|
212
|
+
)
|
|
199
213
|
|
|
200
214
|
# Return a list of dictionaries with only question, sql fields. The content field needs to be renamed to sql
|
|
201
|
-
return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(
|
|
215
|
+
return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(
|
|
216
|
+
orient="records"
|
|
217
|
+
)
|
|
202
218
|
|
|
203
219
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
204
|
-
df = self.fetch_similar_training_data(
|
|
220
|
+
df = self.fetch_similar_training_data(
|
|
221
|
+
training_data_type="ddl", question=question, n_results=self.n_results_ddl
|
|
222
|
+
)
|
|
205
223
|
|
|
206
224
|
# Return a list of strings of the content
|
|
207
225
|
return df["content"].tolist()
|
|
208
226
|
|
|
209
227
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
210
|
-
df = self.fetch_similar_training_data(
|
|
228
|
+
df = self.fetch_similar_training_data(
|
|
229
|
+
training_data_type="documentation",
|
|
230
|
+
question=question,
|
|
231
|
+
n_results=self.n_results_documentation,
|
|
232
|
+
)
|
|
211
233
|
|
|
212
234
|
# Return a list of strings of the content
|
|
213
235
|
return df["content"].tolist()
|
|
214
236
|
|
|
215
237
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
216
|
-
doc = {
|
|
217
|
-
"question": question,
|
|
218
|
-
"sql": sql
|
|
219
|
-
}
|
|
238
|
+
doc = {"question": question, "sql": sql}
|
|
220
239
|
|
|
221
240
|
embedding = self.generate_embedding(str(doc))
|
|
222
241
|
|
|
223
|
-
return self.store_training_data(
|
|
242
|
+
return self.store_training_data(
|
|
243
|
+
training_data_type="sql",
|
|
244
|
+
question=question,
|
|
245
|
+
content=sql,
|
|
246
|
+
embedding=embedding,
|
|
247
|
+
)
|
|
224
248
|
|
|
225
249
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
226
250
|
embedding = self.generate_embedding(ddl)
|
|
227
251
|
|
|
228
|
-
return self.store_training_data(
|
|
252
|
+
return self.store_training_data(
|
|
253
|
+
training_data_type="ddl", question="", content=ddl, embedding=embedding
|
|
254
|
+
)
|
|
229
255
|
|
|
230
256
|
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
231
257
|
embedding = self.generate_embedding(documentation)
|
|
232
258
|
|
|
233
|
-
return self.store_training_data(
|
|
259
|
+
return self.store_training_data(
|
|
260
|
+
training_data_type="documentation",
|
|
261
|
+
question="",
|
|
262
|
+
content=documentation,
|
|
263
|
+
embedding=embedding,
|
|
264
|
+
)
|
|
234
265
|
|
|
235
266
|
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
236
|
-
query =
|
|
267
|
+
query = (
|
|
268
|
+
f"SELECT id, training_data_type, question, content FROM `{self.table_id}`"
|
|
269
|
+
)
|
|
237
270
|
|
|
238
271
|
return self.conn.query(query).result().to_dataframe()
|
|
239
272
|
|
|
@@ -35,14 +35,18 @@ class GoogleGeminiChat(VannaBase):
|
|
|
35
35
|
import vertexai
|
|
36
36
|
from vertexai.generative_models import GenerativeModel
|
|
37
37
|
|
|
38
|
-
json_file_path = config.get(
|
|
38
|
+
json_file_path = config.get(
|
|
39
|
+
"google_credentials"
|
|
40
|
+
) # Assuming the JSON file path is provided in the config
|
|
39
41
|
|
|
40
42
|
if not json_file_path or not os.path.exists(json_file_path):
|
|
41
|
-
raise FileNotFoundError(
|
|
43
|
+
raise FileNotFoundError(
|
|
44
|
+
f"JSON credentials file not found at: {json_file_path}"
|
|
45
|
+
)
|
|
42
46
|
|
|
43
47
|
try:
|
|
44
48
|
# Validate and set the JSON file path for GOOGLE_APPLICATION_CREDENTIALS
|
|
45
|
-
os.environ[
|
|
49
|
+
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = json_file_path
|
|
46
50
|
|
|
47
51
|
# Initialize VertexAI with the credentials
|
|
48
52
|
credentials, _ = google.auth.default()
|
vanna/{hf → legacy/hf}/hf.py
RENAMED
|
@@ -33,6 +33,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
33
33
|
For more models, please refer to:
|
|
34
34
|
https://milvus.io/docs/embeddings.md
|
|
35
35
|
"""
|
|
36
|
+
|
|
36
37
|
def __init__(self, config=None):
|
|
37
38
|
VannaBase.__init__(self, config=config)
|
|
38
39
|
|
|
@@ -45,7 +46,9 @@ class Milvus_VectorStore(VannaBase):
|
|
|
45
46
|
self.embedding_function = config.get("embedding_function")
|
|
46
47
|
else:
|
|
47
48
|
self.embedding_function = model.DefaultEmbeddingFunction()
|
|
48
|
-
self._embedding_dim = self.embedding_function.encode_documents(["foo"])[
|
|
49
|
+
self._embedding_dim = self.embedding_function.encode_documents(["foo"])[
|
|
50
|
+
0
|
|
51
|
+
].shape[0]
|
|
49
52
|
self._create_collections()
|
|
50
53
|
self.n_results = config.get("n_results", 10)
|
|
51
54
|
|
|
@@ -54,21 +57,32 @@ class Milvus_VectorStore(VannaBase):
|
|
|
54
57
|
self._create_ddl_collection("vannaddl")
|
|
55
58
|
self._create_doc_collection("vannadoc")
|
|
56
59
|
|
|
57
|
-
|
|
58
60
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
59
61
|
return self.embedding_function.encode_documents(data).tolist()
|
|
60
62
|
|
|
61
|
-
|
|
62
63
|
def _create_sql_collection(self, name: str):
|
|
63
64
|
if not self.milvus_client.has_collection(collection_name=name):
|
|
64
65
|
vannasql_schema = MilvusClient.create_schema(
|
|
65
66
|
auto_id=False,
|
|
66
67
|
enable_dynamic_field=False,
|
|
67
68
|
)
|
|
68
|
-
vannasql_schema.add_field(
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
69
|
+
vannasql_schema.add_field(
|
|
70
|
+
field_name="id",
|
|
71
|
+
datatype=DataType.VARCHAR,
|
|
72
|
+
max_length=65535,
|
|
73
|
+
is_primary=True,
|
|
74
|
+
)
|
|
75
|
+
vannasql_schema.add_field(
|
|
76
|
+
field_name="text", datatype=DataType.VARCHAR, max_length=65535
|
|
77
|
+
)
|
|
78
|
+
vannasql_schema.add_field(
|
|
79
|
+
field_name="sql", datatype=DataType.VARCHAR, max_length=65535
|
|
80
|
+
)
|
|
81
|
+
vannasql_schema.add_field(
|
|
82
|
+
field_name="vector",
|
|
83
|
+
datatype=DataType.FLOAT_VECTOR,
|
|
84
|
+
dim=self._embedding_dim,
|
|
85
|
+
)
|
|
72
86
|
|
|
73
87
|
vannasql_index_params = self.milvus_client.prepare_index_params()
|
|
74
88
|
vannasql_index_params.add_index(
|
|
@@ -81,7 +95,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
81
95
|
collection_name=name,
|
|
82
96
|
schema=vannasql_schema,
|
|
83
97
|
index_params=vannasql_index_params,
|
|
84
|
-
consistency_level="Strong"
|
|
98
|
+
consistency_level="Strong",
|
|
85
99
|
)
|
|
86
100
|
|
|
87
101
|
def _create_ddl_collection(self, name: str):
|
|
@@ -90,9 +104,20 @@ class Milvus_VectorStore(VannaBase):
|
|
|
90
104
|
auto_id=False,
|
|
91
105
|
enable_dynamic_field=False,
|
|
92
106
|
)
|
|
93
|
-
vannaddl_schema.add_field(
|
|
94
|
-
|
|
95
|
-
|
|
107
|
+
vannaddl_schema.add_field(
|
|
108
|
+
field_name="id",
|
|
109
|
+
datatype=DataType.VARCHAR,
|
|
110
|
+
max_length=65535,
|
|
111
|
+
is_primary=True,
|
|
112
|
+
)
|
|
113
|
+
vannaddl_schema.add_field(
|
|
114
|
+
field_name="ddl", datatype=DataType.VARCHAR, max_length=65535
|
|
115
|
+
)
|
|
116
|
+
vannaddl_schema.add_field(
|
|
117
|
+
field_name="vector",
|
|
118
|
+
datatype=DataType.FLOAT_VECTOR,
|
|
119
|
+
dim=self._embedding_dim,
|
|
120
|
+
)
|
|
96
121
|
|
|
97
122
|
vannaddl_index_params = self.milvus_client.prepare_index_params()
|
|
98
123
|
vannaddl_index_params.add_index(
|
|
@@ -105,7 +130,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
105
130
|
collection_name=name,
|
|
106
131
|
schema=vannaddl_schema,
|
|
107
132
|
index_params=vannaddl_index_params,
|
|
108
|
-
consistency_level="Strong"
|
|
133
|
+
consistency_level="Strong",
|
|
109
134
|
)
|
|
110
135
|
|
|
111
136
|
def _create_doc_collection(self, name: str):
|
|
@@ -114,9 +139,20 @@ class Milvus_VectorStore(VannaBase):
|
|
|
114
139
|
auto_id=False,
|
|
115
140
|
enable_dynamic_field=False,
|
|
116
141
|
)
|
|
117
|
-
vannadoc_schema.add_field(
|
|
118
|
-
|
|
119
|
-
|
|
142
|
+
vannadoc_schema.add_field(
|
|
143
|
+
field_name="id",
|
|
144
|
+
datatype=DataType.VARCHAR,
|
|
145
|
+
max_length=65535,
|
|
146
|
+
is_primary=True,
|
|
147
|
+
)
|
|
148
|
+
vannadoc_schema.add_field(
|
|
149
|
+
field_name="doc", datatype=DataType.VARCHAR, max_length=65535
|
|
150
|
+
)
|
|
151
|
+
vannadoc_schema.add_field(
|
|
152
|
+
field_name="vector",
|
|
153
|
+
datatype=DataType.FLOAT_VECTOR,
|
|
154
|
+
dim=self._embedding_dim,
|
|
155
|
+
)
|
|
120
156
|
|
|
121
157
|
vannadoc_index_params = self.milvus_client.prepare_index_params()
|
|
122
158
|
vannadoc_index_params.add_index(
|
|
@@ -129,7 +165,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
129
165
|
collection_name=name,
|
|
130
166
|
schema=vannadoc_schema,
|
|
131
167
|
index_params=vannadoc_index_params,
|
|
132
|
-
consistency_level="Strong"
|
|
168
|
+
consistency_level="Strong",
|
|
133
169
|
)
|
|
134
170
|
|
|
135
171
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
@@ -139,12 +175,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
139
175
|
embedding = self.embedding_function.encode_documents([question])[0]
|
|
140
176
|
self.milvus_client.insert(
|
|
141
177
|
collection_name="vannasql",
|
|
142
|
-
data={
|
|
143
|
-
"id": _id,
|
|
144
|
-
"text": question,
|
|
145
|
-
"sql": sql,
|
|
146
|
-
"vector": embedding
|
|
147
|
-
}
|
|
178
|
+
data={"id": _id, "text": question, "sql": sql, "vector": embedding},
|
|
148
179
|
)
|
|
149
180
|
return _id
|
|
150
181
|
|
|
@@ -155,11 +186,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
155
186
|
embedding = self.embedding_function.encode_documents([ddl])[0]
|
|
156
187
|
self.milvus_client.insert(
|
|
157
188
|
collection_name="vannaddl",
|
|
158
|
-
data={
|
|
159
|
-
"id": _id,
|
|
160
|
-
"ddl": ddl,
|
|
161
|
-
"vector": embedding
|
|
162
|
-
}
|
|
189
|
+
data={"id": _id, "ddl": ddl, "vector": embedding},
|
|
163
190
|
)
|
|
164
191
|
return _id
|
|
165
192
|
|
|
@@ -170,11 +197,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
170
197
|
embedding = self.embedding_function.encode_documents([documentation])[0]
|
|
171
198
|
self.milvus_client.insert(
|
|
172
199
|
collection_name="vannadoc",
|
|
173
|
-
data={
|
|
174
|
-
"id": _id,
|
|
175
|
-
"doc": documentation,
|
|
176
|
-
"vector": embedding
|
|
177
|
-
}
|
|
200
|
+
data={"id": _id, "doc": documentation, "vector": embedding},
|
|
178
201
|
)
|
|
179
202
|
return _id
|
|
180
203
|
|
|
@@ -237,7 +260,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
237
260
|
data=embeddings,
|
|
238
261
|
limit=self.n_results,
|
|
239
262
|
output_fields=["text", "sql"],
|
|
240
|
-
search_params=search_params
|
|
263
|
+
search_params=search_params,
|
|
241
264
|
)
|
|
242
265
|
res = res[0]
|
|
243
266
|
|
|
@@ -261,7 +284,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
261
284
|
data=embeddings,
|
|
262
285
|
limit=self.n_results,
|
|
263
286
|
output_fields=["ddl"],
|
|
264
|
-
search_params=search_params
|
|
287
|
+
search_params=search_params,
|
|
265
288
|
)
|
|
266
289
|
res = res[0]
|
|
267
290
|
|
|
@@ -282,7 +305,7 @@ class Milvus_VectorStore(VannaBase):
|
|
|
282
305
|
data=embeddings,
|
|
283
306
|
limit=self.n_results,
|
|
284
307
|
output_fields=["doc"],
|
|
285
|
-
search_params=search_params
|
|
308
|
+
search_params=search_params,
|
|
286
309
|
)
|
|
287
310
|
res = res[0]
|
|
288
311
|
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
from ..base import VannaBase
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MockVectorDB(VannaBase):
    """In-memory stand-in for a vector store that persists nothing.

    The ``add_*`` methods only return an ID derived from their input, the
    ``get_related_*`` / ``get_similar_*`` lookups always return an empty
    list, and ``get_training_data`` returns a fixed five-row sample frame.
    """

    def __init__(self, config=None):
        """Accept ``config`` for signature compatibility; it is ignored."""
        # NOTE(review): the base-class initializer is intentionally not
        # invoked here, matching the original behavior - confirm that
        # callers initialise VannaBase state elsewhere if they need it.
        pass

    def _get_id(self, value: str, **kwargs) -> str:
        """Return the builtin ``hash`` of ``value`` rendered as a string.

        Python salts ``str`` hashes per process unless ``PYTHONHASHSEED`` is
        fixed, so the returned ID is only stable within a single run.
        """
        return str(hash(value))

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Return the ID for ``ddl``; nothing is stored."""
        return self._get_id(ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Return the ID for ``doc``; nothing is stored."""
        return self._get_id(doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Return the ID for ``question``; ``sql`` is ignored and nothing is stored."""
        return self._get_id(question)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return an empty list regardless of ``question``."""
        return []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return an empty list regardless of ``question``."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return an empty list regardless of ``question``."""
        return []

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return a fixed sample frame of five training records.

        The frame has the columns ``id``, ``training_data_type``,
        ``question`` and ``content``, indexed 0-4: one DDL row, one
        documentation row and three question/SQL rows.
        """
        return pd.DataFrame(
            {
                "id": {
                    0: "19546-ddl",
                    1: "91597-sql",
                    2: "133976-sql",
                    3: "59851-doc",
                    4: "73046-sql",
                },
                "training_data_type": {
                    0: "ddl",
                    1: "sql",
                    2: "sql",
                    3: "documentation",
                    4: "sql",
                },
                "question": {
                    0: None,
                    1: "What are the top selling genres?",
                    2: "What are the low 7 artists by sales?",
                    3: None,
                    4: "What is the total sales for each customer?",
                },
                "content": {
                    0: "CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)",
                    1: "SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;",
                    2: "SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;",
                    3: "This is a SQLite database. For dates rememeber to use SQLite syntax.",
                    4: "SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;",
                },
            }
        )

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Report success without removing anything.

        The original definition omitted ``self``, so calling it on an
        instance bound the instance to ``id`` and raised ``TypeError`` as
        soon as an actual ID was passed.
        """
        return True
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from httpx import Timeout
|
|
5
|
+
|
|
6
|
+
from ..base import VannaBase
|
|
7
|
+
from ..exceptions import DependencyError
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Ollama(VannaBase):
    """Chat backend that forwards prompts to an Ollama server.

    Configuration keys read by ``__init__``:

    - ``model`` (required): model name; ``":latest"`` is appended when the
      name carries no ``:`` tag.
    - ``ollama_host``: server URL, default ``http://localhost:11434``.
    - ``ollama_timeout``: value handed to ``httpx.Timeout``, default 240.0.
    - ``keep_alive``: forwarded to ``chat``; default ``None``.
    - ``options``: dict forwarded to ``chat``; default ``{}``.
    """

    def __init__(self, config=None):
        """Create the Ollama client and pull the model if it is not listed.

        Raises:
            DependencyError: if the ``ollama`` package cannot be imported.
            ValueError: if ``config`` is empty or has no ``model`` key.
        """
        try:
            ollama = __import__("ollama")
        except ImportError as err:
            # Chain the original failure so the traceback shows why the
            # import failed, not just that it did.
            raise DependencyError(
                "You need to install required dependencies to execute this method, run command:"
                " \npip install ollama"
            ) from err

        if not config or "model" not in config:
            raise ValueError("config must contain at least Ollama model")

        self.host = config.get("ollama_host", "http://localhost:11434")
        self.model = config["model"]
        if ":" not in self.model:
            # An untagged model name is treated as the ":latest" tag so it
            # can be compared against the names the server lists.
            self.model += ":latest"

        self.ollama_timeout = config.get("ollama_timeout", 240.0)

        self.ollama_client = ollama.Client(
            self.host, timeout=Timeout(self.ollama_timeout)
        )
        self.keep_alive = config.get("keep_alive", None)
        self.ollama_options = config.get("options", {})
        # Not read anywhere else in this class; kept as a public attribute.
        self.num_ctx = self.ollama_options.get("num_ctx", 2048)
        self.__pull_model_if_ne(self.ollama_client, self.model)

    @staticmethod
    def __pull_model_if_ne(ollama_client, model):
        """Pull ``model`` unless ``ollama_client.list()`` already reports it."""
        model_response = ollama_client.list()
        model_lists = [
            model_element["model"] for model_element in model_response.get("models", [])
        ]
        if model not in model_lists:
            ollama_client.pull(model)

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role ``system``."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role ``user``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with role ``assistant``."""
        return {"role": "assistant", "content": message}

    def extract_sql(self, llm_response):
        """Extract a SQL statement from an LLM response.

        All backslashes are stripped from the response first. Then, in
        order of preference:

        1. If the text contains a fenced block opened by three backticks
           followed by ``sql`` and a newline, return what follows the
           opener up to (not including) the first ``;``, ``[`` or three
           backticks.
        2. Otherwise, if the text contains ``select`` or ``with ... as (``
           (case-insensitive) followed somewhere by ``;``, ``[`` or three
           backticks, return the text from that keyword up to (not
           including) that terminator.
        3. Otherwise return the backslash-stripped response unchanged.

        Args:
            - llm_response (str): The string to search within for an SQL statement.

        Returns:
            - str: The extracted SQL, or the cleaned response when neither
              pattern matches.
        """
        # Remove ollama-generated extra characters
        llm_response = llm_response.replace("\\_", "_")
        llm_response = llm_response.replace("\\", "")

        # Fenced ```sql block: capture lazily until ';', '[' or the closing fence.
        sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
        # Bare statement starting at 'select' or 'with ... as (' (case-insensitive),
        # ending before ';', '[' (seen with mistral output) or a fence. The
        # lookahead needs one of those terminators, so an unterminated
        # statement does not match.
        select_with = re.search(
            r"(select|(with.*?as \())(.*?)(?=;|\[|```)",
            llm_response,
            re.IGNORECASE | re.DOTALL,
        )
        if sql:
            self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
            return sql.group(1).replace("```", "")
        elif select_with:
            self.log(
                f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}"
            )
            return select_with.group(0)
        else:
            return llm_response

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the Ollama chat endpoint and return the reply text.

        ``prompt`` is passed as the ``messages`` argument of the client's
        ``chat`` call and is also serialised with ``json.dumps`` for
        logging. ``kwargs`` are accepted but not used.

        Returns:
            The ``["message"]["content"]`` field of the non-streaming response.
        """
        self.log(
            f"Ollama parameters:\n"
            f"model={self.model},\n"
            f"options={self.ollama_options},\n"
            f"keep_alive={self.keep_alive}"
        )
        self.log(f"Prompt Content:\n{json.dumps(prompt, ensure_ascii=False)}")
        response_dict = self.ollama_client.chat(
            model=self.model,
            messages=prompt,
            stream=False,
            options=self.ollama_options,
            keep_alive=self.keep_alive,
        )

        self.log(f"Ollama Response:\n{str(response_dict)}")

        return response_dict["message"]["content"]
|
|
@@ -65,9 +65,7 @@ class OpenAI_Chat(VannaBase):
|
|
|
65
65
|
|
|
66
66
|
if kwargs.get("model", None) is not None:
|
|
67
67
|
model = kwargs.get("model", None)
|
|
68
|
-
print(
|
|
69
|
-
f"Using model {model} for {num_tokens} tokens (approx)"
|
|
70
|
-
)
|
|
68
|
+
print(f"Using model {model} for {num_tokens} tokens (approx)")
|
|
71
69
|
response = self.client.chat.completions.create(
|
|
72
70
|
model=model,
|
|
73
71
|
messages=prompt,
|
|
@@ -76,9 +74,7 @@ class OpenAI_Chat(VannaBase):
|
|
|
76
74
|
)
|
|
77
75
|
elif kwargs.get("engine", None) is not None:
|
|
78
76
|
engine = kwargs.get("engine", None)
|
|
79
|
-
print(
|
|
80
|
-
f"Using model {engine} for {num_tokens} tokens (approx)"
|
|
81
|
-
)
|
|
77
|
+
print(f"Using model {engine} for {num_tokens} tokens (approx)")
|
|
82
78
|
response = self.client.chat.completions.create(
|
|
83
79
|
engine=engine,
|
|
84
80
|
messages=prompt,
|