user-simulator 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- user_sim/__init__.py +0 -0
- user_sim/cli/__init__.py +0 -0
- user_sim/cli/gen_user_profile.py +34 -0
- user_sim/cli/init_project.py +65 -0
- user_sim/cli/sensei_chat.py +481 -0
- user_sim/cli/sensei_check.py +103 -0
- user_sim/cli/validation_check.py +143 -0
- user_sim/core/__init__.py +0 -0
- user_sim/core/ask_about.py +665 -0
- user_sim/core/data_extraction.py +260 -0
- user_sim/core/data_gathering.py +134 -0
- user_sim/core/interaction_styles.py +147 -0
- user_sim/core/role_structure.py +608 -0
- user_sim/core/user_simulator.py +302 -0
- user_sim/handlers/__init__.py +0 -0
- user_sim/handlers/asr_module.py +128 -0
- user_sim/handlers/html_parser_module.py +202 -0
- user_sim/handlers/image_recognition_module.py +139 -0
- user_sim/handlers/pdf_parser_module.py +123 -0
- user_sim/utils/__init__.py +0 -0
- user_sim/utils/config.py +47 -0
- user_sim/utils/cost_tracker.py +153 -0
- user_sim/utils/cost_tracker_v2.py +193 -0
- user_sim/utils/errors.py +15 -0
- user_sim/utils/exceptions.py +47 -0
- user_sim/utils/languages.py +78 -0
- user_sim/utils/register_management.py +62 -0
- user_sim/utils/show_logs.py +63 -0
- user_sim/utils/token_cost_calculator.py +338 -0
- user_sim/utils/url_management.py +60 -0
- user_sim/utils/utilities.py +568 -0
- user_simulator-0.1.0.dist-info/METADATA +733 -0
- user_simulator-0.1.0.dist-info/RECORD +37 -0
- user_simulator-0.1.0.dist-info/WHEEL +5 -0
- user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
- user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
- user_simulator-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,47 @@
|
|
1
|
+
class InvalidGoalException(Exception):
    """Error signalling an invalid goal."""


class InvalidInteractionException(Exception):
    """Error signalling an invalid interaction."""


class InvalidLanguageException(Exception):
    """Error signalling an invalid language."""


class OutOfLimitException(Exception):
    """Error signalling that a limit was exceeded."""


class BadDictionaryGeneration(Exception):
    """Error signalling that a dictionary could not be generated correctly."""


class InvalidItemType(Exception):
    """Error signalling an item of an unexpected type."""


class EmptyListExcept(Exception):
    """Error signalling an unexpectedly empty list."""


class InvalidDataType(Exception):
    """Error signalling data of an unexpected type."""


class InvalidFormat(Exception):
    """Error signalling malformed input."""


class MissingStepDefinition(Exception):
    """Error signalling that a required step definition is missing."""


class InvalidGenerator(Exception):
    """Error signalling an invalid generator definition."""


class VariableNotFound(Exception):
    """Error signalling a reference to an undefined variable."""


class InvalidDependence(Exception):
    """Error signalling an invalid dependency between definitions."""


class InvalidFile(Exception):
    """Error signalling an unusable file."""


class NoCostException(Exception):
    """Error signalling that no cost information is available."""


class UnmachedList(Exception):
    """Error signalling lists that do not match each other."""
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# Supported language names, alphabetical (grouped by initial letter).
languages = [
    "Afrikaans", "Albanian", "Amharic", "Arabic", "Armenian", "Azerbaijani",
    "Bengali", "Bosnian", "Bulgarian",
    "Catalan", "Chinese (Simplified)", "Chinese (Traditional)", "Croatian", "Czech",
    "Danish", "Dutch",
    "English", "Estonian",
    "Filipino", "Finnish", "French",
    "Galician", "Georgian", "German", "Greek", "Gujarati",
    "Hausa", "Hebrew", "Hindi", "Hungarian",
    "Icelandic", "Indonesian", "Italian",
    "Japanese",
    "Kannada", "Kazakh", "Korean",
    "Latvian", "Lithuanian",
    "Macedonian", "Malay", "Malayalam", "Marathi",
    "Nepali", "Norwegian",
    "Persian", "Polish", "Portuguese", "Punjabi",
    "Romanian", "Russian",
    "Serbian", "Slovak", "Slovenian", "Spanish", "Swahili", "Swedish",
    "Tamil", "Telugu", "Thai", "Turkish",
    "Ukrainian", "Urdu",
    "Vietnamese",
    "Zulu",
]

