vanna 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {vanna-0.1.0 → vanna-0.2.0}/PKG-INFO +3 -1
  2. {vanna-0.1.0 → vanna-0.2.0}/pyproject.toml +3 -3
  3. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/base/base.py +137 -24
  4. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/chromadb/chromadb_vector.py +22 -6
  5. vanna-0.1.0/src/vanna/flask.py → vanna-0.2.0/src/vanna/flask/__init__.py +59 -52
  6. vanna-0.2.0/src/vanna/flask/assets.py +36 -0
  7. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/openai/openai_chat.py +25 -19
  8. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/openai/openai_embeddings.py +0 -2
  9. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/remote.py +44 -70
  10. {vanna-0.1.0 → vanna-0.2.0}/README.md +0 -0
  11. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
  12. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  13. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ZhipuAI/__init__.py +0 -0
  14. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/__init__.py +0 -0
  15. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/base/__init__.py +0 -0
  16. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/chromadb/__init__.py +0 -0
  17. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/exceptions/__init__.py +0 -0
  18. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/local.py +0 -0
  19. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/marqo/__init__.py +0 -0
  20. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/marqo/marqo.py +0 -0
  21. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/mistral/__init__.py +0 -0
  22. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/mistral/mistral.py +0 -0
  23. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ollama/__init__.py +0 -0
  24. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ollama/ollama.py +0 -0
  25. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/openai/__init__.py +0 -0
  26. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/types/__init__.py +0 -0
  27. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/utils.py +0 -0
  28. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/vannadb/__init__.py +0 -0
  29. {vanna-0.1.0 → vanna-0.2.0}/src/vanna/vannadb/vannadb_vector.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -15,6 +15,7 @@ Requires-Dist: pandas
15
15
  Requires-Dist: sqlparse
16
16
  Requires-Dist: kaleido
17
17
  Requires-Dist: flask
18
+ Requires-Dist: sqlalchemy
18
19
  Requires-Dist: psycopg2-binary ; extra == "all"
19
20
  Requires-Dist: db-dtypes ; extra == "all"
20
21
  Requires-Dist: google-cloud-bigquery ; extra == "all"
@@ -22,6 +23,7 @@ Requires-Dist: snowflake-connector-python ; extra == "all"
22
23
  Requires-Dist: duckdb ; extra == "all"
23
24
  Requires-Dist: openai ; extra == "all"
24
25
  Requires-Dist: mistralai ; extra == "all"
26
+ Requires-Dist: chromadb ; extra == "all"
25
27
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
26
28
  Requires-Dist: chromadb ; extra == "chromadb"
27
29
  Requires-Dist: duckdb ; extra == "duckdb"
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
4
4
 
5
5
  [project]
6
6
  name = "vanna"
7
- version = "0.1.0"
7
+ version = "0.2.0"
8
8
  authors = [
9
9
  { name="Zain Hoda", email="zain@vanna.ai" },
10
10
  ]
@@ -18,7 +18,7 @@ classifiers = [
18
18
  "Operating System :: OS Independent",
19
19
  ]
20
20
  dependencies = [
21
- "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask"
21
+ "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy"
22
22
  ]
23
23
 
24
24
  [project.urls]
@@ -30,7 +30,7 @@ postgres = ["psycopg2-binary", "db-dtypes"]
30
30
  bigquery = ["google-cloud-bigquery"]
31
31
  snowflake = ["snowflake-connector-python"]
32
32
  duckdb = ["duckdb"]
33
- all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai"]
33
+ all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb"]
34
34
  test = ["tox"]
35
35
  chromadb = ["chromadb"]
36
36
  openai = ["openai"]
@@ -72,6 +72,7 @@ class VannaBase(ABC):
72
72
  def __init__(self, config=None):
73
73
  self.config = config
74
74
  self.run_sql_is_set = False
75
+ self.static_documentation = ""
75
76
 
76
77
  def log(self, message: str):
77
78
  print(message)
@@ -140,18 +141,35 @@ class VannaBase(ABC):
140
141
  else:
141
142
  return False
142
143
 
