kobai-sdk 0.3.0rc2__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (22)
  1. {kobai_sdk-0.3.0rc2/kobai_sdk.egg-info → kobai_sdk-0.3.2}/PKG-INFO +50 -62
  2. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/README.md +49 -60
  3. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/ai_query.py +25 -22
  4. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/ai_rag.py +7 -16
  5. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/tenant_client.py +13 -146
  6. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2/kobai_sdk.egg-info}/PKG-INFO +50 -62
  7. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai_sdk.egg-info/SOURCES.txt +0 -2
  8. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai_sdk.egg-info/requires.txt +0 -1
  9. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/pyproject.toml +2 -3
  10. kobai_sdk-0.3.0rc2/kobai/mobi.py +0 -682
  11. kobai_sdk-0.3.0rc2/kobai/mobi_config.py +0 -16
  12. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/LICENSE +0 -0
  13. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/MANIFEST.in +0 -0
  14. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/__init__.py +0 -0
  15. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/databricks_client.py +0 -0
  16. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/demo_tenant_client.py +0 -0
  17. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/ms_authenticate.py +0 -0
  18. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/spark_client.py +0 -0
  19. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai/tenant_api.py +0 -0
  20. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai_sdk.egg-info/dependency_links.txt +0 -0
  21. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/kobai_sdk.egg-info/top_level.txt +0 -0
  22. {kobai_sdk-0.3.0rc2 → kobai_sdk-0.3.2}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: kobai-sdk
-Version: 0.3.0rc2
+Version: 0.3.2
 Summary: A package that enables interaction with a Kobai tenant.
 Author-email: Ryan Oattes <ryan@kobai.io>
 License: Apache License
@@ -223,7 +223,6 @@ Requires-Dist: langchain-core
 Requires-Dist: langchain-community
 Requires-Dist: langchain_openai
 Requires-Dist: databricks_langchain
-Requires-Dist: sentence-transformers
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: bumpver; extra == "dev"
@@ -249,38 +248,51 @@ from kobai import tenant_client, spark_client, databricks_client
 
 schema = 'main.demo'
 uri = 'https://demo.kobai.io'
-tenant_id = '1'
 tenant_name = 'My Demo Tenant'
-
-k = tenant_client.TenantClient(tenant_name, tenant_id, uri, schema)
+k = tenant_client.TenantClient(tenant_name, uri, schema)
 ```
 
 2. Authenticate with the Kobai instance:
+Authentication can be performed using different methods, such as device code flow, on-behalf-of flow, or browser-based tokens.
+
+#### Authentication via device code
+Step 1: Obtain the access token from IDM (Identity and Access Management)
 
 ```python
-client_id = 'your_Entra_app_id_here'
+from kobai import ms_authenticate
+
 tenant_id = 'your_Entra_directory_id_here'
+client_id = 'your_Entra_app_id_here'
 
-k.authenticate(client_id, tenant_id)
+access_token = ms_authenticate.device_code(tenant_id, client_id)
 ```
 
-3. Initialize a Spark client using your current `SparkSession`, and generate semantically-rich SQL views describing this Kobai tenant:
+Step 2: Use the token to retrieve the list of Kobai tenants (unless the tenant ID is already known).
 
 ```python
-k.spark_init_session(spark)
-k.spark_generate_genie_views()
+tenants = k.get_tenants(id_token=access_token)
+print(tenants)
 ```
 
-4. Initialize a Databricks API client using your Notebook context, and create a Genie Data Rooms environment for this Kobai tenant.
+Step 3: Authenticate with Kobai for the specific tenant using the IDM access token.
 
 ```python
-notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
-sql_warehouse = '8834d98a8agffa76'
+kobai_tenant_id = "5c1ba715-3961-4835-8a10-6f6f963b53ff"
+k.use_access_token(access_token = access_token, tenant_id=kobai_tenant_id)
+```
+
+At this point, authentication to the Kobai tenant is successfully completed.
+
+#### Authentication via browser token
 
-k.databricks_init_notebook(notebook_context, sql_warehouse)
-k.databricks_build_genie()
+```python
+k.use_browser_token(access_token="KOBAI_ACESS_TOKEN_FROM_BROWSER")
 ```
 
+#### Authentication via on-behalf-of flow
+The sample code demonstrating authentication via the on-behalf-of flow will be provided, if requested.
+
+
 ## AI Functionality
 The Kobai SDK enables users to ask follow-up questions based on the results of previous queries. This functionality currently supports models hosted on Databricks and Azure OpenAI.
 
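Taken together, the hunk above replaces the old single-call `k.authenticate(client_id, tenant_id)` flow with an explicit three-step token flow. A minimal end-to-end sketch assembled from the snippets in this diff (the Entra IDs and the Kobai tenant ID are placeholders, and the shape of the `get_tenants()` return value is not shown in the diff):

```python
# Sketch of the 0.3.2 device-code flow, pieced together from the hunk above.
# All IDs are placeholders, not working values.
from kobai import tenant_client, ms_authenticate

k = tenant_client.TenantClient('My Demo Tenant', 'https://demo.kobai.io', 'main.demo')

# Step 1: device-code sign-in against Entra (IDM) to obtain an access token.
access_token = ms_authenticate.device_code(
    'your_Entra_directory_id_here',  # Entra directory (tenant) ID
    'your_Entra_app_id_here'         # Entra application (client) ID
)

# Step 2: list the Kobai tenants this identity can reach, if the ID is unknown.
print(k.get_tenants(id_token=access_token))

# Step 3: bind the client to one tenant using the IDM token.
k.use_access_token(access_token=access_token, tenant_id="<kobai-tenant-id-from-step-2>")
```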
@@ -305,68 +317,41 @@ kobai_query_name = "Set ownership"
 question_json = k.run_question_remote(k.get_question_id(kobai_query_name)) # By questionName
 ```
 
-3. Ask a Follow-Up Question: Based on the initial results, you can ask a follow-up question using either Azure OpenAI, Databricks or a user-provided chat model.
-
-#### Using Azure OpenAI
-
-###### Authentication Methods:
-
-1. ApiKey
-
-```python
-from kobai import ai_query, llm_config
-import json
-
-followup_question = "Which owner owns the most sets?"
-
-llm_config = llm_config.LLMConfig(endpoint="https://kobaipoc.openai.azure.com/", api_key="YOUR_API_KEY", deployment="gpt-4o-mini", llm_provider="azure_openai")
-
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, llm_config=llm_config)
-print(output)
-```
-
-2. Azure Active Directory Authentication
+3. Ask a Follow-Up Question: Based on the initial results, you can ask a follow-up question using the user-provided chat and embedding model.
 
-Ensure that the logged-in tenant has access to Azure OpenAI.
-In case of databricks notebook, the logged in service principal should have access to Azure OpenAI.
+#### Using Databricks Embeddings and Chat Models in a Databricks Notebook
+Initialize the AI components by specifying the embedding and chat models, then proceed with follow-up questions for interactive engagement.
 
 ```python
-from kobai import ai_query, llm_config
+from databricks_langchain import DatabricksEmbeddings
+from langchain_community.chat_models import ChatDatabricks
 import json
 
-followup_question = "Which owner owns the most sets?"
-
-llm_config = llm_config.LLMConfig(endpoint="https://kobaipoc.openai.azure.com/", deployment="gpt-4o-mini", llm_provider="azure_openai")
-llm_config.get_azure_ad_token()
-
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, llm_config=llm_config)
-print(output)
-```
-
-#### Using Databricks (Default Configuration)
-
-```python
-from kobai import ai_query, llm_config
-import json
+# choose the embedding and chat model of your choice from the databricks serving and initialize.
+embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
+chat_model = ChatDatabricks(endpoint="databricks-gpt-oss-20b")
+k.init_ai_components(embedding_model=embedding_model, chat_model=chat_model)
 
 followup_question = "Which owner owns the most sets?"
-
-llm_config = llm_config.LLMConfig()
-
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, llm_config=llm_config)
+output = k.followup_question(followup_question, question_id=k.get_question_id(kobai_query_name))
 print(output)
 ```
 
-#### User Provided Chat Model
+#### Using Azure OpenAI Embeddings and Chat Models
 
 ```python
-from kobai import ai_query, llm_config
-import json
 from langchain_openai import AzureChatOpenAI
+from langchain_openai import AzureOpenAIEmbeddings
+import json
 
 followup_question = "Which owner owns the most sets?"
 
-llm_config = llm_config.LLMConfig(debug=True)
+embedding_model = AzureOpenAIEmbeddings(
+model="text-embedding-3-small",
+azure_endpoint="https://kobaipoc.openai.azure.com/",
+api_key="YOUR_API_KEY",
+openai_api_version="2023-05-15"
+)
 
 chat_model = AzureChatOpenAI(
 azure_endpoint="https://kobaipoc.openai.azure.com/", azure_deployment="gpt-4o-mini",
@@ -375,7 +360,10 @@ openai_api_version="2024-02-15-preview",
 temperature=0.5,
 max_tokens=150,)
 
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, override_model=chat_model, llm_config=llm_config)
+k.init_ai_components(embedding_model=embedding_model, chat_model=chat_model)
+
+followup_question = "Which theme has the most sets?"
+output = k.followup_question(followup_question, question_id=k.get_question_id(kobai_query_name))
 print(output)
 ```
 
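With `sentence-transformers` dropped from the requirements, every embedding hook in the SDK now takes a LangChain `Embeddings` object instead (see the `ai_query.py` and `ai_rag.py` hunks below). A caller who still has a SentenceTransformer model could bridge it with a thin adapter; the class below is a hypothetical sketch, not part of the SDK, and `langchain_community`'s `HuggingFaceEmbeddings` wrapper covers much the same ground:

```python
# Hypothetical adapter (not in kobai-sdk): exposes a sentence-transformers
# model through the langchain Embeddings interface that 0.3.2 expects.
from typing import List

from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer


class SentenceTransformerEmbeddings(Embeddings):
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # encode() returns a numpy array; the Embeddings contract wants lists.
        return self.model.encode(texts, normalize_embeddings=True).tolist()

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode(text, normalize_embeddings=True).tolist()
```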
README.md
@@ -15,38 +15,51 @@ from kobai import tenant_client, spark_client, databricks_client
 
 schema = 'main.demo'
 uri = 'https://demo.kobai.io'
-tenant_id = '1'
 tenant_name = 'My Demo Tenant'
-
-k = tenant_client.TenantClient(tenant_name, tenant_id, uri, schema)
+k = tenant_client.TenantClient(tenant_name, uri, schema)
 ```
 
 2. Authenticate with the Kobai instance:
+Authentication can be performed using different methods, such as device code flow, on-behalf-of flow, or browser-based tokens.
+
+#### Authentication via device code
+Step 1: Obtain the access token from IDM (Identity and Access Management)
 
 ```python
-client_id = 'your_Entra_app_id_here'
+from kobai import ms_authenticate
+
 tenant_id = 'your_Entra_directory_id_here'
+client_id = 'your_Entra_app_id_here'
 
-k.authenticate(client_id, tenant_id)
+access_token = ms_authenticate.device_code(tenant_id, client_id)
 ```
 
-3. Initialize a Spark client using your current `SparkSession`, and generate semantically-rich SQL views describing this Kobai tenant:
+Step 2: Use the token to retrieve the list of Kobai tenants (unless the tenant ID is already known).
 
 ```python
-k.spark_init_session(spark)
-k.spark_generate_genie_views()
+tenants = k.get_tenants(id_token=access_token)
+print(tenants)
 ```
 
-4. Initialize a Databricks API client using your Notebook context, and create a Genie Data Rooms environment for this Kobai tenant.
+Step 3: Authenticate with Kobai for the specific tenant using the IDM access token.
 
 ```python
-notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
-sql_warehouse = '8834d98a8agffa76'
+kobai_tenant_id = "5c1ba715-3961-4835-8a10-6f6f963b53ff"
+k.use_access_token(access_token = access_token, tenant_id=kobai_tenant_id)
+```
+
+At this point, authentication to the Kobai tenant is successfully completed.
+
+#### Authentication via browser token
 
-k.databricks_init_notebook(notebook_context, sql_warehouse)
-k.databricks_build_genie()
+```python
+k.use_browser_token(access_token="KOBAI_ACESS_TOKEN_FROM_BROWSER")
 ```
 
+#### Authentication via on-behalf-of flow
+The sample code demonstrating authentication via the on-behalf-of flow will be provided, if requested.
+
+
 ## AI Functionality
 The Kobai SDK enables users to ask follow-up questions based on the results of previous queries. This functionality currently supports models hosted on Databricks and Azure OpenAI.
 
@@ -71,68 +84,41 @@ kobai_query_name = "Set ownership"
 question_json = k.run_question_remote(k.get_question_id(kobai_query_name)) # By questionName
 ```
 
-3. Ask a Follow-Up Question: Based on the initial results, you can ask a follow-up question using either Azure OpenAI, Databricks or a user-provided chat model.
-
-#### Using Azure OpenAI
-
-###### Authentication Methods:
-
-1. ApiKey
-
-```python
-from kobai import ai_query, llm_config
-import json
-
-followup_question = "Which owner owns the most sets?"
-
-llm_config = llm_config.LLMConfig(endpoint="https://kobaipoc.openai.azure.com/", api_key="YOUR_API_KEY", deployment="gpt-4o-mini", llm_provider="azure_openai")
-
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, llm_config=llm_config)
-print(output)
-```
-
-2. Azure Active Directory Authentication
+3. Ask a Follow-Up Question: Based on the initial results, you can ask a follow-up question using the user-provided chat and embedding model.
 
-Ensure that the logged-in tenant has access to Azure OpenAI.
-In case of databricks notebook, the logged in service principal should have access to Azure OpenAI.
+#### Using Databricks Embeddings and Chat Models in a Databricks Notebook
+Initialize the AI components by specifying the embedding and chat models, then proceed with follow-up questions for interactive engagement.
 
 ```python
-from kobai import ai_query, llm_config
+from databricks_langchain import DatabricksEmbeddings
+from langchain_community.chat_models import ChatDatabricks
 import json
 
-followup_question = "Which owner owns the most sets?"
-
-llm_config = llm_config.LLMConfig(endpoint="https://kobaipoc.openai.azure.com/", deployment="gpt-4o-mini", llm_provider="azure_openai")
-llm_config.get_azure_ad_token()
-
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, llm_config=llm_config)
-print(output)
-```
-
-#### Using Databricks (Default Configuration)
-
-```python
-from kobai import ai_query, llm_config
-import json
+# choose the embedding and chat model of your choice from the databricks serving and initialize.
+embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
+chat_model = ChatDatabricks(endpoint="databricks-gpt-oss-20b")
+k.init_ai_components(embedding_model=embedding_model, chat_model=chat_model)
 
 followup_question = "Which owner owns the most sets?"
-
-llm_config = llm_config.LLMConfig()
-
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, llm_config=llm_config)
+output = k.followup_question(followup_question, question_id=k.get_question_id(kobai_query_name))
 print(output)
 ```
 
-#### User Provided Chat Model
+#### Using Azure OpenAI Embeddings and Chat Models
 
 ```python
-from kobai import ai_query, llm_config
-import json
 from langchain_openai import AzureChatOpenAI
+from langchain_openai import AzureOpenAIEmbeddings
+import json
 
 followup_question = "Which owner owns the most sets?"
 
-llm_config = llm_config.LLMConfig(debug=True)
+embedding_model = AzureOpenAIEmbeddings(
+model="text-embedding-3-small",
+azure_endpoint="https://kobaipoc.openai.azure.com/",
+api_key="YOUR_API_KEY",
+openai_api_version="2023-05-15"
+)
 
 chat_model = AzureChatOpenAI(
 azure_endpoint="https://kobaipoc.openai.azure.com/", azure_deployment="gpt-4o-mini",
@@ -141,7 +127,10 @@ openai_api_version="2024-02-15-preview",
 temperature=0.5,
 max_tokens=150,)
 
-output = ai_query.followup_question(followup_question, json.dumps(question_json), kobai_query_name, override_model=chat_model, llm_config=llm_config)
+k.init_ai_components(embedding_model=embedding_model, chat_model=chat_model)
+
+followup_question = "Which theme has the most sets?"
+output = k.followup_question(followup_question, question_id=k.get_question_id(kobai_query_name))
 print(output)
 ```
 
kobai/ai_query.py
@@ -1,8 +1,6 @@
 from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 
-from sentence_transformers import SentenceTransformer, util
-
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.embeddings import Embeddings
 from langchain_core.documents import Document
@@ -10,8 +8,9 @@ from langchain_core.retrievers import BaseRetriever
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain_core.vectorstores import InMemoryVectorStore
+import numpy as np
 
-from typing import Union, List
+from typing import List
 
 
 MESSAGE_SYSTEM_TEMPLATE = """
@@ -73,7 +72,7 @@ def format_docs(docs):
 def input_only(inpt):
     return inpt["question"]
 
-def followup_question(user_question, question_results, question_name, question_def, embedding_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, use_inmem_vectors=False, k=50):
+def followup_question(user_question, question_results, question_name, question_def, embedding_model: Embeddings, chat_model: BaseChatModel, use_inmem_vectors=False, k=50):
 
     row_texts = process_question_results(question_def, question_results)
     question_documents = [Document(page_content=r, metadata={"source": "kobai"}) for r in row_texts]
@@ -118,22 +117,13 @@ def init_question_search_index(tenant_questions, emb_model):
 
     q_ids = [q["id"] for q in tenant_questions]
     q_descs = [q["description"] for q in tenant_questions]
-
-    if isinstance(emb_model, SentenceTransformer):
-        q_vectors = emb_model.encode(q_descs)
-    else:
-        q_vectors = emb_model.embed_documents(q_descs)
-
+    q_vectors = emb_model.embed_documents(q_descs)
     return {"ids": q_ids, "descs": q_descs, "vectors": q_vectors}
 
 
 def question_search(search_text: str, search_index, emb_model, k: int):
-    if isinstance(emb_model, SentenceTransformer):
-        search_vec = emb_model.encode(search_text)
-    else:
-        search_vec = emb_model.embed_query(search_text)
+    search_vec = emb_model.embed_query(search_text)
     #search_vec = emb_model.encode(search_text)
-
     matches = __top_vector_matches(search_vec, search_index["vectors"], top=k)
 
     for mi, m in enumerate(matches):
@@ -142,13 +132,25 @@ def question_search(search_text: str, search_index, emb_model, k: int):
     return matches
 
 def __top_vector_matches(test_vec, options_list_vec, top=1):
-    scores_t = util.cos_sim(test_vec, options_list_vec)[0]
-    scores_l = scores_t.tolist()
-    scores_d = [{"index": i, "value": v} for i, v in enumerate(scores_l)]
-    sorted_d = sorted(scores_d, key=lambda i: i["value"], reverse=True)
-    top_d = sorted_d[0:top]
+    # Normalize the test vector
+    test_vec_norm = test_vec / np.linalg.norm(test_vec)
+    # Normalize the option vectors
+    options_norm = options_list_vec / np.linalg.norm(options_list_vec, axis=1, keepdims=True)
+
+    # Compute cosine similarity (dot product of normalized vectors)
+    cosine_similarities = np.dot(options_norm, test_vec_norm)
+
+    # Get indexes and similarity scores as dict
+    scores_d = [{"index": i, "value": float(v)} for i, v in enumerate(cosine_similarities)]
+
+    # Sort dict by similarity score descending
+    sorted_d = sorted(scores_d, key=lambda x: x["value"], reverse=True)
+
+    # Return top results
+    top_d = sorted_d[:top]
     return top_d
 
+
 def process_question_results(question_def, question_results):
 
     """
@@ -211,8 +213,9 @@ def process_question_results(question_def, question_results):
 
 
     concept_order = [max_src]
-    for t in concept_rels[max_src]["edges"]:
-        concept_order.append(t["dst"])
+    if max_src != "":
+        for t in concept_rels[max_src]["edges"]:
+            concept_order.append(t["dst"])
 
     for c in concept_props:
         if c not in concept_order:
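The added guard reads like a fix for questions with no concept relations: if `max_src` never advanced past an empty default, indexing `concept_rels[max_src]` raised a KeyError. A toy reconstruction of the failure mode (the surrounding data structures are assumptions, since the diff shows only this hunk):

```python
# Assumed shapes: concept_rels maps source concept -> {"edges": [{"dst": ...}]},
# and max_src stays "" when no relations were found for the question.
concept_rels = {}
max_src = ""

concept_order = [max_src]
if max_src != "":  # the 0.3.2 guard; without it, concept_rels[""] raises KeyError
    for t in concept_rels[max_src]["edges"]:
        concept_order.append(t["dst"])
print(concept_order)  # [''] and no crash
```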
kobai/ai_rag.py
@@ -3,9 +3,7 @@ from pyspark.sql import SparkSession
 
 from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, IntegerType
 from pyspark.sql import functions as F
-from sentence_transformers import SentenceTransformer
 from delta import DeltaTable
-from typing import Union
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.embeddings import Embeddings
 from langchain_community.document_loaders import PySparkDataFrameLoader
@@ -145,13 +143,13 @@ def __generate_sentences_from_questions(tc: AIContext, debug):
     ss.sql(full_sql)
 
 
-def encode_to_delta_local(tc: AIContext, st_model: Union[SentenceTransformer, Embeddings], replica_schema=None, batch_size=100000):
+def encode_to_delta_local(tc: AIContext, st_model: Embeddings, replica_schema=None, batch_size=100000):
     """
     Encode Semantic Data to Vectors in Delta Table
 
     Parameters:
     tc (TenantClient): The Kobai tenant_client instance instantiated via the SDK.
-    st_model (SentenceTransformer): A sentence_transformers model to use for encoding.
+    st_model (Embeddings): A langchain embedding model to use for encoding.
     replica_schema (str) OPTIONAL: An alternate schema (catalog.database) to create the Delta table. Useful when the base Kobai schema is not on a Unity Catalog.
     """
 
@@ -174,12 +172,8 @@ def encode_to_delta_local(tc: AIContext, st_model: Union[SentenceTransformer, Em
     content_list = [r["content"] for r in sentences_df.collect()]
     id_list = [r["id"] for r in sentences_df.collect()]
 
-    if isinstance(st_model, SentenceTransformer):
-        vector_list = st_model.encode(
-            content_list, normalize_embeddings=True, show_progress_bar=True).tolist()
-    else:
-        vector_list = st_model.embed_documents(content_list)
-        for i, v in enumerate(vector_list):
+    vector_list = st_model.embed_documents(content_list)
+    for i, v in enumerate(vector_list):
         vector_list[i] = [float(x) for x in v]
     #vector_list = st_model.encode(
     #    content_list, normalize_embeddings=True, show_progress_bar=True)
@@ -214,13 +208,13 @@ def encode_to_delta_local(tc: AIContext, st_model: Union[SentenceTransformer, Em
 #    """)
 
 
-def rag_delta(tc: AIContext, emb_model: Union[SentenceTransformer, Embeddings], chat_model: BaseChatModel, question, k=5, replica_schema=None):
+def rag_delta(tc: AIContext, emb_model: Embeddings, chat_model: BaseChatModel, question, k=5, replica_schema=None):
     """
     Run a RAG query using vectors in Delta table.
 
     Parameters:
     tc (TenantClient): The Kobai tenant_client instance instantiated via the SDK.
-    emb_model (UNION[SentenceTransformer, Embeddings]): A sentence_transformers or langchain embedding model to use for encoding the query.
+    emb_model (Embeddings): A langchain embedding model to use for encoding the query.
     chat_model (BaseChatModel): A langchain chat model to use in the RAG pipeline.
     question (str): The user's query.
     k (int) OPTIONAL: The number of RAG documents to retrieve.
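Given the narrowed signature, `rag_delta` now pairs any LangChain embedding model with a chat model. A usage sketch under a few assumptions: the client is authenticated with a Spark session initialized, `encode_to_delta_local` has already populated the vector table, and the two Databricks serving endpoints named below exist in the workspace:

```python
# Usage sketch for the 0.3.2 rag_delta signature (assumptions noted above).
from databricks_langchain import DatabricksEmbeddings
from langchain_community.chat_models import ChatDatabricks
from kobai import tenant_client, ai_rag

# An authenticated TenantClient with an active Spark session (setup elided).
tc = tenant_client.TenantClient('My Demo Tenant', 'https://demo.kobai.io', 'main.demo')

emb_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
chat_model = ChatDatabricks(endpoint="databricks-gpt-oss-20b")

answer = ai_rag.rag_delta(tc, emb_model, chat_model, "Which owner owns the most sets?", k=5)
print(answer)
```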
@@ -233,10 +227,7 @@ def rag_delta(tc: AIContext, emb_model: Union[SentenceTransformer, Embeddings],
 
     ss = tc.spark_session
 
-    if isinstance(emb_model, SentenceTransformer):
-        vector_list = emb_model.encode(
-            question, normalize_embeddings=True).tolist()
-    elif isinstance(emb_model, Embeddings):
+    if isinstance(emb_model, Embeddings):
         vector_list = emb_model.embed_query(question)
     else:
         print("Invalid Embedding Model Type")