# Weight per language. Every language currently has weight 1, so the mapping
# is derived from ``languages`` (same keys, same order) to keep both in sync.
languages_weights = dict.fromkeys(languages, 1)
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import hashlib
|
4
|
+
import logging
|
5
|
+
|
6
|
+
# Directory containing this module (user_sim/utils).
current_script_dir = os.path.dirname(os.path.abspath(__file__))
# Three directory levels above this module.
# NOTE(review): in a source checkout this is presumably the project root; for an
# installed wheel it resolves outside the package (next to site-packages) —
# confirm this is the intended location for the cache.
project_root = os.path.abspath(
    os.path.join(current_script_dir, os.pardir, os.pardir, os.pardir)
)
# Folder holding the JSON register files.
temp_file_dir = os.path.join(project_root, "data/cache")

logger = logging.getLogger('Info Logger')
|
11
|
+
|
12
|
+
|
13
|
+
def save_register(register, name):
    """Write *register* as pretty-printed UTF-8 JSON to ``temp_file_dir/name``.

    Args:
        register: JSON-serializable object to persist.
        name: File name inside the cache directory ``temp_file_dir``.
    """
    # The cache directory is otherwise only created by load_register; create it
    # here too so saving never fails with FileNotFoundError.
    os.makedirs(temp_file_dir, exist_ok=True)
    path = os.path.join(temp_file_dir, name)
    with open(path, "w", encoding="utf-8") as file:
        json.dump(register, file, ensure_ascii=False, indent=4)
|
17
|
+
|
18
|
+
|
19
|
+
def load_register(register_name):
    """Load the JSON register ``temp_file_dir/register_name``.

    The cache directory is created if needed. When the register file does not
    exist, an empty ``{}`` file is written and ``{}`` is returned. When the
    file exists but is not valid JSON, a warning is logged and ``{}`` is
    returned instead of raising.

    Args:
        register_name: File name inside ``temp_file_dir``.

    Returns:
        The decoded register (normally a dict).
    """
    register_path = os.path.join(temp_file_dir, register_name)
    os.makedirs(temp_file_dir, exist_ok=True)

    if not os.path.exists(register_path):
        with open(register_path, 'w', encoding="utf-8") as file:
            json.dump({}, file, ensure_ascii=False, indent=4)
        return {}

    try:
        with open(register_path, 'r', encoding="utf-8") as file:
            return json.load(file)
    except json.JSONDecodeError as e:
        # A truncated/corrupt cache must not abort the run; it is rebuilt on save.
        logger.warning(f"Register {register_path} is not valid JSON ({e}); starting empty.")
        return {}
|
33
|
+
|
34
|
+
|
35
|
+
def hash_generate(content_type=None, hasher=None, **kwargs):
    """Return the hex digest identifying some content.

    Args:
        content_type: ``"pdf"`` to hash the bytes of the file whose path is
            ``kwargs["content"]`` (always with a fresh MD5). Any other value
            hashes ``kwargs["content"]`` itself: ``str`` is UTF-8 encoded,
            bytes-like values are used as-is.
        hasher: Optional hashlib object used for non-PDF content. Defaults to
            a fresh MD5 on every call.
        **kwargs: ``content`` — file path (pdf) or the data to hash.

    Returns:
        The hexadecimal digest string.
    """
    if content_type == "pdf":
        pdf_hasher = hashlib.md5()
        with open(kwargs.get("content", ""), 'rb') as pdf_file:
            # Chunked reads keep memory flat for large PDFs.
            for chunk in iter(lambda: pdf_file.read(65536), b""):
                pdf_hasher.update(chunk)
        return pdf_hasher.hexdigest()

    # The former default ``hasher=hashlib.md5()`` was evaluated once and shared
    # by every call, so its state accumulated and identical content produced
    # different digests on successive calls. Create a fresh hasher instead.
    if hasher is None:
        hasher = hashlib.md5()
    content = kwargs.get('content', '')
    if isinstance(content, str):
        content = content.encode("utf-8")
    hasher.update(content)
    return hasher.hexdigest()
