MindsDB 25.7.2.0__py3-none-any.whl → 25.7.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic; consult the registry's release advisory for more details.

@@ -7,7 +7,7 @@ from mindsdb_sql_parser import ast
7
7
  from mindsdb.integrations.utilities.handlers.query_utilities import (
8
8
  SELECTQueryParser,
9
9
  SELECTQueryExecutor,
10
- INSERTQueryParser
10
+ INSERTQueryParser,
11
11
  )
12
12
 
13
13
  import pandas as pd
@@ -66,9 +66,13 @@ class YoutubeCommentsTable(APITable):
66
66
  select_statement_executor = SELECTQueryExecutor(
67
67
  comments_df,
68
68
  selected_columns,
69
- [where_condition for where_condition in where_conditions if where_condition[1] not in ['video_id', 'channel_id']],
69
+ [
70
+ where_condition
71
+ for where_condition in where_conditions
72
+ if where_condition[1] not in ["video_id", "channel_id"]
73
+ ],
70
74
  order_by_conditions,
71
- result_limit if query.limit else None
75
+ result_limit if query.limit else None,
72
76
  )
73
77
 
74
78
  comments_df = select_statement_executor.execute_query()
@@ -98,50 +102,30 @@ class YoutubeCommentsTable(APITable):
98
102
  values_to_insert = insert_query_parser.parse_query()
99
103
 
100
104
  for value in values_to_insert:
101
- if not value.get('comment_id'):
102
- if not value.get('comment'):
105
+ if not value.get("comment_id"):
106
+ if not value.get("comment"):
103
107
  raise ValueError("comment is mandatory for inserting a top-level comment.")
104
108
  else:
105
- self.insert_comment(video_id=value['video_id'], text=value['comment'])
109
+ self.insert_comment(video_id=value["video_id"], text=value["comment"])
106
110
 
107
111
  else:
108
- if not value.get('reply'):
112
+ if not value.get("reply"):
109
113
  raise ValueError("reply is mandatory for inserting a reply.")
110
114
  else:
111
- self.insert_comment(comment_id=value['comment_id'], text=value['reply'])
115
+ self.insert_comment(comment_id=value["comment_id"], text=value["reply"])
112
116
 
113
117
  def insert_comment(self, text, video_id: str = None, comment_id: str = None):
114
118
  # if comment_id is provided, define the request body for a reply and insert it
115
119
  if comment_id:
116
- request_body = {
117
- 'snippet': {
118
- 'parentId': comment_id,
119
- 'textOriginal': text
120
- }
121
- }
120
+ request_body = {"snippet": {"parentId": comment_id, "textOriginal": text}}
122
121
 
123
- self.handler.connect().comments().insert(
124
- part='snippet',
125
- body=request_body
126
- ).execute()
122
+ self.handler.connect().comments().insert(part="snippet", body=request_body).execute()
127
123
 
128
124
  # else if video_id is provided, define the request body for a top-level comment and insert it
129
125
  elif video_id:
130
- request_body = {
131
- 'snippet': {
132
- 'topLevelComment': {
133
- 'snippet': {
134
- 'videoId': video_id,
135
- 'textOriginal': text
136
- }
137
- }
138
- }
139
- }
126
+ request_body = {"snippet": {"topLevelComment": {"snippet": {"videoId": video_id, "textOriginal": text}}}}
140
127
 
141
- self.handler.connect().commentThreads().insert(
142
- part='snippet',
143
- body=request_body
144
- ).execute()
128
+ self.handler.connect().commentThreads().insert(part="snippet", body=request_body).execute()
145
129
 
146
130
  def get_columns(self) -> List[str]:
147
131
  """Gets all columns to be returned in pandas DataFrame responses
@@ -150,7 +134,19 @@ class YoutubeCommentsTable(APITable):
150
134
  List[str]
151
135
  List of columns
152
136
  """
153
- return ['comment_id', 'channel_id', 'video_id', 'user_id', 'display_name', 'comment', "published_at", "updated_at", 'reply_user_id', 'reply_author', 'reply']
137
+ return [
138
+ "comment_id",
139
+ "channel_id",
140
+ "video_id",
141
+ "user_id",
142
+ "display_name",
143
+ "comment",
144
+ "published_at",
145
+ "updated_at",
146
+ "reply_user_id",
147
+ "reply_author",
148
+ "reply",
149
+ ]
154
150
 
155
151
  def get_comments(self, video_id: str, channel_id: str):
156
152
  """Pulls all the records from the given youtube api end point and returns it select()
@@ -166,7 +162,12 @@ class YoutubeCommentsTable(APITable):
166
162
  resource = (
167
163
  self.handler.connect()
168
164
  .commentThreads()
169
- .list(part="snippet, replies", videoId=video_id, allThreadsRelatedToChannelId=channel_id, textFormat="plainText")
165
+ .list(
166
+ part="snippet, replies",
167
+ videoId=video_id,
168
+ allThreadsRelatedToChannelId=channel_id,
169
+ textFormat="plainText",
170
+ )
170
171
  )
171
172
 
172
173
  data = []
@@ -175,7 +176,7 @@ class YoutubeCommentsTable(APITable):
175
176
 
176
177
  for comment in comments["items"]:
177
178
  replies = []
178
- if 'replies' in comment:
179
+ if "replies" in comment:
179
180
  for reply in comment["replies"]["comments"]:
180
181
  replies.append(
181
182
  {
@@ -222,18 +223,51 @@ class YoutubeCommentsTable(APITable):
222
223
  else:
223
224
  break
224
225
 
225
- youtube_comments_df = pd.json_normalize(data, 'replies', ['comment_id', 'channel_id', 'video_id', 'user_id', 'display_name', 'comment', "published_at", "updated_at"], record_prefix='replies.')
226
- youtube_comments_df = youtube_comments_df.rename(columns={'replies.user_id': 'reply_user_id', 'replies.reply_author': 'reply_author', 'replies.reply': 'reply'})
226
+ youtube_comments_df = pd.json_normalize(
227
+ data,
228
+ "replies",
229
+ [
230
+ "comment_id",
231
+ "channel_id",
232
+ "video_id",
233
+ "user_id",
234
+ "display_name",
235
+ "comment",
236
+ "published_at",
237
+ "updated_at",
238
+ ],
239
+ record_prefix="replies.",
240
+ )
241
+ youtube_comments_df = youtube_comments_df.rename(
242
+ columns={
243
+ "replies.user_id": "reply_user_id",
244
+ "replies.reply_author": "reply_author",
245
+ "replies.reply": "reply",
246
+ }
247
+ )
227
248
 
228
249
  # check if DataFrame is empty
229
250
  if youtube_comments_df.empty:
230
251
  return youtube_comments_df
231
252
  else:
232
- return youtube_comments_df[['comment_id', 'channel_id', 'video_id', 'user_id', 'display_name', 'comment', "published_at", "updated_at", 'reply_user_id', 'reply_author', 'reply']]
253
+ return youtube_comments_df[
254
+ [
255
+ "comment_id",
256
+ "channel_id",
257
+ "video_id",
258
+ "user_id",
259
+ "display_name",
260
+ "comment",
261
+ "published_at",
262
+ "updated_at",
263
+ "reply_user_id",
264
+ "reply_author",
265
+ "reply",
266
+ ]
267
+ ]
233
268
 
234
269
 
235
270
  class YoutubeChannelsTable(APITable):
236
-
237
271
  """Youtube Channel Info by channel id Table implementation"""
238
272
 
239
273
  def select(self, query: ast.Select) -> pd.DataFrame:
@@ -263,9 +297,9 @@ class YoutubeChannelsTable(APITable):
263
297
  select_statement_executor = SELECTQueryExecutor(
264
298
  channel_df,
265
299
  selected_columns,
266
- [where_condition for where_condition in where_conditions if where_condition[1] == 'channel_id'],
300
+ [where_condition for where_condition in where_conditions if where_condition[1] == "channel_id"],
267
301
  order_by_conditions,
268
- result_limit if query.limit else None
302
+ result_limit if query.limit else None,
269
303
  )
270
304
 
271
305
  channel_df = select_statement_executor.execute_query()
@@ -304,7 +338,6 @@ class YoutubeChannelsTable(APITable):
304
338
 
305
339
 
306
340
  class YoutubeVideosTable(APITable):
307
-
308
341
  """Youtube Video info by video id Table implementation"""
309
342
 
310
343
  def select(self, query: ast.Select) -> pd.DataFrame:
@@ -317,7 +350,7 @@ class YoutubeVideosTable(APITable):
317
350
  result_limit,
318
351
  ) = select_statement_parser.parse_query()
319
352
 
320
- video_id, channel_id = None, None
353
+ video_id, channel_id, search_query = None, None, None
321
354
  for op, arg1, arg2 in where_conditions:
322
355
  if arg1 == "video_id":
323
356
  if op == "=":
@@ -331,38 +364,126 @@ class YoutubeVideosTable(APITable):
331
364
  else:
332
365
  raise NotImplementedError("Only '=' operator is supported for channel_id column.")
333
366
 
334
- if not video_id and not channel_id:
335
- raise ValueError("Either video_id or channel_id has to be present in where clause.")
367
+ elif arg1 == "query":
368
+ if op == "=":
369
+ search_query = arg2
370
+ else:
371
+ raise NotImplementedError("Only '=' operator is supported for query column.")
372
+
373
+ if not video_id and not channel_id and not search_query:
374
+ raise ValueError("At least one of video_id, channel_id, or query must be present in the WHERE clause.")
336
375
 
337
376
  if video_id:
338
377
  video_df = self.get_videos_by_video_ids([video_id])
378
+ elif channel_id and search_query:
379
+ video_df = self.get_videos_by_search_query_in_channel(search_query, channel_id, result_limit)
380
+ elif channel_id:
381
+ video_df = self.get_videos_by_channel_id(channel_id, result_limit)
339
382
  else:
340
- video_df = self.get_videos_by_channel_id(channel_id)
383
+ video_df = self.get_videos_by_search_query(search_query, result_limit)
341
384
 
342
385
  select_statement_executor = SELECTQueryExecutor(
343
386
  video_df,
344
387
  selected_columns,
345
- [where_condition for where_condition in where_conditions if where_condition[1] not in ['video_id', 'channel_id']],
388
+ [
389
+ where_condition
390
+ for where_condition in where_conditions
391
+ if where_condition[1] not in ["video_id", "channel_id", "query"]
392
+ ],
346
393
  order_by_conditions,
347
- result_limit if query.limit else None
394
+ result_limit if query.limit else None,
348
395
  )
349
396
 
350
397
  video_df = select_statement_executor.execute_query()
351
398
 
352
399
  return video_df
353
400
 
354
- def get_videos_by_channel_id(self, channel_id):
401
+ def get_videos_by_search_query(self, search_query, limit=10):
355
402
  video_ids = []
356
403
  resource = (
357
404
  self.handler.connect()
358
405
  .search()
359
- .list(part="snippet", channelId=channel_id, type="video")
406
+ .list(part="snippet", q=search_query, type="video", maxResults=min(50, limit))
360
407
  )
361
- while resource:
408
+ total_fetched = 0
409
+
410
+ while resource and total_fetched < limit:
411
+ response = resource.execute()
412
+ for item in response["items"]:
413
+ video_ids.append(item["id"]["videoId"])
414
+ total_fetched += 1
415
+ if total_fetched >= limit:
416
+ break
417
+
418
+ if "nextPageToken" in response and total_fetched < limit:
419
+ resource = (
420
+ self.handler.connect()
421
+ .search()
422
+ .list(
423
+ part="snippet",
424
+ q=search_query,
425
+ type="video",
426
+ maxResults=min(50, limit - total_fetched),
427
+ pageToken=response["nextPageToken"],
428
+ )
429
+ )
430
+ else:
431
+ break
432
+
433
+ return self.get_videos_by_video_ids(video_ids)
434
+
435
+ def get_videos_by_search_query_in_channel(self, search_query, channel_id, limit=10):
436
+ """Search for videos within a specific channel"""
437
+ video_ids = []
438
+ resource = (
439
+ self.handler.connect()
440
+ .search()
441
+ .list(part="snippet", q=search_query, channelId=channel_id, type="video", maxResults=min(50, limit))
442
+ )
443
+ total_fetched = 0
444
+
445
+ while resource and total_fetched < limit:
446
+ response = resource.execute()
447
+ for item in response["items"]:
448
+ video_ids.append(item["id"]["videoId"])
449
+ total_fetched += 1
450
+ if total_fetched >= limit:
451
+ break
452
+
453
+ if "nextPageToken" in response and total_fetched < limit:
454
+ resource = (
455
+ self.handler.connect()
456
+ .search()
457
+ .list(
458
+ part="snippet",
459
+ q=search_query,
460
+ channelId=channel_id,
461
+ type="video",
462
+ maxResults=min(50, limit - total_fetched),
463
+ pageToken=response["nextPageToken"],
464
+ )
465
+ )
466
+ else:
467
+ break
468
+
469
+ return self.get_videos_by_video_ids(video_ids)
470
+
471
+ def get_videos_by_channel_id(self, channel_id, limit=10):
472
+ video_ids = []
473
+ resource = (
474
+ self.handler.connect()
475
+ .search()
476
+ .list(part="snippet", channelId=channel_id, type="video", maxResults=min(50, limit))
477
+ )
478
+ total_fetched = 0
479
+ while resource and total_fetched < limit:
362
480
  response = resource.execute()
363
481
  for item in response["items"]:
364
482
  video_ids.append(item["id"]["videoId"])
365
- if "nextPageToken" in response:
483
+ total_fetched += 1
484
+ if total_fetched >= limit:
485
+ break
486
+ if "nextPageToken" in response and total_fetched < limit:
366
487
  resource = (
367
488
  self.handler.connect()
368
489
  .search()
@@ -370,6 +491,7 @@ class YoutubeVideosTable(APITable):
370
491
  part="snippet",
371
492
  channelId=channel_id,
372
493
  type="video",
494
+ maxResults=min(50, limit - total_fetched),
373
495
  pageToken=response["nextPageToken"],
374
496
  )
375
497
  )
@@ -388,7 +510,13 @@ class YoutubeVideosTable(APITable):
388
510
  # loop over 50 video ids at a time
389
511
  # an invalid request error is caused otherwise
390
512
  for i in range(0, len(video_ids), 50):
391
- resource = self.handler.connect().videos().list(part="statistics,snippet,contentDetails", id=",".join(video_ids[i:i + 50])).execute()
513
+ resource = (
514
+ self.handler.connect()
515
+ .videos()
516
+ .list(part="statistics,snippet,contentDetails", id=",".join(video_ids[i : i + 50]))
517
+ .execute()
518
+ )
519
+
392
520
  for item in resource["items"]:
393
521
  data.append(
394
522
  {
@@ -415,7 +543,7 @@ class YoutubeVideosTable(APITable):
415
543
  return json_formatted_transcript
416
544
 
417
545
  except Exception as e:
418
- logger.error(f"Encountered an error while fetching transcripts for video ${video_id}: ${e}"),
546
+ (logger.error(f"Encountered an error while fetching transcripts for video ${video_id}: ${e}"),)
419
547
  return "Transcript not available for this video"
420
548
 
421
549
  def parse_duration(self, video_id, duration):
@@ -428,7 +556,7 @@ class YoutubeVideosTable(APITable):
428
556
 
429
557
  return duration_str.strip(":")
430
558
  except Exception as e:
431
- logger.error(f"Encountered an error while parsing duration for video ${video_id}: ${e}"),
559
+ (logger.error(f"Encountered an error while parsing duration for video ${video_id}: ${e}"),)
432
560
  return "Duration not available for this video"
433
561
 
434
562
  def get_columns(self) -> List[str]:
@@ -180,7 +180,7 @@ class AgentsController:
180
180
  agent (db.Agents): The created agent
181
181
 
182
182
  Raises:
183
- ValueError: Agent with given name already exists, or skill/model with given name does not exist.
183
+ EntityExistsError: Agent with given name already exists, or skill/model with given name does not exist.
184
184
  """