143
- def generate_followup_questions(self, question: str, **kwargs) -> str:
144
- question_sql_list = self.get_similar_question_sql(question, **kwargs)
145
- ddl_list = self.get_related_ddl(question, **kwargs)
146
- doc_list = self.get_related_documentation(question, **kwargs)
147
- prompt = self.get_followup_questions_prompt(
148
- question=question,
149
- question_sql_list=question_sql_list,
150
- ddl_list=ddl_list,
151
- doc_list=doc_list,
152
- **kwargs,
153
- )
154
- llm_response = self.submit_prompt(prompt, **kwargs)
144
+ def generate_followup_questions(
145
+ self, question: str, sql: str, df: pd.DataFrame, **kwargs
146
+ ) -> list:
147
+ """
148
+ **Example:**
149
+ ```python
150
+ vn.generate_followup_questions("What are the top 10 customers by sales?", df)
151
+ ```
152
+
153
+ Generate a list of followup questions that you can ask Vanna.AI.
154
+
155
+ Args:
156
+ question (str): The question that was asked.
157
+ df (pd.DataFrame): The results of the SQL query.
158
+
159
+ Returns:
160
+ list: A list of followup questions that you can ask Vanna.AI.
161
+ """
162
+
163
+ message_log = [
164
+ self.system_message(
165
+ f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
166
+ ),
167
+ self.user_message(
168
+ "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
169
+ ),
170
+ ]
171
+
172
+ llm_response = self.submit_prompt(message_log, **kwargs)
155
173
 
156
174
  numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
157
175
  return numbers_removed.split("\n")
@@ -169,6 +187,36 @@ class VannaBase(ABC):
169
187
 
170
188
  return [q["question"] for q in question_sql]
171
189
 
190
+ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
191
+ """
192
+ **Example:**
193
+ ```python
194
+ vn.generate_summary("What are the top 10 customers by sales?", df)
195
+ ```
196
+
197
+ Generate a summary of the results of a SQL query.
198
+
199
+ Args:
200
+ question (str): The question that was asked.
201
+ df (pd.DataFrame): The results of the SQL query.
202
+
203
+ Returns:
204
+ str: The summary of the results of the SQL query.
205
+ """
206
+
207
+ message_log = [
208
+ self.system_message(
209
+ f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
210
+ ),
211
+ self.user_message(
212
+ "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
213
+ ),
214
+ ]
215
+
216
+ summary = self.submit_prompt(message_log, **kwargs)
217
+
218
+ return summary
219
+
172
220
  # ----------------- Use Any Embeddings API ----------------- #
173
221
  @abstractmethod
174
222
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
@@ -184,7 +232,7 @@ class VannaBase(ABC):
184
232
  question (str): The question to get similar questions and their corresponding SQL statements for.
185
233
 
186
234
  Returns:
187
- list: A list of similar questions and their corresponding SQL statements.
235
+ list: A list of similar questions and their corresponding SQL statements.
188
236
  """
189
237
  pass
190
238
 
@@ -224,7 +272,7 @@ class VannaBase(ABC):
224
272
  sql (str): The SQL query to add.
225
273
 
226
274
  Returns:
227
- str: The ID of the training data that was added.
275
+ str: The ID of the training data that was added.
228
276
  """
229
277
  pass
230
278
 
@@ -232,7 +280,7 @@ class VannaBase(ABC):
232
280
  def add_ddl(self, ddl: str, **kwargs) -> str:
233
281
  """
234
282
  This method is used to add a DDL statement to the training data.
235
-
283
+
236
284
  Args:
237
285
  ddl (str): The DDL statement to add.
238
286
 
@@ -265,7 +313,7 @@ class VannaBase(ABC):
265
313
  This method is used to get all the training data from the retrieval layer.
266
314
 
267
315
  Returns:
268
- pd.DataFrame: The training data.
316
+ pd.DataFrame: The training data.
269
317
  """
270
318
  pass
271
319
 
@@ -321,7 +369,10 @@ class VannaBase(ABC):
321
369
  return initial_prompt
322
370
 
323
371
  def add_documentation_to_prompt(
324
- self, initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000
372
+ self,
373
+ initial_prompt: str,
374
+ documentation_list: list[str],
375
+ max_tokens: int = 14000,
325
376
  ) -> str:
326
377
  if len(documentation_list) > 0:
327
378
  initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
@@ -389,6 +440,9 @@ class VannaBase(ABC):
389
440
  initial_prompt, ddl_list, max_tokens=14000
390
441
  )
391
442
 
443
+ if self.static_documentation != "":
444
+ doc_list.append(self.static_documentation)
445
+
392
446
  initial_prompt = self.add_documentation_to_prompt(
393
447
  initial_prompt, doc_list, max_tokens=14000
394
448
  )
@@ -599,6 +653,7 @@ class VannaBase(ABC):
599
653
 
600
654
  return df
601
655
 
656
+ self.static_documentation = "This is a Snowflake database"
602
657
  self.run_sql = run_sql_snowflake
603
658
  self.run_sql_is_set = True
