iatoolkit 0.91.1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- iatoolkit/__init__.py +6 -4
- iatoolkit/base_company.py +0 -16
- iatoolkit/cli_commands.py +3 -14
- iatoolkit/common/exceptions.py +1 -0
- iatoolkit/common/interfaces/__init__.py +0 -0
- iatoolkit/common/interfaces/asset_storage.py +34 -0
- iatoolkit/common/interfaces/database_provider.py +43 -0
- iatoolkit/common/model_registry.py +159 -0
- iatoolkit/common/routes.py +47 -5
- iatoolkit/common/util.py +32 -13
- iatoolkit/company_registry.py +5 -0
- iatoolkit/core.py +51 -20
- iatoolkit/infra/connectors/file_connector_factory.py +1 -0
- iatoolkit/infra/connectors/s3_connector.py +4 -2
- iatoolkit/infra/llm_providers/__init__.py +0 -0
- iatoolkit/infra/llm_providers/deepseek_adapter.py +278 -0
- iatoolkit/infra/{gemini_adapter.py → llm_providers/gemini_adapter.py} +11 -17
- iatoolkit/infra/{openai_adapter.py → llm_providers/openai_adapter.py} +41 -7
- iatoolkit/infra/llm_proxy.py +235 -134
- iatoolkit/infra/llm_response.py +5 -0
- iatoolkit/locales/en.yaml +158 -2
- iatoolkit/locales/es.yaml +158 -0
- iatoolkit/repositories/database_manager.py +52 -47
- iatoolkit/repositories/document_repo.py +7 -0
- iatoolkit/repositories/filesystem_asset_repository.py +36 -0
- iatoolkit/repositories/llm_query_repo.py +2 -0
- iatoolkit/repositories/models.py +72 -79
- iatoolkit/repositories/profile_repo.py +59 -3
- iatoolkit/repositories/vs_repo.py +22 -24
- iatoolkit/services/company_context_service.py +126 -53
- iatoolkit/services/configuration_service.py +299 -73
- iatoolkit/services/dispatcher_service.py +21 -3
- iatoolkit/services/file_processor_service.py +0 -5
- iatoolkit/services/history_manager_service.py +43 -24
- iatoolkit/services/knowledge_base_service.py +425 -0
- iatoolkit/{infra/llm_client.py → services/llm_client_service.py} +38 -29
- iatoolkit/services/load_documents_service.py +26 -48
- iatoolkit/services/profile_service.py +32 -4
- iatoolkit/services/prompt_service.py +32 -30
- iatoolkit/services/query_service.py +51 -26
- iatoolkit/services/sql_service.py +122 -74
- iatoolkit/services/tool_service.py +26 -11
- iatoolkit/services/user_session_context_service.py +115 -63
- iatoolkit/static/js/chat_main.js +44 -4
- iatoolkit/static/js/chat_model_selector.js +227 -0
- iatoolkit/static/js/chat_onboarding_button.js +1 -1
- iatoolkit/static/js/chat_reload_button.js +4 -1
- iatoolkit/static/styles/chat_iatoolkit.css +58 -2
- iatoolkit/static/styles/llm_output.css +34 -1
- iatoolkit/system_prompts/query_main.prompt +26 -2
- iatoolkit/templates/base.html +13 -0
- iatoolkit/templates/chat.html +45 -2
- iatoolkit/templates/onboarding_shell.html +0 -1
- iatoolkit/views/base_login_view.py +7 -2
- iatoolkit/views/chat_view.py +76 -0
- iatoolkit/views/configuration_api_view.py +163 -0
- iatoolkit/views/load_document_api_view.py +14 -10
- iatoolkit/views/login_view.py +8 -3
- iatoolkit/views/rag_api_view.py +216 -0
- iatoolkit/views/users_api_view.py +33 -0
- {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/METADATA +4 -4
- {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/RECORD +66 -58
- iatoolkit/repositories/tasks_repo.py +0 -52
- iatoolkit/services/search_service.py +0 -55
- iatoolkit/services/tasks_service.py +0 -188
- iatoolkit/views/tasks_api_view.py +0 -72
- iatoolkit/views/tasks_review_api_view.py +0 -55
- {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/WHEEL +0 -0
- {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/licenses/LICENSE_COMMUNITY.md +0 -0
- {iatoolkit-0.91.1.dist-info → iatoolkit-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
#
|
|
4
4
|
# IAToolkit is open source software.
|
|
5
5
|
|
|
6
|
-
from iatoolkit.
|
|
6
|
+
from iatoolkit.services.llm_client_service import llmClient
|
|
7
7
|
from iatoolkit.services.profile_service import ProfileService
|
|
8
8
|
from iatoolkit.repositories.profile_repo import ProfileRepo
|
|
9
9
|
from iatoolkit.services.tool_service import ToolService
|
|
@@ -11,11 +11,11 @@ from iatoolkit.services.document_service import DocumentService
|
|
|
11
11
|
from iatoolkit.services.company_context_service import CompanyContextService
|
|
12
12
|
from iatoolkit.services.i18n_service import I18nService
|
|
13
13
|
from iatoolkit.services.configuration_service import ConfigurationService
|
|
14
|
-
from iatoolkit.repositories.models import Task
|
|
15
14
|
from iatoolkit.services.dispatcher_service import Dispatcher
|
|
16
15
|
from iatoolkit.services.prompt_service import PromptService
|
|
17
16
|
from iatoolkit.services.user_session_context_service import UserSessionContextService
|
|
18
17
|
from iatoolkit.services.history_manager_service import HistoryManagerService
|
|
18
|
+
from iatoolkit.common.model_registry import ModelRegistry
|
|
19
19
|
from iatoolkit.common.util import Utility
|
|
20
20
|
from injector import inject
|
|
21
21
|
import base64
|
|
@@ -33,6 +33,7 @@ class HistoryHandle:
|
|
|
33
33
|
company_short_name: str
|
|
34
34
|
user_identifier: str
|
|
35
35
|
type: str
|
|
36
|
+
model: str | None = None
|
|
36
37
|
request_params: dict = None
|
|
37
38
|
|
|
38
39
|
|
|
@@ -52,6 +53,7 @@ class QueryService:
|
|
|
52
53
|
configuration_service: ConfigurationService,
|
|
53
54
|
history_manager: HistoryManagerService,
|
|
54
55
|
util: Utility,
|
|
56
|
+
model_registry: ModelRegistry
|
|
55
57
|
):
|
|
56
58
|
self.profile_service = profile_service
|
|
57
59
|
self.company_context_service = company_context_service
|
|
@@ -66,6 +68,7 @@ class QueryService:
|
|
|
66
68
|
self.configuration_service = configuration_service
|
|
67
69
|
self.llm_client = llm_client
|
|
68
70
|
self.history_manager = history_manager
|
|
71
|
+
self.model_registry = model_registry
|
|
69
72
|
|
|
70
73
|
|
|
71
74
|
def _resolve_model(self, company_short_name: str, model: Optional[str]) -> str:
|
|
@@ -78,13 +81,16 @@ class QueryService:
|
|
|
78
81
|
return effective_model
|
|
79
82
|
|
|
80
83
|
def _get_history_type(self, model: str) -> str:
|
|
81
|
-
|
|
82
|
-
|
|
84
|
+
history_type_str = self.model_registry.get_history_type(model)
|
|
85
|
+
if history_type_str == "server_side":
|
|
86
|
+
return HistoryManagerService.TYPE_SERVER_SIDE
|
|
87
|
+
else:
|
|
88
|
+
return HistoryManagerService.TYPE_CLIENT_SIDE
|
|
83
89
|
|
|
84
90
|
|
|
85
91
|
def _build_user_facing_prompt(self, company, user_identifier: str,
|
|
86
92
|
client_data: dict, files: list,
|
|
87
|
-
prompt_name: Optional[str], question: str)
|
|
93
|
+
prompt_name: Optional[str], question: str):
|
|
88
94
|
# get the user profile data from the session context
|
|
89
95
|
user_profile = self.profile_service.get_profile_by_identifier(company.short_name, user_identifier)
|
|
90
96
|
|
|
@@ -124,9 +130,12 @@ class QueryService:
|
|
|
124
130
|
|
|
125
131
|
return user_turn_prompt, effective_question
|
|
126
132
|
|
|
127
|
-
def _ensure_valid_history(self, company,
|
|
128
|
-
|
|
129
|
-
|
|
133
|
+
def _ensure_valid_history(self, company,
|
|
134
|
+
user_identifier: str,
|
|
135
|
+
effective_model: str,
|
|
136
|
+
user_turn_prompt: str,
|
|
137
|
+
ignore_history: bool
|
|
138
|
+
) -> tuple[Optional[HistoryHandle], Optional[dict]]:
|
|
130
139
|
"""
|
|
131
140
|
Manages the history strategy and rebuilds context if necessary.
|
|
132
141
|
Returns: (HistoryHandle, error_response)
|
|
@@ -137,7 +146,8 @@ class QueryService:
|
|
|
137
146
|
handle = HistoryHandle(
|
|
138
147
|
company_short_name=company.short_name,
|
|
139
148
|
user_identifier=user_identifier,
|
|
140
|
-
type=history_type
|
|
149
|
+
type=history_type,
|
|
150
|
+
model=effective_model
|
|
141
151
|
)
|
|
142
152
|
|
|
143
153
|
# pass the handle to populate request_params
|
|
@@ -197,22 +207,35 @@ class QueryService:
|
|
|
197
207
|
def init_context(self, company_short_name: str,
|
|
198
208
|
user_identifier: str,
|
|
199
209
|
model: str = None) -> dict:
|
|
210
|
+
"""
|
|
211
|
+
Forces a context rebuild for a given user and (optionally) model.
|
|
212
|
+
|
|
213
|
+
- Clears LLM-related context for the resolved model.
|
|
214
|
+
- Regenerates the static company/user context.
|
|
215
|
+
- Sends the context to the LLM for that model.
|
|
216
|
+
"""
|
|
200
217
|
|
|
201
|
-
# 1.
|
|
202
|
-
self.
|
|
203
|
-
logging.info(f"Context for {company_short_name}/{user_identifier} has been cleared.")
|
|
218
|
+
# 1. Resolve the effective model for this user/company
|
|
219
|
+
effective_model = self._resolve_model(company_short_name, model)
|
|
204
220
|
|
|
205
|
-
# 2.
|
|
221
|
+
# 2. Clear only the LLM-related context for this model
|
|
222
|
+
self.session_context.clear_all_context(company_short_name, user_identifier,model=effective_model)
|
|
223
|
+
logging.info(
|
|
224
|
+
f"Context for {company_short_name}/{user_identifier} "
|
|
225
|
+
f"(model={effective_model}) has been cleared."
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
# 3. Static LLM context is now clean, we can prepare it again (model-agnostic)
|
|
206
229
|
self.prepare_context(
|
|
207
230
|
company_short_name=company_short_name,
|
|
208
231
|
user_identifier=user_identifier
|
|
209
232
|
)
|
|
210
233
|
|
|
211
|
-
#
|
|
234
|
+
# 4. Communicate the new context to the specific LLM model
|
|
212
235
|
response = self.set_context_for_llm(
|
|
213
236
|
company_short_name=company_short_name,
|
|
214
237
|
user_identifier=user_identifier,
|
|
215
|
-
model=
|
|
238
|
+
model=effective_model
|
|
216
239
|
)
|
|
217
240
|
|
|
218
241
|
return response
|
|
@@ -257,8 +280,10 @@ class QueryService:
|
|
|
257
280
|
company_short_name: str,
|
|
258
281
|
user_identifier: str,
|
|
259
282
|
model: str = ''):
|
|
260
|
-
|
|
261
|
-
|
|
283
|
+
"""
|
|
284
|
+
Takes a pre-built static context and sends it to the LLM for the given model.
|
|
285
|
+
Also initializes the model-specific history through HistoryManagerService.
|
|
286
|
+
"""
|
|
262
287
|
company = self.profile_repo.get_company_by_short_name(company_short_name)
|
|
263
288
|
if not company:
|
|
264
289
|
logging.error(f"Company not found: {company_short_name} in set_context_for_llm")
|
|
@@ -267,8 +292,8 @@ class QueryService:
|
|
|
267
292
|
# --- Model Resolution ---
|
|
268
293
|
effective_model = self._resolve_model(company_short_name, model)
|
|
269
294
|
|
|
270
|
-
#
|
|
271
|
-
lock_key = f"lock:context:{company_short_name}/{user_identifier}"
|
|
295
|
+
# Lock per (company, user, model) to avoid concurrent rebuilds for the same model
|
|
296
|
+
lock_key = f"lock:context:{company_short_name}/{user_identifier}/{effective_model}"
|
|
272
297
|
if not self.session_context.acquire_lock(lock_key, expire_seconds=60):
|
|
273
298
|
logging.warning(
|
|
274
299
|
f"try to rebuild context for user {user_identifier} while is still in process, ignored.")
|
|
@@ -310,13 +335,13 @@ class QueryService:
|
|
|
310
335
|
def llm_query(self,
|
|
311
336
|
company_short_name: str,
|
|
312
337
|
user_identifier: str,
|
|
313
|
-
|
|
338
|
+
model: Optional[str] = None,
|
|
314
339
|
prompt_name: str = None,
|
|
315
340
|
question: str = '',
|
|
316
341
|
client_data: dict = {},
|
|
317
342
|
ignore_history: bool = False,
|
|
318
|
-
files: list = []
|
|
319
|
-
|
|
343
|
+
files: list = []
|
|
344
|
+
) -> dict:
|
|
320
345
|
try:
|
|
321
346
|
company = self.profile_repo.get_company_by_short_name(short_name=company_short_name)
|
|
322
347
|
if not company:
|
|
@@ -378,10 +403,10 @@ class QueryService:
|
|
|
378
403
|
if not response.get('valid_response'):
|
|
379
404
|
response['error'] = True
|
|
380
405
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
406
|
+
# save history using the manager passing the handle
|
|
407
|
+
self.history_manager.update_history(
|
|
408
|
+
history_handle, user_turn_prompt, response
|
|
409
|
+
)
|
|
385
410
|
|
|
386
411
|
return response
|
|
387
412
|
except Exception as e:
|
|
@@ -3,13 +3,13 @@
|
|
|
3
3
|
#
|
|
4
4
|
# IAToolkit is open source software.
|
|
5
5
|
|
|
6
|
+
from iatoolkit.common.interfaces.database_provider import DatabaseProvider
|
|
6
7
|
from iatoolkit.repositories.database_manager import DatabaseManager
|
|
7
|
-
from iatoolkit.common.util import Utility
|
|
8
8
|
from iatoolkit.services.i18n_service import I18nService
|
|
9
9
|
from iatoolkit.common.exceptions import IAToolkitException
|
|
10
|
-
from
|
|
11
|
-
from sqlalchemy.exc import SQLAlchemyError
|
|
10
|
+
from iatoolkit.common.util import Utility
|
|
12
11
|
from injector import inject, singleton
|
|
12
|
+
from typing import Callable
|
|
13
13
|
import json
|
|
14
14
|
import logging
|
|
15
15
|
|
|
@@ -28,91 +28,124 @@ class SqlService:
|
|
|
28
28
|
self.util = util
|
|
29
29
|
self.i18n_service = i18n_service
|
|
30
30
|
|
|
31
|
-
# Cache for database
|
|
32
|
-
|
|
31
|
+
# Cache for database providers. Key is tuple: (company_short_name, db_name)
|
|
32
|
+
# Value is the abstract interface DatabaseProvider
|
|
33
|
+
self._db_connections: dict[tuple[str, str], DatabaseProvider] = {}
|
|
34
|
+
|
|
35
|
+
# cache for database schemas. Key is tuple: (company_short_name, db_name)
|
|
36
|
+
self._db_schemas: dict[tuple[str, str], str] = {}
|
|
37
|
+
|
|
38
|
+
# Registry of factory functions.
|
|
39
|
+
# Format: {'connection_type': function(config_dict) -> DatabaseProvider}
|
|
40
|
+
self._provider_factories: dict[str, Callable[[dict], DatabaseProvider]] = {}
|
|
41
|
+
|
|
42
|
+
# Register the default 'direct' strategy (SQLAlchemy)
|
|
43
|
+
self.register_provider_factory('direct', self._create_direct_connection)
|
|
33
44
|
|
|
34
|
-
def
|
|
45
|
+
def register_provider_factory(self, connection_type: str, factory: Callable[[dict], DatabaseProvider]):
|
|
35
46
|
"""
|
|
36
|
-
|
|
37
|
-
|
|
47
|
+
Allows plugins (Enterprise) to register new connection types.
|
|
48
|
+
"""
|
|
49
|
+
self._provider_factories[connection_type] = factory
|
|
50
|
+
|
|
51
|
+
def _create_direct_connection(self, config: dict) -> DatabaseProvider:
|
|
52
|
+
"""Default factory for standard SQLAlchemy connections."""
|
|
53
|
+
uri = config.get('db_uri') or config.get('DATABASE_URI')
|
|
54
|
+
schema = config.get('schema')
|
|
55
|
+
if not uri:
|
|
56
|
+
raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR,
|
|
57
|
+
"Missing db_uri for direct connection")
|
|
58
|
+
return DatabaseManager(uri, schema=schema, register_pgvector=False)
|
|
59
|
+
|
|
60
|
+
def register_database(self, company_short_name: str, db_name: str, config: dict):
|
|
38
61
|
"""
|
|
39
|
-
|
|
62
|
+
Creates and caches a DatabaseProvider instance based on the configuration.
|
|
63
|
+
"""
|
|
64
|
+
key = (company_short_name, db_name)
|
|
65
|
+
|
|
66
|
+
# Determine connection type (default to 'direct')
|
|
67
|
+
conn_type = config.get('connection_type', 'direct')
|
|
68
|
+
logging.info(f"Registering DB '{db_name}' ({conn_type}) for company '{company_short_name}'")
|
|
69
|
+
|
|
70
|
+
factory = self._provider_factories.get(conn_type)
|
|
71
|
+
if not factory:
|
|
72
|
+
logging.error(f"Unknown connection type '{conn_type}' for DB '{db_name}'. Skipping.")
|
|
40
73
|
return
|
|
41
74
|
|
|
42
|
-
|
|
75
|
+
try:
|
|
76
|
+
# Create the provider using the appropriate factory
|
|
77
|
+
provider_instance = factory(config)
|
|
78
|
+
self._db_connections[key] = provider_instance
|
|
79
|
+
|
|
80
|
+
# save the db_schema
|
|
81
|
+
self._db_schemas[key] = config.get('schema', 'public')
|
|
82
|
+
except Exception as e:
|
|
83
|
+
logging.error(f"Failed to register DB '{db_name}': {e}")
|
|
84
|
+
# We don't raise here to allow other DBs to load if one fails
|
|
43
85
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
86
|
+
def get_db_names(self, company_short_name: str) -> list[str]:
|
|
87
|
+
"""
|
|
88
|
+
Returns list of logical database names available ONLY for the specified company.
|
|
89
|
+
"""
|
|
90
|
+
return [db for (co, db) in self._db_connections.keys() if co == company_short_name]
|
|
47
91
|
|
|
48
|
-
def
|
|
92
|
+
def get_database_provider(self, company_short_name: str, db_name: str) -> DatabaseProvider:
|
|
49
93
|
"""
|
|
50
|
-
Retrieves a registered
|
|
94
|
+
Retrieves a registered DatabaseProvider instance using the composite key.
|
|
95
|
+
Replaces the old 'get_database_manager'.
|
|
51
96
|
"""
|
|
97
|
+
key = (company_short_name, db_name)
|
|
52
98
|
try:
|
|
53
|
-
return self._db_connections[
|
|
99
|
+
return self._db_connections[key]
|
|
54
100
|
except KeyError:
|
|
55
|
-
logging.error(
|
|
101
|
+
logging.error(
|
|
102
|
+
f"Attempted to access unregistered database: '{db_name}' for company '{company_short_name}'"
|
|
103
|
+
)
|
|
56
104
|
raise IAToolkitException(
|
|
57
105
|
IAToolkitException.ErrorType.DATABASE_ERROR,
|
|
58
|
-
f"Database '{db_name}' is not registered
|
|
106
|
+
f"Database '{db_name}' is not registered for this company."
|
|
59
107
|
)
|
|
60
108
|
|
|
61
|
-
def exec_sql(self, company_short_name: str,
|
|
62
|
-
database: str,
|
|
63
|
-
query: str,
|
|
64
|
-
format: str = 'json',
|
|
65
|
-
commit: bool = False):
|
|
109
|
+
def exec_sql(self, company_short_name: str, **kwargs):
|
|
66
110
|
"""
|
|
67
|
-
Executes a raw SQL statement against a registered database.
|
|
68
|
-
|
|
69
|
-
Args:
|
|
70
|
-
company_short_name: The company identifier (for logging/context).
|
|
71
|
-
database: The logical name of the database to query.
|
|
72
|
-
query: The SQL statement to execute.
|
|
73
|
-
format: The output format ('json' or 'dict'). Only relevant for SELECT queries.
|
|
74
|
-
commit: Whether to commit the transaction immediately after execution.
|
|
75
|
-
Use True for INSERT/UPDATE/DELETE statements.
|
|
76
|
-
|
|
77
|
-
Returns:
|
|
78
|
-
- A JSON string or list of dicts for SELECT queries.
|
|
79
|
-
- A dictionary {'rowcount': N} for non-returning statements (INSERT/UPDATE) if not using RETURNING.
|
|
111
|
+
Executes a raw SQL statement against a registered database provider.
|
|
112
|
+
Delegates the actual execution details to the provider implementation.
|
|
80
113
|
"""
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
114
|
+
database_name = kwargs.get('database_key')
|
|
115
|
+
query = kwargs.get('query')
|
|
116
|
+
format = kwargs.get('format', 'json')
|
|
117
|
+
commit = kwargs.get('commit')
|
|
85
118
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
# 3. Handle Commit
|
|
90
|
-
if commit:
|
|
91
|
-
session.commit()
|
|
119
|
+
if not database_name:
|
|
120
|
+
raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR,
|
|
121
|
+
'missing database_name in call to exec_sql')
|
|
92
122
|
|
|
93
|
-
|
|
94
|
-
#
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
rows_context = [dict(zip(cols, row)) for row in result.fetchall()]
|
|
123
|
+
try:
|
|
124
|
+
# 1. Get the abstract provider (could be Direct or Bridge)
|
|
125
|
+
provider = self.get_database_provider(company_short_name, database_name)
|
|
126
|
+
db_schema = self._db_schemas[(company_short_name, database_name)]
|
|
98
127
|
|
|
99
|
-
|
|
100
|
-
|
|
128
|
+
# 2. Delegate execution
|
|
129
|
+
# The provider returns a clean List[Dict] or Dict result
|
|
130
|
+
result_data = provider.execute_query(query=query, commit=commit)
|
|
101
131
|
|
|
102
|
-
|
|
103
|
-
|
|
132
|
+
# 3. Handle Formatting (Service layer responsibility)
|
|
133
|
+
if format == 'dict':
|
|
134
|
+
return result_data
|
|
104
135
|
|
|
105
|
-
#
|
|
106
|
-
return
|
|
136
|
+
# Serialize the result
|
|
137
|
+
return json.dumps(result_data, default=self.util.serialize)
|
|
107
138
|
|
|
108
139
|
except IAToolkitException:
|
|
109
|
-
# Re-raise exceptions from get_database_manager to preserve the specific error
|
|
110
140
|
raise
|
|
111
141
|
except Exception as e:
|
|
112
|
-
# Attempt
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
142
|
+
# Attempt rollback if supported/needed
|
|
143
|
+
try:
|
|
144
|
+
provider = self.get_database_provider(company_short_name, database_name)
|
|
145
|
+
if provider:
|
|
146
|
+
provider.rollback()
|
|
147
|
+
except Exception:
|
|
148
|
+
pass
|
|
116
149
|
|
|
117
150
|
error_message = str(e)
|
|
118
151
|
if 'timed out' in str(e):
|
|
@@ -122,22 +155,37 @@ class SqlService:
|
|
|
122
155
|
raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR,
|
|
123
156
|
error_message) from e
|
|
124
157
|
|
|
125
|
-
def commit(self,
|
|
158
|
+
def commit(self, company_short_name: str, database_name: str):
|
|
126
159
|
"""
|
|
127
|
-
Commits the current transaction for a registered database.
|
|
128
|
-
Useful when multiple exec_sql calls are part of a single transaction.
|
|
160
|
+
Commits the current transaction for a registered database provider.
|
|
129
161
|
"""
|
|
130
|
-
|
|
131
|
-
# Get the database manager from the cache
|
|
132
|
-
db_manager = self.get_database_manager(database)
|
|
162
|
+
provider = self.get_database_provider(company_short_name, database_name)
|
|
133
163
|
try:
|
|
134
|
-
|
|
135
|
-
except SQLAlchemyError as db_error:
|
|
136
|
-
db_manager.get_session().rollback()
|
|
137
|
-
logging.error(f"Error de base de datos: {str(db_error)}")
|
|
138
|
-
raise db_error
|
|
164
|
+
provider.commit()
|
|
139
165
|
except Exception as e:
|
|
140
|
-
|
|
166
|
+
# Try rollback
|
|
167
|
+
try:
|
|
168
|
+
provider.rollback()
|
|
169
|
+
except:
|
|
170
|
+
pass
|
|
171
|
+
logging.error(f"Error while committing sql: '{str(e)}'")
|
|
141
172
|
raise IAToolkitException(
|
|
142
173
|
IAToolkitException.ErrorType.DATABASE_ERROR, str(e)
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
def get_database_structure(self, company_short_name: str, db_name: str) -> dict:
|
|
177
|
+
"""
|
|
178
|
+
Introspects the specified database and returns its structure (Tables & Columns).
|
|
179
|
+
Used for the Schema Editor 2.0
|
|
180
|
+
"""
|
|
181
|
+
try:
|
|
182
|
+
provider = self.get_database_provider(company_short_name, db_name)
|
|
183
|
+
return provider.get_database_structure()
|
|
184
|
+
except IAToolkitException:
|
|
185
|
+
raise
|
|
186
|
+
except Exception as e:
|
|
187
|
+
logging.error(f"Error introspecting database '{db_name}': {e}")
|
|
188
|
+
raise IAToolkitException(
|
|
189
|
+
IAToolkitException.ErrorType.DATABASE_ERROR,
|
|
190
|
+
f"Failed to introspect database: {str(e)}"
|
|
143
191
|
)
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
from injector import inject
|
|
7
7
|
from iatoolkit.repositories.llm_query_repo import LLMQueryRepo
|
|
8
|
+
from iatoolkit.repositories.profile_repo import ProfileRepo
|
|
8
9
|
from iatoolkit.repositories.models import Company, Tool
|
|
9
10
|
from iatoolkit.common.exceptions import IAToolkitException
|
|
10
11
|
from iatoolkit.services.sql_service import SqlService
|
|
@@ -98,20 +99,20 @@ _SYSTEM_TOOLS = [
|
|
|
98
99
|
},
|
|
99
100
|
{
|
|
100
101
|
"function_name": "iat_sql_query",
|
|
101
|
-
"description": "Servicio SQL de IAToolkit: debes utilizar este servicio para todas las consultas a
|
|
102
|
+
"description": "Servicio SQL de IAToolkit: debes utilizar este servicio para todas las consultas SQL a bases de datos.",
|
|
102
103
|
"parameters": {
|
|
103
104
|
"type": "object",
|
|
104
105
|
"properties": {
|
|
105
|
-
"
|
|
106
|
+
"database_key": {
|
|
106
107
|
"type": "string",
|
|
107
|
-
"description": "nombre de la base de datos a consultar
|
|
108
|
+
"description": "IMPORTANT: nombre de la base de datos a consultar."
|
|
108
109
|
},
|
|
109
110
|
"query": {
|
|
110
111
|
"type": "string",
|
|
111
112
|
"description": "string con la consulta en sql"
|
|
112
113
|
},
|
|
113
114
|
},
|
|
114
|
-
"required": ["
|
|
115
|
+
"required": ["database_key", "query"]
|
|
115
116
|
}
|
|
116
117
|
}
|
|
117
118
|
]
|
|
@@ -121,10 +122,12 @@ class ToolService:
|
|
|
121
122
|
@inject
|
|
122
123
|
def __init__(self,
|
|
123
124
|
llm_query_repo: LLMQueryRepo,
|
|
125
|
+
profile_repo: ProfileRepo,
|
|
124
126
|
sql_service: SqlService,
|
|
125
127
|
excel_service: ExcelService,
|
|
126
128
|
mail_service: MailService):
|
|
127
129
|
self.llm_query_repo = llm_query_repo
|
|
130
|
+
self.profile_repo = profile_repo
|
|
128
131
|
self.sql_service = sql_service
|
|
129
132
|
self.excel_service = excel_service
|
|
130
133
|
self.mail_service = mail_service
|
|
@@ -158,14 +161,22 @@ class ToolService:
|
|
|
158
161
|
self.llm_query_repo.rollback()
|
|
159
162
|
raise IAToolkitException(IAToolkitException.ErrorType.DATABASE_ERROR, str(e))
|
|
160
163
|
|
|
161
|
-
def sync_company_tools(self,
|
|
164
|
+
def sync_company_tools(self, company_short_name: str, tools_config: list):
|
|
162
165
|
"""
|
|
163
166
|
Synchronizes tools from YAML config to Database (Create/Update/Delete strategy).
|
|
164
167
|
"""
|
|
168
|
+
if not tools_config:
|
|
169
|
+
return
|
|
170
|
+
|
|
171
|
+
company = self.profile_repo.get_company_by_short_name(company_short_name)
|
|
172
|
+
if not company:
|
|
173
|
+
raise IAToolkitException(IAToolkitException.ErrorType.INVALID_NAME,
|
|
174
|
+
f'Company {company_short_name} not found')
|
|
175
|
+
|
|
165
176
|
try:
|
|
166
177
|
# 1. Get existing tools map for later cleanup
|
|
167
178
|
existing_tools = {
|
|
168
|
-
f.name: f for f in self.llm_query_repo.get_company_tools(
|
|
179
|
+
f.name: f for f in self.llm_query_repo.get_company_tools(company)
|
|
169
180
|
}
|
|
170
181
|
defined_tool_names = set()
|
|
171
182
|
|
|
@@ -177,7 +188,7 @@ class ToolService:
|
|
|
177
188
|
# Construct the tool object with current config values
|
|
178
189
|
# We create a new transient object and let the repo merge it
|
|
179
190
|
tool_obj = Tool(
|
|
180
|
-
company_id=
|
|
191
|
+
company_id=company.id,
|
|
181
192
|
name=name,
|
|
182
193
|
description=tool_data['description'],
|
|
183
194
|
parameters=tool_data['params'],
|
|
@@ -206,11 +217,12 @@ class ToolService:
|
|
|
206
217
|
Returns the list of tools (System + Company) formatted for the LLM (OpenAI Schema).
|
|
207
218
|
"""
|
|
208
219
|
tools = []
|
|
209
|
-
# Obtiene tanto las de la empresa como las del sistema (la query del repo debería soportar esto con OR)
|
|
210
|
-
functions = self.llm_query_repo.get_company_tools(company)
|
|
211
220
|
|
|
212
|
-
for
|
|
213
|
-
|
|
221
|
+
# get all the tools for the company and system
|
|
222
|
+
company_tools = self.llm_query_repo.get_company_tools(company)
|
|
223
|
+
|
|
224
|
+
for function in company_tools:
|
|
225
|
+
# clone for no modify the SQLAlchemy session object
|
|
214
226
|
params = function.parameters.copy() if function.parameters else {}
|
|
215
227
|
params["additionalProperties"] = False
|
|
216
228
|
|
|
@@ -221,6 +233,9 @@ class ToolService:
|
|
|
221
233
|
"parameters": params,
|
|
222
234
|
"strict": True
|
|
223
235
|
}
|
|
236
|
+
if function.name == 'iat_sql_query':
|
|
237
|
+
params['properties']['database_key']['enum'] = self.sql_service.get_db_names(company.short_name)
|
|
238
|
+
|
|
224
239
|
tools.append(ai_tool)
|
|
225
240
|
return tools
|
|
226
241
|
|