iatoolkit-0.7.4-py3-none-any.whl → iatoolkit-0.7.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iatoolkit might be problematic. Click here for more details.

Files changed (57)
  1. common/__init__.py +0 -0
  2. common/auth.py +200 -0
  3. common/exceptions.py +46 -0
  4. common/routes.py +86 -0
  5. common/session_manager.py +25 -0
  6. common/util.py +358 -0
  7. iatoolkit/iatoolkit.py +3 -3
  8. {iatoolkit-0.7.4.dist-info → iatoolkit-0.7.6.dist-info}/METADATA +1 -1
  9. iatoolkit-0.7.6.dist-info/RECORD +80 -0
  10. iatoolkit-0.7.6.dist-info/top_level.txt +6 -0
  11. infra/__init__.py +5 -0
  12. infra/call_service.py +140 -0
  13. infra/connectors/__init__.py +5 -0
  14. infra/connectors/file_connector.py +17 -0
  15. infra/connectors/file_connector_factory.py +57 -0
  16. infra/connectors/google_cloud_storage_connector.py +53 -0
  17. infra/connectors/google_drive_connector.py +68 -0
  18. infra/connectors/local_file_connector.py +46 -0
  19. infra/connectors/s3_connector.py +33 -0
  20. infra/gemini_adapter.py +356 -0
  21. infra/google_chat_app.py +57 -0
  22. infra/llm_client.py +430 -0
  23. infra/llm_proxy.py +139 -0
  24. infra/llm_response.py +40 -0
  25. infra/mail_app.py +145 -0
  26. infra/openai_adapter.py +90 -0
  27. infra/redis_session_manager.py +76 -0
  28. repositories/__init__.py +5 -0
  29. repositories/database_manager.py +95 -0
  30. repositories/document_repo.py +33 -0
  31. repositories/llm_query_repo.py +91 -0
  32. repositories/models.py +309 -0
  33. repositories/profile_repo.py +118 -0
  34. repositories/tasks_repo.py +52 -0
  35. repositories/vs_repo.py +139 -0
  36. views/__init__.py +5 -0
  37. views/change_password_view.py +91 -0
  38. views/chat_token_request_view.py +98 -0
  39. views/chat_view.py +51 -0
  40. views/download_file_view.py +58 -0
  41. views/external_chat_login_view.py +88 -0
  42. views/external_login_view.py +40 -0
  43. views/file_store_view.py +58 -0
  44. views/forgot_password_view.py +64 -0
  45. views/history_view.py +57 -0
  46. views/home_view.py +34 -0
  47. views/llmquery_view.py +65 -0
  48. views/login_view.py +60 -0
  49. views/prompt_view.py +37 -0
  50. views/signup_view.py +87 -0
  51. views/tasks_review_view.py +83 -0
  52. views/tasks_view.py +98 -0
  53. views/user_feedback_view.py +74 -0
  54. views/verify_user_view.py +55 -0
  55. iatoolkit-0.7.4.dist-info/RECORD +0 -30
  56. iatoolkit-0.7.4.dist-info/top_level.txt +0 -2
  57. {iatoolkit-0.7.4.dist-info → iatoolkit-0.7.6.dist-info}/WHEEL +0 -0
repositories/models.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) 2024 Fernando Libedinsky
2
+ # Product: IAToolkit
3
+ #
4
+ # IAToolkit is open source software.
5
+
6
+ from sqlalchemy import Column, Integer, String, DateTime, Enum, Text, JSON, Boolean, ForeignKey, Table
7
+ from sqlalchemy.orm import DeclarativeBase
8
+ from sqlalchemy.orm import relationship, class_mapper, declarative_base
9
+ from datetime import datetime
10
+ from pgvector.sqlalchemy import Vector
11
+ from enum import Enum as PyEnum
12
+ import secrets
13
+ import enum
14
+
15
+
16
class Base(DeclarativeBase):
    """Declarative base class that every ORM model in this module inherits from."""
19
+
20
# Association table backing the many-to-many relationship between users and
# companies. Rows are removed by the database (ON DELETE CASCADE) when either
# side of the link is deleted.
user_company = Table(
    'iat_user_company',
    Base.metadata,
    Column(
        'user_id',
        Integer,
        ForeignKey('iat_users.id', ondelete='CASCADE'),
        primary_key=True,
    ),
    Column(
        'company_id',
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        primary_key=True,
    ),
    Column('is_active', Boolean, default=True),
    # Role of the user inside this particular company.
    Column('role', String(50), default='user'),
    Column('created_at', DateTime, default=datetime.now),
)
33
+
34
class ApiKey(Base):
    """API key that a company presents to authenticate against the system.

    Every key belongs to exactly one company (``company_id`` is mandatory) and
    the foreign key is declared with ``ON DELETE CASCADE``.
    """
    __tablename__ = 'iat_api_keys'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    # The API key value itself; unique and indexed so it can be looked up by value.
    key = Column(String(128), unique=True, nullable=False, index=True)
    is_active = Column(Boolean, default=True, nullable=False)
    created_at = Column(DateTime, default=datetime.now)
    # Optional: intended for tracking when the key was last used.
    last_used_at = Column(DateTime, nullable=True)

    company = relationship("Company", back_populates="api_keys")
46
+
47
+
48
class Company(Base):
    """A company (tenant) in the multi-tenant system.

    Owns documents, functions, vector-store chunks, LLM query logs, API keys,
    feedback and prompts (all deleted with the company through
    ``cascade="all, delete-orphan"``), and is linked to users through the
    ``user_company`` association table.
    """
    __tablename__ = 'iat_companies'

    id = Column(Integer, primary_key=True)
    short_name = Column(String(20), nullable=False, unique=True, index=True)
    name = Column(String(256), nullable=False)

    # Provider API keys, stored encrypted (per the original author's note).
    openai_api_key = Column(String, nullable=True)
    gemini_api_key = Column(String, nullable=True)

    logo_file = Column(String(128), nullable=True, default='')
    # Callable default: each new row gets its own dict instead of every row
    # sharing (and possibly mutating) one module-level ``{}`` object.
    parameters = Column(JSON, nullable=True, default=dict)
    created_at = Column(DateTime, default=datetime.now)
    allow_jwt = Column(Boolean, default=False, nullable=True)

    documents = relationship(
        "Document",
        back_populates="company",
        cascade="all, delete-orphan",
        lazy='dynamic',
    )
    functions = relationship(
        "Function",
        back_populates="company",
        cascade="all, delete-orphan",
    )
    vsdocs = relationship(
        "VSDoc",
        back_populates="company",
        cascade="all, delete-orphan",
    )
    llm_queries = relationship(
        "LLMQuery",
        back_populates="company",
        cascade="all, delete-orphan",
    )
    users = relationship(
        "User",
        secondary=user_company,
        back_populates="companies",
    )
    api_keys = relationship(
        "ApiKey",
        back_populates="company",
        cascade="all, delete-orphan",
    )

    # Tasks are deliberately declared without a delete cascade here.
    tasks = relationship("Task", back_populates="company")
    feedbacks = relationship(
        "UserFeedback",
        back_populates="company",
        cascade="all, delete-orphan",
    )
    prompts = relationship(
        "Prompt",
        back_populates="company",
        cascade="all, delete-orphan",
    )

    def to_dict(self):
        """Return a dict of every mapped column, keyed by column key.

        Values are returned exactly as stored: ``created_at`` is a
        ``datetime`` and the API key columns are included.
        NOTE(review): confirm callers never send this dict to clients as-is.
        """
        mapper = class_mapper(self.__class__)
        return {column.key: getattr(self, column.key) for column in mapper.columns}
95
+
96
# Users with rights to use this app.
class User(Base):
    """An IAToolkit user, who may belong to several companies."""
    __tablename__ = 'iat_users'

    id = Column(Integer, primary_key=True)
    email = Column(String(80), unique=True, nullable=False)
    first_name = Column(String(50), nullable=False)
    last_name = Column(String(50), nullable=False)
    created_at = Column(DateTime, default=datetime.now)
    password = Column(String, nullable=False)
    verified = Column(Boolean, nullable=False, default=False)
    verification_url = Column(String, nullable=True)
    temp_code = Column(String, nullable=True)

    # NOTE(review): cascade="all" includes "delete" on a many-to-many
    # relationship; confirm that deleting a user can never cascade into
    # deleting the companies it is linked to.
    companies = relationship(
        "Company",
        secondary=user_company,
        back_populates="users",
        cascade="all",
        passive_deletes=True,
        lazy='dynamic',
    )

    def to_dict(self):
        """Return the user's public fields plus a ``to_dict()`` of each company.

        The password, verification URL and temporary code are not included.
        ``created_at`` is converted with ``str()``.
        """
        company_dicts = [company.to_dict() for company in self.companies]
        return {
            'id': self.id,
            'email': self.email,
            'first_name': self.first_name,
            'last_name': self.last_name,
            'created_at': str(self.created_at),
            'verified': self.verified,
            'companies': company_dicts,
        }
130
+
131
class Function(Base):
    """A custom or system function (tool) that the LLM can call.

    ``company_id`` is nullable, so a function does not have to be tied to a
    company.
    """
    __tablename__ = 'iat_functions'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=True,
    )
    name = Column(String(255), nullable=False)
    system_function = Column(Boolean, default=False)
    description = Column(Text, nullable=False)
    parameters = Column(JSON, nullable=False)
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship('Company', back_populates='functions')

    def to_dict(self):
        """Return a dict of every mapped column, keyed by column key."""
        mapper = class_mapper(self.__class__)
        return {col.key: getattr(self, col.key) for col in mapper.columns}
150
+
151
+
152
class Document(Base):
    """A file or document uploaded by a company to be used as context.

    The content is stored twice: as text (``content``) and in ``content_b64``.
    """
    __tablename__ = 'iat_documents'

    id = Column(Integer, primary_key=True, autoincrement=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    filename = Column(String(256), nullable=False, index=True)
    content = Column(Text, nullable=False)
    content_b64 = Column(Text, nullable=False)
    meta = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="documents")

    def to_dict(self):
        """Return a dict of every mapped column, keyed by column key."""
        mapper = class_mapper(self.__class__)
        return {col.key: getattr(self, col.key) for col in mapper.columns}
169
+
170
+
171
class LLMQuery(Base):
    """Log entry for one query sent to the LLM: input, output and metadata."""
    __tablename__ = 'iat_queries'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    user_identifier = Column(String(128), nullable=False)
    task_id = Column(Integer, default=0, nullable=True)
    query = Column(Text, nullable=False)
    output = Column(Text, nullable=False)
    # JSON defaults are callables so each row gets a fresh dict; a literal
    # ``{}`` would be one object shared by every row that uses the default.
    response = Column(JSON, nullable=True, default=dict)
    valid_response = Column(Boolean, nullable=False, default=False)
    function_calls = Column(JSON, nullable=True, default=dict)
    stats = Column(JSON, default=dict)
    answer_time = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="llm_queries")
    tasks = relationship("Task", back_populates="llm_query")

    def to_dict(self):
        """Return a dict of every mapped column, keyed by column key."""
        mapper = class_mapper(self.__class__)
        return {column.key: getattr(self, column.key) for column in mapper.columns}
194
+
195
+
196
class VSDoc(Base):
    """A text chunk together with its vector embedding, used for similarity search."""
    __tablename__ = "iat_vsdocs"

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    document_id = Column(
        Integer,
        ForeignKey('iat_documents.id', ondelete='CASCADE'),
        nullable=False,
    )
    text = Column(Text, nullable=False)
    # Adjust the dimension if needed (it must match the embedding model's output size).
    embedding = Column(Vector(384), nullable=False)

    company = relationship("Company", back_populates="vsdocs")

    def to_dict(self):
        """Return a dict of every mapped column, keyed by column key."""
        mapper = class_mapper(self.__class__)
        return {col.key: getattr(self, col.key) for col in mapper.columns}
212
+
213
class TaskStatus(PyEnum):
    """Possible statuses of a Task (member names and values are in Spanish)."""

    # Task created and waiting to be executed.
    pendiente = "pendiente"
    # The AI algorithm has been executed.
    ejecutado = "ejecutado"
    # Validated and approved by a human.
    aprobada = "aprobada"
    # Validated and rejected by a human.
    rechazada = "rechazada"
    # Error while executing the AI algorithm.
    fallida = "fallida"
220
+
221
class TaskType(Base):
    """A kind of task that can be executed, together with its prompt template."""
    __tablename__ = 'iat_task_types'

    id = Column(Integer, primary_key=True)
    name = Column(String(100), unique=True, nullable=False)
    # Default prompt template.
    prompt_template = Column(String(100), nullable=True)
    # Configuration arguments/prefixes for the template.
    template_args = Column(JSON, nullable=True)
229
+
230
class Task(Base):
    """An asynchronous task to be executed by the system, often involving an LLM.

    A task starts as ``TaskStatus.pendiente`` and carries optional human
    review fields (``review_user``, ``review_date``, ``comment``,
    ``approved``).
    """
    __tablename__ = 'iat_tasks'

    id = Column(Integer, primary_key=True)
    company_id = Column(Integer, ForeignKey("iat_companies.id"))

    user_id = Column(Integer, nullable=True, default=0)
    task_type_id = Column(Integer, ForeignKey('iat_task_types.id'), nullable=False)
    status = Column(
        Enum(TaskStatus, name="task_status_enum"),
        default=TaskStatus.pendiente,
        nullable=False,
    )
    # Callable defaults (dict / list) give every row its own container; a
    # literal ``{}`` or ``[]`` would be a single object shared across rows.
    client_data = Column(JSON, nullable=True, default=dict)
    company_task_id = Column(Integer, nullable=True, default=0)
    execute_at = Column(DateTime, default=datetime.now, nullable=True)
    llm_query_id = Column(Integer, ForeignKey('iat_queries.id'), nullable=True)
    callback_url = Column(String(512), default=None, nullable=True)
    files = Column(JSON, default=list, nullable=True)

    review_user = Column(String(128), nullable=True, default='')
    review_date = Column(DateTime, nullable=True)
    comment = Column(Text, nullable=True)
    approved = Column(Boolean, nullable=False, default=False)

    created_at = Column(DateTime, default=datetime.now)
    # onupdate refreshes the timestamp on every UPDATE that does not set the
    # column explicitly; with only ``default`` it never changed after insert.
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

    task_type = relationship("TaskType")
    llm_query = relationship("LLMQuery", back_populates="tasks", uselist=False)
    company = relationship("Company", back_populates="tasks")
259
+
260
class UserFeedback(Base):
    """Feedback message and rating submitted by a user for a company.

    The author is identified by ``local_user_id`` and/or ``external_user_id``;
    both are nullable.
    """
    __tablename__ = 'iat_feedback'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=False,
    )
    local_user_id = Column(Integer, default=0, nullable=True)
    external_user_id = Column(String(128), default='', nullable=True)
    message = Column(Text, nullable=False)
    rating = Column(Integer, nullable=False)
    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="feedbacks")
274
+
275
+
276
class PromptCategory(Base):
    """A category used to group and order prompts within a company."""
    __tablename__ = 'iat_prompt_categories'

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    order = Column(Integer, nullable=False, default=0)
    company_id = Column(Integer, ForeignKey('iat_companies.id'), nullable=False)

    # Prompts of this category, loaded sorted by their own ``order`` column.
    prompts = relationship(
        "Prompt",
        back_populates="category",
        order_by="Prompt.order",
    )

    def __repr__(self):
        """Return a short debug representation with the name and order."""
        name, order = self.name, self.order
        return f"<PromptCategory(name='{name}', order={order})>"