604
659
 
@@ -632,6 +687,7 @@ class VannaBase(ABC):
632
687
  def run_sql_sqlite(sql: str):
633
688
  return pd.read_sql_query(sql, conn)
634
689
 
690
+ self.static_documentation = "This is a SQLite database"
635
691
  self.run_sql = run_sql_sqlite
636
692
  self.run_sql_is_set = True
637
693
 
@@ -732,6 +788,11 @@ class VannaBase(ABC):
732
788
  conn.rollback()
733
789
  raise ValidationError(e)
734
790
 
791
+ except Exception as e:
792
+ conn.rollback()
793
+ raise e
794
+
795
+ self.static_documentation = "This is a Postgres database"
735
796
  self.run_sql_is_set = True
736
797
  self.run_sql = run_sql_postgres
737
798
 
@@ -821,6 +882,7 @@ class VannaBase(ABC):
821
882
  raise errors
822
883
  return None
823
884
 
885
+ self.static_documentation = "This is a BigQuery database"
824
886
  self.run_sql_is_set = True
825
887
  self.run_sql = run_sql_bigquery
826
888
 
@@ -829,7 +891,7 @@ class VannaBase(ABC):
829
891
  Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
830
892
 
831
893
  Args:
832
- url (str): The URL of the database to connect to.
894
+ url (str): The URL of the database to connect to. Use :memory: to create an in-memory database. Use md: or motherduck: to use the MotherDuck database.
833
895
  init_sql (str, optional): SQL to run when connecting to the database. Defaults to None.
834
896
 
835
897
  Returns:
@@ -843,13 +905,15 @@ class VannaBase(ABC):
843
905
  " run command: \npip install vanna[duckdb]"
844
906
  )
845
907
  # URL of the database to download
846
- if url==":memory:" or url=="":
847
- path=":memory:"
908
+ if url == ":memory:" or url == "":
909
+ path = ":memory:"
848
910
  else:
849
911
  # Path to save the downloaded database
850
912
  print(os.path.exists(url))
851
913
  if os.path.exists(url):
852
- path=url
914
+ path = url
915
+ elif url.startswith("md") or url.startswith("motherduck"):
916
+ path = url
853
917
  else:
854
918
  path = os.path.basename(urlparse(url).path)
855
919
  # Download the database if it doesn't exist
@@ -867,9 +931,57 @@ class VannaBase(ABC):
867
931
  def run_sql_duckdb(sql: str):
868
932
  return conn.query(sql).to_df()
869
933
 
934
+ self.static_documentation = "This is a DuckDB database"
870
935
  self.run_sql = run_sql_duckdb
871
936
  self.run_sql_is_set = True
872
937
 
938
+ def connect_to_mssql(self, odbc_conn_str: str):
939
+ """
940
+ Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
941
+
942
+ Args:
943
+ odbc_conn_str (str): The ODBC connection string.
944
+
945
+ Returns:
946
+ None
947
+ """
948
+ try:
949
+ import pyodbc
950
+ except ImportError:
951
+ raise DependencyError(
952
+ "You need to install required dependencies to execute this method,"
953
+ " run command: pip install pyodbc"
954
+ )
955
+
956
+ try:
957
+ import sqlalchemy as sa
958
+ from sqlalchemy.engine import URL
959
+ except ImportError:
960
+ raise DependencyError(
961
+ "You need to install required dependencies to execute this method,"
962
+ " run command: pip install sqlalchemy"
963
+ )
964
+
965
+ connection_url = URL.create(
966
+ "mssql+pyodbc", query={"odbc_connect": odbc_conn_str}
967
+ )
968
+
969
+ from sqlalchemy import create_engine
970
+
971
+ engine = create_engine(connection_url)
972
+
973
+ def run_sql_mssql(sql: str):
974
+ # Execute the SQL statement and return the result as a pandas DataFrame
975
+ with engine.begin() as conn:
976
+ df = pd.read_sql_query(sa.text(sql), conn)
977
+ return df
978
+
979
+ raise Exception("Couldn't run sql")
980
+
981
+ self.static_documentation = "This is a Microsoft SQL Server database"
982
+ self.run_sql = run_sql_mssql
983
+ self.run_sql_is_set = True
984
+
873
985
  def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
