MindsDB 25.3.2.0__py3-none-any.whl → 25.3.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic. Click here for more details.

Files changed (45)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +0 -1
  3. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +2 -6
  4. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +1 -1
  5. mindsdb/api/http/namespaces/agents.py +9 -5
  6. mindsdb/api/http/namespaces/chatbots.py +6 -5
  7. mindsdb/api/http/namespaces/databases.py +5 -6
  8. mindsdb/api/http/namespaces/skills.py +5 -4
  9. mindsdb/api/http/namespaces/views.py +6 -7
  10. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +23 -2
  11. mindsdb/integrations/handlers/dummy_data_handler/dummy_data_handler.py +16 -6
  12. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +64 -83
  13. mindsdb/integrations/handlers/github_handler/generate_api.py +228 -0
  14. mindsdb/integrations/handlers/github_handler/github_handler.py +15 -8
  15. mindsdb/integrations/handlers/github_handler/requirements.txt +1 -1
  16. mindsdb/integrations/handlers/huggingface_handler/requirements.txt +5 -4
  17. mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt +5 -5
  18. mindsdb/integrations/handlers/ms_one_drive_handler/ms_graph_api_one_drive_client.py +1 -1
  19. mindsdb/integrations/handlers/ms_teams_handler/ms_graph_api_teams_client.py +278 -0
  20. mindsdb/integrations/handlers/ms_teams_handler/ms_teams_handler.py +114 -70
  21. mindsdb/integrations/handlers/ms_teams_handler/ms_teams_tables.py +431 -0
  22. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +18 -4
  23. mindsdb/integrations/handlers/redshift_handler/redshift_handler.py +1 -0
  24. mindsdb/integrations/handlers/salesforce_handler/requirements.txt +1 -1
  25. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +20 -25
  26. mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +2 -2
  27. mindsdb/integrations/handlers/timescaledb_handler/timescaledb_handler.py +11 -6
  28. mindsdb/integrations/libs/ml_handler_process/learn_process.py +9 -3
  29. mindsdb/integrations/libs/vectordatabase_handler.py +2 -2
  30. mindsdb/integrations/utilities/files/file_reader.py +3 -3
  31. mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py +36 -2
  32. mindsdb/integrations/utilities/rag/settings.py +1 -0
  33. mindsdb/interfaces/chatbot/chatbot_controller.py +6 -4
  34. mindsdb/interfaces/jobs/jobs_controller.py +1 -4
  35. mindsdb/interfaces/knowledge_base/controller.py +9 -28
  36. mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +1 -1
  37. mindsdb/interfaces/skills/skills_controller.py +8 -7
  38. mindsdb/utilities/render/sqlalchemy_render.py +11 -5
  39. {mindsdb-25.3.2.0.dist-info → mindsdb-25.3.4.0.dist-info}/METADATA +236 -233
  40. {mindsdb-25.3.2.0.dist-info → mindsdb-25.3.4.0.dist-info}/RECORD +43 -42
  41. {mindsdb-25.3.2.0.dist-info → mindsdb-25.3.4.0.dist-info}/WHEEL +1 -1
  42. mindsdb/integrations/handlers/timescaledb_handler/tests/__init__.py +0 -0
  43. mindsdb/integrations/handlers/timescaledb_handler/tests/test_timescaledb_handler.py +0 -47
  44. {mindsdb-25.3.2.0.dist-info → mindsdb-25.3.4.0.dist-info/licenses}/LICENSE +0 -0
  45. {mindsdb-25.3.2.0.dist-info → mindsdb-25.3.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,431 @@
1
+ from typing import List
2
+
3
+ import pandas as pd
4
+
5
+ from mindsdb.integrations.handlers.ms_teams_handler.ms_graph_api_teams_client import MSGraphAPITeamsDelegatedPermissionsClient
6
+ from mindsdb.integrations.libs.api_handler import APIResource
7
+ from mindsdb.integrations.utilities.sql_utils import (
8
+ FilterCondition,
9
+ FilterOperator,
10
+ SortColumn
11
+ )
12
+
13
+
14
class TeamsTable(APIResource):
    """
    The table abstraction for the 'teams' resource of the Microsoft Graph API.
    """

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ):
        """
        Executes a parsed SELECT SQL query on the 'teams' resource of the Microsoft Graph API.

        The `conditions`, `limit`, `sort` and `targets` arguments are not used by this method:
        every record returned by `client.get_all_groups()` is included in the result.

        Args:
            conditions (List[FilterCondition]): The list of parsed filter conditions.
            limit (int): The maximum number of records to return.
            sort (List[SortColumn]): The list of parsed sort columns.
            targets (List[str]): The list of target columns to return.

        Returns:
            pd.DataFrame: The teams, restricted to the columns listed by `get_columns()`.
        """
        client: MSGraphAPITeamsDelegatedPermissionsClient = self.handler.connect()
        teams = client.get_all_groups()

        teams_df = pd.json_normalize(teams, sep="_")

        # reindex() rather than label selection: an empty response or a record set that lacks
        # one of the attributes yields empty/NaN columns instead of raising a KeyError.
        return teams_df.reindex(columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """
        Retrieves the attributes (columns) of the 'teams' resource.

        Returns:
            List[str]: A list of attributes (columns) of the 'teams' resource.
        """
        return [
            "id",
            "createdDateTime",
            "displayName",
            "description",
            "internalId",
            "classification",
            "specialization",
            "visibility",
            "webUrl",
            "isArchived",
            "tenantId",
            "isMembershipLimitedToOwners",
        ]
64
+
65
+
66
class ChannelsTable(APIResource):
    """
    The table abstraction for the 'channels' resource of the Microsoft Graph API.
    """

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ):
        """
        Executes a parsed SELECT SQL query on the 'channels' resource of the Microsoft Graph API.

        Two filters are handled here and marked as applied:
        'teamId' (equality only) and 'id' (equality or IN). Depending on which of them are
        present, the channels are fetched for one team or across all teams, either by ID or in full.

        Args:
            conditions (List[FilterCondition]): The list of parsed filter conditions.
            limit (int): The maximum number of records to return.
            sort (List[SortColumn]): The list of parsed sort columns.
            targets (List[str]): The list of target columns to return.

        Returns:
            pd.DataFrame: The channels, restricted to the columns listed by `get_columns()`.

        Raises:
            ValueError: If 'teamId' or 'id' is filtered with an unsupported operator.
        """
        client: MSGraphAPITeamsDelegatedPermissionsClient = self.handler.connect()

        team_id, channel_ids = None, None
        # `conditions` defaults to None; treat that as "no filters" instead of failing on iteration.
        for condition in conditions or []:
            if condition.column == "teamId":
                if condition.op != FilterOperator.EQUAL:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'teamId'."
                    )

                team_id = condition.value
                condition.applied = True

            elif condition.column == "id":
                if condition.op == FilterOperator.EQUAL:
                    channel_ids = [condition.value]

                elif condition.op == FilterOperator.IN:
                    channel_ids = condition.value

                else:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'id'."
                    )

                condition.applied = True

        if team_id:
            if channel_ids:
                channels = client.get_channels_in_group_by_ids(team_id, channel_ids)

            else:
                channels = client.get_all_channels_in_group(team_id)

        elif channel_ids:
            channels = client.get_channels_across_all_groups_by_ids(channel_ids)

        else:
            channels = client.get_all_channels_across_all_groups()

        channels_df = pd.json_normalize(channels, sep="_")

        # reindex() rather than label selection: an empty response or a record set that lacks
        # one of the attributes yields empty/NaN columns instead of raising a KeyError.
        return channels_df.reindex(columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """
        Retrieves the attributes (columns) of the 'channels' resource.

        Returns:
            List[str]: A list of attributes (columns) of the 'channels' resource.
        """
        return [
            "id",
            "createdDateTime",
            "displayName",
            "description",
            "isFavoriteByDefault",
            "email",
            "tenantId",
            "webUrl",
            "membershipType",
            "teamId",
        ]
154
+
155
+
156
class ChannelMessagesTable(APIResource):
    """
    The table abstraction for the 'channel messages' resource of the Microsoft Graph API.
    """

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ):
        """
        Executes a parsed SELECT SQL query on the 'channel messages' resource of the Microsoft Graph API.

        Three filters are handled here and marked as applied:
        'channelIdentity_teamId' (equality only), 'channelIdentity_channelId' (equality only)
        and 'id' (equality or IN). The first two are mandatory. `limit` is forwarded to the
        client only when no message IDs are given.

        Args:
            conditions (List[FilterCondition]): The list of parsed filter conditions.
            limit (int): The maximum number of records to return.
            sort (List[SortColumn]): The list of parsed sort columns.
            targets (List[str]): The list of target columns to return.

        Returns:
            pd.DataFrame: The messages, restricted to the columns listed by `get_columns()`.

        Raises:
            ValueError: If a handled column is filtered with an unsupported operator, or if
                the team ID or channel ID filter is missing.
        """
        client: MSGraphAPITeamsDelegatedPermissionsClient = self.handler.connect()

        group_id, channel_id, message_ids = None, None, None
        # `conditions` defaults to None; treat that as "no filters" so that the explicit
        # "columns are required" error below is raised instead of a TypeError.
        for condition in conditions or []:
            if condition.column == "channelIdentity_teamId":
                if condition.op != FilterOperator.EQUAL:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'channelIdentity_teamId'."
                    )

                group_id = condition.value
                condition.applied = True

            elif condition.column == "channelIdentity_channelId":
                if condition.op != FilterOperator.EQUAL:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'channelIdentity_channelId'."
                    )

                channel_id = condition.value
                condition.applied = True

            elif condition.column == "id":
                if condition.op == FilterOperator.EQUAL:
                    message_ids = [condition.value]

                elif condition.op == FilterOperator.IN:
                    message_ids = condition.value

                else:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'id'."
                    )

                condition.applied = True

        if not group_id or not channel_id:
            raise ValueError("The 'channelIdentity_teamId' and 'channelIdentity_channelId' columns are required.")

        if message_ids:
            messages = client.get_messages_in_channel_by_ids(group_id, channel_id, message_ids)

        else:
            messages = client.get_all_messages_in_channel(group_id, channel_id, limit)

        messages_df = pd.json_normalize(messages, sep="_")

        # reindex() rather than label selection: an empty response or a record set that lacks
        # one of the flattened attributes yields empty/NaN columns instead of raising a KeyError.
        return messages_df.reindex(columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """
        Retrieves the attributes (columns) of the 'channel messages' resource.

        Returns:
            List[str]: A list of attributes (columns) of the 'channel messages' resource.
        """
        return [
            "id",
            "replyToId",
            "etag",
            "messageType",
            "createdDateTime",
            "lastModifiedDateTime",
            "lastEditedDateTime",
            "deletedDateTime",
            "subject",
            "summary",
            "chatId",
            "importance",
            "locale",
            "webUrl",
            "policyViolation",
            "from_application",
            "from_device",
            "from_user_id",
            "from_user_displayName",
            "from_user_userIdentityType",
            "body_contentType",
            "body_content",
            "channelIdentity_teamId",
            "channelIdentity_channelId",
        ]
265
+
266
+
267
class ChatsTable(APIResource):
    """
    The table abstraction for the 'chats' resource of the Microsoft Graph API.
    """

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ):
        """
        Executes a parsed SELECT SQL query on the 'chats' resource of the Microsoft Graph API.

        Only the 'id' filter (equality or IN) is handled here and marked as applied. When it is
        absent, all chats are requested and `limit` is forwarded to the client.

        Args:
            conditions (List[FilterCondition]): The list of parsed filter conditions.
            limit (int): The maximum number of records to return.
            sort (List[SortColumn]): The list of parsed sort columns.
            targets (List[str]): The list of target columns to return.

        Returns:
            pd.DataFrame: The chats, restricted to the columns listed by `get_columns()`.

        Raises:
            ValueError: If 'id' is filtered with an unsupported operator.
        """
        client: MSGraphAPITeamsDelegatedPermissionsClient = self.handler.connect()

        chat_ids = None
        # `conditions` defaults to None; treat that as "no filters" instead of failing on iteration.
        for condition in conditions or []:
            if condition.column != "id":
                continue

            if condition.op == FilterOperator.EQUAL:
                chat_ids = [condition.value]

            elif condition.op == FilterOperator.IN:
                chat_ids = condition.value

            else:
                raise ValueError(
                    f"Unsupported operator '{condition.op}' for column 'id'."
                )

            condition.applied = True

        if chat_ids:
            chats = client.get_chats_by_ids(chat_ids)

        else:
            chats = client.get_all_chats(limit)

        chats_df = pd.json_normalize(chats, sep="_")

        # reindex() rather than label selection: an empty response or a record set that lacks
        # one of the attributes yields empty/NaN columns instead of raising a KeyError.
        return chats_df.reindex(columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """
        Retrieves the attributes (columns) of the 'chats' resource.

        Returns:
            List[str]: A list of attributes (columns) of the 'chats' resource.
        """
        return [
            "id",
            "topic",
            "createdDateTime",
            "lastUpdatedDateTime",
            "chatType",
            "webUrl",
            "isHiddenForAllMembers",
        ]
334
+
335
+
336
class ChatMessagesTable(APIResource):
    """
    The table abstraction for the 'chat messages' resource of the Microsoft Graph API.
    """

    def list(
        self,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        sort: List[SortColumn] = None,
        targets: List[str] = None,
        **kwargs
    ):
        """
        Executes a parsed SELECT SQL query on the 'chat messages' resource of the Microsoft Graph API.

        Two filters are handled here and marked as applied:
        'chatId' (equality only, mandatory) and 'id' (equality or IN). `limit` is forwarded to
        the client only when no message IDs are given.

        Args:
            conditions (List[FilterCondition]): The list of parsed filter conditions.
            limit (int): The maximum number of records to return.
            sort (List[SortColumn]): The list of parsed sort columns.
            targets (List[str]): The list of target columns to return.

        Returns:
            pd.DataFrame: The messages, restricted to the columns listed by `get_columns()`.

        Raises:
            ValueError: If a handled column is filtered with an unsupported operator, or if
                the chat ID filter is missing.
        """
        client: MSGraphAPITeamsDelegatedPermissionsClient = self.handler.connect()

        chat_id, message_ids = None, None
        # `conditions` defaults to None; treat that as "no filters" so that the explicit
        # "column is required" error below is raised instead of a TypeError.
        for condition in conditions or []:
            if condition.column == "chatId":
                if condition.op != FilterOperator.EQUAL:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'chatId'."
                    )

                chat_id = condition.value
                condition.applied = True

            elif condition.column == "id":
                if condition.op == FilterOperator.EQUAL:
                    message_ids = [condition.value]

                elif condition.op == FilterOperator.IN:
                    message_ids = condition.value

                else:
                    raise ValueError(
                        f"Unsupported operator '{condition.op}' for column 'id'."
                    )

                condition.applied = True

        if not chat_id:
            raise ValueError("The 'chatId' column is required.")

        if message_ids:
            messages = client.get_messages_in_chat_by_ids(chat_id, message_ids)

        else:
            messages = client.get_all_messages_in_chat(chat_id, limit)

        messages_df = pd.json_normalize(messages, sep="_")

        # reindex() rather than label selection: an empty response or a record set that lacks
        # one of the flattened attributes yields empty/NaN columns instead of raising a KeyError.
        return messages_df.reindex(columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """
        Retrieves the attributes (columns) of the 'chat messages' resource.

        Returns:
            List[str]: A list of attributes (columns) of the 'chat messages' resource.
        """
        return [
            "id",
            "replyToId",
            "etag",
            "messageType",
            "createdDateTime",
            "lastModifiedDateTime",
            "lastEditedDateTime",
            "deletedDateTime",
            "subject",
            "summary",
            "chatId",
            "importance",
            "locale",
            "webUrl",
            "policyViolation",
            "from_application",
            "from_device",
            "from_user_id",
            "from_user_displayName",
            "from_user_userIdentityType",
            "body_contentType",
            "body_content",
        ]
@@ -114,13 +114,27 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
114
114
  if conditions is None:
115
115
  return {}
116
116
 
117
- return {
118
- condition.column.split(".")[-1]: {
117
+ filter_conditions = {}
118
+
119
+ for condition in conditions:
120
+
121
+ parts = condition.column.split(".")
122
+ key = parts[0]
123
+ # converts 'col.el1.el2' to col->'el1'->>'el2'
124
+ if len(parts) > 1:
125
+ # intermediate elements
126
+ for el in parts[1:-1]:
127
+ key += f" -> '{el}'"
128
+
129
+ # last element
130
+ key += f" ->> '{parts[-1]}'"
131
+
132
+ filter_conditions[key] = {
119
133
  "op": condition.op.value,
120
134
  "value": condition.value,
121
135
  }
122
- for condition in conditions
123
- }
136
+
137
+ return filter_conditions
124
138
 
125
139
  @staticmethod
126
140
  def _construct_where_clause(filter_conditions=None):
@@ -57,6 +57,7 @@ class RedshiftHandler(PostgresHandler):
57
57
  connection.commit()
58
58
  except Exception as e:
59
59
  logger.error(f"Error inserting data into {table_name}, {e}!")
60
+ connection.rollback()
60
61
  response = Response(
61
62
  RESPONSE_TYPE.ERROR,
62
63
  error_code=0,
@@ -1 +1 @@
1
- salesforce_api
1
+ salesforce_api==0.1.45
@@ -39,31 +39,8 @@ class SalesforceHandler(APIHandler):
39
39
 
40
40
  self.connection = None
41
41
  self.is_connected = False
42
-
43
- # Register Salesforce tables.
44
- self.resource_names = {
45
- 'Account',
46
- 'Contact',
47
- 'Opportunity',
48
- 'Lead',
49
- 'Task',
50
- 'Event',
51
- 'User',
52
- 'Product2',
53
- 'Pricebook2',
54
- 'PricebookEntry',
55
- 'Order',
56
- 'OrderItem',
57
- 'Case',
58
- 'Campaign',
59
- 'CampaignMember',
60
- 'Contract',
61
- 'Asset'
62
- }
63
-
64
- for resource_name in self.resource_names:
65
- table_class = create_table_class(resource_name, resource_name)
66
- self._register_table(resource_name, table_class(self))
42
+ self.thread_safe = True
43
+ self.resource_names = []
67
44
 
68
45
  def connect(self) -> salesforce_api.client.Client:
69
46
  """
@@ -92,6 +69,12 @@ class SalesforceHandler(APIHandler):
92
69
  is_sandbox=self.connection_data.get('is_sandbox', False)
93
70
  )
94
71
  self.is_connected = True
72
+
73
+ # Register Salesforce tables.
74
+ for resource_name in self._get_resource_names():
75
+ table_class = create_table_class(resource_name)
76
+ self._register_table(resource_name.lower(), table_class(self))
77
+
95
78
  return self.connection
96
79
  except AuthenticationError as auth_error:
97
80
  logger.error(f"Authentication error connecting to Salesforce, {auth_error}!")
@@ -179,3 +162,15 @@ class SalesforceHandler(APIHandler):
179
162
  )
180
163
 
181
164
  return response
165
+
166
def _get_resource_names(self) -> list:
    """
    Retrieves the names of the Salesforce resources.

    While `self.resource_names` is empty, the names are read from the 'sobjects' entry of
    `self.connection.sobjects.describe()` and stored on the handler; otherwise the stored
    list is returned without calling the connection again.

    Returns:
        list: The names of the Salesforce resources.
    """
    if not self.resource_names:
        self.resource_names = [resource['name'] for resource in self.connection.sobjects.describe()['sobjects']]

    return self.resource_names
@@ -11,7 +11,7 @@ from mindsdb.utilities import log
11
11
  logger = log.getLogger(__name__)
12
12
 
13
13
 
14
- def create_table_class(table_name: Text, resource_name: Text) -> APIResource:
14
+ def create_table_class(resource_name: Text) -> APIResource:
15
15
  """
16
16
  Creates a table class for the given Salesforce resource.
17
17
  """
@@ -31,7 +31,7 @@ def create_table_class(table_name: Text, resource_name: Text) -> APIResource:
31
31
  Returns:
32
32
  pd.DataFrame: A DataFrame containing the data retrieved from the Salesforce resource.
33
33
  """
34
- query.from_table = table_name
34
+ query.from_table = resource_name
35
35
 
36
36
  # SOQL does not support * in SELECT queries. Replace * with column names.
37
37
  if isinstance(query.targets[0], Star):
@@ -10,8 +10,6 @@ class TimeScaleDBHandler(PostgresHandler):
10
10
  super().__init__(name, **kwargs)
11
11
 
12
12
 
13
-
14
-
15
13
  connection_args = OrderedDict(
16
14
  host={
17
15
  'type': ARG_TYPE.STR,
@@ -31,6 +29,12 @@ connection_args = OrderedDict(
31
29
  'type': ARG_TYPE.STR,
32
30
  'description': 'The password to authenticate the user with the TimeScaleDB server.'
33
31
  },
32
+ schema={
33
+ 'type': ARG_TYPE.STR,
34
+ 'description': 'The schema in which objects are searched first.',
35
+ 'required': False,
36
+ 'label': 'Schema'
37
+ },
34
38
  port={
35
39
  'type': ARG_TYPE.INT,
36
40
  'description': 'Specify port to connect TimeScaleDB '
@@ -39,8 +43,9 @@ connection_args = OrderedDict(
39
43
 
40
44
  connection_args_example = OrderedDict(
41
45
  host='127.0.0.1',
42
- port=36806,
43
- password='P455W0rD',
44
- user='tsdbadmin',
45
- database="tsdb"
46
+ port=5432,
47
+ password='password',
48
+ user='root',
49
+ database="timescaledb",
50
+ schema='public'
46
51
  )
@@ -111,10 +111,16 @@ def learn_process(data_integration_ref: dict, problem_definition: dict, fetch_da
111
111
  )
112
112
  handlers_cacher[predictor_record.id] = ml_handler
113
113
 
114
- if not ml_handler.generative:
114
+ if not ml_handler.generative and target is not None:
115
115
  if training_data_df is not None and target not in training_data_df.columns:
116
- raise Exception(
117
- f'Prediction target "{target}" not found in training dataframe: {list(training_data_df.columns)}')
116
+ # is the case different? convert column case in input dataframe
117
+ col_names = {c.lower(): c for c in training_data_df.columns}
118
+ target_found = col_names.get(target.lower())
119
+ if target_found:
120
+ training_data_df.rename(columns={target_found: target}, inplace=True)
121
+ else:
122
+ raise Exception(
123
+ f'Prediction target "{target}" not found in training dataframe: {list(training_data_df.columns)}')
118
124
 
119
125
  # create new model
120
126
  if base_model_id is None:
@@ -325,7 +325,7 @@ class VectorStoreHandler(BaseHandler):
325
325
  if not df_insert.empty:
326
326
  self.insert(table_name, df_insert)
327
327
 
328
- def _dispatch_delete(self, query: Delete):
328
+ def dispatch_delete(self, query: Delete):
329
329
  """
330
330
  Dispatch delete query to the appropriate method.
331
331
  """
@@ -382,7 +382,7 @@ class VectorStoreHandler(BaseHandler):
382
382
  DropTables: self._dispatch_drop_table,
383
383
  Insert: self._dispatch_insert,
384
384
  Update: self._dispatch_update,
385
- Delete: self._dispatch_delete,
385
+ Delete: self.dispatch_delete,
386
386
  Select: self.dispatch_select,
387
387
  }
388
388
  if type(query) in dispatch_router:
@@ -309,7 +309,7 @@ class FileReader(FormatDetector):
309
309
  )
310
310
  text = file_obj.read()
311
311
 
312
- metadata = {"source": name}
312
+ metadata = {"source_file": name, "file_format": "txt"}
313
313
  documents = [Document(page_content=text, metadata=metadata)]
314
314
 
315
315
  text_splitter = RecursiveCharacterTextSplitter(
@@ -325,7 +325,7 @@ class FileReader(FormatDetector):
325
325
  )
326
326
 
327
327
  @staticmethod
328
- def read_pdf(file_obj: BytesIO, **kwargs):
328
+ def read_pdf(file_obj: BytesIO, name=None, **kwargs):
329
329
 
330
330
  with fitz.open(stream=file_obj.read()) as pdf: # open pdf
331
331
  text = chr(12).join([page.get_text() for page in pdf])
@@ -337,7 +337,7 @@ class FileReader(FormatDetector):
337
337
  split_text = text_splitter.split_text(text)
338
338
 
339
339
  return pd.DataFrame(
340
- {"content": split_text, "metadata": [{}] * len(split_text)}
340
+ {"content": split_text, "metadata": [{"file_format": "pdf", "source_file": name}] * len(split_text)}
341
341
  )
342
342
 
343
343
  @staticmethod