kobai-sdk 0.2.8rc1-py3-none-any.whl → 0.2.8rc3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



kobai/ai_rag.py CHANGED
@@ -1,6 +1,32 @@
 from kobai import tenant_client
 
-def generate_sentences(tc: tenant_client.TenantClient, replica_schema=None):
+from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, IntegerType
+from pyspark.sql import functions as F
+from sentence_transformers import SentenceTransformer
+from delta import DeltaTable
+from typing import Union
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from langchain_community.document_loaders import PySparkDataFrameLoader
+from langchain import hub
+from langchain_core.output_parsers import StrOutputParser
+
+
+
+
+
+def generate_sentences(tc: tenant_client.TenantClient, replica_schema=None, concept_white_list=None, use_questions=False):
+
+    """
+    Extract Semantic Data from Graph to Delta Table
+
+    Parameters:
+    tc (TenantClient): The Kobai tenant_client instance instantiated via the SDK.
+    replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+    concept_white_list ([str]) OPTIONAL: A list of Domain and Concept names for extraction.
+    use_questions (bool) OPTIONAL: Extract facts from published Kobai questions.
+    """
+
     if tc.spark_client is None:
         return None
 
@@ -9,7 +35,7 @@ def generate_sentences(tc: tenant_client.TenantClient, replica_schema=None):
     print("Getting Tenant Config")
     tenant_json = tc.get_tenant_config()
 
-    concepts = __get_concept_metadata(tenant_json, tc.schema, tc.model_id)
+    concepts = __get_concept_metadata(tenant_json, tc.schema, tc.model_id, concept_white_list)
 
     print("Dropping and Recreating the RAG Table")
     ss.sql(__create_rag_table_sql(tc.schema, tc.model_id))
@@ -23,11 +49,193 @@ def generate_sentences(tc: tenant_client.TenantClient, replica_schema=None):
     for sql_statement in sql_statements:
         ss.sql(sql_statement)
 
+    if use_questions:
+        __generate_sentences_from_questions(tc)
+
     if replica_schema is not None:
         print("Replicating Schema")
         ss.sql(__create_rag_table_sql(replica_schema, tc.model_id))
         ss.sql(__replicate_to_catalog_sql(tc.schema, replica_schema, tc.model_id))
 
+def __generate_sentences_from_questions(tc: tenant_client.TenantClient):
+    ss = tc.spark_client.spark_session
+
+    print("Getting Question Data")
+
+    tenant_json = tc.get_tenant_config()
+
+    published_queries = []
+    for p in tenant_json["publishedAPIs"]:
+        published_queries.append(p["queryId"])
+
+    question_names = {}
+    for q in tenant_json["queries"]:
+        if q["id"] in published_queries:
+            question_names[q["id"]] = q["description"]
+
+    schemaV = StructType([
+        StructField("sentence",StringType(),True),
+        StructField("query_id", StringType(), True)
+    ])
+
+    sentences = []
+    for p in published_queries:
+        output = tc.run_question_remote(p)
+        for r in output:
+            sentence = f"For {question_names[p]}: "
+            for c in r:
+                sentence += f"The {c.replace('_', ' ')} is {r[c]}. "
+            sentences.append([sentence, p])
+
+
+    sentences_df = ss.createDataFrame(sentences, schemaV)
+    sentences_df = sentences_df.select(
+        F.col("sentence").alias("sentence"),
+        F.col("query_id").alias("concept_id"),
+        F.lit("q").alias("type"),
+    )
+
+    schema = tc.schema
+
+    view_name = f"rag_{tc.model_id}_question_sentences"
+    sentences_df.createOrReplaceTempView(view_name)
+
+    full_sql = f"INSERT INTO {schema}.rag_{tc.model_id} (content, concept_id, type)"
+    full_sql += f" SELECT sentence, concept_id, type FROM {view_name}"
+
+    ss.sql(full_sql)
+
+
+
+def encode_to_delta_local(tc: tenant_client.TenantClient, st_model: SentenceTransformer, replica_schema=None):
+
+    """
+    Encode Semantic Data to Vectors in Delta Table
+
+    Parameters:
+    tc (TenantClient): The Kobai tenant_client instance instantiated via the SDK.
+    st_model (SentenceTransformer): A sentence_transformers model to use for encoding.
+    replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+    """
+
+    if tc.spark_client is None:
+        return None
+
+    ss = tc.spark_client.spark_session
+
+    schema = tc.schema
+    if replica_schema is not None:
+        schema = replica_schema
+
+    sentences_sql = f"SELECT content FROM {schema}.rag_{tc.model_id}"
+    sentences_df = ss.sql(sentences_sql)
+
+    num_records = sentences_df.count()
+    query_batch_size = 100000
+
+
+    for x in range(0, num_records, query_batch_size):
+        print(f"Running Batch Starting at {x}")
+        sentences_sql = f" SELECT id, content FROM {schema}.rag_{tc.model_id} ORDER BY id LIMIT {str(query_batch_size)} OFFSET {str(x)}"
+        sentences_df = ss.sql(sentences_sql)
+        content_list = [r["content"] for r in sentences_df.collect()]
+        id_list = [r["id"] for r in sentences_df.collect()]
+
+        vector_list = st_model.encode(content_list, normalize_embeddings=True, show_progress_bar=True)
+
+        schemaV = StructType([
+            StructField("id",IntegerType(),True),
+            StructField("vector", ArrayType(FloatType()), False)
+        ])
+
+        updated_list = [[r[0], r[1].tolist()] for r in zip(id_list, vector_list)]
+        updated_df = ss.createDataFrame(updated_list, schemaV)
+
+        target_table = DeltaTable.forName(ss, f"{schema}.rag_{tc.model_id}")
+
+        target_table.alias("t") \
+            .merge(
+                updated_df.alias("s"),
+                't.id = s.id'
+            ) \
+            .whenMatchedUpdate(set = {"vector": "s.vector"}) \
+            .execute()
+
+    ss.sql(f"""
+        CREATE FUNCTION IF NOT EXISTS {schema}.cos_sim(a ARRAY<FLOAT>, b ARRAY<FLOAT>)
+        RETURNS FLOAT
+        LANGUAGE PYTHON
+        AS $$
+        import numpy as np
+        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+        $$
+    """)
+
+def rag_delta(tc: tenant_client.TenantClient, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
+
+    """
+    Run a RAG query using vectors in Delta table.
+
+    Parameters:
+    tc (TenantClient): The Kobai tenant_client instance instantiated via the SDK.
+    emb_model (UNION[SentenceTransformer, Embeddings]): A sentence_transformers or langchain embedding model to use for encoding the query.
+    chat_model (BaseChatModel): A langchain chat model to use in the RAG pipeline.
+    question (str): The user's query.
+    k (int) OPTIONAL: The number of RAG documents to retrieve.
+    replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
+    """
+
+    schema = tc.schema
+    if replica_schema is not None:
+        schema = replica_schema
+
+    if tc.spark_client is None:
+        print("Instantiate Spark Client First")
+        return None
+
+    ss = tc.spark_client.spark_session
+
+    if isinstance(emb_model, SentenceTransformer):
+        vector_list = emb_model.encode(question, normalize_embeddings=True).tolist()
+    elif isinstance(emb_model, Embeddings):
+        vector_list = emb_model.embed_query(question)
+    else:
+        print("Invalid Embedding Model Type")
+        return None
+
+    if not isinstance(chat_model, BaseChatModel):
+        print("Invalid Chat Model Type")
+        return None
+
+    vector_list = [str(x) for x in vector_list]
+    vector_sql = ", ".join(vector_list)
+
+    results = ss.sql(f"""
+        SELECT content, reduce(zip_with(vector, cast(array({vector_sql}) as array<float>), (x,y) -> x*y), float(0.0), (acc,x) -> acc + x) score
+        FROM {schema}.rag_{tc.model_id}
+        ORDER BY score DESC
+        LIMIT {k}
+    """)
+
+    loader = PySparkDataFrameLoader(ss, results, page_content_column="content")
+    documents = loader.load()
+    docs_content = "\n\n".join(doc.page_content for doc in documents)
+
+    prompt = hub.pull("rlm/rag-prompt")
+
+    output_parser = StrOutputParser()
+
+    chain = prompt | chat_model | output_parser
+
+    response = chain.invoke(
+        {
+            "context": docs_content,
+            "question": question
+        }
+    )
+
+    return response
+
 def __create_rag_table_sql(schema, model_id):
     return f"CREATE OR REPLACE TABLE {schema}.rag_{model_id} (id BIGINT GENERATED BY DEFAULT AS IDENTITY, content STRING, type string, concept_id string, vector ARRAY<FLOAT>) TBLPROPERTIES (delta.enableChangeDataFeed = true)"
 
@@ -40,74 +248,99 @@ def __generate_sentence_sql_concept_literals(concepts, schema, model_id):
     statements = []
     for con in concepts:
         sql = f"'This is a {con['label']}. '"
-        sql += " || 'It is identified by ' || split(cid._conceptid,'#')[1] || '. '"
+        sql += " || 'It is identified by ' || cid._plain_conceptid || '. '"
 
-        sql_from = f"{con['con_table_name']} cid"
+        sql_from = f"(SELECT _conceptid, _plain_conceptid FROM {con['prop_table_name']} GROUP BY _conceptid, _plain_conceptid) cid"
        for prop in con["properties"]:
 
-            sql_from += f" INNER JOIN {con['prop_table_name']} AS {prop['label']}"
+            sql_from += f" LEFT JOIN {con['prop_table_name']} AS {prop['label']}"
             sql_from += f" ON cid._conceptid = {prop['label']}._conceptid"
             sql_from += f" AND {prop['label']}.type = 'l'"
             sql_from += f" AND {prop['label']}.name = '{prop['name']}'"
 
-            sql += f" || 'The {prop['label']} is ' || any_value({prop['label']}.value) IGNORE NULLS || '. '"
-
-        full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
-        full_sql += f" SELECT {sql} content, cid._conceptid concept_id, 'c' type FROM {sql_from} GROUP BY cid._conceptid"
-
-        statements.append(full_sql)
-        #test_df = spark.sql(full_sql)
+            sql += f" || 'The {prop['label']} is ' || ifnull(any_value({prop['label']}.value) IGNORE NULLS, 'unknown') || '. '"
+
+        full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
+        full_sql += f" SELECT {sql} content, cid._conceptid concept_id, 'c' type FROM {sql_from} GROUP BY cid._conceptid, cid._plain_conceptid"
+
+        statements.append(full_sql)
     return statements
 
 def __generate_sentence_sql_concept_relations(concepts, schema, model_id):
     statements = []
     for con in concepts:
-
-        sql_from = f"{con['prop_table_name']} "
         for rel in con["relations"]:
+            sql_from = f"{con['prop_table_name']} rel"
+            sql_from += f" INNER JOIN (SELECT _conceptid, _plain_conceptid FROM {rel['target_table_name']} GROUP BY _conceptid, _plain_conceptid) cid"
+            sql_from += f" ON rel.value = cid._conceptid"
+            sql_from += f" AND rel.type = 'r'"
+            sql_from += f" AND rel.name = '{rel['name']}'"
 
-            sql = f"'The {con['label']} identified by ' || split(_conceptid,'#')[1]"
+            sql = f"'The {con['label']} identified by ' || rel._plain_conceptid"
             sql += f" || ' has a relationship called {rel['label']} that connects it to one or more {rel['target_con_label']} identified by '"
-            sql += " || concat_ws(', ', array_agg(split(value, '#')[1])) || '. '"
+            sql += " || concat_ws(', ', array_agg(cid._plain_conceptid)) || '. '"
 
 
             full_sql = f"INSERT INTO {schema}.rag_{model_id} (content, concept_id, type)"
-            full_sql += f" SELECT {sql} content, _conceptid concept_id, 'e' type FROM {sql_from} GROUP BY _conceptid"
+            full_sql += f" SELECT {sql} content, rel._conceptid concept_id, 'e' type FROM {sql_from} GROUP BY rel._conceptid, rel._plain_conceptid"
 
             statements.append(full_sql)
     return statements
 
-def __get_concept_metadata(tenant_json, schema, model_id):
+def __get_concept_metadata(tenant_json, schema, model_id, whitelist):
     target_concept_labels = {}
+    target_table_names = {}
     for d in tenant_json["domains"]:
         for c in d["concepts"]:
             target_concept_labels[c["uri"]] = d["name"] + " " + c["label"]
-
+            target_table_names[c["uri"]] = {
+                "prop": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_np",
+                "con": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_c"
+            }
+
     concepts = []
-
     for d in tenant_json["domains"]:
         for c in d["concepts"]:
             con_props = []
             for col in c["properties"]:
                 con_props.append({
-                    #"col_name": d["name"] + "_" + c["label"] + "_" + col["label"],
                     "label": col["label"],
                     "name": f"{model_id}/{d['name']}/{c['label']}#{col['label']}"
                 })
             con_rels = []
            for rel in c["relations"]:
+                if whitelist is not None and target_concept_labels[rel["relationTypeUri"]] not in whitelist:
+                    continue
                 con_rels.append({
                     "label": rel["label"],
                     "name": f"{model_id}/{d['name']}/{c['label']}#{rel['label']}",
-                    "target_con_label": target_concept_labels[rel["relationTypeUri"]]
+                    "target_con_label": target_concept_labels[rel["relationTypeUri"]],
+                    "target_table_name": target_table_names[rel["relationTypeUri"]]["prop"]
                 })
+            con_parents = []
+            for p in c["inheritedConcepts"]:
+                con_parents.append(p)
             concepts.append({
+                "uri": c["uri"],
                 "label": d["name"] + " " + c["label"],
-                #"id_column": d["name"] + "_" + c["label"],
                 "relations": con_rels,
                 "properties": con_props,
-                #"table_name": "data_" + k.model_id + "_" + d["name"] + "_" + c["label"] + "_w",
-                "prop_table_name": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_np",
-                "con_table_name": f"{schema}.data_{model_id}_{d['name']}_{c['label']}_c",
+                "parents": con_parents,
+                "prop_table_name": target_table_names[c["uri"]]["prop"],
+                "con_table_name": target_table_names[c["uri"]]["con"]
             })
-    return concepts
+
+    for ci, c in enumerate(concepts):
+        if len(c["parents"]) > 0:
+            for p in c["parents"]:
+                for a in concepts:
+                    if a["uri"] == p:
+                        concepts[ci]["properties"].extend(a["properties"])
+
+    out_concepts = []
+    for c in concepts:
+        if whitelist is not None and c["label"] not in whitelist:
+            continue
+        out_concepts.append(c)
+
+    return out_concepts
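
Taken together, the additions to kobai/ai_rag.py form a small Delta-backed RAG pipeline: generate_sentences extracts graph facts (and, optionally, published question results) into the rag_{model_id} table, encode_to_delta_local fills that table's vector column with sentence embeddings, and rag_delta scores the stored vectors against an embedded query and hands the top-k rows to a LangChain chat model. The sketch below shows how the three functions might be chained; the TenantClient construction, the Spark-client attachment, the whitelist value, and the specific embedding and chat models are illustrative assumptions, not part of this diff.

# Hypothetical usage sketch (not from the package): chaining the new helpers.
from sentence_transformers import SentenceTransformer
from langchain_openai import ChatOpenAI  # any LangChain BaseChatModel should work

from kobai import tenant_client, ai_rag

tc = tenant_client.TenantClient(...)  # assumed: create and authenticate per the SDK docs
# tc.spark_client must be attached first; the helpers return None if it is missing.

# 1. Extract concept, relation, and (optionally) question facts into {schema}.rag_{model_id}.
ai_rag.generate_sentences(tc, concept_white_list=["Sales Customer"], use_questions=True)

# 2. Encode the extracted sentences and MERGE the vectors back into the Delta table.
st_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model choice
ai_rag.encode_to_delta_local(tc, st_model)

# 3. Retrieve the top-k matching sentences and answer with a chat model.
chat = ChatOpenAI(model="gpt-4o-mini")  # assumed chat model choice
answer = ai_rag.rag_delta(tc, st_model, chat, "Which customers placed the largest orders?", k=5)
print(answer)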

kobai_sdk-0.2.8rc1.dist-info/METADATA → kobai_sdk-0.2.8rc3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: kobai-sdk
-Version: 0.2.8rc1
+Version: 0.2.8rc3
 Summary: A package that enables interaction with a Kobai tenant.
 Author-email: Ryan Oattes <ryan@kobai.io>
 License: Apache License
@@ -211,7 +211,7 @@ Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.9
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pyspark
@@ -222,6 +222,7 @@ Requires-Dist: azure-storage-blob
 Requires-Dist: langchain-core
 Requires-Dist: langchain-community
 Requires-Dist: langchain_openai
+Requires-Dist: sentence_transformers
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: bumpver; extra == "dev"

kobai_sdk-0.2.8rc1.dist-info/RECORD → kobai_sdk-0.2.8rc3.dist-info/RECORD
@@ -1,14 +1,14 @@
 kobai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 kobai/ai_query.py,sha256=fMTcfj-6Ma3FRB08VYEDj8PwOEOtFGsJHyQrha5yvPg,4512
-kobai/ai_rag.py,sha256=y_N7qVu8HfUHHZPIyQSO7L995RBeNtDhva7U5HBHSfY,5063
+kobai/ai_rag.py,sha256=TtUbUcSN9mIsauGyS_nw8j58T9jEd4OFiAwNvzo-rr8,13593
 kobai/databricks_client.py,sha256=fyqqMly2Qm0r1AHWsQjkYeNsDdH0G1JSgTkF9KJ55qA,2118
 kobai/demo_tenant_client.py,sha256=wlNc-bdI2wotRXo8ppUOalv4hYdBlek_WzJNARZV-AE,9293
 kobai/llm_config.py,sha256=ZFx81cUAOHYZgRoTkTY-utQYaWYlmR8773ZJpj74C1A,1900
 kobai/spark_client.py,sha256=opM_F-4Ut5Hq5zZjWMuLvUps9sDULvyPNZHXGL8dW1k,776
 kobai/tenant_api.py,sha256=9U6UbxpaAb-kpbuADXx3kbkNKaOzYy0I-GGwbpiCCOk,4212
 kobai/tenant_client.py,sha256=AyJ5R2oukEv3q1dcItpojvTUVp5-gwUKvyGjofjBKyc,41821
-kobai_sdk-0.2.8rc1.dist-info/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-kobai_sdk-0.2.8rc1.dist-info/METADATA,sha256=nZTb2svQk01wT32zBZDPKgeYnSAx22YER5YLHEIjoAQ,19167
-kobai_sdk-0.2.8rc1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-kobai_sdk-0.2.8rc1.dist-info/top_level.txt,sha256=ns1El3BrTTHKvoAgU1XtiSaVIudYeCXbEEUVY8HFDZ4,6
-kobai_sdk-0.2.8rc1.dist-info/RECORD,,
+kobai_sdk-0.2.8rc3.dist-info/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+kobai_sdk-0.2.8rc3.dist-info/METADATA,sha256=f75oEdxRWLrr0bVmH1OvIlvc0KS9TrpNTh65eTlKX6k,19205
+kobai_sdk-0.2.8rc3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+kobai_sdk-0.2.8rc3.dist-info/top_level.txt,sha256=ns1El3BrTTHKvoAgU1XtiSaVIudYeCXbEEUVY8HFDZ4,6
+kobai_sdk-0.2.8rc3.dist-info/RECORD,,