874
986
  """
875
987
  Example:
@@ -894,7 +1006,7 @@ class VannaBase(ABC):
894
1006
  question: Union[str, None] = None,
895
1007
  print_results: bool = True,
896
1008
  auto_train: bool = True,
897
- visualize: bool = True, # if False, will not generate plotly code
1009
+ visualize: bool = True, # if False, will not generate plotly code
898
1010
  ) -> Union[
899
1011
  Tuple[
900
1012
  Union[str, None],
@@ -975,7 +1087,9 @@ class VannaBase(ABC):
975
1087
  display = __import__(
976
1088
  "IPython.display", fromlist=["display"]
977
1089
  ).display
978
- Image = __import__("IPython.display", fromlist=["Image"]).Image
1090
+ Image = __import__(
1091
+ "IPython.display", fromlist=["Image"]
1092
+ ).Image
979
1093
  img_bytes = fig.to_image(format="png", scale=2)
980
1094
  display(Image(img_bytes))
981
1095
  except Exception as e:
@@ -1328,4 +1442,3 @@ class VannaBase(ABC):
1328
1442
  fig.update_layout(template="plotly_dark")
1329
1443
 
1330
1444
  return fig
1331
-
@@ -1,7 +1,6 @@
1
1
  import json
2
- from typing import List
3
2
  import uuid
4
- from abc import abstractmethod
3
+ from typing import List
5
4
 
6
5
  import chromadb
7
6
  import pandas as pd
@@ -20,13 +19,28 @@ class ChromaDB_VectorStore(VannaBase):
20
19
  if config is not None:
21
20
  path = config.get("path", ".")
22
21
  self.embedding_function = config.get("embedding_function", default_ef)
22
+ curr_client = config.get("client", "persistent")
23
+ self.n_results = config.get("n_results", 10)
23
24
  else:
24
25
  path = "."
25
26
  self.embedding_function = default_ef
27
+ curr_client = "persistent" # defaults to persistent storage
28
+ self.n_results = 10 # defaults to 10 documents
29
+
30
+ if curr_client == "persistent":
31
+ self.chroma_client = chromadb.PersistentClient(
32
+ path=path, settings=Settings(anonymized_telemetry=False)
33
+ )
34
+ elif curr_client == "in-memory":
35
+ self.chroma_client = chromadb.EphemeralClient(
36
+ settings=Settings(anonymized_telemetry=False)
37
+ )
38
+ elif isinstance(curr_client, chromadb.api.client.Client):
39
+ # allow providing client directly
40
+ self.chroma_client = curr_client
41
+ else:
42
+ raise ValueError(f"Unsupported client was set in config: {curr_client}")
26
43
 
27
- self.chroma_client = chromadb.PersistentClient(
28
- path=path, settings=Settings(anonymized_telemetry=False)
29
- )
30
44
  self.documentation_collection = self.chroma_client.get_or_create_collection(
31
45
  name="documentation", embedding_function=self.embedding_function
32
46
  )
@@ -196,7 +210,8 @@ class ChromaDB_VectorStore(VannaBase):
196
210
  query_results (pd.DataFrame): The dataframe to use.
197
211
 
198
212
  Returns:
199
- List[str] or None: The extracted documents, or an empty list or single document if an error occurred.
213
+ List[str] or None: The extracted documents, or an empty list or
214
+ single document if an error occurred.
200
215
  """
201
216
  if query_results is None:
202
217
  return []
@@ -216,6 +231,7 @@ class ChromaDB_VectorStore(VannaBase):
216
231
  return ChromaDB_VectorStore._extract_documents(
217
232
  self.sql_collection.query(
218
233
  query_texts=[question],
234
+ n_results=self.n_results,
219
235
  )
220
236
  )
221
237
 
@@ -7,6 +7,8 @@ import flask
7
7
  import requests
8
8
  from flask import Flask, Response, jsonify, request
9
9
 
10
+ from .assets import css_content, html_content, js_content
11
+
10
12
 
11
13
  class Cache(ABC):
12
14
  @abstractmethod
@@ -92,10 +94,11 @@ class VannaFlaskApp:
92
94
 
93
95
  return decorator
94
96
 
95
- def __init__(self, vn, cache: Cache = MemoryCache()):
97
+ def __init__(self, vn, cache: Cache = MemoryCache(), allow_llm_to_see_data=False):
96
98
  self.flask_app = Flask(__name__)
97
99
  self.vn = vn
98
100
  self.cache = cache
101
+ self.allow_llm_to_see_data = allow_llm_to_see_data
99
102
 
100
103
  log = logging.getLogger("werkzeug")
101
104
  log.setLevel(logging.ERROR)
@@ -296,23 +299,55 @@ class VannaFlaskApp:
296
299
  return jsonify({"type": "error", "error": str(e)})
297
300
 
298
301
  @self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