185
185
  if project_name is None:
186
186
  project_name = default_project
@@ -189,7 +189,7 @@ class AgentsController:
189
189
  agent = self.get_agent(name, project_name)
190
190
 
191
191
  if agent is not None:
192
- raise ValueError(f"Agent with name already exists: {name}")
192
+ raise EntityExistsError("Agent already exists", name)
193
193
 
194
194
  # No need to copy params since we're not preserving the original reference
195
195
  params = params or {}
@@ -1,7 +1,7 @@
1
1
  from typing import List, Union
2
-
3
2
  import pandas as pd
4
-
3
+ import json
4
+ import datetime
5
5
  from mindsdb.integrations.libs.response import RESPONSE_TYPE
6
6
  from mindsdb.interfaces.data_catalog.base_data_catalog import BaseDataCatalog
7
7
  from mindsdb.interfaces.storage import db
@@ -204,6 +204,8 @@ class DataCatalogLoader(BaseDataCatalog):
204
204
  # Convert the distinct_values_count to an integer if it is not NaN, otherwise set it to None.
205
205
  val = row.get("distinct_values_count")
206
206
  distinct_values_count = int(val) if pd.notna(val) else None
207
+ min_val = row.get("minimum_value")
208
+ max_val = row.get("maximum_value")
207
209
 
208
210
  # Convert the most_common_frequencies to a list of strings.
209
211
  most_common_frequencies = [str(val) for val in row.get("most_common_frequencies") or []]
@@ -214,8 +216,8 @@ class DataCatalogLoader(BaseDataCatalog):
214
216
  most_common_frequencies=most_common_frequencies,
215
217
  null_percentage=row.get("null_percentage"),
216
218
  distinct_values_count=distinct_values_count,
217
- minimum_value=row.get("minimum_value"),
218
- maximum_value=row.get("maximum_value"),
219
+ minimum_value=self.to_str(min_val),
220
+ maximum_value=self.to_str(max_val),
219
221
  )
220
222
  column_statistics.append(record)
221
223
 
@@ -373,3 +375,15 @@ class DataCatalogLoader(BaseDataCatalog):
373
375
  db.session.delete(table)
374
376
  db.session.commit()
375
377
  self.logger.info(f"Metadata for {self.database_name} removed successfully.")
378
+
379
+ def to_str(self, val) -> str:
380
+ """
381
+ Convert a value to a string.
382
+ """
383
+ if val is None:
384
+ return None
385
+ if isinstance(val, (datetime.datetime, datetime.date)):
386
+ return val.isoformat()
387
+ if isinstance(val, (list, dict, set, tuple)):
388
+ return json.dumps(val, default=str)
389
+ return str(val)
@@ -1186,6 +1186,13 @@ class KnowledgeBaseController:
1186
1186
  if "provider" not in params:
1187
1187
  raise ValueError("'provider' parameter is required for embedding model")
1188
1188
 
1189
+ # check available providers
1190
+ avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google")
1191
+ if params["provider"] not in avail_providers:
1192
+ raise ValueError(
1193
+ f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
1194
+ )
1195
+
1189
1196
  if params["provider"] not in ("openai", "azure_openai"):
1190
1197
  # try use litellm
1191
1198
  try:
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import math
3
+ import re
3
4
  import time
4
5
  from typing import List
5
6
 
@@ -16,15 +17,15 @@ logger = log.getLogger(__name__)
16
17
 
17
18
 
18
19
  GENERATE_QA_SYSTEM_PROMPT = """
19
- Your task is to generate question and answer pairs for a search engine.
20
+ Your task is to generate question and answer pairs for a search engine.
20
21
  The search engine will take your query and return a list of documents.
21
22
  You will be given a text and you need to generate a question that can be answered using the information in the text.
22
23
  Your questions will be used to evaluate the search engine.
23
- Question should always have enough clues to identify the specific text that this question is generated from.
24
+ Question should always have enough clues to identify the specific text that this question is generated from.
24
25
  Never ask questions like "What license number is associated with Amend 6" because Amend 6 could be found in many documents and the question is not specific enough.
25
- Example output 1: {\"query\": \"What processor does the HP 2023 14\" FHD IPS Laptop use?\", \"reference_answer\": \"Ryzen 3 5300U\"}
26
+ Example output 1: {\"query\": \"What processor does the HP 2023 14\" FHD IPS Laptop use?\", \"reference_answer\": \"Ryzen 3 5300U\"}
26
27
  Example output 2: {\"query\": \"What is the name of the river in Paris?\", \"reference_answer\": \"Seine\"}
27
- Don't generate questions like "What is being amended in the application?" because these questions cannot be answered using the text and without knowing which document it refers to.
28
+ Don't generate questions like "What is being amended in the application?" because these questions cannot be answered using the text and without knowing which document it refers to.
28
29
  The question should be answerable without the text, but the answer should be present in the text.
29
30
  Return ONLY a json response. No other text.
30
31
  """