288
+
289
+
290
class Prompt(Base):
    """A system or user-defined prompt template for the LLM.

    ``company_id`` and ``category_id`` are both nullable, so a prompt can
    exist without a company or a category.
    """
    __tablename__ = 'iat_prompt'

    id = Column(Integer, primary_key=True)
    company_id = Column(
        Integer,
        ForeignKey('iat_companies.id', ondelete='CASCADE'),
        nullable=True,
    )
    name = Column(String(64), nullable=False)
    description = Column(String(256), nullable=False)
    filename = Column(String(256), nullable=False)
    active = Column(Boolean, default=True)
    is_system_prompt = Column(Boolean, default=False)
    # Position of the prompt when listed; PromptCategory.prompts sorts by it.
    order = Column(Integer, nullable=False, default=0)
    category_id = Column(Integer, ForeignKey('iat_prompt_categories.id'), nullable=True)
    # Callable default so each row gets its own list rather than all rows
    # sharing one mutable ``[]`` object.
    custom_fields = Column(JSON, nullable=False, default=list)

    created_at = Column(DateTime, default=datetime.now)

    company = relationship("Company", back_populates="prompts")
    category = relationship("PromptCategory", back_populates="prompts")
repositories/profile_repo.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) 2024 Fernando Libedinsky
2
+ # Product: IAToolkit
3
+ #
4
+ # IAToolkit is open source software.
5
+
6
+ from repositories.models import User, Company, ApiKey, UserFeedback
7
+ from injector import inject
8
+ from repositories.database_manager import DatabaseManager
9
+ from sqlalchemy.orm import joinedload # Para cargar la relación eficientemente
10
+
11
+
12
+ class ProfileRepo:
13
+ @inject
14
+ def __init__(self, db_manager: DatabaseManager):
15
+ self.session = db_manager.get_session()
16
+
17
+ def get_user_by_id(self, user_id: int) -> User:
18
+ user = self.session.query(User).filter_by(id=user_id).first()
19
+ return user
20
+
21
+ def get_user_by_email(self, email: str) -> User:
22
+ user = self.session.query(User).filter_by(email=email).first()
23
+ return user
24
+
25
+ def create_user(self, new_user: User):
26
+ self.session.add(new_user)
27
+ self.session.commit()
28
+ return new_user
29
+
30
+ def save_user(self,existing_user: User):
31
+ self.session.add(existing_user)
32
+ self.session.commit()
33
+ return existing_user
34
+
35
+ def update_user(self, email, **kwargs):
36
+ user = self.session.query(User).filter_by(email=email).first()
37
+ if not user:
38
+ return None
39
+
40
+ # get the fields for update
41
+ for key, value in kwargs.items():
42
+ if hasattr(user, key): # Asegura que el campo existe en User
43
+ setattr(user, key, value)
44
+
45
+ self.session.commit()
46
+ return user # return updated object
47
+
48
+ def verify_user(self, email):
49
+ return self.update_user(email, verified=True)
50
+
51
+ def set_temp_code(self, email, temp_code):
52
+ return self.update_user(email, temp_code=temp_code)
53
+
54
+ def reset_temp_code(self, email):
55
+ return self.update_user(email, temp_code=None)
56
+
57
+ def update_password(self, email, hashed_password):
58
+ return self.update_user(email, password=hashed_password)
59
+
60
+ def get_company(self, name: str) -> Company:
61
+ return self.session.query(Company).filter_by(name=name).first()
62
+
63
+ def get_company_by_id(self, company_id: int) -> Company:
64
+ return self.session.query(Company).filter_by(id=company_id).first()
65
+
66
+ def get_company_by_short_name(self, short_name: str) -> Company:
67
+ return self.session.query(Company).filter(Company.short_name == short_name).first()
68
+
69
+ def get_companies(self) -> list[Company]:
70
+ return self.session.query(Company).all()
71
+
72
+ def create_company(self, new_company: Company):
73
+ company = self.session.query(Company).filter_by(name=new_company.name).first()
74
+ if company:
75
+ company.parameters = new_company.parameters
76
+ company.logo_file = new_company.logo_file
77
+ else:
78
+ self.session.add(new_company)
79
+ company = new_company
80
+
81
+ self.session.commit()
82
+ return company
83
+
84
+ def save_feedback(self, feedback: UserFeedback):
85
+ self.session.add(feedback)
86
+ self.session.commit()
87
+ return feedback
88
+
89
+ def create_api_key(self, new_api_key: ApiKey):
90
+ self.session.add(new_api_key)
91
+ self.session.commit()
92
+ return new_api_key
93
+
94
+
95
+ def get_active_api_key_entry(self, api_key_value: str) -> ApiKey | None:
96
+ """
97
+ search for an active API Key by its value.
98
+ returns the entry if found and is active, None otherwise.
99
+ """
100
+ try:
101
+ # Usamos joinedload para cargar la compañía en la misma consulta
102
+ api_key_entry = self.session.query(ApiKey)\
103
+ .options(joinedload(ApiKey.company))\
104
+ .filter(ApiKey.key == api_key_value, ApiKey.is_active == True)\
105
+ .first()
106
+ return api_key_entry
107
+ except Exception:
108
+ self.session.rollback() # Asegura que la sesión esté limpia tras un error
109
+ return None
110
+
111
+ def get_active_api_key_by_company(self, company: Company) -> ApiKey | None:
112
+ return self.session.query(ApiKey)\
113
+ .filter(ApiKey.company == company, ApiKey.is_active == True)\
114
+ .first()
115
+
116
+
117
+
118
+
repositories/tasks_repo.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2024 Fernando Libedinsky
2
+ # Product: IAToolkit
3
+ #
4
+ # IAToolkit is open source software.
5
+
6
+ from injector import inject
7
+ from datetime import datetime
8
+ from repositories.models import Task, TaskStatus, TaskType
9
+ from repositories.database_manager import DatabaseManager
10
+ from sqlalchemy import or_
11
+
12
+
13
+ class TaskRepo:
14
+ @inject
15
+ def __init__(self, db_manager: DatabaseManager):
16
+ self.session = db_manager.get_session()
17
+
18
+ def create_task(self, new_task: Task) -> Task:
19
+ self.session.add(new_task)
20
+ self.session.commit()
21
+ return new_task
22
+
23
+ def update_task(self, task: Task) -> Task:
24
+ self.session.commit()
25
+ return task
26
+
27
+ def get_task_by_id(self, task_id: int):
28
+ return self.session.query(Task).filter_by(id=task_id).first()
29
+
30
+ def create_or_update_task_type(self, new_task_type: TaskType):
31
+ task_type = self.session.query(TaskType).filter_by(name=new_task_type.name).first()
32
+ if task_type:
33
+ task_type.prompt_template = new_task_type.prompt_template
34
+ task_type.template_args = new_task_type.template_args
35
+ else:
36
+ self.session.add(new_task_type)
37
+ task_type = new_task_type
38
+ self.session.commit()
39
+ return task_type
40
+
41
+ def get_task_type(self, name: str):
42
+ task_type = self.session.query(TaskType).filter_by(name=name).first()
43
+ return task_type
44
+
45
+ def get_pending_tasks(self, company_id: int):
46
+ now = datetime.now()
47
+ tasks = self.session.query(Task).filter(
48
+ Task.company_id == company_id,
49
+ Task.status == TaskStatus.pendiente,
50
+ or_(Task.execute_at == None, Task.execute_at <= now)
51
+ ).all()
52
+ return tasks
repositories/vs_repo.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright (c) 2024 Fernando Libedinsky
2
+ # Product: IAToolkit
3
+ #
4
+ # IAToolkit is open source software.
5
+
6
+ from sqlalchemy import text
7
+ from huggingface_hub import InferenceClient
8
+ from injector import inject
9
+ from common.exceptions import IAToolkitException
10
+ from repositories.database_manager import DatabaseManager
11
+ from repositories.models import Document, VSDoc
12
+ import os
13
+ import logging
14
+
15
+ class VSRepo:
16
+ @inject
17
+ def __init__(self, db_manager: DatabaseManager):
18
+ self.session = db_manager.get_session()
19
+
20
+ # Inicializar el modelo de embeddings
21
+ self.embedder = InferenceClient(
22
+ model="sentence-transformers/all-MiniLM-L6-v2",
23
+ token=os.getenv('HF_TOKEN'))
24
+
25
+
26
+ def add_document(self, vs_chunk_list: list[VSDoc]):
27
+ try:
28
+ for doc in vs_chunk_list:
29
+ # calculate the embedding for the text
30
+ doc.embedding = self.embedder.feature_extraction(doc.text)
31
+ self.session.add(doc)
32
+ self.session.commit()
33
+ except Exception as e:
34
+ logging.error(f"Error insertando documentos en PostgreSQL: {str(e)}")
35
+ self.session.rollback()
36
+ raise IAToolkitException(IAToolkitException.ErrorType.VECTOR_STORE_ERROR,
37
+ f"Error insertando documentos en PostgreSQL: {str(e)}")
38
+
39
+ def query(self,
40
+ company_id: int,
41
+ query_text: str,
42
+ n_results=5,
43
+ metadata_filter=None
44
+ ) -> list[Document]:
45
+ """
46
+ search documents similar to the query for a company
47
+
48
+ Args:
49
+ company_id:
50
+ query_text: query text
51
+ n_results: max number of results to return
52
+ metadata_filter: (ej: {"document_type": "certificate"})
53
+
54
+ Returns:
55
+ list of documents matching the query and filters
56
+ """
57
+ # Generate the embedding with the query text
58
+ query_embedding = self.embedder.feature_extraction([query_text])[0]
59
+
60
+ try:
61
+ # build the SQL query
62
+ sql_query_parts = ["""
63
+ SELECT iat_documents.id, \
64
+ iat_documents.filename, \
65
+ iat_documents.content, \
66
+ iat_documents.content_b64, \
67
+ iat_documents.meta
68
+ FROM iat_vsdocs, \
69
+ iat_documents
70
+ WHERE iat_vsdocs.company_id = :company_id
71
+ AND iat_vsdocs.document_id = iat_documents.id \
72
+ """]
73
+
74
+ # query parameters
75
+ params = {
76
+ "company_id": company_id,
77
+ "query_embedding": query_embedding,
78
+ "n_results": n_results
79
+ }
80
+
81
+ # add metadata filter, if exists
82
+ if metadata_filter and isinstance(metadata_filter, dict):
83
+ for key, value in metadata_filter.items():
84
+ # Usar el operador ->> para extraer el valor del JSON como texto.
85
+ # La clave del JSON se interpola directamente.
86
+ # El valor se pasa como parámetro para evitar inyección SQL.
87
+ param_name = f"value_{key}_filter"
88
+ sql_query_parts.append(f" AND documents.meta->>'{key}' = :{param_name}")
89
+ params[param_name] = str(value) # parametros como string
90
+
91
+ # join all the query parts
92
+ sql_query = "".join(sql_query_parts)
93
+
94
+ # add sorting and limit of results
95
+ sql_query += " ORDER BY embedding <-> :query_embedding LIMIT :n_results"
96
+
97
+ logging.debug(f"Executing SQL query: {sql_query}")
98
+ logging.debug(f"With parameters: {params}")
99
+
100
+ # execute the query
101
+ result = self.session.execute(text(sql_query), params)
102
+
103
+ rows = result.fetchall()
104
+ vs_documents = []
105
+
106
+ for row in rows:
107
+ # create the document object with the data
108
+ meta_data = row[4] if len(row) > 4 and row[4] is not None else {}
109
+ doc = Document(
110
+ id=row[0],
111
+ company_id=company_id,
112
+ filename=row[1],
113
+ content=row[2],
114
+ content_b64=row[3],
115
+ meta=meta_data
116
+ )
117
+ vs_documents.append(doc)
118
+
119
+ return self.remove_duplicates_by_id(vs_documents)
120
+
121
+ except Exception as e:
122
+ logging.error(f"Error en la consulta de documentos: {str(e)}")
123
+ logging.error(f"Failed SQL: {sql_query}")
124
+ logging.error(f"Failed params: {params}")
125
+ raise IAToolkitException(IAToolkitException.ErrorType.VECTOR_STORE_ERROR,
126
+ f"Error en la consulta: {str(e)}")
127
+ finally:
128
+ self.session.close()
129
+
130
+ def remove_duplicates_by_id(self, objects):
131
+ unique_by_id = {}
132
+ result = []
133
+
134
+ for obj in objects:
135
+ if obj.id not in unique_by_id:
136
+ unique_by_id[obj.id] = True
137
+ result.append(obj)
138
+
139
+ return result
views/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) 2024 Fernando Libedinsky
2
+ # Product: IAToolkit
3
+ #
4
+ # IAToolkit is open source software.
5
+