299
- @self.requires_cache(["df", "question"])
300
- def generate_followup_questions(id: str, df, question):
301
- followup_questions = []
302
- # followup_questions = vn.generate_followup_questions(question=question, df=df)
303
- # if followup_questions is not None and len(followup_questions) > 5:
304
- # followup_questions = followup_questions[:5]
302
+ @self.requires_cache(["df", "question", "sql"])
303
+ def generate_followup_questions(id: str, df, question, sql):
304
+ if self.allow_llm_to_see_data:
305
+ followup_questions = vn.generate_followup_questions(
306
+ question=question, sql=sql, df=df
307
+ )
308
+ if followup_questions is not None and len(followup_questions) > 5:
309
+ followup_questions = followup_questions[:5]
305
310
 
306
- cache.set(id=id, field="followup_questions", value=followup_questions)
311
+ cache.set(id=id, field="followup_questions", value=followup_questions)
307
312
 
308
- return jsonify(
309
- {
310
- "type": "question_list",
311
- "id": id,
312
- "questions": followup_questions,
313
- "header": "Followup Questions can be enabled in a future version if you allow the LLM to 'see' your query results.",
314
- }
315
- )
313
+ return jsonify(
314
+ {
315
+ "type": "question_list",
316
+ "id": id,
317
+ "questions": followup_questions,
318
+ "header": "Here are some potential followup questions:",
319
+ }
320
+ )
321
+ else:
322
+ return jsonify(
323
+ {
324
+ "type": "question_list",
325
+ "id": id,
326
+ "questions": [],
327
+ "header": "Followup Questions can be enabled if you set allow_llm_to_see_data=True",
328
+ }
329
+ )
330
+
331
+ @self.flask_app.route("/api/v0/generate_summary", methods=["GET"])
332
+ @self.requires_cache(["df", "question"])
333
+ def generate_summary(id: str, df, question):
334
+ if self.allow_llm_to_see_data:
335
+ summary = vn.generate_summary(question=question, df=df)
336
+ return jsonify(
337
+ {
338
+ "type": "text",
339
+ "id": id,
340
+ "text": summary,
341
+ }
342
+ )
343
+ else:
344
+ return jsonify(
345
+ {
346
+ "type": "text",
347
+ "id": id,
348
+ "text": "Summarization can be enabled if you set allow_llm_to_see_data=True",
349
+ }
350
+ )
316
351
 
317
352
  @self.flask_app.route("/api/v0/load_question", methods=["GET"])
318
353
  @self.requires_cache(
@@ -352,25 +387,14 @@ class VannaFlaskApp:
352
387
 
353
388
  @self.flask_app.route("/assets/<path:filename>")
354
389
  def proxy_assets(filename):
355
- remote_url = f"https://vanna.ai/assets/{filename}"
356
- response = requests.get(remote_url, stream=True)
390
+ if ".css" in filename:
391
+ return Response(css_content, mimetype="text/css")
357
392
 
358
- # Check if the request to the remote URL was successful
359
- if response.status_code == 200:
360
- excluded_headers = [
361
- "content-encoding",
362
- "content-length",
363
- "transfer-encoding",
364
- "connection",
365
- ]
366
- headers = [
367
- (name, value)
368
- for (name, value) in response.raw.headers.items()
369
- if name.lower() not in excluded_headers
370
- ]
371
- return Response(response.content, response.status_code, headers)
372
- else:
373
- return "Error fetching file from remote server", response.status_code
393
+ if ".js" in filename:
394
+ return Response(js_content, mimetype="text/javascript")
395
+
396
+ # Return 404
397
+ return "File not found", 404
374
398
 
375
399
  # Proxy the /vanna.svg file to the remote server
376
400
  @self.flask_app.route("/vanna.svg")
@@ -398,24 +422,7 @@ class VannaFlaskApp:
398
422
  @self.flask_app.route("/", defaults={"path": ""})
399
423
  @self.flask_app.route("/<path:path>")
400
424
  def hello(path: str):
401
- return """
402
- <!doctype html>
403
- <html lang="en">
404
- <head>
405
- <meta charset="UTF-8" />
406
- <link rel="icon" type="image/svg+xml" href="/vanna.svg" />
407
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
408
- <link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
409
- <script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
410
- <title>Vanna.AI</title>
411
- <script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
412
- <link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
413
- </head>
414
- <body class="bg-white dark:bg-slate-900">
415
- <div id="app"></div>
416
- </body>
417
- </html>
418
- """
425
+ return html_content
419
426
 
420
427
  def run(self):
421
428
  try: