iatoolkit 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of iatoolkit might be problematic. Click here for more details.
- common/__init__.py +0 -0
- common/auth.py +200 -0
- common/exceptions.py +46 -0
- common/routes.py +86 -0
- common/session_manager.py +25 -0
- common/util.py +358 -0
- iatoolkit/iatoolkit.py +3 -3
- {iatoolkit-0.7.4.dist-info → iatoolkit-0.7.6.dist-info}/METADATA +1 -1
- iatoolkit-0.7.6.dist-info/RECORD +80 -0
- iatoolkit-0.7.6.dist-info/top_level.txt +6 -0
- infra/__init__.py +5 -0
- infra/call_service.py +140 -0
- infra/connectors/__init__.py +5 -0
- infra/connectors/file_connector.py +17 -0
- infra/connectors/file_connector_factory.py +57 -0
- infra/connectors/google_cloud_storage_connector.py +53 -0
- infra/connectors/google_drive_connector.py +68 -0
- infra/connectors/local_file_connector.py +46 -0
- infra/connectors/s3_connector.py +33 -0
- infra/gemini_adapter.py +356 -0
- infra/google_chat_app.py +57 -0
- infra/llm_client.py +430 -0
- infra/llm_proxy.py +139 -0
- infra/llm_response.py +40 -0
- infra/mail_app.py +145 -0
- infra/openai_adapter.py +90 -0
- infra/redis_session_manager.py +76 -0
- repositories/__init__.py +5 -0
- repositories/database_manager.py +95 -0
- repositories/document_repo.py +33 -0
- repositories/llm_query_repo.py +91 -0
- repositories/models.py +309 -0
- repositories/profile_repo.py +118 -0
- repositories/tasks_repo.py +52 -0
- repositories/vs_repo.py +139 -0
- views/__init__.py +5 -0
- views/change_password_view.py +91 -0
- views/chat_token_request_view.py +98 -0
- views/chat_view.py +51 -0
- views/download_file_view.py +58 -0
- views/external_chat_login_view.py +88 -0
- views/external_login_view.py +40 -0
- views/file_store_view.py +58 -0
- views/forgot_password_view.py +64 -0
- views/history_view.py +57 -0
- views/home_view.py +34 -0
- views/llmquery_view.py +65 -0
- views/login_view.py +60 -0
- views/prompt_view.py +37 -0
- views/signup_view.py +87 -0
- views/tasks_review_view.py +83 -0
- views/tasks_view.py +98 -0
- views/user_feedback_view.py +74 -0
- views/verify_user_view.py +55 -0
- iatoolkit-0.7.4.dist-info/RECORD +0 -30
- iatoolkit-0.7.4.dist-info/top_level.txt +0 -2
- {iatoolkit-0.7.4.dist-info → iatoolkit-0.7.6.dist-info}/WHEEL +0 -0
repositories/models.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
# Copyright (c) 2024 Fernando Libedinsky
|
|
2
|
+
# Product: IAToolkit
|
|
3
|
+
#
|
|
4
|
+
# IAToolkit is open source software.
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import Column, Integer, String, DateTime, Enum, Text, JSON, Boolean, ForeignKey, Table
|
|
7
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
8
|
+
from sqlalchemy.orm import relationship, class_mapper, declarative_base
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from pgvector.sqlalchemy import Vector
|
|
11
|
+
from enum import Enum as PyEnum
|
|
12
|
+
import secrets
|
|
13
|
+
import enum
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# base class for the ORM
|
|
17
|
+
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model defined in this module."""
|
|
19
|
+
|
|
20
|
+
# relation table for many-to-many relationship between companies and users
|
|
21
|
+
# Composite primary key (user_id, company_id); both foreign keys cascade on delete.
user_company = Table(
    'iat_user_company',
    Base.metadata,
    Column(
        'user_id',
        Integer,
        ForeignKey('iat_users.id', ondelete='CASCADE'),
        primary_key=True,
    ),
    Column(
        'company_id',
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        primary_key=True,
    ),
    Column('is_active', Boolean, default=True),
    # Role of the user within this specific company.
    Column('role', String(50), default='user'),
    Column('created_at', DateTime, default=datetime.now),
)
|
|
33
|
+
|
|
34
|
+
class ApiKey(Base):
    """API key that a company uses to authenticate against the system."""

    __tablename__ = 'iat_api_keys'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    # The API key value itself; unique and indexed.
    key = Column(String(128), unique=True, nullable=False, index=True)
    is_active = Column(Boolean, default=True, nullable=False)
    created_at = Column(DateTime, default=datetime.now)
    # Optional column meant for usage tracking; nothing in this class writes it.
    last_used_at = Column(DateTime, nullable=True)

    # Owning company (other side is Company.api_keys).
    company = relationship("Company", back_populates="api_keys")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Company(Base):
    """Represents a company or tenant in the multi-tenant system."""

    __tablename__ = 'iat_companies'

    id = Column(Integer, primary_key=True)
    short_name = Column(String(20), nullable=False, unique=True, index=True)
    name = Column(String(256), nullable=False)

    # LLM provider keys; the original author's note says these are stored encrypted.
    openai_api_key = Column(String, nullable=True)
    gemini_api_key = Column(String, nullable=True)

    logo_file = Column(String(128), nullable=True, default='')
    # `dict` (a callable) instead of a `{}` literal: SQLAlchemy invokes it per INSERT,
    # so rows never share one mutable default object.
    parameters = Column(JSON, nullable=True, default=dict)
    created_at = Column(DateTime, default=datetime.now)
    allow_jwt = Column(Boolean, default=False, nullable=True)

    # Child collections removed together with the company ("all, delete-orphan").
    documents = relationship("Document",
                             back_populates="company",
                             cascade="all, delete-orphan",
                             lazy='dynamic')
    functions = relationship("Function",
                             back_populates="company",
                             cascade="all, delete-orphan")
    vsdocs = relationship("VSDoc",
                          back_populates="company",
                          cascade="all, delete-orphan")
    llm_queries = relationship("LLMQuery",
                               back_populates="company",
                               cascade="all, delete-orphan")
    # Many-to-many link to users through the iat_user_company association table.
    users = relationship("User",
                         secondary=user_company,
                         back_populates="companies")
    api_keys = relationship("ApiKey",
                            back_populates="company",
                            cascade="all, delete-orphan")

    # NOTE(review): unlike the other collections, tasks has no delete cascade.
    tasks = relationship("Task", back_populates="company")
    feedbacks = relationship("UserFeedback",
                             back_populates="company",
                             cascade="all, delete-orphan")
    prompts = relationship("Prompt",
                           back_populates="company",
                           cascade="all, delete-orphan")

    def to_dict(self):
        """Return a dict mapping every mapped column key to its current value.

        NOTE(review): the result includes openai_api_key and gemini_api_key;
        confirm callers do not expose it to untrusted clients.
        """
        return {column.key: getattr(self, column.key) for column in class_mapper(self.__class__).columns}
|
|
95
|
+
|
|
96
|
+
# users with rights to use this app
|
|
97
|
+
class User(Base):
    """An IAToolkit user; a user may belong to several companies."""

    __tablename__ = 'iat_users'

    id = Column(Integer, primary_key=True)
    email = Column(String(80), unique=True, nullable=False)
    first_name = Column(String(50), nullable=False)
    last_name = Column(String(50), nullable=False)
    created_at = Column(DateTime, default=datetime.now)
    password = Column(String, nullable=False)
    verified = Column(Boolean, nullable=False, default=False)
    verification_url = Column(String, nullable=True)
    temp_code = Column(String, nullable=True)

    # Many-to-many link to companies through iat_user_company, loaded as a query
    # object (lazy='dynamic').
    # NOTE(review): cascade="all" includes "delete" on a many-to-many relationship;
    # confirm that deleting a user is never meant to delete Company rows.
    companies = relationship(
        "Company",
        secondary=user_company,
        back_populates="users",
        cascade="all",
        passive_deletes=True,
        lazy='dynamic',
    )

    def to_dict(self):
        """Serialize the user and its companies into a plain dict.

        The password, verification URL and temp code are left out; created_at
        is converted with str().
        """
        payload = {
            'id': self.id,
            'email': self.email,
            'first_name': self.first_name,
            'last_name': self.last_name,
            'created_at': str(self.created_at),
            'verified': self.verified,
        }
        payload['companies'] = [company.to_dict() for company in self.companies]
        return payload
|
|
130
|
+
|
|
131
|
+
class Function(Base):
    """A custom or system function (tool) that the LLM can call."""

    __tablename__ = 'iat_functions'

    id = Column(Integer, primary_key=True)
    # Nullable: a function row is allowed to have no owning company.
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=True,
    )
    name = Column(String(255), nullable=False)
    system_function = Column(Boolean, default=False)
    description = Column(Text, nullable=False)
    parameters = Column(JSON, nullable=False)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship('Company', back_populates='functions')

    def to_dict(self):
        """Return a dict mapping every mapped column key to its current value."""
        mapper = class_mapper(self.__class__)
        return {col.key: getattr(self, col.key) for col in mapper.columns}
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class Document(Base):
    """A file or document uploaded by a company to be used as context."""

    __tablename__ = 'iat_documents'

    id = Column(Integer, primary_key=True, autoincrement=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    filename = Column(String(256), nullable=False, index=True)
    content = Column(Text, nullable=False)
    # presumably the base64 encoding of the original file — TODO confirm
    content_b64 = Column(Text, nullable=False)
    meta = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="documents")

    def to_dict(self):
        """Return a dict mapping every mapped column key to its current value."""
        mapper = class_mapper(self.__class__)
        return {col.key: getattr(self, col.key) for col in mapper.columns}
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class LLMQuery(Base):
    """Logs a query made to the LLM, including input, output, and metadata."""

    __tablename__ = 'iat_queries'

    id = Column(Integer, primary_key=True)
    company_id = Column(Integer, ForeignKey('iat_companies.id',
                        ondelete='CASCADE'), nullable=False)
    user_identifier = Column(String(128), nullable=False)
    # Plain integer (not a foreign key); defaults to 0.
    task_id = Column(Integer, default=0, nullable=True)
    query = Column(Text, nullable=False)
    output = Column(Text, nullable=False)
    # JSON defaults use the `dict` callable rather than a `{}` literal so that
    # SQLAlchemy builds a fresh object per INSERT instead of sharing one dict.
    response = Column(JSON, nullable=True, default=dict)
    valid_response = Column(Boolean, nullable=False, default=False)
    function_calls = Column(JSON, nullable=True, default=dict)
    stats = Column(JSON, default=dict)
    answer_time = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="llm_queries")
    # Tasks whose llm_query_id points at this query (other side is Task.llm_query).
    tasks = relationship("Task", back_populates="llm_query")

    def to_dict(self):
        """Return a dict mapping every mapped column key to its current value."""
        return {column.key: getattr(self, column.key) for column in class_mapper(self.__class__).columns}
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class VSDoc(Base):
    """A text chunk together with its vector embedding, used for similarity search."""

    __tablename__ = "iat_vsdocs"

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    document_id = Column(
        Integer,
        ForeignKey('iat_documents.id', ondelete='CASCADE'),
        nullable=False,
    )
    text = Column(Text, nullable=False)
    # 384-dimensional pgvector column; adjust the dimension if the embedding model changes.
    embedding = Column(Vector(384), nullable=False)

    company = relationship("Company", back_populates="vsdocs")

    def to_dict(self):
        """Return a dict mapping every mapped column key to its current value."""
        mapper = class_mapper(self.__class__)
        return {col.key: getattr(self, col.key) for col in mapper.columns}
|
|
212
|
+
|
|
213
|
+
class TaskStatus(PyEnum):
    """Possible statuses of a Task; each member's value equals its name."""

    # Created and waiting to be executed.
    pendiente = "pendiente"
    # The AI algorithm has been executed.
    ejecutado = "ejecutado"
    # Validated and approved by a human.
    aprobada = "aprobada"
    # Validated and rejected by a human.
    rechazada = "rechazada"
    # Error while executing the AI algorithm.
    fallida = "fallida"
|
|
220
|
+
|
|
221
|
+
class TaskType(Base):
    """A kind of task that can be executed, together with its prompt template."""

    __tablename__ = 'iat_task_types'

    id = Column(Integer, primary_key=True)
    name = Column(String(100), unique=True, nullable=False)
    # Default prompt template for this task type.
    prompt_template = Column(String(100), nullable=True)
    # Configuration arguments/prefixes for the template.
    template_args = Column(JSON, nullable=True)
|
|
229
|
+
|
|
230
|
+
class Task(Base):
    """Represents an asynchronous task to be executed by the system, often involving an LLM."""

    __tablename__ = 'iat_tasks'

    id = Column(Integer, primary_key=True)
    # NOTE(review): no ondelete='CASCADE' here, unlike most other company foreign keys.
    company_id = Column(Integer, ForeignKey("iat_companies.id"))

    # Plain integer (not a foreign key); defaults to 0.
    user_id = Column(Integer, nullable=True, default=0)
    task_type_id = Column(Integer, ForeignKey('iat_task_types.id'), nullable=False)
    status = Column(Enum(TaskStatus, name="task_status_enum"),
                    default=TaskStatus.pendiente, nullable=False)
    # `dict` / `list` callables instead of `{}` / `[]` literals: SQLAlchemy calls them
    # per INSERT, so rows never share one mutable default object.
    client_data = Column(JSON, nullable=True, default=dict)
    company_task_id = Column(Integer, nullable=True, default=0)
    execute_at = Column(DateTime, default=datetime.now, nullable=True)
    llm_query_id = Column(Integer, ForeignKey('iat_queries.id'), nullable=True)
    callback_url = Column(String(512), default=None, nullable=True)
    files = Column(JSON, default=list, nullable=True)

    # Human review fields.
    review_user = Column(String(128), nullable=True, default='')
    review_date = Column(DateTime, nullable=True)
    comment = Column(Text, nullable=True)
    approved = Column(Boolean, nullable=False, default=False)

    created_at = Column(DateTime, default=datetime.now)
    # Only an insert-time default is declared (no onupdate), so later changes to
    # this value must be assigned explicitly.
    updated_at = Column(DateTime, default=datetime.now)

    task_type = relationship("TaskType")
    llm_query = relationship("LLMQuery", back_populates="tasks", uselist=False)
    company = relationship("Company", back_populates="tasks")
|
|
259
|
+
|
|
260
|
+
class UserFeedback(Base):
    """Feedback message and rating submitted by a user."""

    __tablename__ = 'iat_feedback'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    # Two alternative user identifiers; both are optional plain values (no foreign key).
    local_user_id = Column(Integer, default=0, nullable=True)
    external_user_id = Column(String(128), default='', nullable=True)
    message = Column(Text, nullable=False)
    rating = Column(Integer, nullable=False)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="feedbacks")
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class PromptCategory(Base):
    """A category used to group and order prompts."""

    __tablename__ = 'iat_prompt_categories'

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    order = Column(Integer, nullable=False, default=0)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id'),
        nullable=False,
    )

    # Prompts of this category, sorted by Prompt.order.
    prompts = relationship(
        "Prompt",
        back_populates="category",
        order_by="Prompt.order",
    )

    def __repr__(self):
        """Return a debug representation showing the name and the order."""
        return f"<PromptCategory(name='{self.name}', order={self.order})>"
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class Prompt(Base):
    """Represents a system or user-defined prompt template for the LLM."""

    __tablename__ = 'iat_prompt'

    id = Column(Integer, primary_key=True)
    # Nullable: a prompt row is allowed to have no owning company.
    company_id = Column(Integer, ForeignKey('iat_companies.id',
                        ondelete='CASCADE'), nullable=True)
    name = Column(String(64), nullable=False)
    description = Column(String(256), nullable=False)
    filename = Column(String(256), nullable=False)
    active = Column(Boolean, default=True)
    is_system_prompt = Column(Boolean, default=False)
    # Sort position; PromptCategory.prompts orders by this column.
    order = Column(Integer, nullable=False, default=0)
    category_id = Column(Integer, ForeignKey('iat_prompt_categories.id'), nullable=True)
    # `list` callable instead of a `[]` literal: SQLAlchemy calls it per INSERT, so
    # rows never share one mutable default object.
    custom_fields = Column(JSON, nullable=False, default=list)

    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="prompts")
    category = relationship("PromptCategory", back_populates="prompts")
|
|
repositories/profile_repo.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Copyright (c) 2024 Fernando Libedinsky
|
|
2
|
+
# Product: IAToolkit
|
|
3
|
+
#
|
|
4
|
+
# IAToolkit is open source software.
|
|
5
|
+
|
|
6
|
+
from repositories.models import User, Company, ApiKey, UserFeedback
|
|
7
|
+
from injector import inject
|
|
8
|
+
from repositories.database_manager import DatabaseManager
|
|
9
|
+
from sqlalchemy.orm import joinedload # Para cargar la relación eficientemente
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ProfileRepo:
    """Persistence helper for users, companies, API keys and user feedback.

    All methods share the single session obtained from the injected
    DatabaseManager.
    """

    @inject
    def __init__(self, db_manager: DatabaseManager):
        self.session = db_manager.get_session()

    def _commit(self):
        """Commit the current transaction; on failure roll back and re-raise.

        Without the rollback a failed flush/commit leaves the shared session in
        a state that rejects every later operation.
        """
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def get_user_by_id(self, user_id: int) -> User | None:
        """Return the user with the given primary key, or None when not found."""
        return self.session.query(User).filter_by(id=user_id).first()

    def get_user_by_email(self, email: str) -> User | None:
        """Return the user with the given email, or None when not found."""
        return self.session.query(User).filter_by(email=email).first()

    def create_user(self, new_user: User):
        """Add new_user to the session, commit, and return it."""
        self.session.add(new_user)
        self._commit()
        return new_user

    def save_user(self, existing_user: User):
        """Add existing_user to the session, commit, and return it."""
        self.session.add(existing_user)
        self._commit()
        return existing_user

    def update_user(self, email, **kwargs):
        """Set the given attributes on the user identified by email and commit.

        Keyword arguments naming an attribute the User object does not have are
        ignored. Returns the updated user, or None when no user has that email.
        """
        user = self.session.query(User).filter_by(email=email).first()
        if not user:
            return None

        for key, value in kwargs.items():
            # Only touch attributes that already exist on the User object.
            if hasattr(user, key):
                setattr(user, key, value)

        self._commit()
        return user

    def verify_user(self, email):
        """Mark the user as verified; returns the user or None."""
        return self.update_user(email, verified=True)

    def set_temp_code(self, email, temp_code):
        """Store temp_code on the user; returns the user or None."""
        return self.update_user(email, temp_code=temp_code)

    def reset_temp_code(self, email):
        """Clear the user's temp_code; returns the user or None."""
        return self.update_user(email, temp_code=None)

    def update_password(self, email, hashed_password):
        """Store hashed_password as the user's password; returns the user or None."""
        return self.update_user(email, password=hashed_password)

    def get_company(self, name: str) -> Company | None:
        """Return the company with the given name, or None when not found."""
        return self.session.query(Company).filter_by(name=name).first()

    def get_company_by_id(self, company_id: int) -> Company | None:
        """Return the company with the given primary key, or None when not found."""
        return self.session.query(Company).filter_by(id=company_id).first()

    def get_company_by_short_name(self, short_name: str) -> Company | None:
        """Return the company with the given short_name, or None when not found."""
        return self.session.query(Company).filter(Company.short_name == short_name).first()

    def get_companies(self) -> list[Company]:
        """Return every company."""
        return self.session.query(Company).all()

    def create_company(self, new_company: Company):
        """Create the company, or refresh an existing one that has the same name.

        When a company with the same name already exists only its parameters
        and logo_file are overwritten, and that existing row is returned.
        """
        company = self.session.query(Company).filter_by(name=new_company.name).first()
        if company:
            company.parameters = new_company.parameters
            company.logo_file = new_company.logo_file
        else:
            self.session.add(new_company)
            company = new_company

        self._commit()
        return company

    def save_feedback(self, feedback: UserFeedback):
        """Add feedback to the session, commit, and return it."""
        self.session.add(feedback)
        self._commit()
        return feedback

    def create_api_key(self, new_api_key: ApiKey):
        """Add new_api_key to the session, commit, and return it."""
        self.session.add(new_api_key)
        self._commit()
        return new_api_key

    def get_active_api_key_entry(self, api_key_value: str) -> ApiKey | None:
        """Look up an active API key by its value.

        Returns the entry (with its company eagerly loaded) when it exists and
        is active, None otherwise. Any error during the lookup is logged, the
        session is rolled back, and None is returned.
        """
        # Imported locally so this block does not depend on a module-level import.
        import logging

        try:
            # joinedload fetches the company in the same SELECT.
            return (
                self.session.query(ApiKey)
                .options(joinedload(ApiKey.company))
                .filter(ApiKey.key == api_key_value, ApiKey.is_active == True)  # noqa: E712
                .first()
            )
        except Exception:
            logging.exception("Error looking up the active API key entry")
            self.session.rollback()  # leave the session clean after the failure
            return None

    def get_active_api_key_by_company(self, company: Company) -> ApiKey | None:
        """Return the first active API key of the company, or None when there is none."""
        return (
            self.session.query(ApiKey)
            .filter(ApiKey.company == company, ApiKey.is_active == True)  # noqa: E712
            .first()
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
|
|
repositories/tasks_repo.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright (c) 2024 Fernando Libedinsky
|
|
2
|
+
# Product: IAToolkit
|
|
3
|
+
#
|
|
4
|
+
# IAToolkit is open source software.
|
|
5
|
+
|
|
6
|
+
from injector import inject
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from repositories.models import Task, TaskStatus, TaskType
|
|
9
|
+
from repositories.database_manager import DatabaseManager
|
|
10
|
+
from sqlalchemy import or_
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TaskRepo:
    """Persistence helper for Task and TaskType rows.

    All methods share the single session obtained from the injected
    DatabaseManager.
    """

    @inject
    def __init__(self, db_manager: DatabaseManager):
        self.session = db_manager.get_session()

    def _commit(self):
        """Commit the current transaction; on failure roll back and re-raise.

        Without the rollback a failed flush/commit leaves the shared session in
        a state that rejects every later operation.
        """
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def create_task(self, new_task: Task) -> Task:
        """Add new_task to the session, commit, and return it."""
        self.session.add(new_task)
        self._commit()
        return new_task

    def update_task(self, task: Task) -> Task:
        """Commit pending changes and return task.

        The task is not added to the session here, so it must already be
        attached for its changes to be written.
        """
        self._commit()
        return task

    def get_task_by_id(self, task_id: int):
        """Return the task with the given primary key, or None when not found."""
        return self.session.query(Task).filter_by(id=task_id).first()

    def create_or_update_task_type(self, new_task_type: TaskType):
        """Create the task type, or refresh an existing one that has the same name.

        When a task type with the same name exists only its prompt_template and
        template_args are overwritten, and that existing row is returned.
        """
        task_type = self.session.query(TaskType).filter_by(name=new_task_type.name).first()
        if task_type:
            task_type.prompt_template = new_task_type.prompt_template
            task_type.template_args = new_task_type.template_args
        else:
            self.session.add(new_task_type)
            task_type = new_task_type
        self._commit()
        return task_type

    def get_task_type(self, name: str):
        """Return the task type with the given name, or None when not found."""
        return self.session.query(TaskType).filter_by(name=name).first()

    def get_pending_tasks(self, company_id: int):
        """Return the company's pending tasks that are due for execution.

        A task is due when its execute_at is NULL or not later than the current
        local time (datetime.now()).
        """
        now = datetime.now()
        return self.session.query(Task).filter(
            Task.company_id == company_id,
            Task.status == TaskStatus.pendiente,
            or_(Task.execute_at.is_(None), Task.execute_at <= now)
        ).all()
|
repositories/vs_repo.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright (c) 2024 Fernando Libedinsky
|
|
2
|
+
# Product: IAToolkit
|
|
3
|
+
#
|
|
4
|
+
# IAToolkit is open source software.
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import text
|
|
7
|
+
from huggingface_hub import InferenceClient
|
|
8
|
+
from injector import inject
|
|
9
|
+
from common.exceptions import IAToolkitException
|
|
10
|
+
from repositories.database_manager import DatabaseManager
|
|
11
|
+
from repositories.models import Document, VSDoc
|
|
12
|
+
import os
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
class VSRepo:
    """Vector-store repository: embeds text chunks and runs similarity queries."""

    @inject
    def __init__(self, db_manager: DatabaseManager):
        self.session = db_manager.get_session()

        # Embedding client for the all-MiniLM-L6-v2 model; the access token is
        # read from the HF_TOKEN environment variable.
        self.embedder = InferenceClient(
            model="sentence-transformers/all-MiniLM-L6-v2",
            token=os.getenv('HF_TOKEN'))

    def add_document(self, vs_chunk_list: list[VSDoc]):
        """Compute the embedding of every chunk and persist all of them in one commit.

        Raises:
            IAToolkitException: VECTOR_STORE_ERROR when embedding or persisting
                fails; the session is rolled back first.
        """
        try:
            for doc in vs_chunk_list:
                # calculate the embedding for the text
                doc.embedding = self.embedder.feature_extraction(doc.text)
                self.session.add(doc)
            self.session.commit()
        except Exception as e:
            logging.error(f"Error insertando documentos en PostgreSQL: {str(e)}")
            self.session.rollback()
            raise IAToolkitException(IAToolkitException.ErrorType.VECTOR_STORE_ERROR,
                                     f"Error insertando documentos en PostgreSQL: {str(e)}") from e

    def query(self,
              company_id: int,
              query_text: str,
              n_results=5,
              metadata_filter=None
              ) -> list[Document]:
        """
        Search the documents of a company that are most similar to a query text.

        Args:
            company_id: id of the company whose chunks are searched
            query_text: query text
            n_results: max number of chunk rows fetched (applied before the
                per-document de-duplication, so fewer documents may be returned)
            metadata_filter: optional dict of exact-match filters on the document
                metadata (e.g. {"document_type": "certificate"}); values are
                compared as text

        Returns:
            list of documents matching the query and filters, without duplicates

        Raises:
            IAToolkitException: VECTOR_STORE_ERROR when the SQL query fails.
        """
        # Generate the embedding with the query text
        query_embedding = self.embedder.feature_extraction([query_text])[0]

        # Defined before the try block so the error handler can always log them.
        sql_query = ""
        # NOTE(review): assumes the DB driver can adapt query_embedding to a
        # pgvector value for the <-> operator — TODO confirm.
        params = {
            "company_id": company_id,
            "query_embedding": query_embedding,
            "n_results": n_results
        }

        try:
            # build the SQL query
            sql_query_parts = ["""
                SELECT iat_documents.id,
                       iat_documents.filename,
                       iat_documents.content,
                       iat_documents.content_b64,
                       iat_documents.meta
                FROM iat_vsdocs,
                     iat_documents
                WHERE iat_vsdocs.company_id = :company_id
                  AND iat_vsdocs.document_id = iat_documents.id
            """]

            # add metadata filter, if exists
            if metadata_filter and isinstance(metadata_filter, dict):
                for idx, (key, value) in enumerate(metadata_filter.items()):
                    # ->> extracts the JSON value as text. Both the key and the
                    # value are bound parameters (with index-based names), so an
                    # arbitrary key can neither inject SQL nor produce an invalid
                    # bind-parameter name.
                    key_param = f"meta_key_{idx}"
                    value_param = f"meta_value_{idx}"
                    sql_query_parts.append(
                        f" AND iat_documents.meta ->> CAST(:{key_param} AS TEXT) = :{value_param}"
                    )
                    params[key_param] = str(key)
                    params[value_param] = str(value)

            # join all the query parts
            sql_query = "".join(sql_query_parts)

            # add sorting and limit of results
            sql_query += " ORDER BY iat_vsdocs.embedding <-> :query_embedding LIMIT :n_results"

            logging.debug(f"Executing SQL query: {sql_query}")
            logging.debug(f"With parameters: {params}")

            # execute the query
            result = self.session.execute(text(sql_query), params)

            vs_documents = []
            for row in result.fetchall():
                # create the document object with the data
                meta_data = row[4] if len(row) > 4 and row[4] is not None else {}
                vs_documents.append(Document(
                    id=row[0],
                    company_id=company_id,
                    filename=row[1],
                    content=row[2],
                    content_b64=row[3],
                    meta=meta_data
                ))

            return self.remove_duplicates_by_id(vs_documents)

        except Exception as e:
            logging.error(f"Error en la consulta de documentos: {str(e)}")
            logging.error(f"Failed SQL: {sql_query}")
            logging.error(f"Failed params: {params}")
            raise IAToolkitException(IAToolkitException.ErrorType.VECTOR_STORE_ERROR,
                                     f"Error en la consulta: {str(e)}") from e
        finally:
            self.session.close()

    def remove_duplicates_by_id(self, objects):
        """Return objects without repeated ids, keeping the first occurrence and the order."""
        seen_ids = set()
        result = []

        for obj in objects:
            if obj.id not in seen_ids:
                seen_ids.add(obj.id)
                result.append(obj)

        return result
|