@@ -43,6 +44,39 @@ def calc_entropy(values: List[float]) -> float:
43
44
  return -sum([pk * math.log(pk) for pk in values])
44
45
 
45
46
 
47
+ def sanitize_json_response(response: str) -> str:
48
+ """Remove markdown code block formatting from JSON response and extract valid JSON."""
49
+ if not response or not response.strip():
50
+ raise ValueError("Empty response provided.")
51
+
52
+ # Remove leading/trailing whitespace
53
+ response = response.strip()
54
+
55
+ # Remove markdown code block markers if present
56
+ response = re.sub(r"^```(?:json|JSON)?\s*", "", response, flags=re.MULTILINE)
57
+ response = re.sub(r"\s*```$", "", response, flags=re.MULTILINE)
58
+ response = response.strip()
59
+
60
+ # Find the first opening brace
61
+ start_idx = response.find("{")
62
+ if start_idx == -1:
63
+ raise ValueError("No JSON object found in the response.")
64
+
65
+ # Try to parse JSON starting from first { with increasing end positions
66
+ # This handles nested objects and strings with braces correctly
67
+ for end_idx in range(len(response), start_idx, -1): # Start from end and work backwards
68
+ candidate = response[start_idx:end_idx]
69
+ try:
70
+ parsed = json.loads(candidate)
71
+ # Ensure it's a dictionary (object) not just any valid JSON
72
+ if isinstance(parsed, dict):
73
+ return candidate
74
+ except json.JSONDecodeError:
75
+ continue
76
+
77
+ raise ValueError("No valid JSON object found in the response.")
78
+
79
+
46
80
  class EvaluateBase:
47
81
  DEFAULT_QUESTION_COUNT = 20
48
82
  DEFAULT_SAMPLE_SIZE = 10000
@@ -178,6 +212,7 @@ class EvaluateBase:
178
212
  test_data = self.read_from_table(test_table)
179
213
 
180
214
  scores = self.evaluate(test_data)
215
+ scores["id"] = math.floor(time.time()) # unique ID for the evaluation run
181
216
  scores["name"] = self.name
182
217
  scores["created_at"] = dt.datetime.now()
183
218
 
@@ -237,9 +272,13 @@ class EvaluateRerank(EvaluateBase):
237
272
  {"role": "system", "content": GENERATE_QA_SYSTEM_PROMPT},
238
273
  {"role": "user", "content": f"\n\nText:\n{text}\n\n"},
239
274
  ]
240
- answer = self.llm_client.completion(messages)
275
+ answer = self.llm_client.completion(messages, json_output=True)
276
+
277
+ # Sanitize the response by removing markdown code block formatting like ```json
278
+ sanitized_answer = sanitize_json_response(answer)
279
+
241
280
  try:
242
- output = json.loads(answer)
281
+ output = json.loads(sanitized_answer)
243
282
  except json.JSONDecodeError:
244
283
  raise ValueError(f"Could not parse response from LLM: {answer}")
245
284
 
@@ -448,9 +487,13 @@ class EvaluateDocID(EvaluateBase):
448
487
  {"role": "system", "content": GENERATE_QA_SYSTEM_PROMPT},
449
488
  {"role": "user", "content": f"\n\nText:\n{text}\n\n"},
450
489
  ]
451
- answer = self.llm_client.completion(messages)
490
+ answer = self.llm_client.completion(messages, json_output=True)
491
+
492
+ # Sanitize the response by removing markdown code block formatting like ```json
493
+ sanitized_answer = sanitize_json_response(answer)
494
+
452
495
  try:
453
- output = json.loads(answer)
496
+ output = json.loads(sanitized_answer)
454
497
  except json.JSONDecodeError:
455
498
  raise ValueError(f"Could not parse response from LLM: {answer}")
456
499
 
@@ -54,12 +54,12 @@ class LLMClient:
54
54
 
55
55
  self.client = module.Handler
56
56
 
57
- def completion(self, messages: List[dict]) -> str:
57
+ def completion(self, messages: List[dict], json_output: bool = False) -> str:
58
58
  """
59
59
  Call LLM completion and get response
60
60
  """
61
61
  params = self.params
62
-
62
+ params["json_output"] = json_output
63
63
  if self.provider in ("azure_openai", "openai"):
64
64
  response = self.client.chat.completions.create(
65
65
  model=params["model_name"],
@@ -69,6 +69,6 @@ class LLMClient:
69
69
  else:
70
70
  kwargs = params.copy()
71
71
  model = kwargs.pop("model_name")
72
-
72
+ kwargs.pop("provider", None)
73
73
  response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs)
74
74
  return response.choices[0].message.content
@@ -1,16 +1,17 @@
1
+ import re
2
+ import html
3
+ import asyncio
1
4
  from typing import List, Dict, Optional, Any
5
+
2
6
  import pandas as pd
3
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
4
- import asyncio
5
-
8
+ from langchain_core.documents import Document as LangchainDocument
6
9
 
7
10
  from mindsdb.integrations.utilities.rag.splitters.file_splitter import (
8
11
  FileSplitter,
9
12
  FileSplitterConfig,
10
13
  )
11
-
12
14
  from mindsdb.interfaces.agents.langchain_agent import create_chat_model
13
-
14
15
  from mindsdb.interfaces.knowledge_base.preprocessing.models import (
15
16
  PreprocessingConfig,
16
17
  ProcessedChunk,
@@ -21,7 +22,6 @@ from mindsdb.interfaces.knowledge_base.preprocessing.models import (
21
22
  )
22
23
  from mindsdb.utilities import log
23
24
 
24
- from langchain_core.documents import Document as LangchainDocument
25
25
 
26
26
  logger = log.getLogger(__name__)
27
27
 
@@ -123,11 +123,11 @@ class ContextualPreprocessor(DocumentPreprocessor):
123
123
 
124
124
  DEFAULT_CONTEXT_TEMPLATE = """
125
125
  <document>
126
- {{WHOLE_DOCUMENT}}
126
+ {WHOLE_DOCUMENT}
127
127
  </document>
128
128
  Here is the chunk we want to situate within the whole document
129
129
  <chunk>
130
- {{CHUNK_CONTENT}}
130
+ {CHUNK_CONTENT}
131
131
  </chunk>
132
132
  Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else."""
133
133
 
@@ -149,12 +149,20 @@ Please give a short succinct context to situate this chunk within the overall do
149
149
  self.summarize = self.config.summarize
150
150
 
151
151
  def _prepare_prompts(self, chunk_contents: list[str], full_documents: list[str]) -> list[str]:
152
- prompts = [
153
- self.context_template.replace("{{WHOLE_DOCUMENT}}", full_document) for full_document in full_documents
154
- ]
155
- prompts = [
156
- prompt.replace("{{CHUNK_CONTENT}}", chunk_content) for prompt, chunk_content in zip(prompts, chunk_contents)
157
- ]
152
+ def tag_replacer(match):
153
+ tag = match.group(0)
154
+ if tag.lower() not in ["<document>", "</document>", "<chunk>", "</chunk>"]:
155
+ return tag
156
+ return html.escape(tag)
157
+
158
+ tag_pattern = r"</?document>|</?chunk>"
159
+ prompts = []
160
+ for chunk_content, full_document in zip(chunk_contents, full_documents):
161
+ chunk_content = re.sub(tag_pattern, tag_replacer, chunk_content, flags=re.IGNORECASE)
162
+ full_document = re.sub(tag_pattern, tag_replacer, full_document, flags=re.IGNORECASE)
163
+ prompts.append(
164
+ self.DEFAULT_CONTEXT_TEMPLATE.format(WHOLE_DOCUMENT=full_document, CHUNK_CONTENT=chunk_content)
165
+ )
158
166
 
159
167
  return prompts
160
168