iatoolkit 0.91.1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71):
  1. iatoolkit/__init__.py +6 -4
  2. iatoolkit/base_company.py +0 -16
  3. iatoolkit/cli_commands.py +3 -14
  4. iatoolkit/common/exceptions.py +1 -0
  5. iatoolkit/common/interfaces/__init__.py +0 -0
  6. iatoolkit/common/interfaces/asset_storage.py +34 -0
  7. iatoolkit/common/interfaces/database_provider.py +43 -0
  8. iatoolkit/common/model_registry.py +159 -0
  9. iatoolkit/common/routes.py +47 -5
  10. iatoolkit/common/util.py +32 -13
  11. iatoolkit/company_registry.py +5 -0
  12. iatoolkit/core.py +51 -20
  13. iatoolkit/infra/connectors/file_connector_factory.py +1 -0
  14. iatoolkit/infra/connectors/s3_connector.py +4 -2
  15. iatoolkit/infra/llm_providers/__init__.py +0 -0
  16. iatoolkit/infra/llm_providers/deepseek_adapter.py +278 -0
  17. iatoolkit/infra/{gemini_adapter.py → llm_providers/gemini_adapter.py} +11 -17
  18. iatoolkit/infra/{openai_adapter.py → llm_providers/openai_adapter.py} +41 -7
  19. iatoolkit/infra/llm_proxy.py +235 -134
  20. iatoolkit/infra/llm_response.py +5 -0
  21. iatoolkit/locales/en.yaml +158 -2
  22. iatoolkit/locales/es.yaml +158 -0
  23. iatoolkit/repositories/database_manager.py +52 -47
  24. iatoolkit/repositories/document_repo.py +7 -0
  25. iatoolkit/repositories/filesystem_asset_repository.py +36 -0
  26. iatoolkit/repositories/llm_query_repo.py +2 -0
  27. iatoolkit/repositories/models.py +72 -79
  28. iatoolkit/repositories/profile_repo.py +59 -3
  29. iatoolkit/repositories/vs_repo.py +22 -24
  30. iatoolkit/services/company_context_service.py +126 -53
  31. iatoolkit/services/configuration_service.py +299 -73
  32. iatoolkit/services/dispatcher_service.py +21 -3
  33. iatoolkit/services/file_processor_service.py +0 -5
  34. iatoolkit/services/history_manager_service.py +43 -24
  35. iatoolkit/services/knowledge_base_service.py +425 -0
  36. iatoolkit/{infra/llm_client.py → services/llm_client_service.py} +38 -29
  37. iatoolkit/services/load_documents_service.py +26 -48
  38. iatoolkit/services/profile_service.py +32 -4
  39. iatoolkit/services/prompt_service.py +32 -30
  40. iatoolkit/services/query_service.py +51 -26
  41. iatoolkit/services/sql_service.py +122 -74
  42. iatoolkit/services/tool_service.py +26 -11
  43. iatoolkit/services/user_session_context_service.py +115 -63
  44. iatoolkit/static/js/chat_main.js +44 -4
  45. iatoolkit/static/js/chat_model_selector.js +227 -0
  46. iatoolkit/static/js/chat_onboarding_button.js +1 -1
  47. iatoolkit/static/js/chat_reload_button.js +4 -1
  48. iatoolkit/static/styles/chat_iatoolkit.css +58 -2
  49. iatoolkit/static/styles/llm_output.css +34 -1
  50. iatoolkit/system_prompts/query_main.prompt +26 -2
  51. iatoolkit/templates/base.html +13 -0
  52. iatoolkit/templates/chat.html +45 -2
  53. iatoolkit/templates/onboarding_shell.html +0 -1
  54. iatoolkit/views/base_login_view.py +7 -2
  55. iatoolkit/views/chat_view.py +76 -0
  56. iatoolkit/views/configuration_api_view.py +163 -0
  57. iatoolkit/views/load_document_api_view.py +14 -10
  58. iatoolkit/views/login_view.py +8 -3
  59. iatoolkit/views/rag_api_view.py +216 -0
  60. iatoolkit/views/users_api_view.py +33 -0
  61. {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/METADATA +4 -4
  62. {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/RECORD +66 -58
  63. iatoolkit/repositories/tasks_repo.py +0 -52
  64. iatoolkit/services/search_service.py +0 -55
  65. iatoolkit/services/tasks_service.py +0 -188
  66. iatoolkit/views/tasks_api_view.py +0 -72
  67. iatoolkit/views/tasks_review_api_view.py +0 -55
  68. {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/WHEEL +0 -0
  69. {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/licenses/LICENSE +0 -0
  70. {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/licenses/LICENSE_COMMUNITY.md +0 -0
  71. {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@
3
3
  #
4
4
  # IAToolkit is open source software.
5
5
 
6
- from iatoolkit.infra.llm_client import llmClient
6
+ from iatoolkit.services.llm_client_service import llmClient
7
7
  from iatoolkit.services.profile_service import ProfileService
8
8
  from iatoolkit.repositories.profile_repo import ProfileRepo
9
9
  from iatoolkit.services.tool_service import ToolService
@@ -11,11 +11,11 @@ from iatoolkit.services.document_service import DocumentService
11
11
  from iatoolkit.services.company_context_service import CompanyContextService
12
12
  from iatoolkit.services.i18n_service import I18nService
13
13
  from iatoolkit.services.configuration_service import ConfigurationService
14
- from iatoolkit.repositories.models import Task
15
14
  from iatoolkit.services.dispatcher_service import Dispatcher
16
15
  from iatoolkit.services.prompt_service import PromptService
17
16
  from iatoolkit.services.user_session_context_service import UserSessionContextService
18
17
  from iatoolkit.services.history_manager_service import HistoryManagerService
18
+ from iatoolkit.common.model_registry import ModelRegistry
19
19
  from iatoolkit.common.util import Utility
20
20
  from injector import inject
21
21
  import base64
@@ -33,6 +33,7 @@ class HistoryHandle:
33
33
  company_short_name: str
34
34
  user_identifier: str
35
35
  type: str
36
+ model: str | None = None
36
37
  request_params: dict = None
37
38
 
38
39
 
@@ -52,6 +53,7 @@ class QueryService:
52
53
  configuration_service: ConfigurationService,
53
54
  history_manager: HistoryManagerService,
54
55
  util: Utility,
56
+ model_registry: ModelRegistry
55
57
  ):
56
58
  self.profile_service = profile_service
57
59
  self.company_context_service = company_context_service
@@ -66,6 +68,7 @@ class QueryService:
66
68
  self.configuration_service = configuration_service
67
69
  self.llm_client = llm_client
68
70
  self.history_manager = history_manager
71
+ self.model_registry = model_registry
69
72
 
70
73
 
71
74
  def _resolve_model(self, company_short_name: str, model: Optional[str]) -> str:
@@ -78,13 +81,16 @@ class QueryService:
78
81
  return effective_model
79
82
 
80
83
  def _get_history_type(self, model: str) -> str:
81
- return HistoryManagerService.TYPE_SERVER_SIDE if self.util.is_openai_model(
82
- model) else HistoryManagerService.TYPE_CLIENT_SIDE
84
+ history_type_str = self.model_registry.get_history_type(model)
85
+ if history_type_str == "server_side":
86
+ return HistoryManagerService.TYPE_SERVER_SIDE
87
+ else:
88
+ return HistoryManagerService.TYPE_CLIENT_SIDE
83
89
 
84
90
 
85
91
  def _build_user_facing_prompt(self, company, user_identifier: str,
86
92
  client_data: dict, files: list,
87
- prompt_name: Optional[str], question: str) -> str:
93
+ prompt_name: Optional[str], question: str):
88
94
  # get the user profile data from the session context
89
95
  user_profile = self.profile_service.get_profile_by_identifier(company.short_name, user_identifier)
90
96
 
@@ -124,9 +130,12 @@ class QueryService:
124
130
 
125
131
  return user_turn_prompt, effective_question
126
132
 
127
- def _ensure_valid_history(self, company, user_identifier: str,
128
- effective_model: str, user_turn_prompt: str,
129
- ignore_history: bool) -> tuple[Optional[HistoryHandle], Optional[dict]]:
133
+ def _ensure_valid_history(self, company,
134
+ user_identifier: str,
135
+ effective_model: str,
136
+ user_turn_prompt: str,
137
+ ignore_history: bool
138
+ ) -> tuple[Optional[HistoryHandle], Optional[dict]]:
130
139
  """
131
140
  Manages the history strategy and rebuilds context if necessary.
132
141
  Returns: (HistoryHandle, error_response)
@@ -137,7 +146,8 @@ class QueryService:
137
146
  handle = HistoryHandle(
138
147
  company_short_name=company.short_name,
139
148
  user_identifier=user_identifier,
140
- type=history_type
149
+ type=history_type,
150
+ model=effective_model
141
151
  )
142
152
 
143
153
  # pass the handle to populate request_params
@@ -197,22 +207,35 @@ class QueryService:
197
207
  def init_context(self, company_short_name: str,
198
208
  user_identifier: str,
199
209
  model: str = None) -> dict:
210
+ """
211
+ Forces a context rebuild for a given user and (optionally) model.
212
+
213
+ - Clears LLM-related context for the resolved model.
214
+ - Regenerates the static company/user context.
215
+ - Sends the context to the LLM for that model.
216
+ """
200
217
 
201
- # 1. Execute the forced rebuild sequence using the unified identifier.
202
- self.session_context.clear_all_context(company_short_name, user_identifier)
203
- logging.info(f"Context for {company_short_name}/{user_identifier} has been cleared.")
218
+ # 1. Resolve the effective model for this user/company
219
+ effective_model = self._resolve_model(company_short_name, model)
204
220
 
205
- # 2. LLM context is clean, now we can load it again
221
+ # 2. Clear only the LLM-related context for this model
222
+ self.session_context.clear_all_context(company_short_name, user_identifier,model=effective_model)
223
+ logging.info(
224
+ f"Context for {company_short_name}/{user_identifier} "
225
+ f"(model={effective_model}) has been cleared."
226
+ )
227
+
228
+ # 3. Static LLM context is now clean, we can prepare it again (model-agnostic)
206
229
  self.prepare_context(
207
230
  company_short_name=company_short_name,
208
231
  user_identifier=user_identifier
209
232
  )
210
233
 
211
- # 3. communicate the new context to the LLM
234
+ # 4. Communicate the new context to the specific LLM model
212
235
  response = self.set_context_for_llm(
213
236
  company_short_name=company_short_name,
214
237
  user_identifier=user_identifier,
215
- model=model
238
+ model=effective_model
216
239
  )
217
240
 
218
241
  return response
@@ -257,8 +280,10 @@ class QueryService:
257
280
  company_short_name: str,
258
281
  user_identifier: str,
259
282
  model: str = ''):
260
-
261
- # This service takes a pre-built context and send to the LLM
283
+ """
284
+ Takes a pre-built static context and sends it to the LLM for the given model.
285
+ Also initializes the model-specific history through HistoryManagerService.
286
+ """
262
287
  company = self.profile_repo.get_company_by_short_name(company_short_name)
263
288
  if not company:
264
289
  logging.error(f"Company not found: {company_short_name} in set_context_for_llm")
@@ -267,8 +292,8 @@ class QueryService:
267
292
  # --- Model Resolution ---
268
293
  effective_model = self._resolve_model(company_short_name, model)
269
294
 
270
- # blocking logic to avoid multiple requests for the same user/company at the same time
271
- lock_key = f"lock:context:{company_short_name}/{user_identifier}"
295
+ # Lock per (company, user, model) to avoid concurrent rebuilds for the same model
296
+ lock_key = f"lock:context:{company_short_name}/{user_identifier}/{effective_model}"
272
297
  if not self.session_context.acquire_lock(lock_key, expire_seconds=60):
273
298
  logging.warning(
274
299
  f"try to rebuild context for user {user_identifier} while is still in process, ignored.")
@@ -310,13 +335,13 @@ class QueryService:
310
335
  def llm_query(self,
311
336
  company_short_name: str,
312
337
  user_identifier: str,
313
- task: Optional[Task] = None,
338
+ model: Optional[str] = None,
314
339
  prompt_name: str = None,
315
340
  question: str = '',
316
341
  client_data: dict = {},
317
342
  ignore_history: bool = False,
318
- files: list = [],
319
- model: Optional[str] = None) -> dict:
343
+ files: list = []
344
+ ) -> dict:
320
345
  try:
321
346
  company = self.profile_repo.get_company_by_short_name(short_name=company_short_name)
322
347
  if not company:
@@ -378,10 +403,10 @@ class QueryService:
378
403
  if not response.get('valid_response'):
379
404
  response['error'] = True
380
405
 
381
- # save history using the manager passing the handle
382
- self.history_manager.update_history(
383
- history_handle, user_turn_prompt, response
384
- )
406
+ # save history using the manager passing the handle
407
+ self.history_manager.update_history(
408
+ history_handle, user_turn_prompt, response
409
+ )
385
410
 
386
411
  return response
387
412
  except Exception as e:
@@ -3,13 +3,13 @@
3
3
  #
4
4
  # IAToolkit is open source software.
5
5
 
6
+ from iatoolkit.common.interfaces.database_provider import DatabaseProvider
6
7
  from iatoolkit.repositories.database_manager import DatabaseManager
7
- from iatoolkit.common.util import Utility
8
8
  from iatoolkit.services.i18n_service import I18nService
9
9
  from iatoolkit.common.exceptions import IAToolkitException
10
- from sqlalchemy import text
11
- from sqlalchemy.exc import SQLAlchemyError
10
+ from iatoolkit.common.util import Utility
12
11
  from injector import inject, singleton
12
+ from typing import Callable
13
13
  import json
14
14
  import logging
15
15
 
@@ -28,91 +28,124 @@ class SqlService:
28
28
  self.util = util
29
29
  self.i18n_service = i18n_service
30
30
 
31
- # Cache for database connections
32
- self._db_connections: dict[str, DatabaseManager] = {}
31
+ # Cache for database providers. Key is tuple: (company_short_name, db_name)
32
+ # Value is the abstract interface DatabaseProvider
33
+ self._db_connections: dict[tuple[str, str], DatabaseProvider] = {}
34
+
35
+ # cache for database schemas. Key is tuple: (company_short_name, db_name)
36
+ self._db_schemas: dict[tuple[str, str], str] = {}
37
+
38
+ # Registry of factory functions.
39
+ # Format: {'connection_type': function(config_dict) -> DatabaseProvider}
40
+ self._provider_factories: dict[str, Callable[[dict], DatabaseProvider]] = {}
41
+
42
+ # Register the default 'direct' strategy (SQLAlchemy)
43
+ self.register_provider_factory('direct', self._create_direct_connection)
33
44
 
34
- def register_database(self, db_uri: str, db_name: str, schema: str | None = None):
45
+ def register_provider_factory(self, connection_type: str, factory: Callable[[dict], DatabaseProvider]):
35
46
  """
36
- Creates and caches a DatabaseManager instance for a given database name and URI.
37
- If a database with the same name is already registered, it does nothing.
47
+ Allows plugins (Enterprise) to register new connection types.
48
+ """
49
+ self._provider_factories[connection_type] = factory
50
+
51
+ def _create_direct_connection(self, config: dict) -> DatabaseProvider:
52
+ """Default factory for standard SQLAlchemy connections."""
53
+ uri = config.get('db_uri') or config.get('DATABASE_URI')
54
+ schema = config.get('schema')
55
+ if not uri:
56
+ raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR,
57
+ "Missing db_uri for direct connection")
58
+ return DatabaseManager(uri, schema=schema, register_pgvector=False)
59
+
60
+ def register_database(self, company_short_name: str, db_name: str, config: dict):
38
61
  """
39
- if db_name in self._db_connections:
62
+ Creates and caches a DatabaseProvider instance based on the configuration.
63
+ """
64
+ key = (company_short_name, db_name)
65
+
66
+ # Determine connection type (default to 'direct')
67
+ conn_type = config.get('connection_type', 'direct')
68
+ logging.info(f"Registering DB '{db_name}' ({conn_type}) for company '{company_short_name}'")
69
+
70
+ factory = self._provider_factories.get(conn_type)
71
+ if not factory:
72
+ logging.error(f"Unknown connection type '{conn_type}' for DB '{db_name}'. Skipping.")
40
73
  return
41
74
 
42
- logging.info(f"Registering and creating connection for database: '{db_name}' (schema: {schema})")
75
+ try:
76
+ # Create the provider using the appropriate factory
77
+ provider_instance = factory(config)
78
+ self._db_connections[key] = provider_instance
79
+
80
+ # save the db_schema
81
+ self._db_schemas[key] = config.get('schema', 'public')
82
+ except Exception as e:
83
+ logging.error(f"Failed to register DB '{db_name}': {e}")
84
+ # We don't raise here to allow other DBs to load if one fails
43
85
 
44
- # create the database connection and save it on the cache
45
- db_manager = DatabaseManager(db_uri, schema=schema, register_pgvector=False)
46
- self._db_connections[db_name] = db_manager
86
+ def get_db_names(self, company_short_name: str) -> list[str]:
87
+ """
88
+ Returns list of logical database names available ONLY for the specified company.
89
+ """
90
+ return [db for (co, db) in self._db_connections.keys() if co == company_short_name]
47
91
 
48
- def get_database_manager(self, db_name: str) -> DatabaseManager:
92
+ def get_database_provider(self, company_short_name: str, db_name: str) -> DatabaseProvider:
49
93
  """
50
- Retrieves a registered DatabaseManager instance from the cache.
94
+ Retrieves a registered DatabaseProvider instance using the composite key.
95
+ Replaces the old 'get_database_manager'.
51
96
  """
97
+ key = (company_short_name, db_name)
52
98
  try:
53
- return self._db_connections[db_name]
99
+ return self._db_connections[key]
54
100
  except KeyError:
55
- logging.error(f"Attempted to access unregistered database: '{db_name}'")
101
+ logging.error(
102
+ f"Attempted to access unregistered database: '{db_name}' for company '{company_short_name}'"
103
+ )
56
104
  raise IAToolkitException(
57
105
  IAToolkitException.ErrorType.DATABASE_ERROR,
58
- f"Database '{db_name}' is not registered with the SqlService."
106
+ f"Database '{db_name}' is not registered for this company."
59
107
  )
60
108
 
61
- def exec_sql(self, company_short_name: str,
62
- database: str,
63
- query: str,
64
- format: str = 'json',
65
- commit: bool = False):
109
+ def exec_sql(self, company_short_name: str, **kwargs):
66
110
  """
67
- Executes a raw SQL statement against a registered database.
68
-
69
- Args:
70
- company_short_name: The company identifier (for logging/context).
71
- database: The logical name of the database to query.
72
- query: The SQL statement to execute.
73
- format: The output format ('json' or 'dict'). Only relevant for SELECT queries.
74
- commit: Whether to commit the transaction immediately after execution.
75
- Use True for INSERT/UPDATE/DELETE statements.
76
-
77
- Returns:
78
- - A JSON string or list of dicts for SELECT queries.
79
- - A dictionary {'rowcount': N} for non-returning statements (INSERT/UPDATE) if not using RETURNING.
111
+ Executes a raw SQL statement against a registered database provider.
112
+ Delegates the actual execution details to the provider implementation.
80
113
  """
81
- try:
82
- # 1. Get the database manager from the cache
83
- db_manager = self.get_database_manager(database)
84
- session = db_manager.get_session()
114
+ database_name = kwargs.get('database_key')
115
+ query = kwargs.get('query')
116
+ format = kwargs.get('format', 'json')
117
+ commit = kwargs.get('commit')
85
118
 
86
- # 2. Execute the SQL statement
87
- result = session.execute(text(query))
88
-
89
- # 3. Handle Commit
90
- if commit:
91
- session.commit()
119
+ if not database_name:
120
+ raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR,
121
+ 'missing database_name in call to exec_sql')
92
122
 
93
- # 4. Process Results
94
- # Check if the query returns rows (e.g., SELECT or INSERT ... RETURNING)
95
- if result.returns_rows:
96
- cols = result.keys()
97
- rows_context = [dict(zip(cols, row)) for row in result.fetchall()]
123
+ try:
124
+ # 1. Get the abstract provider (could be Direct or Bridge)
125
+ provider = self.get_database_provider(company_short_name, database_name)
126
+ db_schema = self._db_schemas[(company_short_name, database_name)]
98
127
 
99
- if format == 'dict':
100
- return rows_context
128
+ # 2. Delegate execution
129
+ # The provider returns a clean List[Dict] or Dict result
130
+ result_data = provider.execute_query(query=query, commit=commit)
101
131
 
102
- # serialize the result
103
- return json.dumps(rows_context, default=self.util.serialize)
132
+ # 3. Handle Formatting (Service layer responsibility)
133
+ if format == 'dict':
134
+ return result_data
104
135
 
105
- # For statements that don't return rows (standard UPDATE/DELETE)
106
- return {'rowcount': result.rowcount}
136
+ # Serialize the result
137
+ return json.dumps(result_data, default=self.util.serialize)
107
138
 
108
139
  except IAToolkitException:
109
- # Re-raise exceptions from get_database_manager to preserve the specific error
110
140
  raise
111
141
  except Exception as e:
112
- # Attempt to rollback if a session was active
113
- db_manager = self._db_connections.get(database)
114
- if db_manager:
115
- db_manager.get_session().rollback()
142
+ # Attempt rollback if supported/needed
143
+ try:
144
+ provider = self.get_database_provider(company_short_name, database_name)
145
+ if provider:
146
+ provider.rollback()
147
+ except Exception:
148
+ pass
116
149
 
117
150
  error_message = str(e)
118
151
  if 'timed out' in str(e):
@@ -122,22 +155,37 @@ class SqlService:
122
155
  raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR,
123
156
  error_message) from e
124
157
 
125
- def commit(self, database: str):
158
+ def commit(self, company_short_name: str, database_name: str):
126
159
  """
127
- Commits the current transaction for a registered database.
128
- Useful when multiple exec_sql calls are part of a single transaction.
160
+ Commits the current transaction for a registered database provider.
129
161
  """
130
-
131
- # Get the database manager from the cache
132
- db_manager = self.get_database_manager(database)
162
+ provider = self.get_database_provider(company_short_name, database_name)
133
163
  try:
134
- db_manager.get_session().commit()
135
- except SQLAlchemyError as db_error:
136
- db_manager.get_session().rollback()
137
- logging.error(f"Error de base de datos: {str(db_error)}")
138
- raise db_error
164
+ provider.commit()
139
165
  except Exception as e:
140
- logging.error(f"error while commiting sql: '{str(e)}'")
166
+ # Try rollback
167
+ try:
168
+ provider.rollback()
169
+ except:
170
+ pass
171
+ logging.error(f"Error while committing sql: '{str(e)}'")
141
172
  raise IAToolkitException(
142
173
  IAToolkitException.ErrorType.DATABASE_ERROR, str(e)
174
+ )
175
+
176
+ def get_database_structure(self, company_short_name: str, db_name: str) -> dict:
177
+ """
178
+ Introspects the specified database and returns its structure (Tables & Columns).
179
+ Used for the Schema Editor 2.0
180
+ """
181
+ try:
182
+ provider = self.get_database_provider(company_short_name, db_name)
183
+ return provider.get_database_structure()
184
+ except IAToolkitException:
185
+ raise
186
+ except Exception as e:
187
+ logging.error(f"Error introspecting database '{db_name}': {e}")
188
+ raise IAToolkitException(
189
+ IAToolkitException.ErrorType.DATABASE_ERROR,
190
+ f"Failed to introspect database: {str(e)}"
143
191
  )
@@ -5,6 +5,7 @@
5
5
 
6
6
  from injector import inject
7
7
  from iatoolkit.repositories.llm_query_repo import LLMQueryRepo
8
+ from iatoolkit.repositories.profile_repo import ProfileRepo
8
9
  from iatoolkit.repositories.models import Company, Tool
9
10
  from iatoolkit.common.exceptions import IAToolkitException
10
11
  from iatoolkit.services.sql_service import SqlService
@@ -98,20 +99,20 @@ _SYSTEM_TOOLS = [
98
99
  },
99
100
  {
100
101
  "function_name": "iat_sql_query",
101
- "description": "Servicio SQL de IAToolkit: debes utilizar este servicio para todas las consultas a base de datos.",
102
+ "description": "Servicio SQL de IAToolkit: debes utilizar este servicio para todas las consultas SQL a bases de datos.",
102
103
  "parameters": {
103
104
  "type": "object",
104
105
  "properties": {
105
- "database": {
106
+ "database_key": {
106
107
  "type": "string",
107
- "description": "nombre de la base de datos a consultar: `database_name`"
108
+ "description": "IMPORTANT: nombre de la base de datos a consultar."
108
109
  },
109
110
  "query": {
110
111
  "type": "string",
111
112
  "description": "string con la consulta en sql"
112
113
  },
113
114
  },
114
- "required": ["database", "query"]
115
+ "required": ["database_key", "query"]
115
116
  }
116
117
  }
117
118
  ]
@@ -121,10 +122,12 @@ class ToolService:
121
122
  @inject
122
123
  def __init__(self,
123
124
  llm_query_repo: LLMQueryRepo,
125
+ profile_repo: ProfileRepo,
124
126
  sql_service: SqlService,
125
127
  excel_service: ExcelService,
126
128
  mail_service: MailService):
127
129
  self.llm_query_repo = llm_query_repo
130
+ self.profile_repo = profile_repo
128
131
  self.sql_service = sql_service
129
132
  self.excel_service = excel_service
130
133
  self.mail_service = mail_service
@@ -158,14 +161,22 @@ class ToolService:
158
161
  self.llm_query_repo.rollback()
159
162
  raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR, str(e))
160
163
 
161
- def sync_company_tools(self, company_instance, tools_config: list):
164
+ def sync_company_tools(self, company_short_name: str, tools_config: list):
162
165
  """
163
166
  Synchronizes tools from YAML config to Database (Create/Update/Delete strategy).
164
167
  """
168
+ if not tools_config:
169
+ return
170
+
171
+ company = self.profile_repo.get_company_by_short_name(company_short_name)
172
+ if not company:
173
+ raise IAToolkitException(IAToolkitException.ErrorType.INVALID_NAME,
174
+ f'Company {company_short_name} not found')
175
+
165
176
  try:
166
177
  # 1. Get existing tools map for later cleanup
167
178
  existing_tools = {
168
- f.name: f for f in self.llm_query_repo.get_company_tools(company_instance.company)
179
+ f.name: f for f in self.llm_query_repo.get_company_tools(company)
169
180
  }
170
181
  defined_tool_names = set()
171
182
 
@@ -177,7 +188,7 @@ class ToolService:
177
188
  # Construct the tool object with current config values
178
189
  # We create a new transient object and let the repo merge it
179
190
  tool_obj = Tool(
180
- company_id=company_instance.company.id,
191
+ company_id=company.id,
181
192
  name=name,
182
193
  description=tool_data['description'],
183
194
  parameters=tool_data['params'],
@@ -206,11 +217,12 @@ class ToolService:
206
217
  Returns the list of tools (System + Company) formatted for the LLM (OpenAI Schema).
207
218
  """
208
219
  tools = []
209
- # Obtiene tanto las de la empresa como las del sistema (la query del repo debería soportar esto con OR)
210
- functions = self.llm_query_repo.get_company_tools(company)
211
220
 
212
- for function in functions:
213
- # Clonamos para no modificar el objeto de la sesión SQLAlchemy
221
+ # get all the tools for the company and system
222
+ company_tools = self.llm_query_repo.get_company_tools(company)
223
+
224
+ for function in company_tools:
225
+ # clone for no modify the SQLAlchemy session object
214
226
  params = function.parameters.copy() if function.parameters else {}
215
227
  params["additionalProperties"] = False
216
228
 
@@ -221,6 +233,9 @@ class ToolService:
221
233
  "parameters": params,
222
234
  "strict": True
223
235
  }
236
+ if function.name == 'iat_sql_query':
237
+ params['properties']['database_key']['enum'] = self.sql_service.get_db_names(company.short_name)
238
+
224
239
  tools.append(ai_tool)
225
240
  return tools
226
241