|
49
|
+
|
50
|
+
def clear_register(register_name):
    """Overwrite ``temp_file_dir/register_name`` with an empty JSON object.

    Best-effort: if the file cannot be written (for example because the cache
    directory was never created), the error is logged and not raised.

    Args:
        register_name: File name inside ``temp_file_dir``.
    """
    path = os.path.join(temp_file_dir, register_name)
    try:
        with open(path, 'w', encoding="utf-8") as file:
            json.dump({}, file)
    except OSError as e:
        # Only I/O failures are expected here; include the cause in the log.
        logger.error(
            "Couldn't clear cache because the cache file was not created during the execution. "
            f"({path}: {e})"
        )
|
57
|
+
|
58
|
+
|
59
|
+
def clean_temp_files():
    """Reset the image, PDF and webpage registers to empty JSON objects."""
    register_names = ("image_register.json", "pdf_register.json", "webpage_register.json")
    for register_name in register_names:
        clear_register(register_name)
|
@@ -0,0 +1,63 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
import colorama
|
4
|
+
|
5
|
+
|
6
|
+
# Initialise colorama once at import time, with autoreset enabled.
colorama.init(autoreset=True)

# Colour / style codes; ColoredFormatter maps log levels onto some of these.
RESET = colorama.Style.RESET_ALL
BLACK, RED, GREEN, YELLOW = (
    colorama.Fore.BLACK,
    colorama.Fore.RED,
    colorama.Fore.GREEN,
    colorama.Fore.YELLOW,
)
BLUE, MAGENTA, CYAN, WHITE = (
    colorama.Fore.BLUE,
    colorama.Fore.MAGENTA,
    colorama.Fore.CYAN,
    colorama.Fore.WHITE,
)
|
19
|
+
|
20
|
+
|
21
|
+
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the level name and message in colour codes.

    The colour comes from ``LEVEL_COLORS`` keyed by the record's level number;
    levels not listed there are shown in white.
    """

    # Log level -> colour code.
    LEVEL_COLORS = {
        logging.DEBUG: CYAN,
        logging.INFO: GREEN,
        logging.WARNING: YELLOW,
        logging.ERROR: RED,
        logging.CRITICAL: MAGENTA,
    }

    def format(self, record):
        """Return the formatted text of *record* with colourised level and message.

        The colours are applied to a copy of the record. The same LogRecord
        object is handed to every handler that processes it, so mutating it in
        place would leak escape codes into other handlers' output and re-wrap
        the message each time the record is formatted again.
        """
        level_color = self.LEVEL_COLORS.get(record.levelno, WHITE)

        colored = logging.makeLogRecord(record.__dict__)
        colored.levelname = f"{level_color}{record.levelname}{RESET}"
        colored.msg = f"{level_color}{record.msg}{RESET}"

        return super().format(colored)
|
41
|
+
|
42
|
+
|
43
|
+
def create_logger(verbose, name=None):
    """Configure and return a logger that prints colourised records to stdout.

    Args:
        verbose: If truthy the logger level is DEBUG, otherwise CRITICAL.
        name: Logger name; falsy values configure the root logger.

    Returns:
        The configured ``logging.Logger``.

    Calling this again for the same logger updates its level but attaches the
    stdout handler only once, so records are not printed multiple times.
    """
    my_logger = logging.getLogger(name) if name else logging.getLogger()
    my_logger.setLevel(logging.DEBUG if verbose else logging.CRITICAL)

    already_attached = any(
        getattr(handler, "_user_sim_console", False) for handler in my_logger.handlers
    )
    if not already_attached:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        # Marker used above to recognise the handler on later calls.
        console_handler._user_sim_console = True
        my_logger.addHandler(console_handler)

    return my_logger
|
@@ -0,0 +1,338 @@
|
|
1
|
+
import re
|
2
|
+
import os
|
3
|
+
import base64
|
4
|
+
import tiktoken
|
5
|
+
import requests
|
6
|
+
import pandas as pd
|
7
|
+
import logging
|
8
|
+
from io import BytesIO
|
9
|
+
from PIL import Image
|
10
|
+
from langchain_core.output_parsers import StrOutputParser
|
11
|
+
from user_sim.utils import config
|
12
|
+
from user_sim.utils.utilities import get_encoding
|
13
|
+
|
14
|
+
logger = logging.getLogger('Info Logger')

# Column layout of the per-run cost report CSV.
columns = ["Conversation", "Test Name", "Module", "Model", "Total Cost",
           "Timestamp", "Input Cost", "Input Message",
           "Output Cost", "Output Message"]

# USD per unit. Chat models: per token, split into "input" and "output" (the
# text/vision cost helpers index PRICING[model]["input"/"output"], so every chat
# model must use this dict shape). "whisper" and "tts-1" are single per-unit
# rates (whisper: 0.006/60, presumably per second of audio — confirm;
# tts-1: per input character).
PRICING = {
    "gpt-4o": {"input": 2.5 / 10**6, "output": 10 / 10**6},
    "gpt-4o-mini": {"input": 0.15 / 10**6, "output": 0.6 / 10**6},
    "whisper": 0.006/60,
    "tts-1": 0.0015/1000,  # (characters, not tokens)
    # Treated as free; previously a bare 0, which made PRICING[model]["input"]
    # raise TypeError for this model.
    "gemini-2.0-flash": {"input": 0.0, "output": 0.0},
}

# Units purchasable per USD (reciprocal of PRICING); used to turn a remaining
# budget into a max-output-token count.
TOKENS = {
    "gpt-4o": {"input": 10**6/2.5, "output": 10**6/10},
    "gpt-4o-mini": {"input": 10**6/0.15, "output": 10**6/0.6},
    "whisper": 60/0.006,
    "tts-1": 1000/0.0015,  # (characters, not tokens)
    # Free model: unlimited tokens per dollar (same dict shape as other models).
    "gemini-2.0-flash": {"input": float("inf"), "output": float("inf")},
}

# Upper bound applied to the computed max output tokens for each chat model.
MAX_MODEL_TOKENS = {
    "gpt-4o": 16384,
    "gpt-4o-mini": 16384,
    "gemini-2.0-flash": 10000000
}

# NOTE(review): not referenced elsewhere in this module, and its gpt-4o /
# gpt-4o-mini rates differ from PRICING above — confirm which table is
# authoritative.
DEFAULT_COSTS = {
    # OpenAI models costs per 1M tokens
    "gpt-4o": {"prompt": 5.00, "completion": 20.00},
    "gpt-4o-mini": {"prompt": 0.60, "completion": 2.40},
    "gpt-4.1": {"prompt": 2.00, "completion": 8.00},
    "gpt-4.1-mini": {"prompt": 0.40, "completion": 1.60},
    "gpt-4.1-nano": {"prompt": 0.10, "completion": 0.40},
    # Google/Gemini models costs per 1M tokens
    "gemini-2.0-flash": {"prompt": 0.10, "completion": 0.40},
    # NOTE(review): verify this model id; the "05-2023" suffix looks unusual.
    "gemini-2.5-flash-preview-05-2023": {"prompt": 0.15, "completion": 0.60},
    # Default fallback rates if model not recognized
    "default": {"prompt": 0.10, "completion": 0.40},
}
|
57
|
+
|
58
|
+
|
59
|
+
def create_cost_dataset(serial, test_cases_folder):
    """Create an empty cost-report CSV for this run and remember its path.

    Writes only the header row (``columns``) to
    ``<test_cases_folder>/reports/__cost_reports__/cost_report_<serial>.csv``,
    creating the folder if needed, and stores the path in
    ``config.cost_ds_path``.
    """
    report_dir = f"{test_cases_folder}/reports/__cost_reports__"
    if not os.path.exists(report_dir):
        os.makedirs(report_dir)
        logger.info(f"Created cost report folder at: {report_dir}")

    report_path = f"{report_dir}/cost_report_{serial}.csv"
    pd.DataFrame(columns=columns).to_csv(report_path, index=False)

    config.cost_ds_path = report_path
    logger.info(f"Cost dataframe created at {report_path}.")
|
72
|
+
|
73
|
+
|
74
|
+
def count_tokens(text, model="gpt-4o-mini"):
    """Return how many tiktoken tokens *text* encodes to for *model*.

    The encoding is resolved from the model name. If tiktoken does not know
    the model, a warning is logged and the ``cl100k_base`` encoding is used as
    an approximation.
    """
    fallback_name = "cl100k_base"
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except (KeyError, ValueError):
        logger.warning(
            f"Model '{model}' not recognized by tiktoken, using {fallback_name} encoding"
        )
        tokenizer = tiktoken.get_encoding(fallback_name)

    token_ids = tokenizer.encode(text)
    return len(token_ids)
|
87
|
+
|
88
|
+
|
89
|
+
def calculate_text_cost(tokens, model="gpt-4o-mini", io_type="input"):
    """Return the USD cost of *tokens* tokens of *model*.

    ``io_type`` selects the ``"input"`` or ``"output"`` rate from ``PRICING``.
    """
    rate = PRICING[model][io_type]
    return tokens * rate
|
92
|
+
|
93
|
+
|
94
|
+
def calculate_image_cost(image):
    """Estimate the USD input cost of one image sent to the vision model.

    Args:
        image: An ``http(s)://`` URL, a base64-encoded image, or a
            ``data:<mime>;base64,<payload>`` URL. ``bytes`` are decoded as
            UTF-8 first.

    Returns:
        The estimated cost as a float, or ``None`` (after logging) when the
        image cannot be fetched or decoded.
    """
    def get_dimensions(image_input):
        """Return ``(width, height)`` of the image, or ``None`` on any error."""
        try:
            if isinstance(image_input, bytes):
                image_input = image_input.decode('utf-8')
            if re.match(r'^https?://', image_input):  # Detects if it's a URL
                # Without a timeout an unresponsive host would block the run forever.
                response = requests.get(image_input, timeout=30)
                response.raise_for_status()
                raw_bytes = response.content
            else:
                if image_input.startswith("data:") and "," in image_input:
                    # Only the part after the "data:<mime>;base64," header is base64.
                    image_input = image_input.split(",", 1)[1]
                raw_bytes = base64.b64decode(image_input)
            # The context manager releases the decoded image's resources.
            with Image.open(BytesIO(raw_bytes)) as img:
                return img.size
        except Exception as e:
            logger.error(e)
            return None

    dimensions = get_dimensions(image)
    if dimensions is None:
        logger.warning("Couldn't get image dimensions.")
        return None
    width, height = dimensions

    # Pricing parameters.
    # NOTE(review): these look like gpt-4o-mini's image-token figures and are
    # applied regardless of the model actually used — confirm.
    price_per_million_tokens = 0.15
    tokens_per_tile = 5667
    base_tokens = 2833

    # Number of 512x512 tiles covering the raw image (ceiling division).
    # NOTE(review): OpenAI's published high-detail formula rescales the image
    # before tiling; tiling the raw size may over-estimate large images.
    horizontal_tiles = (width + 511) // 512
    vertical_tiles = (height + 511) // 512
    total_tiles = horizontal_tiles * vertical_tiles

    total_tokens = base_tokens + (tokens_per_tile * total_tiles)
    return (total_tokens / 1_000_000) * price_per_million_tokens
|
137
|
+
|
138
|
+
|
139
|
+
|
140
|
+
# VISION
|
141
|
+
def input_vision_module_cost(input_message, image, model):
    """Return the USD input cost of a vision request: prompt tokens plus image.

    When the image cost cannot be estimated, a warning is logged and the image
    is counted as $0.
    """
    prompt_tokens = count_tokens(input_message, model)

    image_cost = calculate_image_cost(image)
    if image_cost is None:
        logger.warning("Image cost set to $0.")
        image_cost = 0

    return prompt_tokens * PRICING[model]["input"] + image_cost
|
151
|
+
def output_vision_module_cost(output_message, model):
    """Return the USD cost of *output_message* billed as output tokens of *model*."""
    n_tokens = count_tokens(output_message, model)
    return n_tokens * PRICING[model]["output"]
|
156
|
+
|
157
|
+
# TTS-STT
|
158
|
+
def input_tts_module_cost(input_message, model):
    """Return the USD cost of synthesising *input_message* with *model*.

    The TTS rate in ``PRICING`` is a single number applied per character.
    """
    per_character = PRICING[model]
    return len(input_message) * per_character
|
162
|
+
def whisper_module_cost(audio_length, model):
    """Return the USD cost of transcribing *audio_length* units of audio.

    ``PRICING[model]`` is a single per-unit rate. ``None`` (no length supplied)
    is logged and treated as 0 instead of raising ``TypeError``.
    """
    rate = PRICING[model]
    if audio_length is None:
        logger.warning("Audio length not provided; transcription cost set to $0.")
        audio_length = 0
    return audio_length * rate
|
167
|
+
|
168
|
+
# TEXT
|
169
|
+
def input_text_module_cost(input_message, model):
    """Return the USD cost of *input_message* billed as input tokens of *model*.

    A list of strings is joined with ``", "`` before counting.
    """
    text = ", ".join(input_message) if isinstance(input_message, list) else input_message
    return count_tokens(text, model) * PRICING[model]["input"]
|
176
|
+
def output_text_module_cost(output_message, model):
    """Return the USD cost of *output_message* billed as output tokens of *model*.

    A list of strings is joined with ``", "`` before counting.
    """
    text = ", ".join(output_message) if isinstance(output_message, list) else output_message
    return count_tokens(text, model) * PRICING[model]["output"]
|
183
|
+
|
184
|
+
|
185
|
+
def calculate_cost(input_message='', output_message='', model="gpt-4o", module=None, **kwargs):
    """Compute the cost of one model call and append it to the cost report.

    Args:
        input_message: Prompt text (or list of strings); ``None`` means "".
        output_message: Model output; ``None`` means "". Objects exposing a
            ``content`` attribute (e.g. chat message objects returned when no
            string parser is used) are reduced to that content.
        model: Key into ``PRICING``.
        module: Name recorded in the report's "Module" column.
        **kwargs: ``audio_length`` for "whisper"; ``image`` for vision calls.

    Raises:
        ValueError: If *model* has no entry in ``PRICING``.

    Side effects: appends a row to the CSV at ``config.cost_ds_path`` and sets
    ``config.total_cost`` and ``config.total_individual_cost`` to the sum of
    its "Total Cost" column.
    """
    if input_message is None:
        input_message = ""
    if output_message is None:
        output_message = ""
    # Token counting needs text, not a message object.
    if not isinstance(output_message, (str, list)) and hasattr(output_message, "content"):
        output_message = output_message.content

    if model not in PRICING:
        raise ValueError(f"Pricing not available for model: {model}")

    if model == "whisper":
        input_cost = 0
        # `or 0`: a missing/None length previously raised TypeError downstream.
        output_cost = whisper_module_cost(kwargs.get("audio_length") or 0, model)
        total_cost = output_cost

    elif model == "tts-1":
        input_cost = input_tts_module_cost(input_message, model)
        output_cost = 0
        total_cost = input_cost

    elif kwargs.get("image", None):
        input_cost = input_vision_module_cost(input_message, kwargs.get("image", None), model)
        output_cost = output_vision_module_cost(output_message, model)
        total_cost = input_cost + output_cost

    else:
        input_cost = input_text_module_cost(input_message, model)
        output_cost = output_text_module_cost(output_message, model)
        total_cost = input_cost + output_cost

    new_row = {"Conversation": config.conversation_name, "Test Name": config.test_name, "Module": module,
               "Model": model, "Total Cost": total_cost, "Timestamp": pd.Timestamp.now(),
               "Input Cost": input_cost, "Input Message": input_message,
               "Output Cost": output_cost, "Output Message": output_message}

    encoding = get_encoding(config.cost_ds_path)["encoding"]
    cost_df = pd.read_csv(config.cost_ds_path, encoding=encoding)
    cost_df.loc[len(cost_df)] = new_row
    cost_df.to_csv(config.cost_ds_path, index=False)

    config.total_cost = config.total_individual_cost = float(cost_df['Total Cost'].sum())

    logger.info(f"Updated 'cost_report' dataframe with new cost from {module}.")
|
237
|
+
|
238
|
+
|
239
|
+
def get_cost_report(test_cases_folder):
    """Export a copy of the current cost report CSV.

    Reads ``config.cost_ds_path`` and writes it to
    ``<test_cases_folder>/reports/__cost_report__/report_<config.serial>.csv``,
    creating the folder if needed.

    NOTE(review): create_cost_dataset writes to ``__cost_reports__`` (plural)
    while this exports to ``__cost_report__`` — confirm the two folders are
    meant to differ.
    """
    export_dir = test_cases_folder + "/reports/__cost_report__"
    serial = config.serial
    if not os.path.exists(export_dir):
        os.makedirs(export_dir)

    destination = export_dir + f"/report_{serial}.csv"

    source_encoding = get_encoding(config.cost_ds_path)["encoding"]
    report_df = pd.read_csv(config.cost_ds_path, encoding=source_encoding)
    report_df.to_csv(destination, index=False)
|
250
|
+
|
251
|
+
|
252
|
+
def max_input_tokens_allowed(text='', model_used='gpt-4o-mini', **kwargs):
    """Report whether sending *text* to *model_used* would exhaust a cost limit.

    Despite its name this returns a bool, not a token count: ``True`` means the
    estimated input cost plus the cost accumulated so far reaches or exceeds
    ``config.limit_cost`` or ``config.limit_individual_cost``. Always ``False``
    when ``config.token_count_enabled`` is off.

    Keyword Args:
        image: If truthy, *text* is priced as a vision request with this image.
        audio_length: Used when *model_used* is "whisper" (default 0).
    """
    if not config.token_count_enabled:
        return False

    # Estimate the cost of this input with the matching pricing helper.
    if kwargs.get("image", None):
        input_cost = input_vision_module_cost(text, kwargs.get("image", 0), model_used)
    elif model_used == "tts-1":
        input_cost = input_tts_module_cost(text, model_used)
    elif model_used == "whisper":
        input_cost = whisper_module_cost(kwargs.get("audio_length", 0), model_used)
    else:
        input_cost = input_text_module_cost(text, model_used)

    simulated_cost = input_cost + config.total_cost
    simulated_individual_cost = input_cost + config.total_individual_cost

    delta_cost = config.limit_cost - simulated_cost
    delta_individual_cost = config.limit_individual_cost - simulated_individual_cost
    logger.info(f"${delta_cost} for global and ${delta_individual_cost} for individual input cost left.")
    return bool(delta_cost <= 0 or delta_individual_cost <= 0)
|
283
|
+
|
284
|
+
def max_output_tokens_allowed(model_used):
    """Return how many output tokens the remaining budget allows for *model_used*.

    The smaller of the global and individual remaining budgets is converted to
    tokens with ``TOKENS[model_used]["output"]``, clamped to be non-negative,
    and capped at ``MAX_MODEL_TOKENS[model_used]``. For models whose rate is
    not a usable finite number (free models such as a zero/infinite rate, or a
    non-dict entry), the budget cannot constrain output and the model cap is
    returned.

    Returns:
        The token count, or ``None`` when ``config.token_count_enabled`` is off.
    """
    if not config.token_count_enabled:
        return None

    delta_cost = config.limit_cost - config.total_cost
    delta_individual_cost = config.limit_individual_cost - config.total_individual_cost
    delta = min(delta_cost, delta_individual_cost)

    model_cap = MAX_MODEL_TOKENS.get(model_used)
    rates = TOKENS.get(model_used)
    tokens_per_dollar = rates.get("output") if isinstance(rates, dict) else None

    if not tokens_per_dollar or tokens_per_dollar == float("inf"):
        # round(delta * inf) would raise, and a rate of 0 would yield 0 tokens.
        output_tokens = model_cap
    else:
        # Never hand a negative max_tokens to the LLM when over budget.
        output_tokens = max(0, round(delta * tokens_per_dollar))
        if model_cap is not None and model_cap < output_tokens:
            output_tokens = model_cap

    logger.info(f"{output_tokens} output tokens left.")
    return output_tokens
|
300
|
+
|
301
|
+
|
302
|
+
def invoke_llm(llm, prompt, input_params, model, module, parser=False):
    """Run ``prompt | llm`` (optionally ``| StrOutputParser()``) under cost control.

    Args:
        llm: Chat model; when token counting is enabled its ``max_tokens`` is
            set from the remaining budget.
        prompt: Prompt runnable composed with *llm*.
        input_params: Prompt variables (dict) or a single input value.
        model: Model key used for pricing.
        module: Caller name, used in logs and in the cost report.
        parser: If truthy, a ``StrOutputParser`` is appended so the chain
            returns a string.

    Returns:
        The chain's response. Returns ``None`` when the input budget is
        exhausted. When the invocation fails it returns ``None``, or ``"exit"``
        if *module* is ``"user_simulator"``.
    """
    # Flatten the prompt inputs to text for cost estimation.
    if isinstance(input_params, dict):
        parsed_messages = " ".join(str(value) for value in input_params.values())
    else:
        parsed_messages = input_params

    # max_input_tokens_allowed returns True when the limit would be exceeded.
    if config.token_count_enabled and max_input_tokens_allowed(parsed_messages, model):
        logger.error(f"Token limit was surpassed in {module} module")
        return None

    # Limit the answer length to what the remaining budget allows.
    if config.token_count_enabled:
        llm.max_tokens = max_output_tokens_allowed(model)

    if parser:
        llm_chain = prompt | llm | StrOutputParser()
    else:
        llm_chain = prompt | llm

    try:
        response = llm_chain.invoke(input_params)
    except Exception as e:
        logger.error(e)
        response = None
    else:
        if config.token_count_enabled:
            # Without a parser the response is a message object; cost its text.
            output_text = response if isinstance(response, str) else getattr(response, "content", response)
            # Kept separate from the invocation: a bookkeeping failure must not
            # discard a valid model response.
            try:
                calculate_cost(parsed_messages, output_text, model, module=module)
            except Exception as e:
                logger.error(f"Cost calculation failed in {module} module: {e}")

    if response is None and module == "user_simulator":
        response = "exit"

    return response
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import re
|
2
|
+
from typing import List, Dict
|
3
|
+
from user_sim.handlers.pdf_parser_module import pdf_processor
|
4
|
+
from user_sim.handlers.image_recognition_module import image_description
|
5
|
+
from user_sim.handlers.html_parser_module import webpage_reader
|
6
|
+
|
7
|
+
|
8
|
+
def classify_links(message: str) -> Dict[str, List[str]]:
    """Extract the URLs in *message* and bucket them by content type.

    Returns a dict with keys ``"images"``, ``"pdfs"`` and ``"webpages"``. Each
    value is a list of unique URLs in order of first appearance.

    Classification rules:
        * image: the URL ends in a common image extension, or it appears as
          ``<image>URL`` in the message.
        * pdf: the URL ends in ``.pdf``, or the message mentions
          ``application/pdf``.
        * webpage: anything else.

    Trailing sentence punctuation is not treated as part of a URL.
    """
    # Stop at whitespace and angle brackets so markup such as "</image>" is not
    # swallowed into the URL.
    url_pattern = re.compile(r'https?://[^\s<>]+')
    trailing = ".,;:!?'\""

    classified_links = {
        "images": [],
        "pdfs": [],
        "webpages": []
    }
    seen = set()

    for raw_link in url_pattern.findall(message):
        link = raw_link.rstrip(trailing)
        # A ")" with no "(" in the URL belongs to surrounding prose.
        if link.endswith(")") and "(" not in link:
            link = link[:-1].rstrip(trailing)
        if not link or link in seen:
            continue
        seen.add(link)

        # Tag check is per link: previously any "<image>" in the message made
        # every link an image.
        tagged_image = re.search(r'<image>\s*' + re.escape(link), message) is not None
        if re.search(r'\.(jpg|jpeg|png|gif|webp|bmp|tiff)$', link, re.IGNORECASE) or tagged_image:
            classified_links["images"].append(link)
        # NOTE(review): the 'application/pdf' test is still message-wide, so a
        # mention of it turns every non-image link into a PDF — confirm intent.
        elif re.search(r'\.pdf$', link, re.IGNORECASE) or 'application/pdf' in message:
            classified_links["pdfs"].append(link)
        else:
            classified_links["webpages"].append(link)

    return classified_links
|
28
|
+
|
29
|
+
|
30
|
+
def process_with_llm(link: str, category) -> str:
    """Return ``"<link> <description>"`` for *link*.

    The description comes from the PDF processor for ``"pdfs"``, the detailed
    image describer for ``"images"``, and the webpage reader otherwise.
    """
    if category == "pdfs":
        description = pdf_processor(link)
    elif category == "images":
        description = image_description(link, detailed=True)
    else:
        description = webpage_reader(link)
    return f"{link} {description}"
|
45
|
+
|
46
|
+
|
47
|
+
def get_content(message: str) -> str:
    """Return *message* with every URL expanded to ``"URL <description>"``.

    Links are found by ``classify_links`` and described by
    ``process_with_llm``. Each distinct link is processed once, and all
    substitutions happen in a single pass. This prevents two kinds of
    corruption that sequential ``str.replace`` caused:
        * a repeated link being expanded again inside its own description;
        * a URL that is a prefix of another URL being rewritten inside it.
    """
    replacements = {}
    for category, links in classify_links(message).items():
        for link in links:
            if link and link not in replacements:
                replacements[link] = process_with_llm(link, category)

    if not replacements:
        return message

    # Longest first, so a URL never shadows a longer URL it is a prefix of.
    alternatives = sorted(replacements, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(link) for link in alternatives))
    # A function replacement inserts descriptions literally (no backslash escapes).
    return pattern.sub(lambda match: replacements[match.group(0)], message)
|
55
|
+
|
56
|
+
|
57
|
+
# def clean_temp_files():
|
58
|
+
# clear_pdf_register()
|
59
|
+
# clear_image_register()
|
60
|
+
# clear_webpage_register()
|