kobai-sdk 0.2.8rc3__tar.gz → 0.2.8rc5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of kobai-sdk might be problematic.

PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: kobai-sdk
-Version: 0.2.8rc3
+Version: 0.2.8rc5
 Summary: A package that enables interaction with a Kobai tenant.
 Author-email: Ryan Oattes <ryan@kobai.io>
 License: Apache License
@@ -222,7 +222,7 @@ Requires-Dist: azure-storage-blob
 Requires-Dist: langchain-core
 Requires-Dist: langchain-community
 Requires-Dist: langchain_openai
-Requires-Dist: sentence_transformers
+Requires-Dist: databricks_langchain
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: bumpver; extra == "dev"

kobai/ai_query.py (new implementation; the previous version is removed at the end of this diff)

@@ -0,0 +1,255 @@
+from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+from sentence_transformers import SentenceTransformer, util
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+from langchain_core.vectorstores import InMemoryVectorStore
+
+from typing import Union, List
+
+
+MESSAGE_SYSTEM_TEMPLATE = """
+You are a data analyst tasked with answering questions based on a provided data set. Please answer the questions based on the provided context below. Make sure not to make any changes to the context, if possible, when preparing answers to provide accurate responses. If the answer cannot be found in context, just politely say that you do not know, do not try to make up an answer.
+When you receive a question from the user, answer only that one question in a concise manner. Do not elaborate with other questions.
+"""
+
+MESSAGE_AI_TEMPLATE = """
+The table information is as follows:
+{table_data}
+"""
+
+MESSAGE_USER_CONTEXT_TEMPLATE = """
+The context being provided is from a table named: {table_name}
+"""
+
+MESSAGE_USER_QUESTION_TEMPLATE = """
+{question}
+"""
+
+SIMPLE_PROMPT_TEMPLATE = f"""
+{MESSAGE_SYSTEM_TEMPLATE}
+
+{MESSAGE_USER_CONTEXT_TEMPLATE}
+
+{MESSAGE_AI_TEMPLATE}
+
+Question: {MESSAGE_USER_QUESTION_TEMPLATE}
+"""
+
+class QuestionRetriever(BaseRetriever):
+    #https://python.langchain.com/docs/how_to/custom_retriever/
+    #https://github.com/langchain-ai/langchain/issues/12304
+
+    documents: List[Document]
+    k: int = 5000
+
+    #def __init__(self, documents: List[Document], k: int = 5000):
+    #    self.documents = documents
+    #    self.k = k
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        """Sync implementations for retriever."""
+        matching_documents = []
+        for document in self.documents:
+            if len(matching_documents) > self.k:
+                return matching_documents
+
+            #if query.lower() in document.page_content.lower():
+            #    matching_documents.append(document)
+            matching_documents.append(document)
+        return matching_documents
+
+def format_docs(docs):
+    return "\n\n".join([d.page_content for d in docs])
+
+def input_only(inpt):
+    return inpt["question"]
+
+def followup_question(user_question, question_results, question_name, question_def, embedding_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, use_inmem_vectors=False):
+
+    row_texts = process_question_results(question_def, question_results)
+    question_documents = [Document(page_content=r, metadata={"source": "kobai"}) for r in row_texts]
+
+    if use_inmem_vectors:
+        question_retriever = InMemoryVectorStore.from_documents(question_documents, embedding=embedding_model).as_retriever(
+            search_kwargs={"k": 5}
+        )
+    else:
+        question_retriever = QuestionRetriever(documents=question_documents)
+
+    output_parser = StrOutputParser()
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            SystemMessagePromptTemplate.from_template(
+                MESSAGE_SYSTEM_TEMPLATE),
+            HumanMessagePromptTemplate.from_template(
+                MESSAGE_USER_CONTEXT_TEMPLATE),
+            AIMessagePromptTemplate.from_template(MESSAGE_AI_TEMPLATE),
+            HumanMessagePromptTemplate.from_template(
+                MESSAGE_USER_QUESTION_TEMPLATE)
+        ]
+    )
+
+    chain = (
+        {"table_name": RunnablePassthrough(), "table_data": RunnableLambda(input_only) | question_retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | chat_model
+        | output_parser
+    )
+    response = chain.invoke(
+        {
+            "table_name": question_name,
+            "question": user_question
+        }
+    )
+
+    return response
+
+def init_question_search_index(tenant_questions, emb_model):
+
+    q_ids = [q["id"] for q in tenant_questions]
+    q_descs = [q["description"] for q in tenant_questions]
+
+    if isinstance(emb_model, SentenceTransformer):
+        q_vectors = emb_model.encode(q_descs)
+    else:
+        q_vectors = emb_model.embed_documents(q_descs)
+
+    return {"ids": q_ids, "descs": q_descs, "vectors": q_vectors}
+
+
+def question_search(search_text: str, search_index, emb_model, k: int):
+    if isinstance(emb_model, SentenceTransformer):
+        search_vec = emb_model.encode(search_text)
+    else:
+        search_vec = emb_model.embed_query(search_text)
+    #search_vec = emb_model.encode(search_text)
+
+    matches = __top_vector_matches(search_vec, search_index["vectors"], top=k)
+
+    for mi, m in enumerate(matches):
+        matches[mi]["id"] = search_index["ids"][m["index"]]
+        matches[mi]["description"] = search_index["descs"][m["index"]]
+    return matches
+
+def __top_vector_matches(test_vec, options_list_vec, top=1):
+    scores_t = util.cos_sim(test_vec, options_list_vec)[0]
+    scores_l = scores_t.tolist()
+    scores_d = [{"index": i, "value": v} for i, v in enumerate(scores_l)]
+    sorted_d = sorted(scores_d, key=lambda i: i["value"], reverse=True)
+    top_d = sorted_d[0:top]
+    return top_d
+
+def process_question_results(question_def, question_results):
+
+    """
+    Returns a template to format each row in Kobai JSON question output into a format readable by LLMs.
+
+    Parameters:
+    question_def (any): Kobai standard JSON definition of question.
+    question_results (any): JSON representation of Kobai base question results.
+    """
+
+    concept_props = {}
+    concept_rels = {}
+
+    for ci in question_def["definition"]:
+        con_name = question_def["definition"][ci]["label"].replace("_", " ")
+        con_label = question_def["definition"][ci]["label"]
+        concept_props[ci] = {"name": con_name, "props": []}
+        for p in question_def["definition"][ci]["properties"]:
+            if p["hidden"] == False:
+                if len(p["aggregates"]) > 0:
+                    for a in p["aggregates"]:
+                        prop_column = con_label + "_" + p["label"] + "_" + a["type"]
+                        prop_name = p["label"].replace("_", " ")
+                        concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": a["type"]})
+                else:
+                    prop_column = con_label + "_" + p["label"]
+                    prop_name = p["label"].replace("_", " ")
+                    concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": None})
+        for r in question_def["definition"][ci]["relations"]:
+            prop_name = question_def["definition"][ci]["relations"][r]["label"].replace("_", " ")
+            for ri in question_def["definition"][ci]["relations"][r]["relationInstances"]:
+                if ci not in concept_rels:
+                    concept_rels[ci] = {"count": 0, "edges": []}
+                concept_rels[ci]["edges"].append({"src": ci, "dst": ri["relationTypeUri"], "name": prop_name})
+                concept_rels[ci]["count"] += 1
+
+
+    row_texts = {}
+
+    for ci, c in concept_props.items():
+        p_texts = []
+        for p in c["props"]:
+            if p["agg"] is None:
+                p_text = p["name"] + " " + "{" + p["column"] + "}"
+            else:
+                p_text = p["agg"] + " of " + p["name"] + " " + "{" + p["column"] + "}"
+            p_texts.append(p_text)
+        c_text = __get_article(c["name"]) + " " + c["name"]
+        if len(c["props"]) > 0:
+            c_text += " with " + __smart_comma_formatting(p_texts)
+        row_texts[ci] = c_text
+
+    max_src = ""
+    max_src_count = -1
+
+    for r in concept_rels:
+        if concept_rels[r]["count"] > max_src_count:
+            max_src_count = concept_rels[r]["count"]
+            max_src = r
+
+
+    concept_order = [max_src]
+    for t in concept_rels[max_src]["edges"]:
+        concept_order.append(t["dst"])
+
+    for c in concept_props:
+        if c not in concept_order:
+            concept_order.append(c)
+
+    row_template = concept_order[0] + " is connected to " + " and connected to ".join(concept_order[1:])
+
+    for c in row_texts:
+        row_template = row_template.replace(c, row_texts[c])
+
+    row_template = row_template[0].upper() + row_template[1:] + "."
+
+    row_texts = []
+    for row in question_results:
+        row_text = row_template
+        for col in row:
+            row_text = row_text.replace("{" + col + "}", str(row[col]))
+        row_texts.append(row_text)
+    #data = "\n".join(row_texts)
+    return row_texts
+    #return data
+
+def __smart_comma_formatting(items):
+    if items == None:
+        return ""
+    match len(items):
+        case 0:
+            return ""
+        case 1:
+            return items[0]
+        case 2:
+            return items[0] + " and " + items[1]
+        case _:
+            return ", ".join(items[0: -1]) + " and " + items[-1]
+
+def __get_article(label):
+    if label[0:1].lower() in ["a", "e", "i", "o", "u"]:
+        return "an"
+    else:
+        return "a"

kobai/ai_rag.py

@@ -1,4 +1,5 @@
-from kobai import tenant_client
+from kobai import tenant_api
+from pyspark.sql import SparkSession
 
 from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, IntegerType
 from pyspark.sql import functions as F
@@ -11,12 +12,43 @@ from langchain_community.document_loaders import PySparkDataFrameLoader
 from langchain import hub
 from langchain_core.output_parsers import StrOutputParser
 
+import urllib
+import urllib.parse
 
+class AIContext:
 
+    schema: str
+    spark_session: SparkSession
+    model_id: str
+    tenant_json: str
+    api_client: tenant_api.TenantAPI
 
+def ai_run_question_remote(tc: AIContext, question_id, dynamic_filters: dict = None):
 
-def generate_sentences(tc: tenant_client.TenantClient, replica_schema=None, concept_white_list=None, use_questions=False):
+    """
+    Returns JSON formatted result of Kobai question.
 
+    Parameters:
+    question_id (int): Numeric identifier of Kobai question.
+    """
+
+    uri = '/data-svcs/api/query/' + str(question_id) + '/execute?' #'/data-svcs/api/query/4518/solution/9/execute/tabular?'
+
+    queryParams = {'jsontype': 'tableau'}
+
+    if bool(dynamic_filters):
+        queryParams.update(dynamic_filters)
+
+    uri += urllib.parse.urlencode(queryParams)
+
+    json={
+        'simulations': {'concepts': {}, 'data': None}
+    }
+    response = tc.api_client._TenantAPI__run_post(uri, json)
+
+    return response.json()
+
+def generate_sentences(tc: AIContext, replica_schema=None, concept_white_list=None, use_questions=False):
 
     """
     Extract Semantic Data from Graph to Delta Table
@@ -26,24 +58,27 @@ def generate_sentences(tc: tenant_client.TenantClient, conc
     concept_white_list ([str]) OPTIONAL: A list of Domain and Concept names for extraction.
     use_questions (bool) OPTIONAL: Extract facts from published Kobai questions.
     """
-
-    if tc.spark_client is None:
-        return None
-
-    ss = tc.spark_client.spark_session
+
+    #if tc.spark_client is None:
+    #    return None
+
+    ss = tc.spark_session
 
     print("Getting Tenant Config")
-    tenant_json = tc.get_tenant_config()
+    tenant_json = tc.tenant_json
 
-    concepts = __get_concept_metadata(tenant_json, tc.schema, tc.model_id, concept_white_list)
+    concepts = __get_concept_metadata(
+        tenant_json, tc.schema, tc.model_id, concept_white_list)
 
     print("Dropping and Recreating the RAG Table")
     ss.sql(__create_rag_table_sql(tc.schema, tc.model_id))
 
     print("Generating Extraction SQL")
     sql_statements = []
-    sql_statements.extend(__generate_sentence_sql_concept_literals(concepts, tc.schema, tc.model_id))
-    sql_statements.extend(__generate_sentence_sql_concept_relations(concepts, tc.schema, tc.model_id))
+    sql_statements.extend(__generate_sentence_sql_concept_literals(
+        concepts, tc.schema, tc.model_id))
+    sql_statements.extend(__generate_sentence_sql_concept_relations(
+        concepts, tc.schema, tc.model_id))
 
     print("Running the Extraction")
    for sql_statement in sql_statements:
@@ -55,14 +90,16 @@ def generate_sentences(tc: tenant_client.TenantClient, conc
     if replica_schema is not None:
         print("Replicating Schema")
         ss.sql(__create_rag_table_sql(replica_schema, tc.model_id))
-        ss.sql(__replicate_to_catalog_sql(tc.schema, replica_schema, tc.model_id))
+        ss.sql(__replicate_to_catalog_sql(
+            tc.schema, replica_schema, tc.model_id))
 
-def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
-    ss = tc.spark_client.spark_session
+
+def __generate_sentences_from_questions(tc: AIContext):
+    ss = tc.spark_session
 
     print("Getting Question Data")
 
-    tenant_json = tc.get_tenant_config()
+    tenant_json = tc.tenant_json
 
     published_queries = []
     for p in tenant_json["publishedAPIs"]:
@@ -73,22 +110,21 @@ def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
        if q["id"] in published_queries:
            question_names[q["id"]] = q["description"]
 
-    schemaV = StructType([
-        StructField("sentence",StringType(),True),
-        StructField("query_id", StringType(), True)
-    ])
+    schema_v = StructType([
+        StructField("sentence", StringType(), True),
+        StructField("query_id", StringType(), True)
+    ])
 
     sentences = []
     for p in published_queries:
-        output = tc.run_question_remote(p)
+        output = ai_run_question_remote(tc, p)
         for r in output:
             sentence = f"For {question_names[p]}: "
             for c in r:
                 sentence += f"The {c.replace('_', ' ')} is {r[c]}. "
             sentences.append([sentence, p])
 
-
-    sentences_df = ss.createDataFrame(sentences, schemaV)
+    sentences_df = ss.createDataFrame(sentences, schema_v)
     sentences_df = sentences_df.select(
         F.col("sentence").alias("sentence"),
         F.col("query_id").alias("concept_id"),
@@ -96,19 +132,17 @@ def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
     )
 
     schema = tc.schema
-
+
     view_name = f"rag_{tc.model_id}_question_sentences"
     sentences_df.createOrReplaceTempView(view_name)
 
     full_sql = f"INSERT INTO {schema}.rag_{tc.model_id} (content, concept_id, type)"
     full_sql += f" SELECT sentence, concept_id, type FROM {view_name}"
-
-    ss.sql(full_sql)
 
+    ss.sql(full_sql)
 
 
-def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTransformer, replica_schema=None):
-
+def encode_to_delta_local(tc: AIContext, st_model: SentenceTransformer, replica_schema=None):
     """
     Encode Semantic Data to Vectors in Delta Table
 
@@ -118,10 +152,10 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran
     replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
     """
 
-    if tc.spark_client is None:
-        return None
-
-    ss = tc.spark_client.spark_session
+    #if tc.spark_client is None:
+    #    return None
+
+    ss = tc.spark_session
 
     schema = tc.schema
     if replica_schema is not None:
@@ -133,7 +167,6 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran
     num_records = sentences_df.count()
     query_batch_size = 100000
 
-
     for x in range(0, num_records, query_batch_size):
         print(f"Running Batch Starting at {x}")
         sentences_sql = f" SELECT id, content FROM {schema}.rag_{tc.model_id} ORDER BY id LIMIT {str(query_batch_size)} OFFSET {str(x)}"
@@ -141,25 +174,27 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran
         content_list = [r["content"] for r in sentences_df.collect()]
         id_list = [r["id"] for r in sentences_df.collect()]
 
-        vector_list = st_model.encode(content_list, normalize_embeddings=True, show_progress_bar=True)
+        vector_list = st_model.encode(
+            content_list, normalize_embeddings=True, show_progress_bar=True)
 
-        schemaV = StructType([
-            StructField("id",IntegerType(),True),
+        schema_v = StructType([
+            StructField("id", IntegerType(), True),
             StructField("vector", ArrayType(FloatType()), False)
         ])
 
-        updated_list = [[r[0], r[1].tolist()] for r in zip(id_list, vector_list)]
-        updated_df = ss.createDataFrame(updated_list, schemaV)
+        updated_list = [[r[0], r[1].tolist()]
+                        for r in zip(id_list, vector_list)]
+        updated_df = ss.createDataFrame(updated_list, schema_v)
 
         target_table = DeltaTable.forName(ss, f"{schema}.rag_{tc.model_id}")
 
         target_table.alias("t") \
-            .merge(
+        .merge(
             updated_df.alias("s"),
             't.id = s.id'
         ) \
-            .whenMatchedUpdate(set = {"vector": "s.vector"}) \
-            .execute()
+        .whenMatchedUpdate(set={"vector": "s.vector"}) \
+        .execute()
 
     ss.sql(f"""
     CREATE FUNCTION IF NOT EXISTS {schema}.cos_sim(a ARRAY<FLOAT>, b ARRAY<FLOAT>)
@@ -171,8 +206,8 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran
     $$
     """)
 
-def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
 
+def rag_delta(tc: AIContext, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
     """
     Run a RAG query using vectors in Delta table.
 
@@ -184,25 +219,26 @@ def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransform
     k (int) OPTIONAL: The number of RAG documents to retrieve.
     replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
     """
-
+
     schema = tc.schema
     if replica_schema is not None:
         schema = replica_schema
 
-    if tc.spark_client is None:
-        print("Instantiate Spark Client First")
-        return None
-
-    ss = tc.spark_client.spark_session
+    #if tc.spark_client is None:
+    #    print("Instantiate Spark Client First")
+    #    return None
+
+    ss = tc.spark_session
 
     if isinstance(emb_model, SentenceTransformer):
-        vector_list = emb_model.encode(question, normalize_embeddings=True).tolist()
+        vector_list = emb_model.encode(
+            question, normalize_embeddings=True).tolist()
     elif isinstance(emb_model, Embeddings):
         vector_list = emb_model.embed_query(question)
     else:
         print("Invalid Embedding Model Type")
         return None
-
+
     if not isinstance(chat_model, BaseChatModel):
         print("Invalid Chat Model Type")
         return None
@@ -216,7 +252,7 @@ def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransform
     ORDER BY score DESC
     LIMIT {k}
     """)
-
+
     loader = PySparkDataFrameLoader(ss, results, page_content_column="content")
     documents = loader.load()
     docs_content = "\n\n".join(doc.page_content for doc in documents)
@@ -236,20 +272,24 @@ def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransform
 
     return response
 
+
 def __create_rag_table_sql(schema, model_id):
     return f"CREATE OR REPLACE TABLE {schema}.rag_{model_id} (id BIGINT GENERATED BY DEFAULT AS IDENTITY, content STRING, type string, concept_id string, vector ARRAY<FLOAT>) TBLPROPERTIES (delta.enableChangeDataFeed = true)"
-
+
+
 def __replicate_to_catalog_sql(base_schema, target_schema, model_id):
     move_sql = f"INSERT INTO {target_schema}.rag_{model_id} (content, concept_id, type)"
     move_sql += f" SELECT content, concept_id, type FROM {base_schema}.rag_{model_id}"
     return move_sql
 
+
+
 def __generate_sentence_sql_concept_literals(concepts, schema, model_id):
     statements = []
     for con in concepts:
         sql = f"'This is a {con['label']}. '"
         sql += " || 'It is identified by ' || cid._plain_conceptid || '. '"
-
+
         sql_from = f"(SELECT _conceptid, _plain_conceptid FROM {con['prop_table_name']} GROUP BY _conceptid, _plain_conceptid) cid"
         for prop in con["properties"]:
 
@@ -257,15 +297,16 @@ def __generate_sentence_sql_concept_literals(concepts, schema, model_id):
             sql_from += f" ON cid._conceptid = {prop['label']}._conceptid"
             sql_from += f" AND {prop['label']}.type = 'l'"
             sql_from += f" AND {prop['label']}.name = '{prop['name']}'"
-
+
             sql += f" || 'The {prop['label']} is ' || ifnull(any_value({prop['label']}.value) IGNORE NULLS, 'unknown') || '. '"
-
+
         full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
         full_sql += f" SELECT {sql} content, cid._conceptid concept_id, 'c' type FROM {sql_from} GROUP BY cid._conceptid, cid._plain_conceptid"
-
+
         statements.append(full_sql)
     return statements
 
+
 def __generate_sentence_sql_concept_relations(concepts, schema, model_id):
     statements = []
     for con in concepts:
@@ -280,13 +321,13 @@ def __generate_sentence_sql_concept_relations(concepts, schema, model_id):
         sql += f" || ' has a relationship called {rel['label']} that connects it to one or more {rel['target_con_label']} identified by '"
         sql += " || concat_ws(', ', array_agg(cid._plain_conceptid)) || '. '"
 
-
         full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
         full_sql += f" SELECT {sql} content, rel._conceptid concept_id, 'e' type FROM {sql_from} GROUP BY rel._conceptid, rel._plain_conceptid"
 
         statements.append(full_sql)
     return statements
 
+
 def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
     target_concept_labels = {}
     target_table_names = {}
@@ -297,7 +338,7 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
             "prop": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_np",
             "con": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_c"
         }
-
+
     concepts = []
     for d in tenant_json["domains"]:
         for c in d["concepts"]:
@@ -306,7 +347,7 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
                 con_props.append({
                     "label": col["label"],
                     "name": f"{model_id}/{d['name']}/{c['label']}#{col['label']}"
-                    })
+                })
             con_rels = []
             for rel in c["relations"]:
                 if whitelist is not None and target_concept_labels[rel["relationTypeUri"]] not in whitelist:
@@ -328,7 +369,7 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
                 "parents": con_parents,
                 "prop_table_name": target_table_names[c["uri"]]["prop"],
                 "con_table_name": target_table_names[c["uri"]]["con"]
-                })
+            })
 
     for ci, c in enumerate(concepts):
         if len(c["parents"]) > 0:
@@ -343,4 +384,4 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
             continue
         out_concepts.append(c)
 
-    return out_concepts
+    return out_concepts
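
The refactor above removes the TenantClient dependency from kobai/ai_rag.py: every entry point now takes the lightweight AIContext instead. A hypothetical wiring sketch follows; the schema and model id are placeholders, and spark, tenant_json, api_client, st_model, and chat_model stand in for objects you would normally obtain from an authenticated client (see get_ai_context() in the tenant_client.py section below).

    from kobai import ai_rag

    ctx = ai_rag.AIContext()
    ctx.schema = "main.kobai_demo"    # placeholder catalog.database
    ctx.model_id = "42"               # placeholder tenant model id
    ctx.spark_session = spark         # assumed: an active SparkSession
    ctx.tenant_json = tenant_json     # assumed: TenantClient.get_tenant_config() output
    ctx.api_client = api_client       # assumed: authenticated tenant_api.TenantAPI

    # Extract graph facts into the Delta RAG table, encode them, then query.
    ai_rag.generate_sentences(ctx)
    ai_rag.encode_to_delta_local(ctx, st_model)   # assumed: a SentenceTransformer
    print(ai_rag.rag_delta(ctx, st_model, chat_model, "Which concepts mention orders?"))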

kobai/tenant_client.py

@@ -7,7 +7,14 @@ import requests
 from azure.identity import DeviceCodeCredential
 from pyspark.sql import SparkSession
 
-from . import spark_client, databricks_client, ai_query, tenant_api
+from langchain_community.chat_models import ChatDatabricks
+from databricks_langchain import DatabricksEmbeddings
+from sentence_transformers import SentenceTransformer
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from typing import Union
+
+from . import spark_client, databricks_client, ai_query, tenant_api, ai_rag
 
 class TenantClient:
 
@@ -39,6 +46,10 @@ class TenantClient:
         self.model_id = ""
         self.proxies = None
         self.ssl_verify = True
+        self.question_search_index = None
+        self.embedding_model = None
+        self.chat_model = None
+
 
     def update_proxy(self, proxies: any):
         self.proxies = proxies
@@ -99,10 +110,11 @@ class TenantClient:
 
         self.__api_init_session()
         self.__set_tenant_solutionid()
+        self.init_ai_components()
 
         print("Authentication Successful.")
 
-    def authenticate_brower_token(self, access_token):
+    def authenticate_browser_token(self, access_token):
 
         """
         Authenticate the TenantClient with the Kobai instance. Returns nothing, but stores bearer token in client.
@@ -116,6 +128,8 @@ class TenantClient:
 
         self.__api_init_session()
         self.__set_tenant_solutionid()
+        self.init_ai_components()
+
 
         print("Authentication Successful.")
 
@@ -410,212 +424,114 @@ class TenantClient:
         return return_questions
 
     ########################################
-    # AI Functions
+    # RAG Functions
     ########################################
 
-    def followup_question(self, followup_question, question_results, question_id, override_model=None, use_simple_prompt=False, debug=False):
-
+    def get_ai_context(self):
+        context = ai_rag.AIContext()
+        context.model_id = self.model_id
+        context.schema = self.schema
+        context.tenant_json = self.get_tenant_config()
+        context.spark_session = self.spark_client.spark_session
+        context.api_client = self.api_client
+        return context
+
+    def rag_generate_sentences(self, replica_schema=None, concept_white_list=None, use_questions=False):
         """
-        Use LLM to further investigate the results of a Kobai base question.
+        Extract Semantic Data from Graph to Delta Table
 
         Parameters:
-        followup_question (str): A natural language question to apply.
-        question_results (any): JSON representation of Kobai base question results.
-        question_id (int): Numeric id for base Kobai question.
-        override_model (LangChain BaseLanguageModel) OPTIONAL: Langchain LLM or ChatModel runnable.
-        use_simple_prompt (bool) OPTIONAL: Uses ChatPrompt when True, Prompt when False.
-        debug (bool) OPTIONAL: Set Langchain debug for troubleshooting.
+        replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+        concept_white_list ([str]) OPTIONAL: A list of Domain and Concept names for extraction.
+        use_questions (bool) OPTIONAL: Extract facts from published Kobai questions.
         """
-
-        question_def = self.get_question(question_id)
-        question_name = question_def["description"]
-
-        row_texts = []
-        row_template = self.process_question_results(question_def)
-        for row in question_results:
-            row_text = row_template
-            for col in row:
-                row_text = row_text.replace("{" + col + "}", str(row[col]))
-            row_texts.append(row_text)
-        data = "\n".join(row_texts)
+        ai_rag.generate_sentences(self.get_ai_context(), replica_schema=replica_schema, concept_white_list=concept_white_list, use_questions=use_questions)
 
-        return ai_query.followup_question(followup_question,
-                                          data,
-                                          question_name,
-                                          None,
-                                          override_model=override_model,
-                                          )
+    def rag_encode_to_delta_local(self, st_model: SentenceTransformer, replica_schema=None):
+        """
+        Encode Semantic Data to Vectors in Delta Table
 
-    def process_question_results(self, question_def):
+        Parameters:
+        st_model (SentenceTransformer): A sentence_transformers model to use for encoding.
+        replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+        """
+        ai_rag.encode_to_delta_local(self.get_ai_context(), st_model=st_model, replica_schema=replica_schema)
 
+    def rag_delta(self, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
         """
-        Returns a template to format each row in Kobai JSON question output into a format readable by LLMs.
+        Run a RAG query using vectors in Delta table.
 
         Parameters:
-        question_def (any): Kobai standard JSON definition of question.
-        """
-
-        concept_props = {}
-        concept_rels = {}
-
-        for ci in question_def["definition"]:
-            con_name = question_def["definition"][ci]["label"].replace("_", " ")
-            con_label = question_def["definition"][ci]["label"]
-            concept_props[ci] = {"name": con_name, "props": []}
-            for p in question_def["definition"][ci]["properties"]:
-                if p["hidden"] == False:
-                    if len(p["aggregates"]) > 0:
-                        for a in p["aggregates"]:
-                            prop_column = con_label + "_" + p["label"] + "_" + a["type"]
-                            prop_name = p["label"].replace("_", " ")
-                            concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": a["type"]})
-                    else:
-                        prop_column = con_label + "_" + p["label"]
-                        prop_name = p["label"].replace("_", " ")
-                        concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": None})
-            for r in question_def["definition"][ci]["relations"]:
-                prop_name = question_def["definition"][ci]["relations"][r]["label"].replace("_", " ")
-                for ri in question_def["definition"][ci]["relations"][r]["relationInstances"]:
-                    if ci not in concept_rels:
-                        concept_rels[ci] = {"count": 0, "edges": []}
-                    concept_rels[ci]["edges"].append({"src": ci, "dst": ri["relationTypeUri"], "name": prop_name})
-                    concept_rels[ci]["count"] += 1
-
-
-        row_texts = {}
-
-        for ci, c in concept_props.items():
-            p_texts = []
-            for p in c["props"]:
-                if p["agg"] is None:
-                    p_text = p["name"] + " " + "{" + p["column"] + "}"
-                else:
-                    p_text = p["agg"] + " of " + p["name"] + " " + "{" + p["column"] + "}"
-                p_texts.append(p_text)
-            c_text = self.__get_article(c["name"]) + " " + c["name"]
-            if len(c["props"]) > 0:
-                c_text += " with " + self.__smart_comma_formatting(p_texts)
-            row_texts[ci] = c_text
+        emb_model (UNION[SentenceTransformer, Embeddings]): A sentence_transformers or langchain embedding model to use for encoding the query.
+        chat_model (BaseChatModel): A langchain chat model to use in the RAG pipeline.
+        question (str): The user's query.
+        k (int) OPTIONAL: The number of RAG documents to retrieve.
+        replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+        """
+        ai_rag.rag_delta(self.get_ai_context(), emb_model=emb_model, chat_model=chat_model, question=question, k=k, replica_schema=replica_schema)
 
-        max_src = ""
-        max_src_count = -1
+    ########################################
+    # AI Functions
+    ########################################
 
-        for r in concept_rels:
-            if concept_rels[r]["count"] > max_src_count:
-                max_src_count = concept_rels[r]["count"]
-                max_src = r
+    def followup_question(self, user_question, question_id=None, use_inmem_vectors=False):
+        """
+        Use LLM to answer question in the context of a Kobai Studio question.
 
+        Parameters:
+        user_question (str): A natural language question to apply.
+        question_id (int) OPTIONAL: A Kobai question to use as a data source. Otherwise, an appropriate question will be automatically found.
+        use_inmem_vectors (bool) OPTIONAL: For large query sets, this secondary processing can reduce the data required in the context window.
+        """
 
-        concept_order = [max_src]
-        for t in concept_rels[max_src]["edges"]:
-            concept_order.append(t["dst"])
+        if question_id is None:
 
-        for c in concept_props:
-            if c not in concept_order:
-                concept_order.append(c)
+            suggestions = self.question_search(user_question, k=1)
 
-        row_text = concept_order[0] + " is connected to " + " and connected to ".join(concept_order[1:])
+            question_id = suggestions[0]["id"]
 
-        for c in row_texts:
-            row_text = row_text.replace(c, row_texts[c])
+        question_results = self.run_question_remote(question_id)
 
-        row_text = row_text[0].upper() + row_text[1:] + "."
-        return row_text
-
-    def process_question_results2(self, question_def):
+        question_def = self.get_question(question_id)
+        question_name = question_def["description"]
 
+        return ai_query.followup_question(user_question, question_results, question_name, question_def, self.embedding_model, self.chat_model, use_inmem_vectors=use_inmem_vectors)
+
+    def init_ai_components(self, embedding_model: Union[SentenceTransformer, Embeddings] = None, chat_model: BaseChatModel = None):
         """
-        Returns a template to format each row in Kobai JSON question output into a format readable by LLMs.
+        Set Chat and Embedding models for AI functions to use. If no arguments provided, Databricks hosted services are used.
 
         Parameters:
-        question_def (any): Kobai standard JSON definition of question.
+        embedding_model (Union[SentenceTransformer, Embeddings]) OPTIONAL: A sentence_transformer or Langchain Embedding model.
+        chat_model (BaseChatModel) OPTIONAL: A Langchain BaseChatModel chat model.
         """
 
-        concept_props = {}
-        concept_rels = {}
-
-        for ci in question_def["definition"]:
-            con_name = question_def["definition"][ci]["label"].replace("_", " ")
-            con_label = question_def["definition"][ci]["label"]
-            concept_props[ci] = {"name": con_name, "props": []}
-            for p in question_def["definition"][ci]["properties"]:
-                if p["hidden"] == False:
-                    if len(p["aggregates"]) > 0:
-                        for a in p["aggregates"]:
-                            prop_column = con_label + "_" + p["label"] + "_" + a["type"]
-                            prop_name = p["label"].replace("_", " ")
-                            concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": a["type"]})
-                    else:
-                        prop_column = con_label + "_" + p["label"]
-                        prop_name = p["label"].replace("_", " ")
-                        concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": None})
-            for r in question_def["definition"][ci]["relations"]:
-                prop_name = question_def["definition"][ci]["relations"][r]["label"].replace("_", " ")
-                for ri in question_def["definition"][ci]["relations"][r]["relationInstances"]:
-                    if ci not in concept_rels:
-                        concept_rels[ci] = {"count": 0, "edges": []}
-                    concept_rels[ci]["edges"].append({"src": ci, "dst": ri["relationTypeUri"], "name": prop_name})
-                    concept_rels[ci]["count"] += 1
-
-
-        row_texts = {}
-
-        for ci, c in concept_props.items():
-            p_texts = []
-            for p in c["props"]:
-                if p["agg"] is None:
-                    p_text = p["name"] + " " + "{" + p["column"] + "}"
-                else:
-                    p_text = p["agg"] + " of " + p["name"] + " " + "{" + p["column"] + "}"
-                p_texts.append(p_text)
-            c_text = self.__get_article(c["name"]) + " " + c["name"]
-            if len(c["props"]) > 0:
-                c_text += " with " + self.__smart_comma_formatting(p_texts)
-            row_texts[ci] = c_text
+        if embedding_model is not None:
+            self.embedding_model = embedding_model
+        else:
+            #self.embedding_model = SentenceTransformer("baai/bge-large-en-v1.5")
+            self.embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
+
+        if chat_model is not None:
+            self.chat_model = chat_model
+        else:
+            self.chat_model = ChatDatabricks(endpoint="databricks-dbrx-instruct")
 
-        max_src = ""
-        max_src_count = -1
+        self.question_search_index = ai_query.init_question_search_index(self.list_questions(), self.embedding_model)
 
-        for r in concept_rels:
-            if concept_rels[r]["count"] > max_src_count:
-                max_src_count = concept_rels[r]["count"]
-                max_src = r
+    def question_search(self, search_text, k: int = 1):
+        """
+        Retrieve metadata about Kobai Questions based on user search text.
 
+        Parameters:
+        search_text (str): Text to compare against question names.
+        k (int) OPTIONAL: Number of top-k matches to return.
+        """
 
-        concept_order = [max_src]
-        for t in concept_rels[max_src]["edges"]:
-            concept_order.append(t["dst"])
+        question_list = ai_query.question_search(search_text, self.question_search_index, self.embedding_model, k)
+        return question_list
 
-        for c in concept_props:
-            if c not in concept_order:
-                concept_order.append(c)
 
-        row_text = concept_order[0] + " is connected to " + " and connected to ".join(concept_order[1:])
-
-        for c in row_texts:
-            row_text = row_text.replace(c, row_texts[c])
-
-        row_text = row_text[0].upper() + row_text[1:] + "."
-        return row_text
-
-    def __smart_comma_formatting(self, items):
-        if items == None:
-            return ""
-        match len(items):
-            case 0:
-                return ""
-            case 1:
-                return items[0]
-            case 2:
-                return items[0] + " and " + items[1]
-            case _:
-                return ", ".join(items[0: -1]) + " and " + items[-1]
-
-    def __get_article(self, label):
-        if label[0:1].lower() in ["a", "e", "i", "o", "u"]:
-            return "an"
-        else:
-            return "a"
 
     ########################################
     # Tenant Questions
@@ -943,7 +859,14 @@ class TenantClient:
         response = self.api_client._TenantAPI__run_get('/data-svcs/model/domain/questions/count')
         for q in response.json()["drafts"]:
             question_list.append({"id": q["id"], "description": q["description"]})
-        return question_list
+
+        visited_ids = []
+        unique_question_list = []
+        for q in question_list:
+            if q["id"] not in visited_ids:
+                visited_ids.append(q["id"])
+                unique_question_list.append(q)
+        return unique_question_list
 
     def get_question_id(self, label, domain_label=None):
 
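Taken together, the tenant_client.py changes mean authentication now finishes by calling init_ai_components(), which defaults to Databricks-hosted models (databricks-bge-large-en for embeddings, the databricks-dbrx-instruct chat endpoint) and builds the question search index. A hypothetical sketch, assuming tc is an already-authenticated TenantClient running where those endpoints are reachable; the search text and question are made up:

    # Optionally override the Databricks defaults with your own models:
    # tc.init_ai_components(embedding_model=my_embeddings, chat_model=my_chat)

    # Rank published questions against a natural-language search.
    for match in tc.question_search("sales by region", k=3):
        print(match["id"], match["description"], match["value"])

    # Ask a follow-up question. With question_id omitted, the closest question
    # is found via the search index, run remotely, and its rows become context.
    print(tc.followup_question("Which region had the highest sales?"))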

kobai_sdk.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: kobai-sdk
-Version: 0.2.8rc3
+Version: 0.2.8rc5
 Summary: A package that enables interaction with a Kobai tenant.
 Author-email: Ryan Oattes <ryan@kobai.io>
 License: Apache License
@@ -222,7 +222,7 @@ Requires-Dist: azure-storage-blob
 Requires-Dist: langchain-core
 Requires-Dist: langchain-community
 Requires-Dist: langchain_openai
-Requires-Dist: sentence_transformers
+Requires-Dist: databricks_langchain
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: bumpver; extra == "dev"

kobai_sdk.egg-info/SOURCES.txt

@@ -7,7 +7,6 @@ kobai/ai_query.py
 kobai/ai_rag.py
 kobai/databricks_client.py
 kobai/demo_tenant_client.py
-kobai/llm_config.py
 kobai/spark_client.py
 kobai/tenant_api.py
 kobai/tenant_client.py

kobai_sdk.egg-info/requires.txt

@@ -6,7 +6,7 @@ azure-storage-blob
 langchain-core
 langchain-community
 langchain_openai
-sentence_transformers
+databricks_langchain
 
 [dev]
 black

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "kobai-sdk"
-version = "0.2.8rc3"
+version = "0.2.8rc5"
 description = "A package that enables interaction with a Kobai tenant."
 readme = "README.md"
 authors = [{ name = "Ryan Oattes", email = "ryan@kobai.io" }]
@@ -26,8 +26,8 @@ dependencies = [
     "langchain-core",
     "langchain-community",
     "langchain_openai",
-    "sentence_transformers"
-    ]
+    "databricks_langchain"
+]
 requires-python = ">=3.11"
 
 [project.optional-dependencies]

kobai/ai_query.py (previous implementation, removed; rewritten above)

@@ -1,114 +0,0 @@
-from kobai import llm_config
-from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from langchain_community.chat_models import ChatDatabricks
-from langchain.globals import set_debug
-from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-from langchain_openai import AzureChatOpenAI
-
-MESSAGE_SYSTEM_TEMPLATE = """
-You are a data analyst tasked with answering questions based on a provided data set. Please answer the questions based on the provided context below. Make sure not to make any changes to the context, if possible, when preparing answers to provide accurate responses. If the answer cannot be found in context, just politely say that you do not know, do not try to make up an answer.
-When you receive a question from the user, answer only that one question in a concise manner. Do not elaborate with other questions.
-"""
-
-MESSAGE_AI_TEMPLATE = """
-The table information is as follows:
-{table_data}
-"""
-
-MESSAGE_USER_CONTEXT_TEMPLATE = """
-The context being provided is from a table named: {table_name}
-"""
-
-MESSAGE_USER_QUESTION_TEMPLATE = """
-{question}
-"""
-
-SIMPLE_PROMPT_TEMPLATE = f"""
-{MESSAGE_SYSTEM_TEMPLATE}
-
-{MESSAGE_USER_CONTEXT_TEMPLATE}
-
-{MESSAGE_AI_TEMPLATE}
-
-Question: {MESSAGE_USER_QUESTION_TEMPLATE}
-"""
-
-def followup_question(question, data, question_name, llm_config:llm_config, override_model=None):
-
-    """
-    Use LLM to answer question in the context of provided data.
-
-    Parameters:
-    question (str): A natural language question to apply.
-    data (str): Simple dictionary-like structured data.
-    question_name (str): Dataset name for context.
-    llm_config (LLMConfig): User set LLM configurations and some default ones.
-    override_model (LangChain BaseLanguageModel) OPTIONAL: Langchain LLM or ChatModel runnable.
-    """
-
-    set_debug(llm_config.debug)
-
-    # If override model is provided, then use the override model as chat model.
-    if override_model is not None:
-        chat_model=override_model
-    elif llm_config.llm_provider == "databricks":
-        chat_model = ChatDatabricks(
-            endpoint = llm_config.endpoint,
-            temperature = llm_config.temperature,
-            max_tokens = llm_config.max_tokens,
-        )
-    elif llm_config.llm_provider == "azure_openai":
-        if(llm_config.api_key is None):
-            # Authenticate through AZ Login or through service principal
-            # Instantiate the AzureChatOpenAI model
-            chat_model = AzureChatOpenAI(
-                azure_endpoint=llm_config.endpoint,
-                azure_deployment=llm_config.deployment,
-                azure_ad_token=llm_config.aad_token,
-                openai_api_version=llm_config.api_version,
-                temperature = llm_config.temperature,
-                max_tokens = llm_config.max_tokens,
-            )
-        else:
-            # Authenticate through API Key
-            chat_model = AzureChatOpenAI(
-                api_key = llm_config.api_key,
-                azure_endpoint=llm_config.endpoint,
-                azure_deployment=llm_config.deployment,
-                openai_api_version=llm_config.api_version,
-                temperature = llm_config.temperature,
-                max_tokens = llm_config.max_tokens,
-            )
-    else:
-        chat_model = ChatDatabricks(
-            endpoint = llm_config.endpoint,
-            temperature = llm_config.temperature,
-            max_tokens = llm_config.max_tokens,
-        )
-
-    if llm_config.use_simple_prompt:
-        prompt = PromptTemplate.from_template(SIMPLE_PROMPT_TEMPLATE)
-    else:
-        prompt = ChatPromptTemplate.from_messages(
-            [
-                SystemMessagePromptTemplate.from_template(MESSAGE_SYSTEM_TEMPLATE),
-                HumanMessagePromptTemplate.from_template(MESSAGE_USER_CONTEXT_TEMPLATE),
-                AIMessagePromptTemplate.from_template(MESSAGE_AI_TEMPLATE),
-                HumanMessagePromptTemplate.from_template(MESSAGE_USER_QUESTION_TEMPLATE)
-            ]
-        )
-
-    output_parser = StrOutputParser()
-
-    chain = prompt | chat_model | output_parser
-
-    response = chain.invoke(
-        {
-            "table_name": question_name,
-            "table_data": str(data),
-            "question": question
-        }
-    )
-
-    return response

kobai/llm_config.py (removed)

@@ -1,40 +0,0 @@
-import os
-from azure.identity import DefaultAzureCredential
-
-class LLMConfig:
-
-    def __init__(self, deployment: str = None, api_key: str = None, max_tokens: int = 150, temperature: float = 0.1, endpoint: str = "databricks-dbrx-instruct", use_simple_prompt: bool = False, debug: bool = False,
-                 llm_provider: str = "databricks", api_version: str = "2024-02-15-preview"):
-
-        """
-        Initialize the LLMConfig
-        Parameters:
-        deployment (str): LLM against which the query is run.
-        api_key (str): The api_key used for authenticating with the LLM.
-        max_tokens (int): Maximum number of tokens that the model can generate in a single response.
-        temperature (float): Parameter that controls the randomness and creativity of the text generated by the LLM.
-        endpoint (str): The endpoint of the LLM to connect to.
-        debug (bool) OPTIONAL: Set Langchain debug for troubleshooting.
-        use_simple_prompt (bool) OPTIONAL: Simple Prompt template for a language model.
-        llm_provider (str): Provider of the LLM.
-        api_version (str): version of the LLM API that the application will use for making requests.
-        """
-
-        self.endpoint = endpoint
-        self.deployment = deployment
-        self.api_key = api_key
-        self.api_version = api_version
-        self.use_simple_prompt = use_simple_prompt
-        self.debug = debug
-        self.llm_provider = llm_provider
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-
-
-    def get_azure_ad_token(self):
-        # Get the Azure Credential
-        credential = DefaultAzureCredential()
-        # Set the API type to `azure_ad`
-        os.environ["OPENAI_API_TYPE"] = "azure_ad"
-        # Set the API_KEY to the token from the Azure credential
-        self.aad_token = credential.get_token("https://cognitiveservices.azure.com/.default").token