MindsDB 25.1.2.0__py3-none-any.whl → 25.1.5.0__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their public registry. The information is provided for informational purposes only.

Potentially problematic release.



Files changed (99)
  1. {MindsDB-25.1.2.0.dist-info → MindsDB-25.1.5.0.dist-info}/METADATA +258 -255
  2. {MindsDB-25.1.2.0.dist-info → MindsDB-25.1.5.0.dist-info}/RECORD +98 -85
  3. {MindsDB-25.1.2.0.dist-info → MindsDB-25.1.5.0.dist-info}/WHEEL +1 -1
  4. mindsdb/__about__.py +1 -1
  5. mindsdb/__main__.py +5 -3
  6. mindsdb/api/executor/__init__.py +0 -1
  7. mindsdb/api/executor/command_executor.py +2 -1
  8. mindsdb/api/executor/data_types/answer.py +1 -1
  9. mindsdb/api/executor/datahub/datanodes/datanode.py +1 -1
  10. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
  11. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +8 -3
  12. mindsdb/api/executor/datahub/datanodes/project_datanode.py +9 -26
  13. mindsdb/api/executor/sql_query/__init__.py +1 -0
  14. mindsdb/api/executor/sql_query/result_set.py +36 -21
  15. mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +1 -1
  16. mindsdb/api/executor/sql_query/steps/join_step.py +4 -4
  17. mindsdb/api/executor/sql_query/steps/map_reduce_step.py +6 -39
  18. mindsdb/api/executor/utilities/sql.py +2 -10
  19. mindsdb/api/http/namespaces/agents.py +3 -1
  20. mindsdb/api/http/namespaces/knowledge_bases.py +3 -3
  21. mindsdb/api/http/namespaces/sql.py +3 -1
  22. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +2 -1
  23. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +7 -0
  24. mindsdb/api/postgres/postgres_proxy/executor/executor.py +2 -1
  25. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +2 -2
  26. mindsdb/integrations/handlers/chromadb_handler/requirements.txt +1 -1
  27. mindsdb/integrations/handlers/databricks_handler/requirements.txt +1 -1
  28. mindsdb/integrations/handlers/file_handler/file_handler.py +1 -1
  29. mindsdb/integrations/handlers/file_handler/requirements.txt +0 -4
  30. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +17 -1
  31. mindsdb/integrations/handlers/jira_handler/jira_handler.py +15 -1
  32. mindsdb/integrations/handlers/jira_handler/jira_table.py +52 -31
  33. mindsdb/integrations/handlers/langchain_embedding_handler/fastapi_embeddings.py +82 -0
  34. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +8 -1
  35. mindsdb/integrations/handlers/langchain_handler/requirements.txt +1 -1
  36. mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_handler.py +1 -1
  37. mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_tables.py +8 -0
  38. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +49 -12
  39. mindsdb/integrations/handlers/pinecone_handler/pinecone_handler.py +123 -72
  40. mindsdb/integrations/handlers/pinecone_handler/requirements.txt +1 -1
  41. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +12 -6
  42. mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +5 -3
  43. mindsdb/integrations/handlers/slack_handler/slack_handler.py +13 -2
  44. mindsdb/integrations/handlers/slack_handler/slack_tables.py +21 -1
  45. mindsdb/integrations/handlers/web_handler/requirements.txt +0 -1
  46. mindsdb/integrations/libs/ml_handler_process/learn_process.py +2 -2
  47. mindsdb/integrations/utilities/files/__init__.py +0 -0
  48. mindsdb/integrations/utilities/files/file_reader.py +258 -0
  49. mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py +2 -1
  50. mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py +8 -3
  51. mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py +5 -9
  52. mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py +76 -27
  53. mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py +18 -1
  54. mindsdb/integrations/utilities/rag/pipelines/rag.py +84 -20
  55. mindsdb/integrations/utilities/rag/rag_pipeline_builder.py +16 -1
  56. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +166 -108
  57. mindsdb/integrations/utilities/rag/retrievers/__init__.py +3 -0
  58. mindsdb/integrations/utilities/rag/retrievers/multi_hop_retriever.py +85 -0
  59. mindsdb/integrations/utilities/rag/retrievers/retriever_factory.py +57 -0
  60. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +117 -48
  61. mindsdb/integrations/utilities/rag/settings.py +190 -17
  62. mindsdb/integrations/utilities/sql_utils.py +1 -1
  63. mindsdb/interfaces/agents/agents_controller.py +18 -8
  64. mindsdb/interfaces/agents/constants.py +1 -0
  65. mindsdb/interfaces/agents/langchain_agent.py +124 -157
  66. mindsdb/interfaces/agents/langfuse_callback_handler.py +4 -37
  67. mindsdb/interfaces/agents/mindsdb_database_agent.py +21 -13
  68. mindsdb/interfaces/chatbot/chatbot_controller.py +7 -11
  69. mindsdb/interfaces/chatbot/chatbot_task.py +16 -5
  70. mindsdb/interfaces/chatbot/memory.py +58 -13
  71. mindsdb/interfaces/database/integrations.py +5 -1
  72. mindsdb/interfaces/database/projects.py +55 -16
  73. mindsdb/interfaces/database/views.py +12 -25
  74. mindsdb/interfaces/knowledge_base/controller.py +38 -9
  75. mindsdb/interfaces/knowledge_base/preprocessing/document_loader.py +7 -26
  76. mindsdb/interfaces/model/functions.py +15 -4
  77. mindsdb/interfaces/model/model_controller.py +4 -7
  78. mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +51 -40
  79. mindsdb/interfaces/skills/retrieval_tool.py +10 -3
  80. mindsdb/interfaces/skills/skill_tool.py +97 -54
  81. mindsdb/interfaces/skills/skills_controller.py +7 -3
  82. mindsdb/interfaces/skills/sql_agent.py +127 -41
  83. mindsdb/interfaces/storage/db.py +1 -1
  84. mindsdb/migrations/versions/2025-01-15_c06c35f7e8e1_project_company.py +88 -0
  85. mindsdb/utilities/cache.py +7 -4
  86. mindsdb/utilities/context.py +11 -1
  87. mindsdb/utilities/langfuse.py +279 -0
  88. mindsdb/utilities/log.py +20 -2
  89. mindsdb/utilities/otel/__init__.py +206 -0
  90. mindsdb/utilities/otel/logger.py +25 -0
  91. mindsdb/utilities/otel/meter.py +19 -0
  92. mindsdb/utilities/otel/metric_handlers/__init__.py +25 -0
  93. mindsdb/utilities/otel/tracer.py +16 -0
  94. mindsdb/utilities/partitioning.py +52 -0
  95. mindsdb/utilities/render/sqlalchemy_render.py +7 -1
  96. mindsdb/utilities/utils.py +34 -0
  97. mindsdb/utilities/otel.py +0 -72
  98. {MindsDB-25.1.2.0.dist-info → MindsDB-25.1.5.0.dist-info}/LICENSE +0 -0
  99. {MindsDB-25.1.2.0.dist-info → MindsDB-25.1.5.0.dist-info}/top_level.txt +0 -0

mindsdb/interfaces/skills/sql_agent.py
@@ -1,37 +1,84 @@
-from typing import Iterable, List, Optional
 
 import re
-from mindsdb_sql_parser.ast import Select, Show, Describe, Explain
+import inspect
+from typing import Iterable, List, Optional
 
 import pandas as pd
 from mindsdb_sql_parser import parse_sql
-from mindsdb_sql_parser.ast import Identifier
-from mindsdb.integrations.utilities.query_traversal import query_traversal
+from mindsdb_sql_parser.ast import Select, Show, Describe, Explain, Identifier
 
 from mindsdb.utilities import log
 from mindsdb.utilities.context import context as ctx
+from mindsdb.integrations.utilities.query_traversal import query_traversal
 
 logger = log.getLogger(__name__)
 
 
-class SQLAgent:
+def split_table_name(table_name: str) -> List[str]:
+    """Split table name from llm to parst
+
+    Args:
+        table_name (str): input table name
+
+    Returns:
+        List[str]: parts of table identifier like ['database', 'schema', 'table']
+
+    Example:
+        Input: 'aaa.bbb', Output: ['aaa', 'bbb']
+        Input: '`aaa.bbb`', Output: ['aaa', 'bbb']
+        Input: '`aaa.`bbb``', Output: ['aaa', 'bbb']
+        Input: 'aaa.bbb.ccc', Output: ['aaa', 'bbb', 'ccc']
+        Input: '`aaa.bbb.ccc`', Output: ['aaa', 'bbb', 'ccc']
+        Input: '`aaa.`bbb.ccc``', Output: ['aaa', 'bbb.ccc']
+        Input: 'aaa.`bbb.ccc`', Output: ['aaa', 'bbb.ccc']
+        Input: 'aaa.`bbb.ccc`', Output: ['aaa', 'bbb.ccc']
+        Input: '`` aaa.`bbb.ccc`` \n`', Output: ['aaa', 'bbb.ccc']
+    """
+    table_name = table_name.strip(' "\'\n\r')
+    while table_name.startswith('`') and table_name.endswith('`'):
+        table_name = table_name[1:-1]
+        table_name = table_name.strip(' "\'\n\r')
+
+    result = []
+    part = []
+    inside_quotes = False
+
+    for char in table_name:
+        if char == '`':
+            inside_quotes = not inside_quotes
+            continue
+
+        if char == '.' and not inside_quotes:
+            result.append(''.join(part))
+            part = []
+        else:
+            part.append(char)
+
+    if part:
+        result.append(''.join(part))
+
+    return [x for x in result if len(x) > 0]
+
 
+class SQLAgent:
     def __init__(
         self,
         command_executor,
-        database: str,
+        databases: List[str],
+        databases_struct: dict,
         include_tables: Optional[List[str]] = None,
         ignore_tables: Optional[List[str]] = None,
         sample_rows_in_table_info: int = 3,
         cache: Optional[dict] = None
     ):
         self._command_executor = command_executor
+        self._mindsdb_db_struct = databases_struct
 
         self._sample_rows_in_table_info = int(sample_rows_in_table_info)
 
         self._tables_to_include = include_tables
         self._tables_to_ignore = []
-        self._databases = database.split(',')
+        self._databases = databases
         if not self._tables_to_include:
             # ignore_tables and include_tables should not be used together.
             # include_tables takes priority if it's set.
@@ -40,7 +87,6 @@ class SQLAgent:
 
     def _call_engine(self, query: str, database=None):
         # switch database
-
         ast_query = parse_sql(query.strip('`'))
         self._check_permissions(ast_query)
 
@@ -55,7 +101,6 @@ class SQLAgent:
         return ret
 
     def _check_permissions(self, ast_query):
-
         # check type of query
         if not isinstance(ast_query, (Select, Show, Describe, Explain)):
             raise ValueError(f"Query is not allowed: {ast_query.to_string()}")
@@ -66,14 +111,21 @@ class SQLAgent:
             if is_table and isinstance(node, Identifier):
                 name1 = node.to_string()
                 name2 = '.'.join(node.parts)
-                name3 = node.parts[-1]
+                if len(node.parts) == 3:
+                    name3 = '.'.join(node.parts[1:])
+                else:
+                    name3 = node.parts[-1]
                 if not {name1, name2, name3}.intersection(self._tables_to_include):
                     raise ValueError(f"Table {name1} not found. Available tables: {', '.join(self._tables_to_include)}")
 
         query_traversal(ast_query, _check_f)
 
     def get_usable_table_names(self) -> Iterable[str]:
+        """Get a list of tables that the agent has access to.
 
+        Returns:
+            Iterable[str]: list with table names
+        """
        cache_key = f'{ctx.company_id}_{",".join(self._databases)}_tables'
 
         # first check cache and return if found
@@ -85,25 +137,52 @@ class SQLAgent:
         if self._tables_to_include:
             return self._tables_to_include
 
-        ret = self._call_engine('show databases;')
-        dbs = [lst[0] for lst in ret.data.to_lists() if lst[0] != 'information_schema']
-        usable_tables = []
-        for db in dbs:
-            if db != 'mindsdb' and db in self._databases:
-                try:
-                    ret = self._call_engine('show tables', database=db)
-                    tables = [lst[0] for lst in ret.data.to_lists() if lst[0] != 'information_schema']
-                    for table in tables:
-                        # By default, include all tables in a database unless expilcitly ignored.
-                        table_name = f'{db}.{table}'
-                        if table_name not in self._tables_to_ignore:
-                            usable_tables.append(table_name)
-                except Exception as e:
-                    logger.warning('Unable to get tables for %s: %s', db, str(e))
+        result_tables = []
+
+        for db_name in self._mindsdb_db_struct:
+            handler = self._command_executor.session.integration_controller.get_data_handler(db_name)
+
+            schemas_names = list(self._mindsdb_db_struct[db_name].keys())
+            if len(schemas_names) > 1 and None in schemas_names:
+                raise Exception('default schema and named schemas can not be used in same filter')
+
+            if None in schemas_names:
+                # get tables only from default schema
+                response = handler.get_tables()
+                tables_in_default_schema = list(response.data_frame.table_name)
+                schema_tables_restrictions = self._mindsdb_db_struct[db_name][None]  # None - is default schema
+                if schema_tables_restrictions is None:
+                    for table_name in tables_in_default_schema:
+                        result_tables.append([db_name, table_name])
+                else:
+                    for table_name in schema_tables_restrictions:
+                        if table_name in tables_in_default_schema:
+                            result_tables.append([db_name, table_name])
+            else:
+                if 'all' in inspect.signature(handler.get_tables).parameters:
+                    response = handler.get_tables(all=True)
+                else:
+                    response = handler.get_tables()
+                response_schema_names = list(response.data_frame.table_schema.unique())
+                schemas_intersection = set(schemas_names) & set(response_schema_names)
+                if len(schemas_intersection) == 0:
+                    raise Exception('There are no allowed schemas in ds')
+
+                for schema_name in schemas_intersection:
+                    schema_sub_df = response.data_frame[response.data_frame['table_schema'] == schema_name]
+                    if self._mindsdb_db_struct[db_name][schema_name] is None:
+                        # all tables from schema allowed
+                        for row in schema_sub_df:
+                            result_tables.append([db_name, schema_name, row['table_name']])
+                    else:
+                        for table_name in self._mindsdb_db_struct[db_name][schema_name]:
+                            if table_name in schema_sub_df['table_name'].values:
+                                result_tables.append([db_name, schema_name, table_name])
+
+        result_tables = ['.'.join(x) for x in result_tables]
         if self._cache:
-            self._cache.set(cache_key, set(usable_tables))
-
-        return usable_tables
+            self._cache.set(cache_key, set(result_tables))
+        return result_tables
 
     def _resolve_table_names(self, table_names: List[str], all_tables: List[Identifier]) -> List[Identifier]:
         """
@@ -115,7 +194,10 @@ class SQLAgent:
         tables_idx = {}
         for table in all_tables:
             # by name
-            tables_idx[(table.parts[-1],)] = table
+            if len(table.parts) == 3:
+                tables_idx[tuple(table.parts[1:])] = table
+            else:
+                tables_idx[(table.parts[-1],)] = table
             # by path
             tables_idx[tuple(table.parts)] = table
 
@@ -125,15 +207,14 @@ class SQLAgent:
                 continue
 
            # Some LLMs (e.g. gpt-4o) may include backticks or quotes when invoking tools.
-            table_name = table_name.strip(' `"\'\n\r')
-            table = Identifier(table_name)
+            table_parts = split_table_name(table_name)
 
            # resolved table
-            table2 = tables_idx.get(tuple(table.parts))
+            table_identifier = tables_idx.get(tuple(table_parts))
 
-            if table2 is None:
+            if table_identifier is None:
                 raise ValueError(f"Table {table} not found in database")
-            tables.append(table2)
+            tables.append(table_identifier)
 
         return tables
 
@@ -165,26 +246,31 @@ class SQLAgent:
     def _get_single_table_info(self, table: Identifier) -> str:
         if len(table.parts) < 2:
             raise ValueError(f"Database is required for table: {table}")
-        integration, table_name = table.parts[-2:]
+        if len(table.parts) == 3:
+            integration, schema_name, table_name = table.parts[-3:]
+        else:
+            schema_name = None
+            integration, table_name = table.parts[-2:]
+
         table_str = str(table)
 
         dn = self._command_executor.session.datahub.get(integration)
 
         fields, dtypes = [], []
-        for column in dn.get_table_columns(table_name):
+        for column in dn.get_table_columns(table_name, schema_name):
             fields.append(column['name'])
             dtypes.append(column.get('type', ''))
 
-        info = f'Table named `{table_name}`\n'
-        info += f"\n/* Sample with first {self._sample_rows_in_table_info} rows from table {table_str}:\n"
+        info = f'Table named `{table_str}`:\n'
+        info += f"\nSample with first {self._sample_rows_in_table_info} rows from table {table_str}:\n"
         info += "\t".join([field for field in fields])
-        info += self._get_sample_rows(table_str, fields) + "\n*/"
+        info += self._get_sample_rows(table_str, fields) + "\n"
         info += '\nColumn data types: ' + ",\t".join(
-            [f'`{field}` : `{dtype}`' for field, dtype in zip(fields, dtypes)]) + '\n'  # noqa
+            [f'\n`{field}` : `{dtype}`' for field, dtype in zip(fields, dtypes)]) + '\n'  # noqa
         return info
 
     def _get_sample_rows(self, table: str, fields: List[str]) -> str:
-        command = f"select {','.join(fields)} from {table} limit {self._sample_rows_in_table_info};"
+        command = f"select {', '.join(fields)} from {table} limit {self._sample_rows_in_table_info};"
         try:
             ret = self._call_engine(command)
             sample_rows = ret.data.to_lists()
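
As a quick illustration of the new split_table_name helper above (assuming, per file 82 in the list, that this hunk is mindsdb/interfaces/skills/sql_agent.py and that MindsDB 25.1.5.0 is installed), backticks and stray quotes coming from LLM tool calls are stripped before the identifier is split on dots:

from mindsdb.interfaces.skills.sql_agent import split_table_name

# Unquoted and backtick-quoted identifiers both split on '.'
assert split_table_name('aaa.bbb') == ['aaa', 'bbb']
assert split_table_name('`aaa.bbb`') == ['aaa', 'bbb']
# A backtick-quoted part keeps its inner dot
assert split_table_name('aaa.`bbb.ccc`') == ['aaa', 'bbb.ccc']
# Three parts are returned as ['database', 'schema', 'table']
assert split_table_name('aaa.bbb.ccc') == ['aaa', 'bbb', 'ccc']
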
mindsdb/interfaces/storage/db.py
@@ -212,7 +212,7 @@ class Project(Base):
     )
     deleted_at = Column(DateTime)
     name = Column(String, nullable=False)
-    company_id = Column(Integer)
+    company_id = Column(Integer, default=0)
     __table_args__ = (
         UniqueConstraint("name", "company_id", name="unique_project_name_company_id"),
     )
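
The default change matters because the unique constraint above includes company_id, and SQL treats NULLs as distinct, so UNIQUE(name, company_id) never fires while company_id is NULL. A standalone sqlite3 sketch of that behaviour (hypothetical table, not MindsDB code); the migration that backfills NULL company_id to 0 follows below:

import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE project (name TEXT, company_id INTEGER, UNIQUE (name, company_id))')
conn.execute("INSERT INTO project VALUES ('mindsdb', NULL)")
conn.execute("INSERT INTO project VALUES ('mindsdb', NULL)")  # accepted: NULL is never equal to NULL
conn.execute("INSERT INTO project VALUES ('mindsdb', 0)")
try:
    conn.execute("INSERT INTO project VALUES ('mindsdb', 0)")  # rejected once company_id is 0
except sqlite3.IntegrityError as e:
    print(e)  # UNIQUE constraint failed: project.name, project.company_id
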
mindsdb/migrations/versions/2025-01-15_c06c35f7e8e1_project_company.py (new file)
@@ -0,0 +1,88 @@
+"""project-company
+
+Revision ID: c06c35f7e8e1
+Revises: f6dc924079fa
+Create Date: 2025-01-15 14:14:29.295834
+
+"""
+from collections import defaultdict
+
+from alembic import op
+import sqlalchemy as sa
+import mindsdb.interfaces.storage.db  # noqa
+from mindsdb.utilities import log
+
+# revision identifiers, used by Alembic.
+revision = 'c06c35f7e8e1'
+down_revision = 'f6dc924079fa'
+branch_labels = None
+depends_on = None
+
+
+logger = log.getLogger(__name__)
+
+
+def upgrade():
+
+    """
+    convert company_id from null to 0 to make constrain works
+    duplicated names are renamed
+    """
+
+    conn = op.get_bind()
+    table = sa.Table(
+        'project',
+        sa.MetaData(),
+        sa.Column('id', sa.Integer()),
+        sa.Column('name', sa.String()),
+        sa.Column('company_id', sa.Integer()),
+    )
+
+    data = conn.execute(
+        table
+        .select()
+        .where(table.c.company_id == sa.null())
+    ).fetchall()
+
+    names = defaultdict(list)
+    for id, name, _ in data:
+        names[name].append(id)
+
+    # get duplicated
+    for name, ids in names.items():
+        if len(ids) == 1:
+            continue
+
+        # rename all except first
+        for id in ids[1:]:
+            new_name = f'{name}__{id}'
+
+            op.execute(
+                table
+                .update()
+                .where(table.c.id == id)
+                .values({'name': new_name})
+            )
+            logger.warning(f'Found duplicated project name: {name}, renamed to: {new_name}')
+
+    op.execute(
+        table
+        .update()
+        .where(table.c.company_id == sa.null())
+        .values({'company_id': 0})
+    )
+
+
+def downgrade():
+    table = sa.Table(
+        'project',
+        sa.MetaData(),
+        sa.Column('company_id', sa.Integer())
+    )
+
+    op.execute(
+        table
+        .update()
+        .where(table.c.company_id == 0)
+        .values({'company_id': sa.null()})
+    )
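
A minimal sketch of the rename rule in upgrade() above: projects that share a name (with company_id still NULL) keep their first row unchanged, while the rest are renamed to '<name>__<id>'. The rows here are hypothetical data for illustration:

from collections import defaultdict

rows = [(1, 'mindsdb'), (2, 'my_project'), (5, 'my_project')]  # (id, name) pairs with company_id NULL

names = defaultdict(list)
for row_id, name in rows:
    names[name].append(row_id)

# keep the first id per name, suffix the duplicates with '__<id>'
renames = {
    row_id: f'{name}__{row_id}'
    for name, ids in names.items() if len(ids) > 1
    for row_id in ids[1:]
}
print(renames)  # {5: 'my_project__5'}
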
mindsdb/utilities/cache.py
@@ -71,10 +71,13 @@ _CACHE_MAX_SIZE = 500
 
 
 def dataframe_checksum(df: pd.DataFrame):
-
-    return str_checksum(str(
-        df.set_axis(range(len(df.columns)), axis=1).to_records(index=False)
-    ))
+    original_columns = df.columns
+    df.columns = list(range(len(df.columns)))
+    result = hashlib.sha256(
+        str(df.values).encode()
+    ).hexdigest()
+    df.columns = original_columns
+    return result
 
 
 def json_checksum(obj: t.Union[dict, list]):
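
The rewritten dataframe_checksum hashes only the values (sha256 over str(df.values)), so two frames that differ only in column labels produce the same checksum. A standalone sketch with a local copy of the function (hypothetical name, for illustration only):

import hashlib
import pandas as pd

def df_values_checksum(df: pd.DataFrame) -> str:
    # mirrors the new dataframe_checksum: swap labels out, hash the values, restore labels
    original_columns = df.columns
    df.columns = list(range(len(df.columns)))
    result = hashlib.sha256(str(df.values).encode()).hexdigest()
    df.columns = original_columns
    return result

a = pd.DataFrame({'x': [1, 2], 'y': [3, 4]})
b = a.rename(columns={'x': 'col_1', 'y': 'col_2'})
assert df_values_checksum(a) == df_values_checksum(b)      # same values, same checksum
assert df_values_checksum(a) != df_values_checksum(a + 1)  # different values, different checksum
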
mindsdb/utilities/context.py
@@ -24,7 +24,8 @@ class Context:
                 'enabled': False,
                 'pointer': None,
                 'tree': None
-            }
+            },
+            'email_confirmed': 0,
         })
 
     def __getattr__(self, name: str) -> Any:
@@ -52,6 +53,15 @@ class Context:
     def load(self, storage: dict) -> None:
         self._storage.set(storage)
 
+    def metadata(self, **kwargs) -> dict:
+        return {
+            'user_id': self.user_id or "",
+            'company_id': self.company_id or "",
+            'session_id': self.session_id,
+            'user_class': self.user_class,
+            **kwargs
+        }
+
 
 _context_var = ContextVar('mindsdb.context')
 context = Context(_context_var)
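
The new Context.metadata() helper packages the request identity fields into a dict and merges in any caller-supplied keys. A standalone sketch of the same merge behaviour (hypothetical function and values, not the MindsDB class itself):

def metadata(user_id=None, company_id=None, session_id='sess-1', user_class=0, **kwargs) -> dict:
    # same shape as Context.metadata(): falsy user/company ids collapse to ""
    return {
        'user_id': user_id or "",
        'company_id': company_id or "",
        'session_id': session_id,
        'user_class': user_class,
        **kwargs
    }

print(metadata(trace_id='abc123'))
# {'user_id': '', 'company_id': '', 'session_id': 'sess-1', 'user_class': 0, 'trace_id': 'abc123'}
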
mindsdb/utilities/langfuse.py (new file)
@@ -0,0 +1,279 @@
+import os
+import typing
+
+from mindsdb.utilities import log
+from langfuse import Langfuse
+from langfuse.client import StatefulSpanClient
+from langfuse.callback import CallbackHandler
+from langfuse.api.resources.commons.errors.not_found_error import NotFoundError as TraceNotFoundError
+
+logger = log.getLogger(__name__)
+
+# Define Langfuse public key.
+LANGFUSE_PUBLIC_KEY = os.getenv("LANGFUSE_PUBLIC_KEY", "langfuse_public_key")
+
+# Define Langfuse secret key.
+LANGFUSE_SECRET_KEY = os.getenv("LANGFUSE_SECRET_KEY", "langfuse_secret_key")
+
+# Define Langfuse host.
+LANGFUSE_HOST = os.getenv("LANGFUSE_HOST", "http://localhost:3000")
+
+# Define Langfuse environment.
+LANGFUSE_ENVIRONMENT = os.getenv("LANGFUSE_ENVIRONMENT", "local")
+
+# Define Langfuse release.
+LANGFUSE_RELEASE = os.getenv("LANGFUSE_RELEASE", "local")
+
+# Define Langfuse debug mode.
+LANGFUSE_DEBUG = os.getenv("LANGFUSE_DEBUG", "false").lower() == "true"
+
+# Define Langfuse timeout.
+LANGFUSE_TIMEOUT = int(os.getenv("LANGFUSE_TIMEOUT", 10))
+
+# Define Langfuse sample rate.
+LANGFUSE_SAMPLE_RATE = float(os.getenv("LANGFUSE_SAMPLE_RATE", 1.0))
+
+# Define if Langfuse is disabled.
+LANGFUSE_DISABLED = os.getenv("LANGFUSE_DISABLED", "false").lower() == "true" or LANGFUSE_ENVIRONMENT == "local"
+LANGFUSE_FORCE_RUN = os.getenv("LANGFUSE_FORCE_RUN", "false").lower() == "true"
+
+
+class LangfuseClientWrapper:
+    """
+    Langfuse client wrapper. Defines Langfuse client configuration and initializes Langfuse client.
+    """
+
+    def __init__(self,
+                 public_key: str = LANGFUSE_PUBLIC_KEY,
+                 secret_key: str = LANGFUSE_SECRET_KEY,
+                 host: str = LANGFUSE_HOST,
+                 environment: str = LANGFUSE_ENVIRONMENT,
+                 release: str = LANGFUSE_RELEASE,
+                 debug: bool = LANGFUSE_DEBUG,
+                 timeout: int = LANGFUSE_TIMEOUT,
+                 sample_rate: float = LANGFUSE_SAMPLE_RATE,
+                 disable: bool = LANGFUSE_DISABLED,
+                 force_run: bool = LANGFUSE_FORCE_RUN) -> None:
+        """
+        Initialize Langfuse client.
+
+        Args:
+            public_key (str): Langfuse public key.
+            secret_key (str): Langfuse secret key.
+            host (str): Langfuse host.
+            release (str): Langfuse release.
+            timeout (int): Langfuse timeout.
+            sample_rate (float): Langfuse sample rate.
+        """
+
+        self.metadata = None
+        self.public_key = public_key
+        self.secret_key = secret_key
+        self.host = host
+        self.environment = environment
+        self.release = release
+        self.debug = debug
+        self.timeout = timeout
+        self.sample_rate = sample_rate
+        self.disable = disable
+        self.force_run = force_run
+
+        self.client = None
+        self.trace = None
+        self.metadata = None
+        self.tags = None
+
+        # Check if Langfuse is disabled.
+        if LANGFUSE_DISABLED and not LANGFUSE_FORCE_RUN:
+            logger.info("Langfuse is disabled.")
+            return
+
+        logger.info("Langfuse enabled")
+        logger.debug(f"LANGFUSE_PUBLIC_KEY: {LANGFUSE_PUBLIC_KEY}")
+        logger.debug(f"LANGFUSE_SECRET_KEY: {'*' * len(LANGFUSE_SECRET_KEY)}")
+        logger.debug(f"LANGFUSE_HOST: {LANGFUSE_HOST}")
+        logger.debug(f"LANGFUSE_ENVIRONMENT: {LANGFUSE_ENVIRONMENT}")
+        logger.debug(f"LANGFUSE_RELEASE: {LANGFUSE_RELEASE}")
+        logger.debug(f"LANGFUSE_DEBUG: {LANGFUSE_DEBUG}")
+        logger.debug(f"LANGFUSE_TIMEOUT: {LANGFUSE_TIMEOUT}")
+        logger.debug(f"LANGFUSE_SAMPLE_RATE: {LANGFUSE_SAMPLE_RATE * 100}%")
+
+        self.client = Langfuse(
+            public_key=public_key,
+            secret_key=secret_key,
+            host=host,
+            release=release,
+            debug=debug,
+            timeout=timeout,
+            sample_rate=sample_rate
+        )
+
+    def setup_trace(self,
+                    name: str,
+                    input: typing.Optional[typing.Any] = None,
+                    tags: typing.Optional[typing.List] = None,
+                    metadata: typing.Optional[typing.Dict] = None,
+                    user_id: str = None,
+                    session_id: str = None) -> None:
+        """
+        Setup trace. If Langfuse is disabled, nothing will be done.
+        Args:
+            name (str): Trace name.
+            input (dict): Trace input.
+            tags (dict): Trace tags.
+            metadata (dict): Trace metadata.
+            user_id (str): User ID.
+            session_id (str): Session ID.
+        """
+
+        if self.client is None:
+            logger.debug("Langfuse is disabled.")
+            return
+
+        self.set_metadata(metadata)
+        self.set_tags(tags)
+
+        try:
+            self.trace = self.client.trace(
+                name=name,
+                input=input,
+                metadata=self.metadata,
+                tags=self.tags,
+                user_id=user_id,
+                session_id=session_id
+            )
+        except Exception as e:
+            logger.error(f'Something went wrong while processing Langfuse trace {self.trace.id}: {str(e)}')
+
+        logger.info(f"Langfuse trace configured with ID: {self.trace.id}")
+
+    def get_trace_id(self) -> typing.Optional[str]:
+        """
+        Get trace ID. If Langfuse is disabled, returns None.
+        """
+
+        if self.client is None:
+            logger.debug("Langfuse is disabled.")
+            return ""
+
+        if self.trace is None:
+            logger.debug("Langfuse trace is not setup.")
+            return ""
+
+        return self.trace.id
+
+    def start_span(self,
+                   name: str,
+                   input: typing.Optional[typing.Any] = None) -> typing.Optional[StatefulSpanClient]:
+        """
+        Create span. If Langfuse is disabled, nothing will be done.
+
+        Args:
+            name (str): Span name.
+            input (dict): Span input.
+        """
+
+        if self.client is None:
+            logger.debug("Langfuse is disabled.")
+            return None
+
+        return self.trace.span(name=name, input=input)
+
+    def end_span_stream(self,
+                        span: typing.Optional[StatefulSpanClient] = None) -> None:
+        """
+        End span. If Langfuse is disabled, nothing will happen.
+        Args:
+            span (Any): Span object.
+        """
+
+        if self.client is None:
+            logger.debug("Langfuse is disabled.")
+            return
+
+        span.end()
+        self.trace.update()
+
+    def end_span(self,
+                 span: typing.Optional[StatefulSpanClient] = None,
+                 output: typing.Optional[typing.Any] = None) -> None:
+        """
+        End trace. If Langfuse is disabled, nothing will be done.
+
+        Args:
+            span (Any): Span object.
+            output (Any): Span output.
+        """
+
+        if self.client is None:
+            logger.debug("Langfuse is disabled.")
+            return
+
+        if span is None:
+            logger.debug("Langfuse span is not created.")
+            return
+
+        span.end(output=output)
+        self.trace.update(output=output)
+
+        metadata = self.metadata or {}
+
+        try:
+            # Ensure all batched traces are sent before fetching.
+            self.client.flush()
+            metadata['tool_usage'] = self._get_tool_usage()
+            self.trace.update(metadata=metadata)
+
+        except Exception as e:
+            logger.error(f'Something went wrong while processing Langfuse trace {self.trace.id}: {str(e)}')
+
+    def get_langchain_handler(self) -> typing.Optional[CallbackHandler]:
+        """
+        Get Langchain handler. If Langfuse is disabled, returns None.
+        """
+
+        if self.client is None:
+            logger.debug("Langfuse is disabled.")
+            return None
+
+        return self.trace.get_langchain_handler()
+
+    def set_metadata(self, custom_metadata: dict = None) -> None:
+        """
+        Get default metadata.
+        """
+        self.metadata = custom_metadata or {}
+
+        self.metadata["environment"] = self.environment
+        self.metadata["release"] = self.release
+
+    def set_tags(self, custom_tags: typing.Optional[typing.List] = None) -> None:
+        """
+        Get default tags.
+        """
+        self.tags = custom_tags or []
+
+        self.tags.append(self.environment)
+        self.tags.append(self.release)
+
+    def _get_tool_usage(self) -> typing.Dict:
+        """ Retrieves tool usage information from a langfuse trace.
+        Note: assumes trace marks an action with string `AgentAction` """
+
+        tool_usage = {}
+
+        try:
+            fetched_trace = self.client.get_trace(self.trace.id)
+            steps = [s.name for s in fetched_trace.observations]
+            for step in steps:
+                if 'AgentAction' in step:
+                    tool_name = step.split('-')[1]
+                    if tool_name not in tool_usage:
+                        tool_usage[tool_name] = 0
+                    tool_usage[tool_name] += 1
+        except TraceNotFoundError:
+            logger.warning(f'Langfuse trace {self.trace.id} not found')
+        except Exception as e:
+            logger.error(f'Something went wrong while processing Langfuse trace {self.trace.id}: {str(e)}')
+
+        return tool_usage
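
A usage sketch for the new LangfuseClientWrapper (file 87 in the list, mindsdb/utilities/langfuse.py), assuming MindsDB 25.1.5.0 and its langfuse dependency are installed. With the LANGFUSE_* environment variables left at their defaults the client stays disabled, so every call below is a safe no-op; the trace and span names and payloads are hypothetical:

from mindsdb.utilities.langfuse import LangfuseClientWrapper

langfuse = LangfuseClientWrapper()

langfuse.setup_trace(
    name='agent-run',
    input={'question': 'How many users signed up last week?'},
    tags=['example'],
    metadata={'trace_id': 'abc123'},  # set_metadata() adds environment and release on top
    user_id='user-1',
    session_id='session-1',
)

span = langfuse.start_span(name='tool-call', input={'tool': 'sql_db_query'})
# ... run the tool ...
langfuse.end_span(span=span, output={'rows': 3})

print(langfuse.get_trace_id())  # '' while the client is disabled
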