wizit-context-ingestor 0.2.5b3__py3-none-any.whl → 0.3.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wizit_context_ingestor/__init__.py +2 -2
- wizit_context_ingestor/application/context_chunk_service.py +149 -35
- wizit_context_ingestor/application/transcription_service.py +132 -52
- wizit_context_ingestor/data/kdb.py +10 -0
- wizit_context_ingestor/data/prompts.py +150 -3
- wizit_context_ingestor/data/storage.py +10 -0
- wizit_context_ingestor/infra/persistence/local_storage.py +19 -9
- wizit_context_ingestor/infra/persistence/s3_storage.py +29 -23
- wizit_context_ingestor/infra/rag/chroma_embeddings.py +30 -31
- wizit_context_ingestor/infra/rag/pg_embeddings.py +57 -54
- wizit_context_ingestor/infra/rag/redis_embeddings.py +34 -25
- wizit_context_ingestor/infra/rag/semantic_chunks.py +9 -1
- wizit_context_ingestor/infra/vertex_model.py +56 -28
- wizit_context_ingestor/main.py +192 -106
- wizit_context_ingestor/utils/file_utils.py +13 -0
- wizit_context_ingestor/workflows/context_nodes.py +73 -0
- wizit_context_ingestor/workflows/context_state.py +10 -0
- wizit_context_ingestor/workflows/context_tools.py +58 -0
- wizit_context_ingestor/workflows/context_workflow.py +42 -0
- wizit_context_ingestor/workflows/transcription_nodes.py +136 -0
- wizit_context_ingestor/workflows/transcription_schemas.py +25 -0
- wizit_context_ingestor/workflows/transcription_state.py +17 -0
- wizit_context_ingestor/workflows/transcription_tools.py +54 -0
- wizit_context_ingestor/workflows/transcription_workflow.py +42 -0
- {wizit_context_ingestor-0.2.5b3.dist-info → wizit_context_ingestor-0.3.0b2.dist-info}/METADATA +9 -1
- wizit_context_ingestor-0.3.0b2.dist-info/RECORD +44 -0
- {wizit_context_ingestor-0.2.5b3.dist-info → wizit_context_ingestor-0.3.0b2.dist-info}/WHEEL +1 -1
- wizit_context_ingestor-0.2.5b3.dist-info/RECORD +0 -32
wizit_context_ingestor/infra/vertex_model.py CHANGED

```diff
@@ -15,14 +15,23 @@ class VertexModels(AiApplicationService):
     A wrapper class for Google Cloud Vertex AI models that handles credentials and
     provides methods to load embeddings and chat models.
     """
-
+
+    __slots__ = (
+        "project_id",
+        "location",
+        "json_service_account",
+        "scopes",
+        "llm_model_id",
+    )
+
     def __init__(
+        self,
+        project_id: str,
+        location: str,
+        json_service_account: Dict[str, Any],
+        scopes: Optional[List[str]] = None,
+        llm_model_id: str = "claude-sonnet-4@20250514",
+    ):
         """
         Initialize the VertexModels class with Google Cloud credentials.
@@ -36,25 +45,24 @@ class VertexModels(AiApplicationService):
             print(location)
             self.scopes = scopes or ["https://www.googleapis.com/auth/cloud-platform"]
             self.credentials = service_account.Credentials.from_service_account_info(
-                json_service_account,
-                scopes=self.scopes
+                json_service_account, scopes=self.scopes
             )
             self.llm_model_id = llm_model_id
             self.project_id = project_id
             self.location = location
             vertexai_init(
-                project=project_id,
+                project=project_id, location=location, credentials=self.credentials
+            )
+            logger.info(
+                f"VertexModels initialized with project {project_id} in {location}"
             )
-            logger.info(f"VertexModels initialized with project {project_id} in {location}")
         except Exception as e:
             logger.error(f"Failed to initialize VertexModels: {str(e)}")
             raise
 
     def load_embeddings_model(
-        self,
+        self, embeddings_model_id: str = "text-multilingual-embedding-002"
+    ) -> VertexAIEmbeddings:  # noqa: E125
         """
         Load and return a Vertex AI embeddings model.
         default embeddings length is 768 https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings
@@ -73,14 +81,18 @@ class VertexModels(AiApplicationService):
             logger.debug(f"Loaded embedding model: {embeddings_model_id}")
             return embeddings
         except Exception as e:
-            logger.error(
+            logger.error(
+                f"Failed to load embeddings model {embeddings_model_id}: {str(e)}"
+            )
             raise
 
-    def load_chat_model(
+    def load_chat_model(
+        self,
         temperature: float = 0.15,
         max_tokens: int = 8192,
         stop: Optional[List[str]] = None,
-        **chat_model_params
+        **chat_model_params,
+    ) -> Union[ChatVertexAI, ChatAnthropicVertex]:
         """
         Load a Vertex AI chat model for text generation.
@@ -98,21 +110,35 @@ class VertexModels(AiApplicationService):
         """
         try:
             if "gemini" in self.llm_model_id:
-                return self.load_chat_model_gemini(
+                return self.load_chat_model_gemini(
+                    self.llm_model_id,
+                    temperature,
+                    max_tokens,
+                    stop,
+                    **chat_model_params,
+                )
             elif "claude" in self.llm_model_id:
-                return self.load_chat_model_anthropic(
+                return self.load_chat_model_anthropic(
+                    self.llm_model_id,
+                    temperature,
+                    max_tokens,
+                    stop,
+                    **chat_model_params,
+                )
             else:
                 raise ValueError(f"Unsupported chat model: {self.llm_model_id}")
         except Exception as e:
             logger.error(f"Failed to retrieve chat model {self.llm_model_id}: {str(e)}")
             raise
 
-    def load_chat_model_gemini(
+    def load_chat_model_gemini(
+        self,
         chat_model_id: str = "publishers/google/models/gemini-2.5-flash",
         temperature: float = 0.15,
-        max_tokens: int =
+        max_tokens: int = 64000,
         stop: Optional[List[str]] = None,
-        **chat_model_params
+        **chat_model_params,
+    ) -> ChatVertexAI:
         """
         Load a Vertex AI chat model for text generation.
@@ -137,7 +163,7 @@ class VertexModels(AiApplicationService):
                 max_tokens=max_tokens,
                 max_retries=1,
                 stop=stop,
-                **chat_model_params
+                **chat_model_params,
             )
             logger.debug(f"Retrieved chat model: {chat_model_id}")
             return self.llm_model
@@ -145,12 +171,14 @@ class VertexModels(AiApplicationService):
             logger.error(f"Failed to retrieve chat model {chat_model_id}: {str(e)}")
             raise
 
-    def load_chat_model_anthropic(
+    def load_chat_model_anthropic(
+        self,
         chat_model_id: str = "claude-3-5-haiku@20241022",
         temperature: float = 0.7,
-        max_tokens: int =
+        max_tokens: int = 64000,
         stop: Optional[List[str]] = None,
-        **chat_model_params
+        **chat_model_params,
+    ) -> ChatAnthropicVertex:
         """
         Load a Vertex AI chat model for text generation.
         """
@@ -163,7 +191,7 @@ class VertexModels(AiApplicationService):
                 max_tokens=max_tokens,
                 max_retries=1,
                 stop=stop,
-                **chat_model_params
+                **chat_model_params,
             )
             logger.debug(f"Retrieved chat model: {chat_model_id}")
             return self.llm_model
```
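In `vertex_model.py` the bulk of the change is a formatter pass (trailing commas, wrapped calls), with a few substantive tweaks: the class now declares `__slots__`, the Gemini and Anthropic loaders default to `max_tokens=64000`, and the loaders gain return-type annotations. A minimal usage sketch, assuming valid GCP service-account credentials; the file path, project id, and location below are placeholders:

```python
# Illustrative only: "sa.json", the project id, and the location are placeholders.
import json

from wizit_context_ingestor.infra.vertex_model import VertexModels

with open("sa.json") as f:
    sa_dict = json.load(f)  # service-account JSON as a dict

models = VertexModels(
    project_id="my-gcp-project",
    location="us-east5",
    json_service_account=sa_dict,
    llm_model_id="claude-sonnet-4@20250514",
)
embeddings = models.load_embeddings_model()  # text-multilingual-embedding-002
chat = models.load_chat_model(temperature=0.15, max_tokens=8192)
```

`load_chat_model` dispatches on a substring check: an `llm_model_id` containing "gemini" returns a `ChatVertexAI`, one containing "claude" returns a `ChatAnthropicVertex`, and anything else raises `ValueError`.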
wizit_context_ingestor/main.py CHANGED

```diff
@@ -1,4 +1,5 @@
 import json
+from typing import Dict, Any, Literal
 from .infra.vertex_model import VertexModels
 from .application.transcription_service import TranscriptionService
 from .application.context_chunk_service import ContextChunksInDocumentService
@@ -6,17 +7,79 @@ from .infra.persistence.s3_storage import S3StorageService
 from .infra.persistence.local_storage import LocalStorageService
 from .infra.rag.semantic_chunks import SemanticChunks
 from .infra.rag.redis_embeddings import RedisEmbeddingsManager
+from .infra.rag.chroma_embeddings import ChromaEmbeddingsManager
 from .infra.secrets.aws_secrets_manager import AwsSecretsManager
+from .data.storage import storage_services, StorageServices
+from .data.kdb import kdb_services, KdbServices
+from .utils.file_utils import has_invalid_file_name_format
+from langsmith import Client, tracing_context
 
-class DeelabTranscribeManager:
 
+class KdbManager:
+    def __init__(
+        self, embeddings_model, kdb_service: kdb_services, kdb_params: Dict[Any, Any]
+    ):
+        self.kdb_service = kdb_service
+        self.kdb_params = kdb_params
+        self.embeddings_model = embeddings_model
+
+    def retrieve_kdb_service(self):
+        if self.kdb_service == KdbServices.REDIS.value:
+            return RedisEmbeddingsManager(
+                self.embeddings_model,
+                **self.kdb_params,
+            )
+        elif self.kdb_service == KdbServices.CHROMA.value:
+            return ChromaEmbeddingsManager(
+                self.embeddings_model,
+                **self.kdb_params,
+            )
+        else:
+            raise ValueError(f"Unsupported kdb provider: {self.kdb_service}")
+
+
+class PersistenceManager:
+    def __init__(
+        self,
+        storage_service: storage_services,
+        source_storage_route,
+        target_storage_route,
+    ):
+        self.storage_service = storage_service
+        self.source_storage_route = source_storage_route
+        self.target_storage_route = target_storage_route
+
+    def retrieve_storage_service(self):
+        if self.storage_service == StorageServices.S3.value:
+            return S3StorageService(
+                origin_bucket_name=self.source_storage_route,
+                target_bucket_name=self.target_storage_route,
+            )
+        elif self.storage_service == StorageServices.LOCAL.value:
+            return LocalStorageService(
+                source_storage_route=self.source_storage_route,
+                target_storage_route=self.target_storage_route,
+            )
+        else:
+            raise ValueError(f"Unsupported storage service: {self.storage_service}")
+
+
+class TranscriptionManager:
+    def __init__(
+        self,
         gcp_project_id: str,
         gcp_project_location: str,
         gcp_secret_name: str,
+        langsmith_api_key: str,
+        langsmith_project_name: str,
+        storage_service: storage_services,
+        source_storage_route: str,
+        target_storage_route: str,
         llm_model_id: str = "claude-sonnet-4@20250514",
-        target_language: str =
-        transcription_additional_instructions: str =
+        target_language: str = "es",
+        transcription_additional_instructions: str = "",
+        transcription_accuracy_threshold: int = 90,
+        max_transcription_retries: int = 2,
     ):
         self.gcp_project_id = gcp_project_id
         self.gcp_project_location = gcp_project_location
```
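The former `DeelabTranscribeManager` is split into smaller pieces: `KdbManager` and `PersistenceManager` are factories that dispatch on the new `KdbServices` and `StorageServices` enum values, and `TranscriptionManager` now takes LangSmith credentials, storage configuration, and accuracy/retry thresholds up front. A construction sketch; every value is a placeholder, and the `"s3"` string assumes the `StorageServices` enum value implied by the dispatch code:

```python
# Placeholder values throughout; secret names and routes depend on your deployment.
from wizit_context_ingestor.main import TranscriptionManager

manager = TranscriptionManager(
    gcp_project_id="my-gcp-project",
    gcp_project_location="us-east5",
    gcp_secret_name="vertex/service-account",  # resolved via AwsSecretsManager
    langsmith_api_key="lsv2_...",
    langsmith_project_name="ingestion-traces",
    storage_service="s3",  # assumed StorageServices.S3.value; "local" selects LocalStorageService
    source_storage_route="raw-documents-bucket",
    target_storage_route="parsed-documents-bucket",
)
result_key = manager.transcribe_document("contract_2024.pdf")  # -> "contract_2024.pdf.md"
```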
```diff
@@ -24,9 +87,19 @@ class DeelabTranscribeManager:
         self.gcp_secret_name = gcp_secret_name
         self.llm_model_id = llm_model_id
         self.target_language = target_language
-        self.
+        self.storage_service = storage_service
+        self.source_storage_route = source_storage_route
+        self.target_storage_route = target_storage_route
+        self.transcription_additional_instructions = (
+            transcription_additional_instructions
+        )
+        self.transcription_accuracy_threshold = transcription_accuracy_threshold
+        self.max_transcription_retries = max_transcription_retries
         self.gcp_sa_dict = self._get_gcp_sa_dict(gcp_secret_name)
         self.vertex_model = self._get_vertex_model()
+        self.langsmith_api_key = langsmith_api_key
+        self.langsmith_project_name = langsmith_project_name
+        self.langsmith_client = Client(api_key=self.langsmith_api_key)
 
     def _get_gcp_sa_dict(self, gcp_secret_name: str):
         vertex_gcp_sa = self.aws_secrets_manager.get_secret(gcp_secret_name)
@@ -38,51 +111,92 @@ class DeelabTranscribeManager:
             self.gcp_project_id,
             self.gcp_project_location,
             self.gcp_sa_dict,
-            llm_model_id=self.llm_model_id
+            llm_model_id=self.llm_model_id,
         )
         return vertex_model
 
-    def
+    def tracing(func):
+        def gen_tracing_context(self, *args, **kwargs):
+            with tracing_context(
+                enabled=True,
+                project_name=self.langsmith_project_name,
+                client=self.langsmith_client,
+            ):
+                return func(self, *args, **kwargs)
+
+        return gen_tracing_context
+
+    @tracing
+    def transcribe_document(self, file_key: str):
+        """Transcribe a document from source storage to target storage.
+
+        This method serves as a generic interface for transcribing documents from
+        various storage sources to target destinations. The specific implementation
+        depends on the storage route types provided.
+
+        Args:
+            file_key (str): The unique identifier or path of the file to be transcribed.
+        Returns:
+            The result of the transcription process, typically the path or identifier
+            of the transcribed document.
+
+        Raises:
+            Exception: If an error occurs during the transcription process.
+        """
         try:
+            if has_invalid_file_name_format(file_key):
+                raise ValueError(
+                    "Invalid file name format, do not provide special characters or spaces (instead use underscores or hyphens)"
+                )
+            persistence_layer = PersistenceManager(
+                self.storage_service,
+                self.source_storage_route,
+                self.target_storage_route,
             )
+            persistence_service = persistence_layer.retrieve_storage_service()
 
             transcribe_document_service = TranscriptionService(
                 ai_application_service=self.vertex_model,
-                persistence_service=
+                persistence_service=persistence_service,
                 target_language=self.target_language,
-                transcription_additional_instructions=self.transcription_additional_instructions
+                transcription_additional_instructions=self.transcription_additional_instructions,
+                transcription_accuracy_threshold=self.transcription_accuracy_threshold,
+                max_transcription_retries=self.max_transcription_retries,
+            )
+            parsed_pages, parsed_document = (
+                transcribe_document_service.process_document(file_key)
+            )
+            source_storage_file_tags = {}
+            if persistence_service.supports_tagging:
+                # source_storage_file_tags.tag_file(file_key, {"status": "transcribed"})
+                source_storage_file_tags = persistence_service.retrieve_file_tags(
+                    file_key, self.source_storage_route
+                )
+            transcribe_document_service.save_parsed_document(
+                f"{file_key}.md", parsed_document, source_storage_file_tags
             )
-            parsed_pages, parsed_document = transcribe_document_service.process_document(file_key)
-            origin_bucket_file_tags = s3_persistence_service.retrieve_file_tags(file_key, s3_origin_bucket_name)
-            transcribe_document_service.save_parsed_document(f"{file_key}.md", parsed_document, origin_bucket_file_tags)
             # create md document from parsed_pages
             print("parsed_pages", len(parsed_pages))
             # print("parsed_document", parsed_document)
             return f"{file_key}.md"
         except Exception as e:
-            print(f"Error
+            print(f"Error processing document: {e}")
             raise e
 
 
-class
+class ChunksManager:
     def __init__(
+        self,
+        gcp_project_id: str,
+        gcp_project_location: str,
+        gcp_secret_name: str,
+        langsmith_api_key: str,
+        langsmith_project_name: str,
+        storage_service: storage_services,
+        kdb_service: Literal["redis", "chroma"],
+        kdb_params: Dict[Any, Any],
+        llm_model_id: str = "claude-3-5-haiku@20241022",
+        embeddings_model_id: str = "text-multilingual-embedding-002",
+        target_language: str = "es",
     ):
         self.gcp_project_id = gcp_project_id
         self.gcp_project_location = gcp_project_location
```
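Both managers define a `tracing` decorator that runs the wrapped call inside a LangSmith `tracing_context` tied to the instance's client and project (the `ChunksManager` copy spells its inner function `gen_tacing_context`, a typo preserved from the released code). The same pattern in isolation; the class and method names here are illustrative, not from the package:

```python
# Standalone sketch of the decorator pattern used by both managers.
# MyService and run() are illustrative names.
from langsmith import Client, tracing_context


def tracing(func):
    # Forward self so the wrapper can read the instance's LangSmith settings.
    def gen_tracing_context(self, *args, **kwargs):
        with tracing_context(
            enabled=True,
            project_name=self.langsmith_project_name,
            client=self.langsmith_client,
        ):
            return func(self, *args, **kwargs)

    return gen_tracing_context


class MyService:
    def __init__(self, api_key: str, project_name: str):
        self.langsmith_client = Client(api_key=api_key)
        self.langsmith_project_name = project_name

    @tracing
    def run(self, payload: str) -> str:
        return payload.upper()  # the traced work
```

Because the package defines `tracing` inside the class body and applies it at class-definition time, it acts as a plain function: it never receives `self` itself, it only forwards it to the wrapped method.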
```diff
@@ -91,9 +205,16 @@ class DeelabRedisChunksManager:
         self.llm_model_id = llm_model_id
         self.target_language = target_language
         self.gcp_sa_dict = self._get_gcp_sa_dict(gcp_secret_name)
-        self.
+        self.storage_service = storage_service
+        self.kdb_params = kdb_params
+        self.kdb_service = kdb_service
         self.vertex_model = self._get_vertex_model()
-        self.embeddings_model = self.vertex_model.load_embeddings_model(
+        self.embeddings_model = self.vertex_model.load_embeddings_model(
+            embeddings_model_id
+        )
+        self.langsmith_api_key = langsmith_api_key
+        self.langsmith_project_name = langsmith_project_name
+        self.langsmith_client = Client(api_key=self.langsmith_api_key)
 
     def _get_gcp_sa_dict(self, gcp_secret_name: str):
         vertex_gcp_sa = self.aws_secrets_manager.get_secret(gcp_secret_name)
@@ -105,92 +226,57 @@ class DeelabRedisChunksManager:
             self.gcp_project_id,
             self.gcp_project_location,
             self.gcp_sa_dict,
-            llm_model_id=self.llm_model_id
+            llm_model_id=self.llm_model_id,
         )
         return vertex_model
 
-    def
-        self,
-        self
-                "file_key": file_key
-            }
-        )
-        local_persistence_service = LocalStorageService()
-        context_chunks_in_document_service = ContextChunksInDocumentService(
-            ai_application_service=self.vertex_model,
-            persistence_service=local_persistence_service,
-            rag_chunker=rag_chunker,
-            embeddings_manager=redis_embeddings_manager,
-            target_language=self.target_language
-        )
-        context_chunks = context_chunks_in_document_service.get_context_chunks_in_document(file_key)
-        print("context_chunks", context_chunks)
-        return context_chunks
-    except Exception as e:
-        print(f"Error getting context chunks in document: {e}")
-        raise e
+    def tracing(func):
+        def gen_tacing_context(self, *args, **kwargs):
+            with tracing_context(
+                enabled=True,
+                project_name=self.langsmith_project_name,
+                client=self.langsmith_client,
+            ):
+                return func(self, *args, **kwargs)
+
+        return gen_tacing_context
 
-    def
-        s3_origin_bucket_name: str,
-        s3_target_bucket_name: str
-    ):
+    @tracing
+    def gen_context_chunks(
+        self, file_key: str, source_storage_route: str, target_storage_route: str
+    ):
         try:
+            if has_invalid_file_name_format(file_key):
+                raise ValueError(
+                    "Invalid file name format, do not provide special characters or spaces (instead use underscores or hyphens)"
+                )
+            persistence_layer = PersistenceManager(
+                self.storage_service, source_storage_route, target_storage_route
             )
+            persistence_service = persistence_layer.retrieve_storage_service()
+            target_bucket_file_tags = []
+            if persistence_service.supports_tagging:
+                target_bucket_file_tags = persistence_service.retrieve_file_tags(
+                    file_key, target_storage_route
+                )
             rag_chunker = SemanticChunks(self.embeddings_model)
-                redis_conn_string=self.redis_connection_string,
-                metadata_tags=target_bucket_file_tags
+            kdb_manager = KdbManager(
+                self.embeddings_model, self.kdb_service, self.kdb_params
             )
+            kdb_service = kdb_manager.retrieve_kdb_service()
             context_chunks_in_document_service = ContextChunksInDocumentService(
                 ai_application_service=self.vertex_model,
-                persistence_service=
+                persistence_service=persistence_service,
                 rag_chunker=rag_chunker,
-                embeddings_manager=
-                target_language=self.target_language
+                embeddings_manager=kdb_service,
+                target_language=self.target_language,
+            )
+            context_chunks = (
+                context_chunks_in_document_service.get_context_chunks_in_document(
+                    file_key, target_bucket_file_tags
+                )
             )
-            context_chunks = context_chunks_in_document_service.get_context_chunks_in_document(file_key, target_bucket_file_tags)
             return context_chunks
         except Exception as e:
             print(f"Error getting context chunks in document: {e}")
             raise e
-
-
-    def delete_document_context_chunks_from_aws_cloud(
-        self,
-        file_key: str,
-        s3_origin_bucket_name: str,
-        s3_target_bucket_name: str
-    ):
-        pass
-        # rag_chunker = SemanticChunks(self.embeddings_model)
-        # pg_embeddings_manager = PgEmbeddingsManager(
-        #     embeddings_model=self.embeddings_model,
-        #     pg_connection=self.vector_store_connection
-        # )
-        # s3_persistence_service = S3StorageService(
-        #     origin_bucket_name=s3_origin_bucket_name,
-        #     target_bucket_name=s3_target_bucket_name
-        # )
-        # context_chunks_in_document_service = ContextChunksInDocumentService(
-        #     ai_application_service=self.vertex_model,
-        #     persistence_service=s3_persistence_service,
-        #     rag_chunker=rag_chunker,
-        #     embeddings_manager=pg_embeddings_manager
-        # )
-        # context_chunks_in_document_service.delete_document_context_chunks(file_key)
```
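`gen_context_chunks` replaces the old Redis/S3-specific flow: it validates the file name, resolves the storage backend and knowledge base through the factories, reads file tags only when the backend reports `supports_tagging`, and hands the chosen embeddings manager to `ContextChunksInDocumentService`. An illustrative call; all values are placeholders, and the `kdb_params` key is an assumption extrapolated from the removed `redis_conn_string` argument:

```python
# Placeholder configuration; kdb_params keys depend on the chosen backend's constructor.
from wizit_context_ingestor.main import ChunksManager

chunks_manager = ChunksManager(
    gcp_project_id="my-gcp-project",
    gcp_project_location="us-east5",
    gcp_secret_name="vertex/service-account",
    langsmith_api_key="lsv2_...",
    langsmith_project_name="ingestion-traces",
    storage_service="s3",
    kdb_service="redis",  # or "chroma"
    kdb_params={"redis_conn_string": "redis://localhost:6379"},  # assumed key
)
context_chunks = chunks_manager.gen_context_chunks(
    "contract_2024.pdf.md",
    source_storage_route="raw-documents-bucket",
    target_storage_route="parsed-documents-bucket",
)
```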
wizit_context_ingestor/utils/file_utils.py ADDED

```diff
@@ -0,0 +1,13 @@
+import re
+
+
+def has_invalid_file_name_format(file_name):
+    """Check if file name has special characters or spaces instead of underscores"""
+    # Check for spaces
+    if " " in file_name:
+        return True
+
+    # Check for special characters (anything that's not alphanumeric, underscore, dash, or dot)
+    if re.search(r"[^a-zA-Z0-9_.-]", file_name):
+        return True
+    return False
```
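The validator rejects anything outside `[a-zA-Z0-9_.-]`, including spaces, path separators, and accented letters; both managers call it before touching storage. For example:

```python
from wizit_context_ingestor.utils.file_utils import has_invalid_file_name_format

has_invalid_file_name_format("quarterly report.pdf")     # True: contains a space
has_invalid_file_name_format("informe_señal.pdf")        # True: "ñ" is outside [a-zA-Z0-9_.-]
has_invalid_file_name_format("reports/q1.pdf")           # True: "/" is rejected too
has_invalid_file_name_format("quarterly-report_v2.pdf")  # False: letters, digits, "-", "_", "."
```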
wizit_context_ingestor/workflows/context_nodes.py ADDED

```diff
@@ -0,0 +1,73 @@
+from ..data.prompts import WORKFLOW_CONTEXT_CHUNKS_IN_DOCUMENT_SYSTEM_PROMPT
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import MessagesPlaceholder
+from langchain_core.messages import SystemMessage, ToolMessage
+from langgraph.graph import END
+from langgraph.pregel.main import Command
+from .context_state import ContextState
+
+
+class ContextNodes:
+    def __init__(self, llm_model, tools, context_additional_instructions):
+        self.llm_model = llm_model
+        self.tools = tools
+        self.tools_by_name = {tool.name: tool for tool in tools}
+        self.context_additional_instructions = context_additional_instructions
+
+    def gen_context(self, state: ContextState, config):
+        try:
+            messages = state["messages"]
+            document_content = state["document_content"]
+            if not messages:
+                raise ValueError("No messages provided")
+            # parser = PydanticOutputParser(pydantic_object=Transcription)
+            # format_instructions=parser.get_format_instructions(),
+            formatted_context_system_prompt = WORKFLOW_CONTEXT_CHUNKS_IN_DOCUMENT_SYSTEM_PROMPT.format(
+                context_additional_instructions=self.context_additional_instructions,
+                document_content=document_content,
+            )
+
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    SystemMessage(content=formatted_context_system_prompt),
+                    MessagesPlaceholder("messages"),
+                ]
+            )
+            model_with_structured_output = self.llm_model.bind_tools(self.tools)
+            context_chain = prompt | model_with_structured_output
+            context_result = context_chain.invoke({"messages": messages})
+            return {"messages": [context_result]}
+        except Exception as e:
+            print(f"Error occurred: {e}")
+            raise e
+
+    def return_context(self, state: ContextState, config):
+        latest_message = state["messages"][-1]
+        if type(latest_message) is ToolMessage:
+            return Command(goto=END, update={"context": latest_message.content})
+        else:
+            raise ValueError("Invalid message type to return context")
+
+    def tool_node(self, state: ContextState, config):
+        messages = state["messages"]
+        tool_calls = messages[-1].tool_calls
+        should_end_workflow = False
+        observations = []
+        for tool_call in tool_calls:
+            tool_name = tool_call["name"]
+            tool = self.tools_by_name[tool_name]
+            tool_result = tool.invoke(tool_call["args"])
+            observations.append(
+                ToolMessage(
+                    content=tool_result,
+                    name=tool_call["name"],
+                    tool_call_id=tool_call["id"],
+                )
+            )
+            if tool_call["name"] == "complete_context_gen":
+                should_end_workflow = True
+
+        if should_end_workflow:
+            return Command(goto="return_context", update={"messages": observations})
+        else:
+            return Command(goto="gen_context", update={"messages": observations})
```
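`context_nodes.py` implements a tool-calling loop: `gen_context` formats the system prompt with the document content and invokes the tool-bound model, `tool_node` executes any requested tools and routes back to `gen_context` (or on to `return_context` once the model calls `complete_context_gen`), and `return_context` copies the final `ToolMessage` content into the `context` field. The actual graph assembly lives in `context_workflow.py`, which this diff does not show, so the wiring below is a sketch inferred from the `Command(goto=...)` targets rather than the package's own code:

```python
# Hypothetical assembly; node names follow the Command(goto=...) targets above.
from langgraph.graph import StateGraph, START

from wizit_context_ingestor.workflows.context_nodes import ContextNodes
from wizit_context_ingestor.workflows.context_state import ContextState


def build_context_graph(llm_model, tools, extra_instructions: str = ""):
    nodes = ContextNodes(llm_model, tools, extra_instructions)
    graph = StateGraph(ContextState)
    graph.add_node("gen_context", nodes.gen_context)
    graph.add_node("tool_node", nodes.tool_node)  # routes itself via Command
    graph.add_node("return_context", nodes.return_context)  # ends via Command(goto=END)
    graph.add_edge(START, "gen_context")
    graph.add_edge("gen_context", "tool_node")
    return graph.compile()
```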
wizit_context_ingestor/workflows/context_state.py ADDED

```diff
@@ -0,0 +1,10 @@
+from typing_extensions import Annotated, TypedDict, Sequence
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+
+class ContextState(TypedDict):
+    messages: Annotated[Sequence[BaseMessage], add_messages]
+    document_content: str
+    context: str
+    context_relevance: float
```