kobai-sdk 0.2.8rc3__tar.gz → 0.2.8rc5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kobai-sdk might be problematic.
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/PKG-INFO +2 -2
- kobai_sdk-0.2.8rc5/kobai/ai_query.py +255 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/ai_rag.py +101 -60
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/tenant_client.py +100 -177
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai_sdk.egg-info/PKG-INFO +2 -2
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai_sdk.egg-info/SOURCES.txt +0 -1
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai_sdk.egg-info/requires.txt +1 -1
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/pyproject.toml +3 -3
- kobai_sdk-0.2.8rc3/kobai/ai_query.py +0 -114
- kobai_sdk-0.2.8rc3/kobai/llm_config.py +0 -40
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/LICENSE +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/MANIFEST.in +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/README.md +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/__init__.py +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/databricks_client.py +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/demo_tenant_client.py +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/spark_client.py +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/tenant_api.py +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai_sdk.egg-info/dependency_links.txt +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai_sdk.egg-info/top_level.txt +0 -0
- {kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/setup.cfg +0 -0
{kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: kobai-sdk
-Version: 0.2.8rc3
+Version: 0.2.8rc5
 Summary: A package that enables interaction with a Kobai tenant.
 Author-email: Ryan Oattes <ryan@kobai.io>
 License: Apache License
@@ -222,7 +222,7 @@ Requires-Dist: azure-storage-blob
 Requires-Dist: langchain-core
 Requires-Dist: langchain-community
 Requires-Dist: langchain_openai
-Requires-Dist: …
+Requires-Dist: databricks_langchain
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: bumpver; extra == "dev"
kobai_sdk-0.2.8rc5/kobai/ai_query.py (new file)
@@ -0,0 +1,255 @@
+from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+from sentence_transformers import SentenceTransformer, util
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+from langchain_core.vectorstores import InMemoryVectorStore
+
+from typing import Union, List
+
+
+MESSAGE_SYSTEM_TEMPLATE = """
+You are a data analyst tasked with answering questions based on a provided data set. Please answer the questions based on the provided context below. Make sure not to make any changes to the context, if possible, when preparing answers to provide accurate responses. If the answer cannot be found in context, just politely say that you do not know, do not try to make up an answer.
+When you receive a question from the user, answer only that one question in a concise manner. Do not elaborate with other questions.
+"""
+
+MESSAGE_AI_TEMPLATE = """
+The table information is as follows:
+{table_data}
+"""
+
+MESSAGE_USER_CONTEXT_TEMPLATE = """
+The context being provided is from a table named: {table_name}
+"""
+
+MESSAGE_USER_QUESTION_TEMPLATE = """
+{question}
+"""
+
+SIMPLE_PROMPT_TEMPLATE = f"""
+{MESSAGE_SYSTEM_TEMPLATE}
+
+{MESSAGE_USER_CONTEXT_TEMPLATE}
+
+{MESSAGE_AI_TEMPLATE}
+
+Question: {MESSAGE_USER_QUESTION_TEMPLATE}
+"""
+
+class QuestionRetriever(BaseRetriever):
+    #https://python.langchain.com/docs/how_to/custom_retriever/
+    #https://github.com/langchain-ai/langchain/issues/12304
+
+    documents: List[Document]
+    k: int = 5000
+
+    #def __init__(self, documents: List[Document], k: int = 5000):
+    #    self.documents = documents
+    #    self.k = k
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        """Sync implementations for retriever."""
+        matching_documents = []
+        for document in self.documents:
+            if len(matching_documents) > self.k:
+                return matching_documents
+
+            #if query.lower() in document.page_content.lower():
+            #    matching_documents.append(document)
+            matching_documents.append(document)
+        return matching_documents
+
+def format_docs(docs):
+    return "\n\n".join([d.page_content for d in docs])
+
+def input_only(inpt):
+    return inpt["question"]
+
+def followup_question(user_question, question_results, question_name, question_def, embedding_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, use_inmem_vectors=False):
+
+    row_texts = process_question_results(question_def, question_results)
+    question_documents = [Document(page_content=r, metadata={"source": "kobai"}) for r in row_texts]
+
+    if use_inmem_vectors:
+        question_retriever = InMemoryVectorStore.from_documents(question_documents, embedding=embedding_model).as_retriever(
+            search_kwargs={"k": 5}
+        )
+    else:
+        question_retriever = QuestionRetriever(documents=question_documents)
+
+    output_parser = StrOutputParser()
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            SystemMessagePromptTemplate.from_template(
+                MESSAGE_SYSTEM_TEMPLATE),
+            HumanMessagePromptTemplate.from_template(
+                MESSAGE_USER_CONTEXT_TEMPLATE),
+            AIMessagePromptTemplate.from_template(MESSAGE_AI_TEMPLATE),
+            HumanMessagePromptTemplate.from_template(
+                MESSAGE_USER_QUESTION_TEMPLATE)
+        ]
+    )
+
+    chain = (
+        {"table_name": RunnablePassthrough(), "table_data": RunnableLambda(input_only) | question_retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | chat_model
+        | output_parser
+    )
+    response = chain.invoke(
+        {
+            "table_name": question_name,
+            "question": user_question
+        }
+    )
+
+    return response
+
+def init_question_search_index(tenant_questions, emb_model):
+
+    q_ids = [q["id"] for q in tenant_questions]
+    q_descs = [q["description"] for q in tenant_questions]
+
+    if isinstance(emb_model, SentenceTransformer):
+        q_vectors = emb_model.encode(q_descs)
+    else:
+        q_vectors = emb_model.embed_documents(q_descs)
+
+    return {"ids": q_ids, "descs": q_descs, "vectors": q_vectors}
+
+
+def question_search(search_text: str, search_index, emb_model, k: int):
+    if isinstance(emb_model, SentenceTransformer):
+        search_vec = emb_model.encode(search_text)
+    else:
+        search_vec = emb_model.embed_query(search_text)
+    #search_vec = emb_model.encode(search_text)
+
+    matches = __top_vector_matches(search_vec, search_index["vectors"], top=k)
+
+    for mi, m in enumerate(matches):
+        matches[mi]["id"] = search_index["ids"][m["index"]]
+        matches[mi]["description"] = search_index["descs"][m["index"]]
+    return matches
+
+def __top_vector_matches(test_vec, options_list_vec, top=1):
+    scores_t = util.cos_sim(test_vec, options_list_vec)[0]
+    scores_l = scores_t.tolist()
+    scores_d = [{"index": i, "value": v} for i, v in enumerate(scores_l)]
+    sorted_d = sorted(scores_d, key=lambda i: i["value"], reverse=True)
+    top_d = sorted_d[0:top]
+    return top_d
+
+def process_question_results(question_def, question_results):
+
+    """
+    Returns a template to format each row in Kobai JSON question output into a format readable by LLMs.
+
+    Parameters:
+    question_def (any): Kobai standard JSON definition of question.
+    question_results (any): JSON representation of Kobai base question results.
+    """
+
+    concept_props = {}
+    concept_rels = {}
+
+    for ci in question_def["definition"]:
+        con_name = question_def["definition"][ci]["label"].replace("_", " ")
+        con_label = question_def["definition"][ci]["label"]
+        concept_props[ci] = {"name": con_name, "props": []}
+        for p in question_def["definition"][ci]["properties"]:
+            if p["hidden"] == False:
+                if len(p["aggregates"]) > 0:
+                    for a in p["aggregates"]:
+                        prop_column = con_label + "_" + p["label"] + "_" + a["type"]
+                        prop_name = p["label"].replace("_", " ")
+                        concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": a["type"]})
+                else:
+                    prop_column = con_label + "_" + p["label"]
+                    prop_name = p["label"].replace("_", " ")
+                    concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": None})
+        for r in question_def["definition"][ci]["relations"]:
+            prop_name = question_def["definition"][ci]["relations"][r]["label"].replace("_", " ")
+            for ri in question_def["definition"][ci]["relations"][r]["relationInstances"]:
+                if ci not in concept_rels:
+                    concept_rels[ci] = {"count": 0, "edges": []}
+                concept_rels[ci]["edges"].append({"src": ci, "dst": ri["relationTypeUri"], "name": prop_name})
+                concept_rels[ci]["count"] += 1
+
+
+    row_texts = {}
+
+    for ci, c in concept_props.items():
+        p_texts = []
+        for p in c["props"]:
+            if p["agg"] is None:
+                p_text = p["name"] + " " + "{" + p["column"] + "}"
+            else:
+                p_text = p["agg"] + " of " + p["name"] + " " + "{" + p["column"] + "}"
+            p_texts.append(p_text)
+        c_text = __get_article(c["name"]) + " " + c["name"]
+        if len(c["props"]) > 0:
+            c_text += " with " + __smart_comma_formatting(p_texts)
+        row_texts[ci] = c_text
+
+    max_src = ""
+    max_src_count = -1
+
+    for r in concept_rels:
+        if concept_rels[r]["count"] > max_src_count:
+            max_src_count = concept_rels[r]["count"]
+            max_src = r
+
+
+    concept_order = [max_src]
+    for t in concept_rels[max_src]["edges"]:
+        concept_order.append(t["dst"])
+
+    for c in concept_props:
+        if c not in concept_order:
+            concept_order.append(c)
+
+    row_template = concept_order[0] + " is connected to " + " and connected to ".join(concept_order[1:])
+
+    for c in row_texts:
+        row_template = row_template.replace(c, row_texts[c])
+
+    row_template = row_template[0].upper() + row_template[1:] + "."
+
+    row_texts = []
+    for row in question_results:
+        row_text = row_template
+        for col in row:
+            row_text = row_text.replace("{" + col + "}", str(row[col]))
+        row_texts.append(row_text)
+    #data = "\n".join(row_texts)
+    return row_texts
+    #return data
+
+def __smart_comma_formatting(items):
+    if items == None:
+        return ""
+    match len(items):
+        case 0:
+            return ""
+        case 1:
+            return items[0]
+        case 2:
+            return items[0] + " and " + items[1]
+        case _:
+            return ", ".join(items[0: -1]) + " and " + items[-1]
+
+def __get_article(label):
+    if label[0:1].lower() in ["a", "e", "i", "o", "u"]:
+        return "an"
+    else:
+        return "a"
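The ranking behind the new question_search path is __top_vector_matches: cosine similarity between the query vector and every stored question-description vector, sorted descending, top-k kept. A minimal standalone sketch of that logic, with toy 3-dimensional vectors standing in for real embeddings:

from sentence_transformers import util

def top_vector_matches(test_vec, options_list_vec, top=1):
    # util.cos_sim returns a 1 x N tensor of cosine similarities
    scores = util.cos_sim(test_vec, options_list_vec)[0].tolist()
    scored = [{"index": i, "value": v} for i, v in enumerate(scores)]
    return sorted(scored, key=lambda s: s["value"], reverse=True)[:top]

query_vec = [1.0, 0.0, 0.0]
question_vecs = [
    [0.9, 0.1, 0.0],  # near-duplicate of the query
    [0.0, 1.0, 0.0],  # orthogonal
    [0.7, 0.7, 0.0],  # partial overlap
]
print(top_vector_matches(query_vec, question_vecs, top=2))
# highest-scoring indices come back first: index 0, then index 2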
{kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/ai_rag.py
@@ -1,4 +1,5 @@
-from kobai import tenant_client
+from kobai import tenant_api
+from pyspark.sql import SparkSession
 
 from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, IntegerType
 from pyspark.sql import functions as F
@@ -11,12 +12,43 @@ from langchain_community.document_loaders import PySparkDataFrameLoader
 from langchain import hub
 from langchain_core.output_parsers import StrOutputParser
 
+import urllib
+import urllib.parse
 
+class AIContext:
 
+    schema: str
+    spark_session: SparkSession
+    model_id: str
+    tenant_json: str
+    api_client: tenant_api.TenantAPI
 
+def ai_run_question_remote(tc: AIContext, question_id, dynamic_filters: dict = None):
 
-…
+    """
+    Returns JSON formatted result of Kobai question.
 
+    Parameters:
+    question_id (int): Numeric identifier of Kobai question.
+    """
+
+    uri = '/data-svcs/api/query/' + str(question_id) + '/execute?' #'/data-svcs/api/query/4518/solution/9/execute/tabular?'
+
+    queryParams = {'jsontype': 'tableau'}
+
+    if bool(dynamic_filters):
+        queryParams.update(dynamic_filters)
+
+    uri += urllib.parse.urlencode(queryParams)
+
+    json = {
+        'simulations': {'concepts': {}, 'data': None}
+    }
+    response = tc.api_client._TenantAPI__run_post(uri, json)
+
+    return response.json()
+
+def generate_sentences(tc: AIContext, replica_schema=None, concept_white_list=None, use_questions=False):
     """
     Extract Semantic Data from Graph to Delta Table
 
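The new ai_run_question_remote assembles its endpoint by URL-encoding the fixed jsontype parameter together with any caller-supplied dynamic filters before POSTing. A stdlib-only sketch of that URI construction (the filter key is a made-up example; 4518 echoes the id in the commented-out URI above):

import urllib.parse

question_id = 4518                    # hypothetical question id
dynamic_filters = {"region": "EMEA"}  # hypothetical filter

uri = '/data-svcs/api/query/' + str(question_id) + '/execute?'
query_params = {'jsontype': 'tableau'}
if bool(dynamic_filters):             # an empty dict adds nothing
    query_params.update(dynamic_filters)
uri += urllib.parse.urlencode(query_params)
print(uri)  # /data-svcs/api/query/4518/execute?jsontype=tableau&region=EMEA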
@@ -26,24 +58,27 @@ def generate_sentences(tc: tenant_client.TenantClient, conc…
     concept_white_list ([str]) OPTIONAL: A list of Domain and Concept names for extraction.
     use_questions (bool) OPTIONAL: Extract facts from published Kobai questions.
     """
-…
-    if tc.spark_client is None:
-…
-…
-    ss = tc.…
+
+    #if tc.spark_client is None:
+    #    return None
+
+    ss = tc.spark_session
 
     print("Getting Tenant Config")
-    tenant_json = tc.…
+    tenant_json = tc.tenant_json
 
-    concepts = __get_concept_metadata(…
+    concepts = __get_concept_metadata(
+        tenant_json, tc.schema, tc.model_id, concept_white_list)
 
     print("Dropping and Recreating the RAG Table")
     ss.sql(__create_rag_table_sql(tc.schema, tc.model_id))
 
     print("Generating Extraction SQL")
     sql_statements = []
-    sql_statements.extend(__generate_sentence_sql_concept_literals(
-…
+    sql_statements.extend(__generate_sentence_sql_concept_literals(
+        concepts, tc.schema, tc.model_id))
+    sql_statements.extend(__generate_sentence_sql_concept_relations(
+        concepts, tc.schema, tc.model_id))
 
     print("Running the Extraction")
     for sql_statement in sql_statements:
@@ -55,14 +90,16 @@ def generate_sentences(tc: tenant_client.TenantClient, conc…
     if replica_schema is not None:
         print("Replicating Schema")
         ss.sql(__create_rag_table_sql(replica_schema, tc.model_id))
-        ss.sql(__replicate_to_catalog_sql(…
+        ss.sql(__replicate_to_catalog_sql(
+            tc.schema, replica_schema, tc.model_id))
 
-…
-def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
+
+def __generate_sentences_from_questions(tc: AIContext):
+    ss = tc.spark_session
 
     print("Getting Question Data")
 
-    tenant_json = tc.…
+    tenant_json = tc.tenant_json
 
     published_queries = []
     for p in tenant_json["publishedAPIs"]:
@@ -73,22 +110,21 @@ def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
         if q["id"] in published_queries:
             question_names[q["id"]] = q["description"]
 
-…
-…
-…
-…
+    schema_v = StructType([
+        StructField("sentence", StringType(), True),
+        StructField("query_id", StringType(), True)
+    ])
 
     sentences = []
     for p in published_queries:
-        output = tc…
+        output = ai_run_question_remote(tc, p)
         for r in output:
             sentence = f"For {question_names[p]}: "
             for c in r:
                 sentence += f"The {c.replace('_', ' ')} is {r[c]}. "
             sentences.append([sentence, p])
 
-…
-    sentences_df = ss.createDataFrame(sentences, schemaV)
+    sentences_df = ss.createDataFrame(sentences, schema_v)
     sentences_df = sentences_df.select(
         F.col("sentence").alias("sentence"),
         F.col("query_id").alias("concept_id"),
@@ -96,19 +132,17 @@ def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
     )
 
     schema = tc.schema
-…
+
     view_name = f"rag_{tc.model_id}_question_sentences"
     sentences_df.createOrReplaceTempView(view_name)
 
     full_sql = f"INSERT INTO {schema}.rag_{tc.model_id} (content, concept_id, type)"
     full_sql += f" SELECT sentence, concept_id, type FROM {view_name}"
-…
-    ss.sql(full_sql)
+
+    ss.sql(full_sql)
 
 
-def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTransformer, replica_schema=None):
-…
+def encode_to_delta_local(tc: AIContext, st_model: SentenceTransformer, replica_schema=None):
     """
     Encode Semantic Data to Vectors in Delta Table
 
@@ -118,10 +152,10 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran…
     replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
     """
 
-    if tc.spark_client is None:
-…
-…
-    ss = tc.…
+    #if tc.spark_client is None:
+    #    return None
+
+    ss = tc.spark_session
 
     schema = tc.schema
     if replica_schema is not None:
@@ -133,7 +167,6 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran…
     num_records = sentences_df.count()
     query_batch_size = 100000
 
-…
     for x in range(0, num_records, query_batch_size):
         print(f"Running Batch Starting at {x}")
         sentences_sql = f" SELECT id, content FROM {schema}.rag_{tc.model_id} ORDER BY id LIMIT {str(query_batch_size)} OFFSET {str(x)}"
@@ -141,25 +174,27 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran…
         content_list = [r["content"] for r in sentences_df.collect()]
         id_list = [r["id"] for r in sentences_df.collect()]
 
-        vector_list = st_model.encode(…
+        vector_list = st_model.encode(
+            content_list, normalize_embeddings=True, show_progress_bar=True)
 
-…
-            StructField("id",IntegerType(),True),
+        schema_v = StructType([
+            StructField("id", IntegerType(), True),
             StructField("vector", ArrayType(FloatType()), False)
         ])
 
-        updated_list = [[r[0], r[1].tolist()]
-…
+        updated_list = [[r[0], r[1].tolist()]
+                        for r in zip(id_list, vector_list)]
+        updated_df = ss.createDataFrame(updated_list, schema_v)
 
         target_table = DeltaTable.forName(ss, f"{schema}.rag_{tc.model_id}")
 
         target_table.alias("t") \
-…
+            .merge(
                 updated_df.alias("s"),
                 't.id = s.id'
             ) \
-…
-…
+            .whenMatchedUpdate(set={"vector": "s.vector"}) \
+            .execute()
 
     ss.sql(f"""
         CREATE FUNCTION IF NOT EXISTS {schema}.cos_sim(a ARRAY<FLOAT>, b ARRAY<FLOAT>)
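The vector writeback in encode_to_delta_local uses the Delta Lake merge API: staged (id, vector) rows are matched against the RAG table on id and only the vector column is updated. A sketch of that pattern, assuming a SparkSession `ss` with Delta configured; the table name is a placeholder:

from delta.tables import DeltaTable

# Toy staged rows; in the module these come from st_model.encode batches.
updated_df = ss.createDataFrame(
    [[1, [0.1, 0.2]], [2, [0.3, 0.4]]],
    "id INT, vector ARRAY<FLOAT>",
)

target = DeltaTable.forName(ss, "my_catalog.my_db.rag_123")  # placeholder table
(target.alias("t")
    .merge(updated_df.alias("s"), "t.id = s.id")    # join staged rows by id
    .whenMatchedUpdate(set={"vector": "s.vector"})  # refresh only the vector column
    .execute())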
@@ -171,8 +206,8 @@ def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTran…
     $$
     """)
 
-def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
 
+def rag_delta(tc: AIContext, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
     """
     Run a RAG query using vectors in Delta table.
 
@@ -184,25 +219,26 @@ def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransform…
     k (int) OPTIONAL: The number of RAG documents to retrieve.
     replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
     """
-…
+
     schema = tc.schema
     if replica_schema is not None:
         schema = replica_schema
 
-    if tc.spark_client is None:
-…
-…
-…
-    ss = tc.…
+    #if tc.spark_client is None:
+    #    print("Instantiate Spark Client First")
+    #    return None
+
+    ss = tc.spark_session
 
     if isinstance(emb_model, SentenceTransformer):
-        vector_list = emb_model.encode(…
+        vector_list = emb_model.encode(
+            question, normalize_embeddings=True).tolist()
     elif isinstance(emb_model, Embeddings):
         vector_list = emb_model.embed_query(question)
     else:
         print("Invalid Embedding Model Type")
         return None
-…
+
     if not isinstance(chat_model, BaseChatModel):
         print("Invalid Chat Model Type")
         return None
@@ -216,7 +252,7 @@ def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransform…
         ORDER BY score DESC
         LIMIT {k}
     """)
-…
+
     loader = PySparkDataFrameLoader(ss, results, page_content_column="content")
     documents = loader.load()
     docs_content = "\n\n".join(doc.page_content for doc in documents)
@@ -236,20 +272,24 @@ def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransform…
 
     return response
 
+
 def __create_rag_table_sql(schema, model_id):
     return f"CREATE OR REPLACE TABLE {schema}.rag_{model_id} (id BIGINT GENERATED BY DEFAULT AS IDENTITY, content STRING, type string, concept_id string, vector ARRAY<FLOAT>) TBLPROPERTIES (delta.enableChangeDataFeed = true)"
-…
+
+
 def __replicate_to_catalog_sql(base_schema, target_schema, model_id):
     move_sql = f"INSERT INTO {target_schema}.rag_{model_id} (content, concept_id, type)"
     move_sql += f" SELECT content, concept_id, type FROM {base_schema}.rag_{model_id}"
     return move_sql
 
+
+
 def __generate_sentence_sql_concept_literals(concepts, schema, model_id):
     statements = []
     for con in concepts:
         sql = f"'This is a {con['label']}. '"
         sql += " || 'It is identified by ' || cid._plain_conceptid || '. '"
-…
+
         sql_from = f"(SELECT _conceptid, _plain_conceptid FROM {con['prop_table_name']} GROUP BY _conceptid, _plain_conceptid) cid"
         for prop in con["properties"]:
 
@@ -257,15 +297,16 @@ def __generate_sentence_sql_concept_literals(concepts, schema, model_id):
             sql_from += f" ON cid._conceptid = {prop['label']}._conceptid"
             sql_from += f" AND {prop['label']}.type = 'l'"
             sql_from += f" AND {prop['label']}.name = '{prop['name']}'"
-…
+
             sql += f" || 'The {prop['label']} is ' || ifnull(any_value({prop['label']}.value) IGNORE NULLS, 'unknown') || '. '"
-…
+
         full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
         full_sql += f" SELECT {sql} content, cid._conceptid concept_id, 'c' type FROM {sql_from} GROUP BY cid._conceptid, cid._plain_conceptid"
-…
+
         statements.append(full_sql)
     return statements
 
+
 def __generate_sentence_sql_concept_relations(concepts, schema, model_id):
     statements = []
     for con in concepts:
@@ -280,13 +321,13 @@ def __generate_sentence_sql_concept_relations(concepts, schema, model_id):
         sql += f" || ' has a relationship called {rel['label']} that connects it to one or more {rel['target_con_label']} identified by '"
         sql += " || concat_ws(', ', array_agg(cid._plain_conceptid)) || '. '"
 
-…
         full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
         full_sql += f" SELECT {sql} content, rel._conceptid concept_id, 'e' type FROM {sql_from} GROUP BY rel._conceptid, rel._plain_conceptid"
 
         statements.append(full_sql)
     return statements
 
+
 def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
     target_concept_labels = {}
     target_table_names = {}
@@ -297,7 +338,7 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
             "prop": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_np",
             "con": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_c"
         }
-…
+
     concepts = []
     for d in tenant_json["domains"]:
         for c in d["concepts"]:
@@ -306,7 +347,7 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
                 con_props.append({
                     "label": col["label"],
                     "name": f"{model_id}/{d['name']}/{c['label']}#{col['label']}"
-…
+                })
             con_rels = []
             for rel in c["relations"]:
                 if whitelist is not None and target_concept_labels[rel["relationTypeUri"]] not in whitelist:
@@ -328,7 +369,7 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
                 "parents": con_parents,
                 "prop_table_name": target_table_names[c["uri"]]["prop"],
                 "con_table_name": target_table_names[c["uri"]]["con"]
-…
+            })
 
     for ci, c in enumerate(concepts):
         if len(c["parents"]) > 0:
@@ -343,4 +384,4 @@ def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
             continue
         out_concepts.append(c)
 
-    return out_concepts
+    return out_concepts
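rag_delta scores stored vectors in SQL with the cos_sim UDF created above and keeps the top k rows; only the ORDER BY score DESC / LIMIT tail of that query is visible in this diff, so the SELECT below is a hedged reconstruction with placeholder names, not the shipped statement:

# Illustrative retrieval query in the shape rag_delta runs.
schema, model_id, k = "my_catalog.my_db", "123", 5
vector_list = [0.1, 0.2, 0.3]  # output of emb_model.encode(question)

vector_sql = "ARRAY(" + ", ".join(f"CAST({v} AS FLOAT)" for v in vector_list) + ")"
retrieval_sql = f"""
    SELECT content, {schema}.cos_sim(vector, {vector_sql}) AS score
    FROM {schema}.rag_{model_id}
    ORDER BY score DESC
    LIMIT {k}
"""
# results = ss.sql(retrieval_sql)  # then wrapped by PySparkDataFrameLoader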
{kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai/tenant_client.py
@@ -7,7 +7,14 @@ import requests
 from azure.identity import DeviceCodeCredential
 from pyspark.sql import SparkSession
 
-from . import …
+from langchain_community.chat_models import ChatDatabricks
+from databricks_langchain import DatabricksEmbeddings
+from sentence_transformers import SentenceTransformer
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from typing import Union
+
+from . import spark_client, databricks_client, ai_query, tenant_api, ai_rag
 
 class TenantClient:
 
@@ -39,6 +46,10 @@ class TenantClient:
         self.model_id = ""
         self.proxies = None
         self.ssl_verify = True
+        self.question_search_index = None
+        self.embedding_model = None
+        self.chat_model = None
+
 
     def update_proxy(self, proxies: any):
         self.proxies = proxies
@@ -99,10 +110,11 @@ class TenantClient:
 
         self.__api_init_session()
         self.__set_tenant_solutionid()
+        self.init_ai_components()
 
         print("Authentication Successful.")
 
-    def …
+    def authenticate_browser_token(self, access_token):
 
         """
         Authenticate the TenantClient with the Kobai instance. Returns nothing, but stores bearer token in client.
@@ -116,6 +128,8 @@ class TenantClient:
 
         self.__api_init_session()
         self.__set_tenant_solutionid()
+        self.init_ai_components()
+
 
         print("Authentication Successful.")
 
@@ -410,212 +424,114 @@ class TenantClient:
         return return_questions
 
     ########################################
-    # …
+    # RAG Functions
     ########################################
 
-    def …
-…
+    def get_ai_context(self):
+        context = ai_rag.AIContext()
+        context.model_id = self.model_id
+        context.schema = self.schema
+        context.tenant_json = self.get_tenant_config()
+        context.spark_session = self.spark_client.spark_session
+        context.api_client = self.api_client
+        return context
+
+    def rag_generate_sentences(self, replica_schema=None, concept_white_list=None, use_questions=False):
         """
-…
+        Extract Semantic Data from Graph to Delta Table
 
         Parameters:
-…
-…
-…
-        override_model (LangChain BaseLanguageModel) OPTIONAL: Langchain LLM or ChatModel runnable.
-        use_simple_prompt (bool) OPTIONAL: Uses ChatPrompt when True, Prompt when False.
-        debug (bool) OPTIONAL: Set Langchain debug for troubleshooting.
+        replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+        concept_white_list ([str]) OPTIONAL: A list of Domain and Concept names for extraction.
+        use_questions (bool) OPTIONAL: Extract facts from published Kobai questions.
         """
-…
-        question_def = self.get_question(question_id)
-        question_name = question_def["description"]
-…
-        row_texts = []
-        row_template = self.process_question_results(question_def)
-        for row in question_results:
-            row_text = row_template
-            for col in row:
-                row_text = row_text.replace("{" + col + "}", str(row[col]))
-            row_texts.append(row_text)
-        data = "\n".join(row_texts)
+        ai_rag.generate_sentences(self.get_ai_context(), replica_schema=replica_schema, concept_white_list=concept_white_list, use_questions=use_questions)
 
-…
-…
-…
-            None,
-            override_model=override_model,
-        )
+    def rag_encode_to_delta_local(self, st_model: SentenceTransformer, replica_schema=None):
+        """
+        Encode Semantic Data to Vectors in Delta Table
 
-…
+        Parameters:
+        st_model (SentenceTransformer): A sentence_transformers model to use for encoding.
+        replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+        """
+        ai_rag.encode_to_delta_local(self.get_ai_context(), st_model=st_model, replica_schema=replica_schema)
 
+    def rag_delta(self, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
         """
-…
+        Run a RAG query using vectors in Delta table.
 
         Parameters:
-…
-…
-…
-…
-…
-…
-…
-            con_name = question_def["definition"][ci]["label"].replace("_", " ")
-            con_label = question_def["definition"][ci]["label"]
-            concept_props[ci] = {"name": con_name, "props": []}
-            for p in question_def["definition"][ci]["properties"]:
-                if p["hidden"] == False:
-                    if len(p["aggregates"]) > 0:
-                        for a in p["aggregates"]:
-                            prop_column = con_label + "_" + p["label"] + "_" + a["type"]
-                            prop_name = p["label"].replace("_", " ")
-                            concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": a["type"]})
-                    else:
-                        prop_column = con_label + "_" + p["label"]
-                        prop_name = p["label"].replace("_", " ")
-                        concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": None})
-            for r in question_def["definition"][ci]["relations"]:
-                prop_name = question_def["definition"][ci]["relations"][r]["label"].replace("_", " ")
-                for ri in question_def["definition"][ci]["relations"][r]["relationInstances"]:
-                    if ci not in concept_rels:
-                        concept_rels[ci] = {"count": 0, "edges": []}
-                    concept_rels[ci]["edges"].append({"src": ci, "dst": ri["relationTypeUri"], "name": prop_name})
-                    concept_rels[ci]["count"] += 1
-
-
-        row_texts = {}
-
-        for ci, c in concept_props.items():
-            p_texts = []
-            for p in c["props"]:
-                if p["agg"] is None:
-                    p_text = p["name"] + " " + "{" + p["column"] + "}"
-                else:
-                    p_text = p["agg"] + " of " + p["name"] + " " + "{" + p["column"] + "}"
-                p_texts.append(p_text)
-            c_text = self.__get_article(c["name"]) + " " + c["name"]
-            if len(c["props"]) > 0:
-                c_text += " with " + self.__smart_comma_formatting(p_texts)
-            row_texts[ci] = c_text
+        emb_model (UNION[SentenceTransformer, Embeddings]): A sentence_transformers or langchain embedding model to use for encoding the query.
+        chat_model (BaseChatModel): A langchain chat model to use in the RAG pipeline.
+        question (str): The user's query.
+        k (int) OPTIONAL: The number of RAG documents to retrieve.
+        replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+        """
+        ai_rag.rag_delta(self.get_ai_context(), emb_model=emb_model, chat_model=chat_model, question=question, k=k, replica_schema=replica_schema)
 
-…
-…
+    ########################################
+    # AI Functions
+    ########################################
 
-…
-…
-…
-                max_src = r
+    def followup_question(self, user_question, question_id=None, use_inmem_vectors=False):
+        """
+        Use LLM to answer question in the context of a Kobai Studio question.
 
+        Parameters:
+        user_question (str): A natural language question to apply.
+        question_id (int) OPTIONAL: A Kobai question to use as a data source. Otherwise, an appropriate question will be automatically found.
+        use_inmem_vectors (bool) OPTIONAL: For large query sets, this secondary processing can reduce the data required in the context window.
+        """
 
-…
-        for t in concept_rels[max_src]["edges"]:
-            concept_order.append(t["dst"])
+        if question_id is None:
 
-…
-            if c not in concept_order:
-                concept_order.append(c)
+            suggestions = self.question_search(user_question, k=1)
 
-…
+            question_id = suggestions[0]["id"]
 
-…
-            row_text = row_text.replace(c, row_texts[c])
+        question_results = self.run_question_remote(question_id)
 
-…
-…
-…
-    def process_question_results2(self, question_def):
+        question_def = self.get_question(question_id)
+        question_name = question_def["description"]
 
+        return ai_query.followup_question(user_question, question_results, question_name, question_def, self.embedding_model, self.chat_model, use_inmem_vectors=use_inmem_vectors)
+
+    def init_ai_components(self, embedding_model: Union[SentenceTransformer, Embeddings] = None, chat_model: BaseChatModel = None):
         """
-…
+        Set Chat and Embedding models for AI functions to use. If no arguments provided, Databricks hosted services are used.
 
         Parameters:
-…
+        embedding_model (Union[SentenceTransformer, Embeddings]) OPTIONAL: A sentence_transformer or Langchain Embedding model.
+        chat_model (BaseChatModel) OPTIONAL: A Langchain BaseChatModel chat model.
         """
 
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-                        for a in p["aggregates"]:
-                            prop_column = con_label + "_" + p["label"] + "_" + a["type"]
-                            prop_name = p["label"].replace("_", " ")
-                            concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": a["type"]})
-                    else:
-                        prop_column = con_label + "_" + p["label"]
-                        prop_name = p["label"].replace("_", " ")
-                        concept_props[ci]["props"].append({"column": prop_column, "name": prop_name, "agg": None})
-            for r in question_def["definition"][ci]["relations"]:
-                prop_name = question_def["definition"][ci]["relations"][r]["label"].replace("_", " ")
-                for ri in question_def["definition"][ci]["relations"][r]["relationInstances"]:
-                    if ci not in concept_rels:
-                        concept_rels[ci] = {"count": 0, "edges": []}
-                    concept_rels[ci]["edges"].append({"src": ci, "dst": ri["relationTypeUri"], "name": prop_name})
-                    concept_rels[ci]["count"] += 1
-
-
-        row_texts = {}
-
-        for ci, c in concept_props.items():
-            p_texts = []
-            for p in c["props"]:
-                if p["agg"] is None:
-                    p_text = p["name"] + " " + "{" + p["column"] + "}"
-                else:
-                    p_text = p["agg"] + " of " + p["name"] + " " + "{" + p["column"] + "}"
-                p_texts.append(p_text)
-            c_text = self.__get_article(c["name"]) + " " + c["name"]
-            if len(c["props"]) > 0:
-                c_text += " with " + self.__smart_comma_formatting(p_texts)
-            row_texts[ci] = c_text
+        if embedding_model is not None:
+            self.embedding_model = embedding_model
+        else:
+            #self.embedding_model = SentenceTransformer("baai/bge-large-en-v1.5")
+            self.embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
+
+        if chat_model is not None:
+            self.chat_model = chat_model
+        else:
+            self.chat_model = ChatDatabricks(endpoint="databricks-dbrx-instruct")
 
-…
-        max_src_count = -1
+        self.question_search_index = ai_query.init_question_search_index(self.list_questions(), self.embedding_model)
 
-…
-…
-…
-                max_src = r
+    def question_search(self, search_text, k: int = 1):
+        """
+        Retrieve metadata about Kobai Questions based on user search text.
 
+        Parameters:
+        search_text (str): Text to compare against question names.
+        k (int) OPTIONAL: Number of top-k matches to return.
+        """
 
-…
-…
-            concept_order.append(t["dst"])
+        question_list = ai_query.question_search(search_text, self.question_search_index, self.embedding_model, k)
+        return question_list
 
-        for c in concept_props:
-            if c not in concept_order:
-                concept_order.append(c)
 
-        row_text = concept_order[0] + " is connected to " + " and connected to ".join(concept_order[1:])
-
-        for c in row_texts:
-            row_text = row_text.replace(c, row_texts[c])
-
-        row_text = row_text[0].upper() + row_text[1:] + "."
-        return row_text
-
-    def __smart_comma_formatting(self, items):
-        if items == None:
-            return ""
-        match len(items):
-            case 0:
-                return ""
-            case 1:
-                return items[0]
-            case 2:
-                return items[0] + " and " + items[1]
-            case _:
-                return ", ".join(items[0: -1]) + " and " + items[-1]
-
-    def __get_article(self, label):
-        if label[0:1].lower() in ["a", "e", "i", "o", "u"]:
-            return "an"
-        else:
-            return "a"
 
     ########################################
     # Tenant Questions
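Taken together, the hunk above reduces the AI flow to a few calls on TenantClient. A usage sketch, assuming `tc` is an already-authenticated client (authentication now triggers init_ai_components() automatically, defaulting to the Databricks-hosted databricks-bge-large-en embedding and databricks-dbrx-instruct chat endpoints):

# Optionally swap in your own models first:
# tc.init_ai_components(embedding_model=my_embeddings, chat_model=my_chat_model)

# Rank published questions against free text; entries carry id/description/value.
matches = tc.question_search("monthly revenue by region", k=3)
for m in matches:
    print(m["id"], m["description"], m["value"])

# Ask a follow-up; with no question_id the closest question is located,
# executed remotely, and its rows become the LLM context.
answer = tc.followup_question("Which region had the highest revenue?")
print(answer)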
@@ -943,7 +859,14 @@ class TenantClient:
         response = self.api_client._TenantAPI__run_get('/data-svcs/model/domain/questions/count')
         for q in response.json()["drafts"]:
             question_list.append({"id": q["id"], "description": q["description"]})
-        …
+
+        visited_ids = []
+        unique_question_list = []
+        for q in question_list:
+            if q["id"] not in visited_ids:
+                visited_ids.append(q["id"])
+                unique_question_list.append(q)
+        return unique_question_list
 
     def get_question_id(self, label, domain_label=None):
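The new tail of list_questions is a first-seen-wins dedup on question id. A dict keyed by id gives the same order-preserving result with constant-time lookups; shown only to illustrate the semantics:

seen = {}
for q in question_list:
    seen.setdefault(q["id"], q)  # first occurrence of each id wins
unique_question_list = list(seen.values())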
{kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/kobai_sdk.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: kobai-sdk
-Version: 0.2.8rc3
+Version: 0.2.8rc5
 Summary: A package that enables interaction with a Kobai tenant.
 Author-email: Ryan Oattes <ryan@kobai.io>
 License: Apache License
@@ -222,7 +222,7 @@ Requires-Dist: azure-storage-blob
 Requires-Dist: langchain-core
 Requires-Dist: langchain-community
 Requires-Dist: langchain_openai
-Requires-Dist: …
+Requires-Dist: databricks_langchain
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: bumpver; extra == "dev"
{kobai_sdk-0.2.8rc3 → kobai_sdk-0.2.8rc5}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "kobai-sdk"
-version = "0.2.8rc3"
+version = "0.2.8rc5"
 description = "A package that enables interaction with a Kobai tenant."
 readme = "README.md"
 authors = [{ name = "Ryan Oattes", email = "ryan@kobai.io" }]
@@ -26,8 +26,8 @@ dependencies = [
     "langchain-core",
     "langchain-community",
     "langchain_openai",
-    "…",
-]
+    "databricks_langchain"
+]
 requires-python = ">=3.11"
 
 [project.optional-dependencies]
kobai_sdk-0.2.8rc3/kobai/ai_query.py (deleted)
@@ -1,114 +0,0 @@
-from kobai import llm_config
-from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from langchain_community.chat_models import ChatDatabricks
-from langchain.globals import set_debug
-from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-from langchain_openai import AzureChatOpenAI
-
-MESSAGE_SYSTEM_TEMPLATE = """
-You are a data analyst tasked with answering questions based on a provided data set. Please answer the questions based on the provided context below. Make sure not to make any changes to the context, if possible, when preparing answers to provide accurate responses. If the answer cannot be found in context, just politely say that you do not know, do not try to make up an answer.
-When you receive a question from the user, answer only that one question in a concise manner. Do not elaborate with other questions.
-"""
-
-MESSAGE_AI_TEMPLATE = """
-The table information is as follows:
-{table_data}
-"""
-
-MESSAGE_USER_CONTEXT_TEMPLATE = """
-The context being provided is from a table named: {table_name}
-"""
-
-MESSAGE_USER_QUESTION_TEMPLATE = """
-{question}
-"""
-
-SIMPLE_PROMPT_TEMPLATE = f"""
-{MESSAGE_SYSTEM_TEMPLATE}
-
-{MESSAGE_USER_CONTEXT_TEMPLATE}
-
-{MESSAGE_AI_TEMPLATE}
-
-Question: {MESSAGE_USER_QUESTION_TEMPLATE}
-"""
-
-def followup_question(question, data, question_name, llm_config:llm_config, override_model=None):
-
-    """
-    Use LLM to answer question in the context of provided data.
-
-    Parameters:
-    question (str): A natural language question to apply.
-    data (str): Simple dictionary-like structured data.
-    question_name (str): Dataset name for context.
-    llm_config (LLMConfig): User set LLM configurations and some default ones.
-    override_model (LangChain BaseLanguageModel) OPTIONAL: Langchain LLM or ChatModel runnable.
-    """
-
-    set_debug(llm_config.debug)
-
-    # If override model is provided, then use the override model as chat model.
-    if override_model is not None:
-        chat_model=override_model
-    elif llm_config.llm_provider == "databricks":
-        chat_model = ChatDatabricks(
-            endpoint = llm_config.endpoint,
-            temperature = llm_config.temperature,
-            max_tokens = llm_config.max_tokens,
-        )
-    elif llm_config.llm_provider == "azure_openai":
-        if(llm_config.api_key is None):
-            # Authenticate through AZ Login or through service principal
-            # Instantiate the AzureChatOpenAI model
-            chat_model = AzureChatOpenAI(
-                azure_endpoint=llm_config.endpoint,
-                azure_deployment=llm_config.deployment,
-                azure_ad_token=llm_config.aad_token,
-                openai_api_version=llm_config.api_version,
-                temperature = llm_config.temperature,
-                max_tokens = llm_config.max_tokens,
-            )
-        else:
-            # Authenticate through API Key
-            chat_model = AzureChatOpenAI(
-                api_key = llm_config.api_key,
-                azure_endpoint=llm_config.endpoint,
-                azure_deployment=llm_config.deployment,
-                openai_api_version=llm_config.api_version,
-                temperature = llm_config.temperature,
-                max_tokens = llm_config.max_tokens,
-            )
-    else:
-        chat_model = ChatDatabricks(
-            endpoint = llm_config.endpoint,
-            temperature = llm_config.temperature,
-            max_tokens = llm_config.max_tokens,
-        )
-
-    if llm_config.use_simple_prompt:
-        prompt = PromptTemplate.from_template(SIMPLE_PROMPT_TEMPLATE)
-    else:
-        prompt = ChatPromptTemplate.from_messages(
-            [
-                SystemMessagePromptTemplate.from_template(MESSAGE_SYSTEM_TEMPLATE),
-                HumanMessagePromptTemplate.from_template(MESSAGE_USER_CONTEXT_TEMPLATE),
-                AIMessagePromptTemplate.from_template(MESSAGE_AI_TEMPLATE),
-                HumanMessagePromptTemplate.from_template(MESSAGE_USER_QUESTION_TEMPLATE)
-            ]
-        )
-
-    output_parser = StrOutputParser()
-
-    chain = prompt | chat_model | output_parser
-
-    response = chain.invoke(
-        {
-            "table_name": question_name,
-            "table_data": str(data),
-            "question": question
-        }
-    )
-
-    return response
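For contrast, the removed 0.2.8rc3 entry point took pre-stringified data plus an LLMConfig. A sketch of how it was driven (values illustrative):

from kobai import ai_query, llm_config  # module layout as of 0.2.8rc3

config = llm_config.LLMConfig(endpoint="databricks-dbrx-instruct", max_tokens=150)
answer = ai_query.followup_question(
    "Which region had the highest revenue?",  # question
    "{'region': 'EMEA', 'revenue': 100}",     # data, already stringified
    "Revenue by Region",                      # question_name
    config,
)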
kobai_sdk-0.2.8rc3/kobai/llm_config.py (deleted)
@@ -1,40 +0,0 @@
-import os
-from azure.identity import DefaultAzureCredential
-
-class LLMConfig:
-
-    def __init__(self, deployment: str = None, api_key: str = None, max_tokens: int = 150, temperature: float = 0.1, endpoint: str = "databricks-dbrx-instruct", use_simple_prompt: bool = False, debug: bool = False,
-                 llm_provider: str = "databricks", api_version: str = "2024-02-15-preview"):
-
-        """
-        Initialize the LLMConfig
-        Parameters:
-        deployment (str): LLM against which the query is run.
-        api_key (str): The api_key used for authenticating with the LLM.
-        max_tokens (int): Maximum number of tokens that the model can generate in a single response.
-        temperature (float): Parameter that controls the randomness and creativity of the text generated by the LLM.
-        endpoint (str): The endpoint of the LLM to connect to.
-        debug (bool) OPTIONAL: Set Langchain debug for troubleshooting.
-        use_simple_prompt (bool) OPTIONAL: Simple Prompt template for a language model.
-        llm_provider (str): Provider of the LLM.
-        api_version (str): version of the LLM API that the application will use for making requests.
-        """
-
-        self.endpoint = endpoint
-        self.deployment = deployment
-        self.api_key = api_key
-        self.api_version = api_version
-        self.use_simple_prompt = use_simple_prompt
-        self.debug = debug
-        self.llm_provider = llm_provider
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-
-
-    def get_azure_ad_token(self):
-        # Get the Azure Credential
-        credential = DefaultAzureCredential()
-        # Set the API type to `azure_ad`
-        os.environ["OPENAI_API_TYPE"] = "azure_ad"
-        # Set the API_KEY to the token from the Azure credential
-        self.aad_token = credential.get_token("https://cognitiveservices.azure.com/.default").token
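With LLMConfig gone, callers now construct LangChain models directly and register them through init_ai_components. A sketch of the two provider paths LLMConfig used to wrap (endpoint and deployment names are placeholders):

from langchain_community.chat_models import ChatDatabricks
from langchain_openai import AzureChatOpenAI

chat_model = ChatDatabricks(endpoint="databricks-dbrx-instruct", temperature=0.1, max_tokens=150)

# Azure OpenAI equivalent of the removed azure_openai branch:
# chat_model = AzureChatOpenAI(
#     azure_endpoint="https://example.openai.azure.com",
#     azure_deployment="my-gpt-deployment",
#     api_key="...",
#     openai_api_version="2024-02-15-preview",
# )

# tc.init_ai_components(chat_model=chat_model)  # tc: an authenticated TenantClient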