wizit-context-ingestor 0.2.5b3__py3-none-any.whl → 0.3.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. wizit_context_ingestor/__init__.py +2 -2
  2. wizit_context_ingestor/application/context_chunk_service.py +149 -35
  3. wizit_context_ingestor/application/transcription_service.py +132 -52
  4. wizit_context_ingestor/data/kdb.py +10 -0
  5. wizit_context_ingestor/data/prompts.py +150 -3
  6. wizit_context_ingestor/data/storage.py +10 -0
  7. wizit_context_ingestor/infra/persistence/local_storage.py +19 -9
  8. wizit_context_ingestor/infra/persistence/s3_storage.py +29 -23
  9. wizit_context_ingestor/infra/rag/chroma_embeddings.py +30 -31
  10. wizit_context_ingestor/infra/rag/pg_embeddings.py +57 -54
  11. wizit_context_ingestor/infra/rag/redis_embeddings.py +34 -25
  12. wizit_context_ingestor/infra/rag/semantic_chunks.py +9 -1
  13. wizit_context_ingestor/infra/vertex_model.py +56 -28
  14. wizit_context_ingestor/main.py +192 -106
  15. wizit_context_ingestor/utils/file_utils.py +13 -0
  16. wizit_context_ingestor/workflows/context_nodes.py +73 -0
  17. wizit_context_ingestor/workflows/context_state.py +10 -0
  18. wizit_context_ingestor/workflows/context_tools.py +58 -0
  19. wizit_context_ingestor/workflows/context_workflow.py +42 -0
  20. wizit_context_ingestor/workflows/transcription_nodes.py +136 -0
  21. wizit_context_ingestor/workflows/transcription_schemas.py +25 -0
  22. wizit_context_ingestor/workflows/transcription_state.py +17 -0
  23. wizit_context_ingestor/workflows/transcription_tools.py +54 -0
  24. wizit_context_ingestor/workflows/transcription_workflow.py +42 -0
  25. {wizit_context_ingestor-0.2.5b3.dist-info → wizit_context_ingestor-0.3.0b2.dist-info}/METADATA +9 -1
  26. wizit_context_ingestor-0.3.0b2.dist-info/RECORD +44 -0
  27. {wizit_context_ingestor-0.2.5b3.dist-info → wizit_context_ingestor-0.3.0b2.dist-info}/WHEEL +1 -1
  28. wizit_context_ingestor-0.2.5b3.dist-info/RECORD +0 -32
wizit_context_ingestor/infra/vertex_model.py

@@ -15,14 +15,23 @@ class VertexModels(AiApplicationService):
     A wrapper class for Google Cloud Vertex AI models that handles credentials and
     provides methods to load embeddings and chat models.
     """
-    __slots__ = ('project_id', 'location', 'json_service_account', 'scopes', 'llm_model_id')
+
+    __slots__ = (
+        "project_id",
+        "location",
+        "json_service_account",
+        "scopes",
+        "llm_model_id",
+    )
+
     def __init__(
-            self,
-            project_id: str,
-            location: str,
-            json_service_account: Dict[str, Any],
-            scopes: Optional[List[str]] = None,
-            llm_model_id: str = "claude-3-5-haiku@20241022"):
+        self,
+        project_id: str,
+        location: str,
+        json_service_account: Dict[str, Any],
+        scopes: Optional[List[str]] = None,
+        llm_model_id: str = "claude-sonnet-4@20250514",
+    ):
         """
         Initialize the VertexModels class with Google Cloud credentials.

@@ -36,25 +45,24 @@ class VertexModels(AiApplicationService):
             print(location)
             self.scopes = scopes or ["https://www.googleapis.com/auth/cloud-platform"]
             self.credentials = service_account.Credentials.from_service_account_info(
-                json_service_account,
-                scopes=self.scopes
+                json_service_account, scopes=self.scopes
             )
             self.llm_model_id = llm_model_id
             self.project_id = project_id
             self.location = location
             vertexai_init(
-                project=project_id,
-                location=location,
-                credentials=self.credentials
+                project=project_id, location=location, credentials=self.credentials
+            )
+            logger.info(
+                f"VertexModels initialized with project {project_id} in {location}"
             )
-            logger.info(f"VertexModels initialized with project {project_id} in {location}")
         except Exception as e:
             logger.error(f"Failed to initialize VertexModels: {str(e)}")
             raise

     def load_embeddings_model(
-        self,
-        embeddings_model_id: str = "text-multilingual-embedding-002") -> VertexAIEmbeddings:  # noqa: E125
+        self, embeddings_model_id: str = "text-multilingual-embedding-002"
+    ) -> VertexAIEmbeddings:  # noqa: E125
         """
         Load and return a Vertex AI embeddings model.
         default embeddings length is 768 https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings
@@ -73,14 +81,18 @@ class VertexModels(AiApplicationService):
             logger.debug(f"Loaded embedding model: {embeddings_model_id}")
             return embeddings
         except Exception as e:
-            logger.error(f"Failed to load embeddings model {embeddings_model_id}: {str(e)}")
+            logger.error(
+                f"Failed to load embeddings model {embeddings_model_id}: {str(e)}"
+            )
             raise

-    def load_chat_model(self,
+    def load_chat_model(
+        self,
         temperature: float = 0.15,
         max_tokens: int = 8192,
         stop: Optional[List[str]] = None,
-        **chat_model_params) -> Union[ChatVertexAI, ChatAnthropicVertex]:
+        **chat_model_params,
+    ) -> Union[ChatVertexAI, ChatAnthropicVertex]:
         """
         Load a Vertex AI chat model for text generation.

@@ -98,21 +110,35 @@ class VertexModels(AiApplicationService):
         """
         try:
             if "gemini" in self.llm_model_id:
-                return self.load_chat_model_gemini(self.llm_model_id, temperature, max_tokens, stop, **chat_model_params)
+                return self.load_chat_model_gemini(
+                    self.llm_model_id,
+                    temperature,
+                    max_tokens,
+                    stop,
+                    **chat_model_params,
+                )
             elif "claude" in self.llm_model_id:
-                return self.load_chat_model_anthropic(self.llm_model_id, temperature, max_tokens, stop, **chat_model_params)
+                return self.load_chat_model_anthropic(
+                    self.llm_model_id,
+                    temperature,
+                    max_tokens,
+                    stop,
+                    **chat_model_params,
+                )
             else:
                 raise ValueError(f"Unsupported chat model: {self.llm_model_id}")
         except Exception as e:
             logger.error(f"Failed to retrieve chat model {self.llm_model_id}: {str(e)}")
             raise

-    def load_chat_model_gemini(self,
+    def load_chat_model_gemini(
+        self,
         chat_model_id: str = "publishers/google/models/gemini-2.5-flash",
         temperature: float = 0.15,
-        max_tokens: int = 8192,
+        max_tokens: int = 64000,
         stop: Optional[List[str]] = None,
-        **chat_model_params) -> ChatVertexAI:
+        **chat_model_params,
+    ) -> ChatVertexAI:
         """
         Load a Vertex AI chat model for text generation.

@@ -137,7 +163,7 @@ class VertexModels(AiApplicationService):
                 max_tokens=max_tokens,
                 max_retries=1,
                 stop=stop,
-                **chat_model_params
+                **chat_model_params,
             )
             logger.debug(f"Retrieved chat model: {chat_model_id}")
             return self.llm_model
@@ -145,12 +171,14 @@ class VertexModels(AiApplicationService):
             logger.error(f"Failed to retrieve chat model {chat_model_id}: {str(e)}")
             raise

-    def load_chat_model_anthropic(self,
+    def load_chat_model_anthropic(
+        self,
         chat_model_id: str = "claude-3-5-haiku@20241022",
         temperature: float = 0.7,
-        max_tokens: int = 8000,
+        max_tokens: int = 64000,
         stop: Optional[List[str]] = None,
-        **chat_model_params) -> ChatAnthropicVertex:
+        **chat_model_params,
+    ) -> ChatAnthropicVertex:
         """
         Load a Vertex AI chat model for text generation.
         """
@@ -163,7 +191,7 @@ class VertexModels(AiApplicationService):
                 max_tokens=max_tokens,
                 max_retries=1,
                 stop=stop,
-                **chat_model_params
+                **chat_model_params,
             )
             logger.debug(f"Retrieved chat model: {chat_model_id}")
             return self.llm_model
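Net effect of the vertex_model.py changes: beyond Black-style reformatting, the constructor's default `llm_model_id` moves from `claude-3-5-haiku@20241022` to `claude-sonnet-4@20250514`, and the Gemini and Anthropic loaders raise their default `max_tokens` from 8192/8000 to 64000. A minimal usage sketch of the updated wrapper; the project, location, and service-account values below are placeholders, not part of the package:

```python
# Hypothetical usage of the updated VertexModels wrapper (0.3.0b2 defaults);
# project, location, and service-account values are placeholders.
from wizit_context_ingestor.infra.vertex_model import VertexModels

sa_info = {"type": "service_account"}  # placeholder service-account dict

models = VertexModels(
    project_id="my-gcp-project",
    location="us-east5",
    json_service_account=sa_info,
    # llm_model_id now defaults to "claude-sonnet-4@20250514"
)

# "claude" in the model id routes load_chat_model to load_chat_model_anthropic,
# whose max_tokens default is now 64000 (was 8000).
chat_model = models.load_chat_model(temperature=0.15)

# Unchanged default: "text-multilingual-embedding-002" (768-dim embeddings).
embeddings = models.load_embeddings_model()
```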
wizit_context_ingestor/main.py

@@ -1,4 +1,5 @@
 import json
+from typing import Dict, Any, Literal
 from .infra.vertex_model import VertexModels
 from .application.transcription_service import TranscriptionService
 from .application.context_chunk_service import ContextChunksInDocumentService
@@ -6,17 +7,79 @@ from .infra.persistence.s3_storage import S3StorageService
 from .infra.persistence.local_storage import LocalStorageService
 from .infra.rag.semantic_chunks import SemanticChunks
 from .infra.rag.redis_embeddings import RedisEmbeddingsManager
+from .infra.rag.chroma_embeddings import ChromaEmbeddingsManager
 from .infra.secrets.aws_secrets_manager import AwsSecretsManager
+from .data.storage import storage_services, StorageServices
+from .data.kdb import kdb_services, KdbServices
+from .utils.file_utils import has_invalid_file_name_format
+from langsmith import Client, tracing_context

-class DeelabTranscribeManager:

-    def __init__(self,
+class KdbManager:
+    def __init__(
+        self, embeddings_model, kdb_service: kdb_services, kdb_params: Dict[Any, Any]
+    ):
+        self.kdb_service = kdb_service
+        self.kdb_params = kdb_params
+        self.embeddings_model = embeddings_model
+
+    def retrieve_kdb_service(self):
+        if self.kdb_service == KdbServices.REDIS.value:
+            return RedisEmbeddingsManager(
+                self.embeddings_model,
+                **self.kdb_params,
+            )
+        elif self.kdb_service == KdbServices.CHROMA.value:
+            return ChromaEmbeddingsManager(
+                self.embeddings_model,
+                **self.kdb_params,
+            )
+        else:
+            raise ValueError(f"Unsupported kdb provider: {self.kdb_service}")
+
+
+class PersistenceManager:
+    def __init__(
+        self,
+        storage_service: storage_services,
+        source_storage_route,
+        target_storage_route,
+    ):
+        self.storage_service = storage_service
+        self.source_storage_route = source_storage_route
+        self.target_storage_route = target_storage_route
+
+    def retrieve_storage_service(self):
+        if self.storage_service == StorageServices.S3.value:
+            return S3StorageService(
+                origin_bucket_name=self.source_storage_route,
+                target_bucket_name=self.target_storage_route,
+            )
+        elif self.storage_service == StorageServices.LOCAL.value:
+            return LocalStorageService(
+                source_storage_route=self.source_storage_route,
+                target_storage_route=self.target_storage_route,
+            )
+        else:
+            raise ValueError(f"Unsupported storage service: {self.storage_service}")
+
+
+class TranscriptionManager:
+    def __init__(
+        self,
         gcp_project_id: str,
         gcp_project_location: str,
         gcp_secret_name: str,
+        langsmith_api_key: str,
+        langsmith_project_name: str,
+        storage_service: storage_services,
+        source_storage_route: str,
+        target_storage_route: str,
         llm_model_id: str = "claude-sonnet-4@20250514",
-        target_language: str = 'es',
-        transcription_additional_instructions: str = ''
+        target_language: str = "es",
+        transcription_additional_instructions: str = "",
+        transcription_accuracy_threshold: int = 90,
+        max_transcription_retries: int = 2,
     ):
         self.gcp_project_id = gcp_project_id
         self.gcp_project_location = gcp_project_location
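The `KdbManager` and `PersistenceManager` classes added above are small factories that map a string selector onto a concrete backend (Redis or Chroma for embeddings, S3 or local disk for storage). A hedged wiring sketch; the `"s3"` selector value and the `kdb_params` keys are assumptions, since `data/storage.py`, `data/kdb.py`, and the 0.3.0 embeddings-manager constructors are not expanded in this diff (`redis_conn_string` matches the kwarg used by the removed 0.2.5 code):

```python
# Sketch only: selector strings and kdb_params keys are inferred from the
# factory code above, not confirmed by this diff.
from wizit_context_ingestor.main import KdbManager, PersistenceManager


def build_backends(embeddings_model):
    persistence = PersistenceManager(
        storage_service="s3",  # presumably StorageServices.S3.value
        source_storage_route="my-origin-bucket",  # placeholder bucket names
        target_storage_route="my-target-bucket",
    )
    storage = persistence.retrieve_storage_service()  # -> S3StorageService

    kdb = KdbManager(
        embeddings_model,
        kdb_service="redis",  # Literal["redis", "chroma"] per ChunksManager below
        kdb_params={"redis_conn_string": "redis://localhost:6379"},  # assumed key
    )
    vector_store = kdb.retrieve_kdb_service()  # -> RedisEmbeddingsManager
    return storage, vector_store
```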
@@ -24,9 +87,19 @@ class DeelabTranscribeManager:
         self.gcp_secret_name = gcp_secret_name
         self.llm_model_id = llm_model_id
         self.target_language = target_language
-        self.transcription_additional_instructions = transcription_additional_instructions
+        self.storage_service = storage_service
+        self.source_storage_route = source_storage_route
+        self.target_storage_route = target_storage_route
+        self.transcription_additional_instructions = (
+            transcription_additional_instructions
+        )
+        self.transcription_accuracy_threshold = transcription_accuracy_threshold
+        self.max_transcription_retries = max_transcription_retries
         self.gcp_sa_dict = self._get_gcp_sa_dict(gcp_secret_name)
         self.vertex_model = self._get_vertex_model()
+        self.langsmith_api_key = langsmith_api_key
+        self.langsmith_project_name = langsmith_project_name
+        self.langsmith_client = Client(api_key=self.langsmith_api_key)

     def _get_gcp_sa_dict(self, gcp_secret_name: str):
         vertex_gcp_sa = self.aws_secrets_manager.get_secret(gcp_secret_name)
@@ -38,51 +111,92 @@ class DeelabTranscribeManager:
             self.gcp_project_id,
             self.gcp_project_location,
             self.gcp_sa_dict,
-            llm_model_id=self.llm_model_id
+            llm_model_id=self.llm_model_id,
         )
         return vertex_model

-    def aws_cloud_transcribe_document(
-        self,
-        file_key: str,
-        s3_origin_bucket_name: str,
-        s3_target_bucket_name: str
-    ):
+    def tracing(func):
+        def gen_tracing_context(self, *args, **kwargs):
+            with tracing_context(
+                enabled=True,
+                project_name=self.langsmith_project_name,
+                client=self.langsmith_client,
+            ):
+                return func(self, *args, **kwargs)
+
+        return gen_tracing_context
+
+    @tracing
+    def transcribe_document(self, file_key: str):
+        """Transcribe a document from source storage to target storage.
+        This method serves as a generic interface for transcribing documents from
+        various storage sources to target destinations. The specific implementation
+        depends on the storage route types provided.
+
+        Args:
+            file_key (str): The unique identifier or path of the file to be transcribed.
+        Returns:
+            The result of the transcription process, typically the path or identifier
+            of the transcribed document.
+
+        Raises:
+            Exception: If an error occurs during the transcription process.
+        """
         try:
-            s3_persistence_service = S3StorageService(
-                origin_bucket_name=s3_origin_bucket_name,
-                target_bucket_name=s3_target_bucket_name
+            if has_invalid_file_name_format(file_key):
+                raise ValueError(
+                    "Invalid file name format, do not provide special characters or spaces (instead use underscores or hyphens)"
+                )
+            persistence_layer = PersistenceManager(
+                self.storage_service,
+                self.source_storage_route,
+                self.target_storage_route,
             )
+            persistence_service = persistence_layer.retrieve_storage_service()

             transcribe_document_service = TranscriptionService(
                 ai_application_service=self.vertex_model,
-                persistence_service=s3_persistence_service,
+                persistence_service=persistence_service,
                 target_language=self.target_language,
-                transcription_additional_instructions=self.transcription_additional_instructions
+                transcription_additional_instructions=self.transcription_additional_instructions,
+                transcription_accuracy_threshold=self.transcription_accuracy_threshold,
+                max_transcription_retries=self.max_transcription_retries,
+            )
+            parsed_pages, parsed_document = (
+                transcribe_document_service.process_document(file_key)
+            )
+            source_storage_file_tags = {}
+            if persistence_service.supports_tagging:
+                # source_storage_file_tags.tag_file(file_key, {"status": "transcribed"})
+                source_storage_file_tags = persistence_service.retrieve_file_tags(
+                    file_key, self.source_storage_route
+                )
+            transcribe_document_service.save_parsed_document(
+                f"{file_key}.md", parsed_document, source_storage_file_tags
             )
-            parsed_pages, parsed_document = transcribe_document_service.process_document(file_key)
-            origin_bucket_file_tags = s3_persistence_service.retrieve_file_tags(file_key, s3_origin_bucket_name)
-            transcribe_document_service.save_parsed_document(f"{file_key}.md", parsed_document, origin_bucket_file_tags)
             # create md document from parsed_pages
             print("parsed_pages", len(parsed_pages))
             # print("parsed_document", parsed_document)
             return f"{file_key}.md"
         except Exception as e:
-            print(f"Error transcribing document: {e}")
+            print(f"Error processing document: {e}")
             raise e


-class DeelabRedisChunksManager:
-
+class ChunksManager:
     def __init__(
-            self,
-            gcp_project_id: str,
-            gcp_project_location: str,
-            gcp_secret_name: str,
-            redis_connection_string: str,
-            llm_model_id: str = "claude-3-5-haiku@20241022",
-            embeddings_model_id: str = "text-multilingual-embedding-002",
-            target_language: str = "es"
+        self,
+        gcp_project_id: str,
+        gcp_project_location: str,
+        gcp_secret_name: str,
+        langsmith_api_key: str,
+        langsmith_project_name: str,
+        storage_service: storage_services,
+        kdb_service: Literal["redis", "chroma"],
+        kdb_params: Dict[Any, Any],
+        llm_model_id: str = "claude-3-5-haiku@20241022",
+        embeddings_model_id: str = "text-multilingual-embedding-002",
+        target_language: str = "es",
     ):
         self.gcp_project_id = gcp_project_id
         self.gcp_project_location = gcp_project_location
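Both managers define `tracing` as an in-class decorator: it is declared in the class body (so `func` receives the undecorated method at class-definition time), and at call time it opens a LangSmith `tracing_context` bound to the instance's project name and client before delegating. The same pattern in isolation, as a standalone sketch rather than the package's code:

```python
# Standalone sketch of the in-class tracing decorator used by both managers.
from langsmith import Client, tracing_context


class TracedService:
    def __init__(self, langsmith_api_key: str, langsmith_project_name: str):
        self.langsmith_project_name = langsmith_project_name
        self.langsmith_client = Client(api_key=langsmith_api_key)

    def tracing(func):  # no `self`: this runs at class-definition time
        def gen_tracing_context(self, *args, **kwargs):
            # per-instance settings are read only when the method is called
            with tracing_context(
                enabled=True,
                project_name=self.langsmith_project_name,
                client=self.langsmith_client,
            ):
                return func(self, *args, **kwargs)

        return gen_tracing_context

    @tracing
    def do_work(self, payload: str) -> str:
        # every call is now recorded under this instance's LangSmith project
        return payload.upper()
```

Incidentally, the copy of this decorator in `ChunksManager` (next hunk) spells its inner function `gen_tacing_context`; since the misspelled name is defined and returned consistently, it still works.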
@@ -91,9 +205,16 @@ class DeelabRedisChunksManager:
         self.llm_model_id = llm_model_id
         self.target_language = target_language
         self.gcp_sa_dict = self._get_gcp_sa_dict(gcp_secret_name)
-        self.redis_connection_string = redis_connection_string
+        self.storage_service = storage_service
+        self.kdb_params = kdb_params
+        self.kdb_service = kdb_service
         self.vertex_model = self._get_vertex_model()
-        self.embeddings_model = self.vertex_model.load_embeddings_model(embeddings_model_id)
+        self.embeddings_model = self.vertex_model.load_embeddings_model(
+            embeddings_model_id
+        )
+        self.langsmith_api_key = langsmith_api_key
+        self.langsmith_project_name = langsmith_project_name
+        self.langsmith_client = Client(api_key=self.langsmith_api_key)

     def _get_gcp_sa_dict(self, gcp_secret_name: str):
         vertex_gcp_sa = self.aws_secrets_manager.get_secret(gcp_secret_name)
@@ -105,92 +226,57 @@ class DeelabRedisChunksManager:
             self.gcp_project_id,
             self.gcp_project_location,
             self.gcp_sa_dict,
-            llm_model_id=self.llm_model_id
+            llm_model_id=self.llm_model_id,
         )
         return vertex_model

-    def context_chunks_in_document(
-        self,
-        file_key: str
-    ):
-        try:
-            rag_chunker = SemanticChunks(self.embeddings_model)
-            redis_embeddings_manager = RedisEmbeddingsManager(
-                self.embeddings_model,
-                self.redis_connection_string,
-                {
-                    "file_key": file_key
-                }
-            )
-            local_persistence_service = LocalStorageService()
-            context_chunks_in_document_service = ContextChunksInDocumentService(
-                ai_application_service=self.vertex_model,
-                persistence_service=local_persistence_service,
-                rag_chunker=rag_chunker,
-                embeddings_manager=redis_embeddings_manager,
-                target_language=self.target_language
-            )
-            context_chunks = context_chunks_in_document_service.get_context_chunks_in_document(file_key)
-            print("context_chunks", context_chunks)
-            return context_chunks
-        except Exception as e:
-            print(f"Error getting context chunks in document: {e}")
-            raise e
+    def tracing(func):
+        def gen_tacing_context(self, *args, **kwargs):
+            with tracing_context(
+                enabled=True,
+                project_name=self.langsmith_project_name,
+                client=self.langsmith_client,
+            ):
+                return func(self, *args, **kwargs)
+
+        return gen_tacing_context

-    # TODO
-    def context_chunks_in_document_from_aws_cloud(
-        self,
-        file_key: str,
-        s3_origin_bucket_name: str,
-        s3_target_bucket_name: str
-    ):
+    @tracing
+    def gen_context_chunks(
+        self, file_key: str, source_storage_route: str, target_storage_route: str
+    ):
         try:
-            s3_persistence_service = S3StorageService(
-                origin_bucket_name=s3_origin_bucket_name,
-                target_bucket_name=s3_target_bucket_name
+            if has_invalid_file_name_format(file_key):
+                raise ValueError(
+                    "Invalid file name format, do not provide special characters or spaces (instead use underscores or hyphens)"
+                )
+            persistence_layer = PersistenceManager(
+                self.storage_service, source_storage_route, target_storage_route
             )
-            target_bucket_file_tags = s3_persistence_service.retrieve_file_tags(file_key, s3_target_bucket_name)
-
+            persistence_service = persistence_layer.retrieve_storage_service()
+            target_bucket_file_tags = []
+            if persistence_service.supports_tagging:
+                target_bucket_file_tags = persistence_service.retrieve_file_tags(
+                    file_key, target_storage_route
+                )
             rag_chunker = SemanticChunks(self.embeddings_model)
-            redis_embeddings_manager = RedisEmbeddingsManager(
-                embeddings_model=self.embeddings_model,
-                redis_conn_string=self.redis_connection_string,
-                metadata_tags=target_bucket_file_tags
+            kdb_manager = KdbManager(
+                self.embeddings_model, self.kdb_service, self.kdb_params
             )
+            kdb_service = kdb_manager.retrieve_kdb_service()
             context_chunks_in_document_service = ContextChunksInDocumentService(
                 ai_application_service=self.vertex_model,
-                persistence_service=s3_persistence_service,
+                persistence_service=persistence_service,
                 rag_chunker=rag_chunker,
-                embeddings_manager=redis_embeddings_manager,
-                target_language=self.target_language
+                embeddings_manager=kdb_service,
+                target_language=self.target_language,
+            )
+            context_chunks = (
+                context_chunks_in_document_service.get_context_chunks_in_document(
+                    file_key, target_bucket_file_tags
+                )
             )
-            context_chunks = context_chunks_in_document_service.get_context_chunks_in_document(file_key, target_bucket_file_tags)
             return context_chunks
         except Exception as e:
            print(f"Error getting context chunks in document: {e}")
            raise e
-
-
-    def delete_document_context_chunks_from_aws_cloud(
-        self,
-        file_key: str,
-        s3_origin_bucket_name: str,
-        s3_target_bucket_name: str
-    ):
-        pass
-        # rag_chunker = SemanticChunks(self.embeddings_model)
-        # pg_embeddings_manager = PgEmbeddingsManager(
-        #     embeddings_model=self.embeddings_model,
-        #     pg_connection=self.vector_store_connection
-        # )
-        # s3_persistence_service = S3StorageService(
-        #     origin_bucket_name=s3_origin_bucket_name,
-        #     target_bucket_name=s3_target_bucket_name
-        # )
-        # context_chunks_in_document_service = ContextChunksInDocumentService(
-        #     ai_application_service=self.vertex_model,
-        #     persistence_service=s3_persistence_service,
-        #     rag_chunker=rag_chunker,
-        #     embeddings_manager=pg_embeddings_manager
-        # )
-        # context_chunks_in_document_service.delete_document_context_chunks(file_key)
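In 0.3.0 the public surface therefore moves from `DeelabTranscribeManager.aws_cloud_transcribe_document(file_key, s3_origin_bucket_name, s3_target_bucket_name)` and `DeelabRedisChunksManager.context_chunks_in_document_from_aws_cloud(...)` to constructor-configured managers with short call sites. A hedged migration sketch; secret names, buckets, keys, and the `storage_service` value are placeholders:

```python
# Hypothetical 0.3.0b2 call site; all values are placeholders. Class and
# method names come from the main.py diff above.
from wizit_context_ingestor.main import TranscriptionManager

manager = TranscriptionManager(
    gcp_project_id="my-gcp-project",
    gcp_project_location="us-east5",
    gcp_secret_name="vertex-sa-secret",
    langsmith_api_key="lsv2-placeholder-key",  # LangSmith tracing is now built in
    langsmith_project_name="ingestion-dev",
    storage_service="s3",  # presumably StorageServices.S3.value
    source_storage_route="my-origin-bucket",  # was s3_origin_bucket_name
    target_storage_route="my-target-bucket",  # was s3_target_bucket_name
    transcription_accuracy_threshold=90,  # new in 0.3.0; presumably gates retries
    max_transcription_retries=2,          # new in 0.3.0; bounds retry attempts
)

# Validates the key, transcribes, copies source tags when the backend supports
# tagging, and writes "<file_key>.md" to the target route.
markdown_key = manager.transcribe_document("2024_q1_report.pdf")
```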
wizit_context_ingestor/utils/file_utils.py (new file)

@@ -0,0 +1,13 @@
+import re
+
+
+def has_invalid_file_name_format(file_name):
+    """Check if file name has special characters or spaces instead of underscores"""
+    # Check for spaces
+    if " " in file_name:
+        return True
+
+    # Check for special characters (anything that's not alphanumeric, underscore, dash, or dot)
+    if re.search(r"[^a-zA-Z0-9_.-]", file_name):
+        return True
+    return False
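The validator is strict: a name is invalid if it contains a space or anything outside `[a-zA-Z0-9_.-]`. Note that this also rejects `/`, so path-style (prefixed) object keys fail validation as written. A few illustrative calls:

```python
from wizit_context_ingestor.utils.file_utils import has_invalid_file_name_format

assert has_invalid_file_name_format("my report.pdf")         # space
assert has_invalid_file_name_format("informe_señal.pdf")     # non-ASCII character
assert has_invalid_file_name_format("docs/report.pdf")       # "/" is not allowed
assert not has_invalid_file_name_format("my_report-v2.pdf")  # only safe characters
```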
wizit_context_ingestor/workflows/context_nodes.py (new file)

@@ -0,0 +1,73 @@
+from ..data.prompts import WORKFLOW_CONTEXT_CHUNKS_IN_DOCUMENT_SYSTEM_PROMPT
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import MessagesPlaceholder
+from langchain_core.messages import SystemMessage, ToolMessage
+from langgraph.graph import END
+from langgraph.pregel.main import Command
+from .context_state import ContextState
+
+
+class ContextNodes:
+    def __init__(self, llm_model, tools, context_additional_instructions):
+        self.llm_model = llm_model
+        self.tools = tools
+        self.tools_by_name = {tool.name: tool for tool in tools}
+        self.context_additional_instructions = context_additional_instructions
+
+    def gen_context(self, state: ContextState, config):
+        try:
+            messages = state["messages"]
+            document_content = state["document_content"]
+            if not messages:
+                raise ValueError("No messages provided")
+            # parser = PydanticOutputParser(pydantic_object=Transcription)
+            # format_instructions=parser.get_format_instructions(),
+            formatted_context_system_prompt = WORKFLOW_CONTEXT_CHUNKS_IN_DOCUMENT_SYSTEM_PROMPT.format(
+                context_additional_instructions=self.context_additional_instructions,
+                document_content=document_content,
+            )
+
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    SystemMessage(content=formatted_context_system_prompt),
+                    MessagesPlaceholder("messages"),
+                ]
+            )
+            model_with_structured_output = self.llm_model.bind_tools(self.tools)
+            context_chain = prompt | model_with_structured_output
+            context_result = context_chain.invoke({"messages": messages})
+            return {"messages": [context_result]}
+        except Exception as e:
+            print(f"Error occurred: {e}")
+            raise e
+
+    def return_context(self, state: ContextState, config):
+        latest_message = state["messages"][-1]
+        if type(latest_message) is ToolMessage:
+            return Command(goto=END, update={"context": latest_message.content})
+        else:
+            raise ValueError("Invalid message type to return context")
+
+    def tool_node(self, state: ContextState, config):
+        messages = state["messages"]
+        tool_calls = messages[-1].tool_calls
+        should_end_workflow = False
+        observations = []
+        for tool_call in tool_calls:
+            tool_name = tool_call["name"]
+            tool = self.tools_by_name[tool_name]
+            tool_result = tool.invoke(tool_call["args"])
+            observations.append(
+                ToolMessage(
+                    content=tool_result,
+                    name=tool_call["name"],
+                    tool_call_id=tool_call["id"],
+                )
+            )
+            if tool_call["name"] == "complete_context_gen":
+                should_end_workflow = True
+
+        if should_end_workflow:
+            return Command(goto="return_context", update={"messages": observations})
+        else:
+            return Command(goto="gen_context", update={"messages": observations})
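`context_workflow.py` (+42 lines in this release) is not expanded in this diff, but the `Command(goto=...)` targets in `ContextNodes` imply a three-node loop: `gen_context` proposes tool calls, `tool_node` executes them and either loops back or hands off once `complete_context_gen` is called, and `return_context` writes the final answer into state. A speculative wiring sketch consistent with those node names; the actual graph construction in the package may differ:

```python
# Speculative reconstruction of the workflow wiring; the real code lives in
# context_workflow.py, which this diff lists but does not show.
from langgraph.graph import StateGraph, START

from wizit_context_ingestor.workflows.context_nodes import ContextNodes
from wizit_context_ingestor.workflows.context_state import ContextState


def build_context_graph(llm_model, tools, context_additional_instructions=""):
    nodes = ContextNodes(llm_model, tools, context_additional_instructions)
    graph = StateGraph(ContextState)
    graph.add_node("gen_context", nodes.gen_context)
    # tool_node routes dynamically via Command(goto=...), so its possible
    # destinations are declared instead of adding static outgoing edges.
    graph.add_node(
        "tool_node", nodes.tool_node, destinations=("gen_context", "return_context")
    )
    graph.add_node("return_context", nodes.return_context)
    graph.add_edge(START, "gen_context")
    graph.add_edge("gen_context", "tool_node")
    return graph.compile()
```

Invoking the compiled graph with an initial state like `{"messages": [...], "document_content": doc_text}` should leave the tool's final output under the `context` key, per `return_context` above.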
wizit_context_ingestor/workflows/context_state.py (new file)

@@ -0,0 +1,10 @@
+from typing_extensions import Annotated, TypedDict, Sequence
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+
+class ContextState(TypedDict):
+    messages: Annotated[Sequence[BaseMessage], add_messages]
+    document_content: str
+    context: str
+    context_relevance: float