user_simulator-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. user_sim/__init__.py +0 -0
  2. user_sim/cli/__init__.py +0 -0
  3. user_sim/cli/gen_user_profile.py +34 -0
  4. user_sim/cli/init_project.py +65 -0
  5. user_sim/cli/sensei_chat.py +481 -0
  6. user_sim/cli/sensei_check.py +103 -0
  7. user_sim/cli/validation_check.py +143 -0
  8. user_sim/core/__init__.py +0 -0
  9. user_sim/core/ask_about.py +665 -0
  10. user_sim/core/data_extraction.py +260 -0
  11. user_sim/core/data_gathering.py +134 -0
  12. user_sim/core/interaction_styles.py +147 -0
  13. user_sim/core/role_structure.py +608 -0
  14. user_sim/core/user_simulator.py +302 -0
  15. user_sim/handlers/__init__.py +0 -0
  16. user_sim/handlers/asr_module.py +128 -0
  17. user_sim/handlers/html_parser_module.py +202 -0
  18. user_sim/handlers/image_recognition_module.py +139 -0
  19. user_sim/handlers/pdf_parser_module.py +123 -0
  20. user_sim/utils/__init__.py +0 -0
  21. user_sim/utils/config.py +47 -0
  22. user_sim/utils/cost_tracker.py +153 -0
  23. user_sim/utils/cost_tracker_v2.py +193 -0
  24. user_sim/utils/errors.py +15 -0
  25. user_sim/utils/exceptions.py +47 -0
  26. user_sim/utils/languages.py +78 -0
  27. user_sim/utils/register_management.py +62 -0
  28. user_sim/utils/show_logs.py +63 -0
  29. user_sim/utils/token_cost_calculator.py +338 -0
  30. user_sim/utils/url_management.py +60 -0
  31. user_sim/utils/utilities.py +568 -0
  32. user_simulator-0.1.0.dist-info/METADATA +733 -0
  33. user_simulator-0.1.0.dist-info/RECORD +37 -0
  34. user_simulator-0.1.0.dist-info/WHEEL +5 -0
  35. user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
  36. user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
  37. user_simulator-0.1.0.dist-info/top_level.txt +1 -0
user_sim/utils/exceptions.py
@@ -0,0 +1,47 @@
+ class InvalidGoalException(Exception):
+     pass
+
+ class InvalidInteractionException(Exception):
+     pass
+
+ class InvalidLanguageException(Exception):
+     pass
+
+ class OutOfLimitException(Exception):
+     pass
+
+ class BadDictionaryGeneration(Exception):
+     pass
+
+ class InvalidItemType(Exception):
+     pass
+
+ class EmptyListExcept(Exception):
+     pass
+
+ class InvalidDataType(Exception):
+     pass
+
+ class InvalidFormat(Exception):
+     pass
+
+ class MissingStepDefinition(Exception):
+     pass
+
+ class InvalidGenerator(Exception):
+     pass
+
+ class VariableNotFound(Exception):
+     pass
+
+ class InvalidDependence(Exception):
+     pass
+
+ class InvalidFile(Exception):
+     pass
+
+ class NoCostException(Exception):
+     pass
+
+ class UnmachedList(Exception):
+     pass
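As a usage note: a minimal sketch of how callers might raise and catch one of these exceptions. The file path follows from the 47-line count matching user_sim/utils/exceptions.py in the file list; the validate_language helper and the supported-language check are hypothetical, not part of this diff.

    from user_sim.utils.exceptions import InvalidLanguageException

    def validate_language(language, supported):  # hypothetical helper
        if language not in supported:
            raise InvalidLanguageException(f"Unsupported language: {language}")

    try:
        validate_language("Klingon", ["English", "Spanish"])
    except InvalidLanguageException as error:
        print(f"Profile validation failed: {error}")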
user_sim/utils/languages.py
@@ -0,0 +1,78 @@
+ languages = [
+     "Afrikaans", "Albanian", "Amharic", "Arabic", "Armenian", "Azerbaijani", "Bengali", "Bosnian", "Bulgarian",
+     "Catalan", "Chinese (Simplified)", "Chinese (Traditional)", "Croatian", "Czech", "Danish", "Dutch",
+     "English", "Estonian", "Filipino", "Finnish", "French", "Galician", "Georgian", "German", "Greek",
+     "Gujarati", "Hausa", "Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian", "Italian", "Japanese",
+     "Kannada", "Kazakh", "Korean", "Latvian", "Lithuanian", "Macedonian", "Malay", "Malayalam", "Marathi",
+     "Nepali", "Norwegian", "Persian", "Polish", "Portuguese", "Punjabi", "Romanian", "Russian", "Serbian",
+     "Slovak", "Slovenian", "Spanish", "Swahili", "Swedish", "Tamil", "Telugu", "Thai", "Turkish", "Ukrainian",
+     "Urdu", "Vietnamese", "Zulu"
+ ]
+
+ languages_weights = {
+     "Afrikaans": 1,
+     "Albanian": 1,
+     "Amharic": 1,
+     "Arabic": 1,
+     "Armenian": 1,
+     "Azerbaijani": 1,
+     "Bengali": 1,
+     "Bosnian": 1,
+     "Bulgarian": 1,
+     "Catalan": 1,
+     "Chinese (Simplified)": 1,
+     "Chinese (Traditional)": 1,
+     "Croatian": 1,
+     "Czech": 1,
+     "Danish": 1,
+     "Dutch": 1,
+     "English": 1,
+     "Estonian": 1,
+     "Filipino": 1,
+     "Finnish": 1,
+     "French": 1,
+     "Galician": 1,
+     "Georgian": 1,
+     "German": 1,
+     "Greek": 1,
+     "Gujarati": 1,
+     "Hausa": 1,
+     "Hebrew": 1,
+     "Hindi": 1,
+     "Hungarian": 1,
+     "Icelandic": 1,
+     "Indonesian": 1,
+     "Italian": 1,
+     "Japanese": 1,
+     "Kannada": 1,
+     "Kazakh": 1,
+     "Korean": 1,
+     "Latvian": 1,
+     "Lithuanian": 1,
+     "Macedonian": 1,
+     "Malay": 1,
+     "Malayalam": 1,
+     "Marathi": 1,
+     "Nepali": 1,
+     "Norwegian": 1,
+     "Persian": 1,
+     "Polish": 1,
+     "Portuguese": 1,
+     "Punjabi": 1,
+     "Romanian": 1,
+     "Russian": 1,
+     "Serbian": 1,
+     "Slovak": 1,
+     "Slovenian": 1,
+     "Spanish": 1,
+     "Swahili": 1,
+     "Swedish": 1,
+     "Tamil": 1,
+     "Telugu": 1,
+     "Thai": 1,
+     "Turkish": 1,
+     "Ukrainian": 1,
+     "Urdu": 1,
+     "Vietnamese": 1,
+     "Zulu": 1
+ }
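A minimal sketch of how languages_weights might drive language selection; that it feeds weighted sampling is an assumption, and with every weight at 1 this reduces to a uniform choice:

    import random
    from user_sim.utils.languages import languages_weights

    # random.choices takes parallel population/weights sequences;
    # uniform weights make this equivalent to random.choice(languages).
    language = random.choices(
        population=list(languages_weights.keys()),
        weights=list(languages_weights.values()),
        k=1,
    )[0]
    print(language)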
user_sim/utils/register_management.py
@@ -0,0 +1,62 @@
+ import os
+ import json
+ import hashlib
+ import logging
+
+ current_script_dir = os.path.dirname(os.path.abspath(__file__))
+ project_root = os.path.abspath(os.path.join(current_script_dir, "../../..")) #change
+ temp_file_dir = os.path.join(project_root, "data/cache")
+
+ logger = logging.getLogger('Info Logger')
+
+
+ def save_register(register, name):
+     path = os.path.join(temp_file_dir, name)
+     with open(path, "w", encoding="utf-8") as file:
+         json.dump(register, file, ensure_ascii=False, indent=4)
+
+
+ def load_register(register_name):
+     register_path = os.path.join(temp_file_dir, register_name)
+     if not os.path.exists(temp_file_dir):
+         os.makedirs(temp_file_dir)
+         return {}
+     elif not os.path.exists(register_path):
+         with open(register_path, 'w', encoding="utf-8") as file:
+             json.dump({}, file, ensure_ascii=False, indent=4)
+         return {}
+     else:
+         with open(register_path, 'r', encoding="utf-8") as file:
+             hash_reg = json.load(file)
+         return hash_reg
+
+
+ def hash_generate(content_type=None, hasher=None, **kwargs):
+     # Create a fresh hasher per call; a shared default argument would
+     # accumulate state across calls.
+     if hasher is None:
+         hasher = hashlib.md5()
+     if content_type == "pdf":
+         with open(kwargs.get("content", ""), 'rb') as pdf_file:
+             buf = pdf_file.read()
+         hasher.update(buf)
+         return hasher.hexdigest()
+     else:
+         content = kwargs.get('content', '')
+         if isinstance(content, str):
+             hasher.update(content.encode("utf-8"))
+         else:
+             hasher.update(content)
+         return hasher.hexdigest()
+
+
+ def clear_register(register_name):
+     try:
+         path = os.path.join(temp_file_dir, register_name)
+         with open(path, 'w') as file:
+             json.dump({}, file)
+     except Exception as e:
+         logger.error(f"Couldn't clear cache; the cache file may not have been created during the execution: {e}")
+
+
+ def clean_temp_files():
+     clear_register("image_register.json")
+     clear_register("pdf_register.json")
+     clear_register("webpage_register.json")
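A short usage sketch for the cache helpers above, assuming a content hash keys each register entry; the key scheme and the placeholder value are illustrative, not taken from this diff:

    from user_sim.utils.register_management import (
        hash_generate, load_register, save_register)

    register = load_register("webpage_register.json")
    key = hash_generate(content="https://example.com/page")
    if key not in register:
        register[key] = "parsed page text"  # illustrative cached value
        save_register(register, "webpage_register.json")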
user_sim/utils/show_logs.py
@@ -0,0 +1,63 @@
+ import logging
+ import sys
+ import colorama
+
+
+ # Initialize colorama
+ colorama.init(autoreset=True)
+
+ # Define color codes
+ RESET = colorama.Style.RESET_ALL
+ BLACK = colorama.Fore.BLACK
+ RED = colorama.Fore.RED
+ GREEN = colorama.Fore.GREEN
+ YELLOW = colorama.Fore.YELLOW
+ BLUE = colorama.Fore.BLUE
+ MAGENTA = colorama.Fore.MAGENTA
+ CYAN = colorama.Fore.CYAN
+ WHITE = colorama.Fore.WHITE
+
+
+ class ColoredFormatter(logging.Formatter):
+     # Mapping of log levels to colors
+     LEVEL_COLORS = {
+         logging.DEBUG: CYAN,
+         logging.INFO: GREEN,
+         logging.WARNING: YELLOW,
+         logging.ERROR: RED,
+         logging.CRITICAL: MAGENTA,
+     }
+
+     def format(self, record):
+         # Get the color for the current log level
+         level_color = self.LEVEL_COLORS.get(record.levelno, WHITE)
+
+         # Apply the color to the level name and message
+         record.levelname = f"{level_color}{record.levelname}{RESET}"
+         record.msg = f"{level_color}{record.msg}{RESET}"
+
+         # Format the message
+         return super().format(record)
+
+
+ def create_logger(verbose, name=None):
+     if name:
+         my_logger = logging.getLogger(name)
+     else:
+         my_logger = logging.getLogger()
+
+     if verbose:
+         my_logger.setLevel(logging.DEBUG)
+     else:
+         my_logger.setLevel(logging.CRITICAL)
+
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setLevel(logging.DEBUG)
+
+     log_format = ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     # log_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     console_handler.setFormatter(log_format)
+
+     my_logger.addHandler(console_handler)
+
+     return my_logger
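A minimal sketch of wiring the colored logger up; the name matches the 'Info Logger' used by the other modules in this diff, and the messages are illustrative:

    from user_sim.utils.show_logs import create_logger

    logger = create_logger(verbose=True, name="Info Logger")
    logger.debug("cache hit")                 # printed in cyan
    logger.warning("approaching cost limit")  # printed in yellow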
user_sim/utils/token_cost_calculator.py
@@ -0,0 +1,338 @@
+ import re
+ import os
+ import base64
+ import tiktoken
+ import requests
+ import pandas as pd
+ import logging
+ from io import BytesIO
+ from PIL import Image
+ from langchain_core.output_parsers import StrOutputParser
+ from user_sim.utils import config
+ from user_sim.utils.utilities import get_encoding
+
+ logger = logging.getLogger('Info Logger')
+
+ columns = ["Conversation", "Test Name", "Module", "Model", "Total Cost",
+            "Timestamp", "Input Cost", "Input Message",
+            "Output Cost", "Output Message"]
+
+ PRICING = {
+     "gpt-4o": {"input": 2.5 / 10**6, "output": 10 / 10**6},
+     "gpt-4o-mini": {"input": 0.15 / 10**6, "output": 0.6 / 10**6},
+     "whisper": 0.006/60,
+     "tts-1": 0.0015/1000,  # (characters, not tokens)
+     "gemini-2.0-flash": 0
+ }
+
+ TOKENS = {
+     "gpt-4o": {"input": 10**6/2.5, "output": 10**6/10},
+     "gpt-4o-mini": {"input": 10**6/0.15, "output": 10**6/0.6},
+     "whisper": 60/0.006,
+     "tts-1": 1000/0.0015,  # (characters, not tokens)
+     "gemini-2.0-flash": 0
+ }
+
+ MAX_MODEL_TOKENS = {
+     "gpt-4o": 16384,
+     "gpt-4o-mini": 16384,
+     "gemini-2.0-flash": 10000000
+ }
+
+
+ DEFAULT_COSTS = {
+     # OpenAI model costs per 1M tokens
+     "gpt-4o": {"prompt": 5.00, "completion": 20.00},
+     "gpt-4o-mini": {"prompt": 0.60, "completion": 2.40},
+     "gpt-4.1": {"prompt": 2.00, "completion": 8.00},
+     "gpt-4.1-mini": {"prompt": 0.40, "completion": 1.60},
+     "gpt-4.1-nano": {"prompt": 0.10, "completion": 0.40},
+     # Google/Gemini model costs per 1M tokens
+     "gemini-2.0-flash": {"prompt": 0.10, "completion": 0.40},
+     "gemini-2.5-flash-preview-05-2023": {"prompt": 0.15, "completion": 0.60},
+     # Default fallback rates if model not recognized
+     "default": {"prompt": 0.10, "completion": 0.40},
+ }
+
+
+ def create_cost_dataset(serial, test_cases_folder):
+     folder = f"{test_cases_folder}/reports/__cost_reports__"
+     file = f"cost_report_{serial}.csv"
+     if not os.path.exists(folder):
+         os.makedirs(folder)
+         logger.info(f"Created cost report folder at: {folder}")
+
+     path = f"{folder}/{file}"
+
+     cost_df = pd.DataFrame(columns=columns)
+     cost_df.to_csv(path, index=False)
+     config.cost_ds_path = path
+     logger.info(f"Cost dataframe created at {path}.")
+
+
+ def count_tokens(text, model="gpt-4o-mini"):
+     try:
+         # First try to use the model name directly with tiktoken
+         encoding = tiktoken.encoding_for_model(model)
+     except (KeyError, ValueError):
+         # If tiktoken doesn't recognize the model, fall back to the
+         # cl100k_base encoding used by the GPT-4 model family
+         logger.warning(
+             f"Model '{model}' not recognized by tiktoken, using cl100k_base encoding"
+         )
+         encoding = tiktoken.get_encoding("cl100k_base")
+
+     return len(encoding.encode(text))
+
+
+ def calculate_text_cost(tokens, model="gpt-4o-mini", io_type="input"):
+     cost = tokens * PRICING[model][io_type]
+     return cost
+
+
+ def calculate_image_cost(image):
+     def get_dimensions(image_input):
+         try:
+             if isinstance(image_input, bytes):
+                 image_input = image_input.decode('utf-8')
+             if re.match(r'^https?://', image_input):  # Detects if it's a URL
+                 response = requests.get(image_input)
+                 response.raise_for_status()
+                 image = Image.open(BytesIO(response.content))
+             else:
+                 decoded_image = base64.b64decode(image_input)
+                 image = Image.open(BytesIO(decoded_image))
+
+             # Get the dimensions
+             w, h = image.size
+             return w, h
+         except Exception as e:
+             logger.error(e)
+             return None
+
+     dimensions = get_dimensions(image)
+     if dimensions is None:
+         logger.warning("Couldn't get image dimensions.")
+         return None
+     width, height = dimensions
+
+     # Initial configuration
+     price_per_million_tokens = 0.15
+     tokens_per_tile = 5667
+     base_tokens = 2833
+
+     # Calculate the number of tiles needed (512 x 512 pixels)
+     horizontal_tiles = (width + 511) // 512
+     vertical_tiles = (height + 511) // 512
+     total_tiles = horizontal_tiles * vertical_tiles
+
+     # Calculate the total tokens
+     total_tokens = base_tokens + (tokens_per_tile * total_tiles)
+
+     # Convert tokens to price
+     total_price = (total_tokens / 1_000_000) * price_per_million_tokens
+
+     return total_price
+
+
+ # VISION
+ def input_vision_module_cost(input_message, image, model):
+     input_tokens = count_tokens(input_message, model)
+     image_cost = calculate_image_cost(image)
+     if image_cost is None:
+         logger.warning("Image cost set to $0.")
+         image_cost = 0
+
+     model_pricing = PRICING[model]
+     input_cost = input_tokens * model_pricing["input"] + image_cost
+     return input_cost
+
+
+ def output_vision_module_cost(output_message, model):
+     output_tokens = count_tokens(output_message, model)
+     model_pricing = PRICING[model]
+     output_cost = output_tokens * model_pricing["output"]
+     return output_cost
+
+
+ # TTS-STT
+ def input_tts_module_cost(input_message, model):
+     model_pricing = PRICING[model]
+     input_cost = len(input_message) * model_pricing
+     return input_cost
+
+
+ def whisper_module_cost(audio_length, model):
+     model_pricing = PRICING[model]
+     input_cost = audio_length * model_pricing
+     return input_cost
+
+
+ # TEXT
+ def input_text_module_cost(input_message, model):
+     if isinstance(input_message, list):
+         input_message = ", ".join(input_message)
+     input_tokens = count_tokens(input_message, model)
+     model_pricing = PRICING[model]
+     input_cost = input_tokens * model_pricing["input"]
+     return input_cost
+
+
+ def output_text_module_cost(output_message, model):
+     if isinstance(output_message, list):
+         output_message = ", ".join(output_message)
+     output_tokens = count_tokens(output_message, model)
+     model_pricing = PRICING[model]
+     output_cost = output_tokens * model_pricing["output"]
+     return output_cost
+
+
+ def calculate_cost(input_message='', output_message='', model="gpt-4o", module=None, **kwargs):
+     if input_message is None:
+         input_message = ""
+     if output_message is None:
+         output_message = ""
+
+     if model not in PRICING:
+         raise ValueError(f"Pricing not available for model: {model}")
+
+     if model == "whisper":
+         input_cost = 0
+         output_cost = whisper_module_cost(kwargs.get("audio_length", 0), model)
+         total_cost = output_cost
+
+     elif model == "tts-1":
+         input_cost = input_tts_module_cost(input_message, model)
+         output_cost = 0
+         total_cost = input_cost
+
+     elif kwargs.get("image", None):
+         input_cost = input_vision_module_cost(input_message, kwargs.get("image", None), model)
+         output_cost = output_vision_module_cost(output_message, model)
+         total_cost = input_cost + output_cost
+
+     else:
+         input_cost = input_text_module_cost(input_message, model)
+         output_cost = output_text_module_cost(output_message, model)
+         total_cost = input_cost + output_cost
+
+     def update_dataframe():
+         new_row = {"Conversation": config.conversation_name, "Test Name": config.test_name, "Module": module,
+                    "Model": model, "Total Cost": total_cost, "Timestamp": pd.Timestamp.now(),
+                    "Input Cost": input_cost, "Input Message": input_message,
+                    "Output Cost": output_cost, "Output Message": output_message}
+
+         encoding = get_encoding(config.cost_ds_path)["encoding"]
+         cost_df = pd.read_csv(config.cost_ds_path, encoding=encoding)
+         cost_df.loc[len(cost_df)] = new_row
+         cost_df.to_csv(config.cost_ds_path, index=False)
+
+         config.total_cost = config.total_individual_cost = float(cost_df['Total Cost'].sum())
+
+         logger.info(f"Updated 'cost_report' dataframe with new cost from {module}.")
+
+     update_dataframe()
+
+
+ def get_cost_report(test_cases_folder):
+     export_path = test_cases_folder + "/reports/__cost_report__"
+     serial = config.serial
+     if not os.path.exists(export_path):
+         os.makedirs(export_path)
+
+     export_file_name = export_path + f"/report_{serial}.csv"
+
+     encoding = get_encoding(config.cost_ds_path)["encoding"]
+     temp_cost_df = pd.read_csv(config.cost_ds_path, encoding=encoding)
+     temp_cost_df.to_csv(export_file_name, index=False)
+
+
+ def max_input_tokens_allowed(text='', model_used='gpt-4o-mini', **kwargs):
+
+     def get_delta_verification(sim_cost, sim_ind_cost):
+         delta_cost = config.limit_cost - sim_cost
+         delta_individual_cost = config.limit_individual_cost - sim_ind_cost
+         logger.info(f"${delta_cost} for global and ${delta_individual_cost} for individual input cost left.")
+         return delta_cost <= 0 or delta_individual_cost <= 0
+
+     if not config.token_count_enabled:
+         return False
+
+     if kwargs.get("image", None):
+         input_cost = input_vision_module_cost(text, kwargs.get("image", 0), model_used)
+     elif model_used == "tts-1":
+         input_cost = input_tts_module_cost(text, model_used)
+     elif model_used == "whisper":
+         input_cost = whisper_module_cost(kwargs.get("audio_length", 0), model_used)
+     else:
+         input_cost = input_text_module_cost(text, model_used)
+
+     simulated_cost = input_cost + config.total_cost
+     simulated_individual_cost = input_cost + config.total_individual_cost
+     return get_delta_verification(simulated_cost, simulated_individual_cost)
+
+
+ def max_output_tokens_allowed(model_used):
+     if config.token_count_enabled:
+         delta_cost = config.limit_cost - config.total_cost
+         delta_individual_cost = config.limit_individual_cost - config.total_individual_cost
+
+         delta = min([delta_cost, delta_individual_cost])
+         output_tokens = round(delta * TOKENS[model_used]["output"])
+
+         if MAX_MODEL_TOKENS[model_used] < output_tokens:
+             output_tokens = MAX_MODEL_TOKENS[model_used]
+
+         logger.info(f"{output_tokens} output tokens left.")
+         return output_tokens
+     else:
+         return None
+
+
+ def invoke_llm(llm, prompt, input_params, model, module, parser=False):
+     # Outputs input messages as text.
+     if isinstance(input_params, dict):
+         messages = list(input_params.values())
+         parsed_messages = " ".join(messages)
+     else:
+         parsed_messages = input_params
+
+     # Measures max input tokens allowed by the execution
+     if config.token_count_enabled and max_input_tokens_allowed(parsed_messages, model):
+         logger.error(f"Token limit was surpassed in {module} module")
+         return None
+
+     # Calculates the amount of tokens left and updates the LLM max_tokens parameter
+     if config.token_count_enabled:
+         llm.max_tokens = max_output_tokens_allowed(model)
+
+     # Enables str output parser
+     if parser:
+         parser = StrOutputParser()
+         llm_chain = prompt | llm | parser
+     else:
+         llm_chain = prompt | llm
+
+     # Invoke LLM
+     try:
+         response = llm_chain.invoke(input_params)
+         if config.token_count_enabled:
+             calculate_cost(parsed_messages, response, model, module="user_simulator")
+     except Exception as e:
+         logger.error(e)
+         response = None
+     if response is None and module == "user_simulator":
+         response = "exit"
+
+     return response
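For orientation, a minimal sketch of the basic accounting primitive above; count_tokens and PRICING come straight from this module, while the prompt text is illustrative:

    from user_sim.utils.token_cost_calculator import PRICING, count_tokens

    prompt = "Summarize the refund policy in one sentence."
    tokens = count_tokens(prompt, model="gpt-4o-mini")
    # PRICING stores dollars per token: 0.15 / 10**6 for gpt-4o-mini input.
    cost = tokens * PRICING["gpt-4o-mini"]["input"]
    print(f"{tokens} tokens -> ${cost:.8f}")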
user_sim/utils/url_management.py
@@ -0,0 +1,60 @@
+ import re
+ from typing import List, Dict
+ from user_sim.handlers.pdf_parser_module import pdf_processor
+ from user_sim.handlers.image_recognition_module import image_description
+ from user_sim.handlers.html_parser_module import webpage_reader
+
+
+ def classify_links(message: str) -> Dict[str, List[str]]:
+     url_pattern = re.compile(r'https?://\S+')  # Capture URLs
+     links = url_pattern.findall(message)
+
+     classified_links = {
+         "images": [],
+         "pdfs": [],
+         "webpages": []
+     }
+
+     for link in links:
+         if re.search(r'\.(jpg|jpeg|png|gif|webp|bmp|tiff)$', link, re.IGNORECASE) or '<image>' in message:
+             clean_link = re.sub(r'</?image>', '', link)
+             classified_links["images"].append(clean_link)
+         elif re.search(r'\.pdf$', link, re.IGNORECASE) or 'application/pdf' in message:
+             classified_links["pdfs"].append(link)
+         else:
+             classified_links["webpages"].append(link)
+
+     return classified_links
+
+
+ def process_with_llm(link: str, category) -> str:
+     if category == "pdfs":
+         description = pdf_processor(link)
+     elif category == "images":
+         description = image_description(link, detailed=True)
+     else:
+         description = webpage_reader(link)
+     return f"{link} {description}"
+
+
+ def get_content(message: str) -> str:
+     classified_links = classify_links(message)
+     for category in classified_links:
+         for link in classified_links[category]:
+             replacement = process_with_llm(link, category)
+             message = message.replace(link, replacement)
+
+     return message
+
+
+ # def clean_temp_files():
+ #     clear_pdf_register()
+ #     clear_image_register()
+ #     clear_webpage_register()
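A short usage sketch for get_content; the message is illustrative, and each matched link triggers the corresponding handler, so this may call out to the network or an LLM:

    from user_sim.utils.url_management import get_content

    message = "Please read https://example.com/terms.pdf before replying."
    # The PDF link is replaced in place by "<link> <description>".
    print(get_content(message))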