user_simulator-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- user_sim/__init__.py +0 -0
- user_sim/cli/__init__.py +0 -0
- user_sim/cli/gen_user_profile.py +34 -0
- user_sim/cli/init_project.py +65 -0
- user_sim/cli/sensei_chat.py +481 -0
- user_sim/cli/sensei_check.py +103 -0
- user_sim/cli/validation_check.py +143 -0
- user_sim/core/__init__.py +0 -0
- user_sim/core/ask_about.py +665 -0
- user_sim/core/data_extraction.py +260 -0
- user_sim/core/data_gathering.py +134 -0
- user_sim/core/interaction_styles.py +147 -0
- user_sim/core/role_structure.py +608 -0
- user_sim/core/user_simulator.py +302 -0
- user_sim/handlers/__init__.py +0 -0
- user_sim/handlers/asr_module.py +128 -0
- user_sim/handlers/html_parser_module.py +202 -0
- user_sim/handlers/image_recognition_module.py +139 -0
- user_sim/handlers/pdf_parser_module.py +123 -0
- user_sim/utils/__init__.py +0 -0
- user_sim/utils/config.py +47 -0
- user_sim/utils/cost_tracker.py +153 -0
- user_sim/utils/cost_tracker_v2.py +193 -0
- user_sim/utils/errors.py +15 -0
- user_sim/utils/exceptions.py +47 -0
- user_sim/utils/languages.py +78 -0
- user_sim/utils/register_management.py +62 -0
- user_sim/utils/show_logs.py +63 -0
- user_sim/utils/token_cost_calculator.py +338 -0
- user_sim/utils/url_management.py +60 -0
- user_sim/utils/utilities.py +568 -0
- user_simulator-0.1.0.dist-info/METADATA +733 -0
- user_simulator-0.1.0.dist-info/RECORD +37 -0
- user_simulator-0.1.0.dist-info/WHEEL +5 -0
- user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
- user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
- user_simulator-0.1.0.dist-info/top_level.txt +1 -0
user_sim/handlers/image_recognition_module.py
ADDED
@@ -0,0 +1,139 @@

import re
import logging
from langchain.schema.messages import HumanMessage, SystemMessage
from user_sim.utils.token_cost_calculator import calculate_cost, max_input_tokens_allowed, max_output_tokens_allowed
from user_sim.utils import config
from user_sim.utils.utilities import init_model
from user_sim.utils.register_management import save_register, load_register, hash_generate


logger = logging.getLogger('Info Logger')
model = None
llm = None


image_register_name = "image_register.json"


def init_vision_module():
    global model
    global llm
    model, llm = init_model()


def generate_image_description(image, url=True, detailed=False):

    if not url:
        image_parsed = f"data:image/png;base64,{image.decode('utf-8')}"
    else:
        image_parsed = image

    if detailed:
        prompt = ("""
        Describe this image and its content in detail.
        If there's text, describe everything you read. Don't give vague descriptions.
        If there is content listed, read it as it is.
        Be as detailed as possible.
        """)
    else:
        prompt = "Briefly describe this image. Don't over-explain, just give a simple and fast explanation of the main characteristics."

    if llm is None:
        logger.error("Vision module not initialized.")
        return "Empty data"

    if max_input_tokens_allowed(prompt, model, image=image):
        logger.error("Token limit was surpassed.")
        return None

    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {
                    "url": image_parsed,
                    # "detail": "auto"
                }
            }
        ]
    )

    try:
        if config.token_count_enabled:
            llm.max_tokens = max_output_tokens_allowed(model)
        output = llm.invoke([message])
        output_text = f"(Image description: {output.content})"
    except Exception as e:
        logger.error(e)
        logger.error("Couldn't get image description.")
        output_text = "Empty data"
    logger.info(output_text)
    if config.token_count_enabled:
        calculate_cost(prompt, output_text, model=model, module="image recognition module", image=image)

    return output_text


def image_description(image, detailed=False, url=True):
    if config.ignore_cache:
        register = {}
        logger.info("Cache will be ignored.")
    else:
        register = load_register(image_register_name)

    image_hash = hash_generate(content=image)

    if image_hash in register:
        if config.update_cache:
            description = generate_image_description(image, url, detailed)
            register[image_hash] = description
            logger.info("Cache updated!")
        else:
            description = register[image_hash]
            logger.info("Retrieved information from cache.")
    else:
        description = generate_image_description(image, url, detailed)
        register[image_hash] = description

    if config.ignore_cache:
        logger.info("Images cache was ignored.")
    else:
        save_register(register, image_register_name)
        logger.info("Images cache was saved!")

    return description


def image_processor(text):

    def get_images(phrase):
        pattern = r"<image>(.*?)</image>"
        matches = re.findall(pattern, phrase)
        return matches

    def replacer(match):
        nonlocal replacement_index, descriptions
        if replacement_index < len(descriptions):
            original_image = match.group(1)
            replacement = descriptions[replacement_index]
            replacement_index += 1
            return f"<image>{original_image}</image> {replacement}"
        return match.group(0)  # If no more replacements, return the original match

    if text is None:
        return text
    else:
        images = get_images(text)
        if images:
            descriptions = []
            for image in images:
                descriptions.append(image_description(image))

            replacement_index = 0

            result = re.sub(r"<image>(.*?)</image>", replacer, text)
            return result
        else:
            return text
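A minimal usage sketch of this module, assuming init_model() resolves provider credentials from the environment; the message text and image URL are illustrative, not part of the package:

from user_sim.handlers.image_recognition_module import init_vision_module, image_processor

init_vision_module()  # binds the module-level model/llm pair via init_model()

# Hypothetical chatbot message carrying an inline image tag.
text = "Here is the venue map <image>https://example.com/map.png</image> can you find the exit?"
annotated = image_processor(text)
# Each <image>...</image> tag is kept and followed by its generated
# "(Image description: ...)" text, so downstream prompts see both.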
user_sim/handlers/pdf_parser_module.py
ADDED
@@ -0,0 +1,123 @@

import fitz
import base64
import re
import logging
import os
import requests
from urllib.parse import urlparse
from user_sim.utils import config
from user_sim.utils.register_management import save_register, load_register, hash_generate
from user_sim.handlers.image_recognition_module import image_description

logger = logging.getLogger('Info Logger')
current_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_script_dir, "../../.."))
pdf_register_name = "pdf_register.json"


def pdf_reader(pdf):

    if config.ignore_cache:
        register = {}
        logger.info("Cache will be ignored.")
    else:
        register = load_register(pdf_register_name)

    pdf_hash = hash_generate(content_type="pdf", content=pdf)

    def process_pdf(pdf_file):
        doc = fitz.open(pdf_file)
        plain_text = ""
        for page_number in range(len(doc)):
            page = doc.load_page(page_number)
            plain_text += f"Page nº{page_number}: {page.get_text()} "

            images = page.get_images(full=True)
            if images:
                plain_text += "Images in this page: "
                for img_index, img in enumerate(images):
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image_bytes = base_image["image"]
                    image_base64 = base64.b64encode(image_bytes)
                    description = image_description(image_base64, detailed=False, url=False)
                    plain_text += f"Image description {img_index + 1}: {description}"
        return f"(PDF content: {plain_text} >>)"

    if pdf_hash in register:
        if config.update_cache:
            output_text = process_pdf(pdf)
            register[pdf_hash] = output_text
            logger.info("Cache updated!")
        else:
            output_text = register[pdf_hash]
            logger.info("Retrieved information from cache.")
    else:
        output_text = process_pdf(pdf)
        register[pdf_hash] = output_text

    if config.ignore_cache:
        logger.info("PDF cache was ignored.")
    else:
        save_register(register, pdf_register_name)
        logger.info("PDF cache was saved!")

    logger.info(output_text)
    return output_text


def get_pdf(url):
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return f"Error accessing the page: {response.status_code}"

    response.encoding = response.apparent_encoding
    content_type = response.headers.get("Content-Type", "")

    filename = None
    content_disposition = response.headers.get('Content-Disposition', '')

    if 'application/pdf' in content_type:
        extension = ".pdf"

        pdfs_dir = os.path.join(project_root, "data/pdfs")

        if not os.path.exists(pdfs_dir):
            os.makedirs(pdfs_dir)

        if 'filename=' in content_disposition:
            filename_match = re.search(r'filename="?([^"]+)"?', content_disposition)
            if filename_match:
                filename = filename_match.group(1)

        if not filename:
            parsed_url = urlparse(url)
            filename = os.path.basename(parsed_url.path)
            if not filename:
                filename = 'pdf_download'
        if extension and not filename.endswith(extension):
            filename += extension

        full_path = os.path.join(pdfs_dir, filename)
        content = response.content

        with open(full_path, 'wb') as f:
            f.write(content)

        return full_path

    else:
        return None


def pdf_processor(pdf_url):
    if pdf_url is None:
        return pdf_url
    else:
        pdf_path = get_pdf(pdf_url)
        # get_pdf returns an error string on non-200 responses, so only
        # feed real files to pdf_reader.
        if pdf_path is not None and os.path.isfile(pdf_path):
            description = pdf_reader(pdf_path)
            return description
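A usage sketch with an illustrative URL; note that get_pdf only saves the download when the response advertises Content-Type: application/pdf, and pdf_processor returns None otherwise:

from user_sim.handlers.pdf_parser_module import pdf_processor

summary = pdf_processor("https://example.com/whitepaper.pdf")  # hypothetical URL
if summary is not None:
    print(summary)  # "(PDF content: Page nº0: ... >>)"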
user_sim/utils/__init__.py
File without changes
user_sim/utils/config.py
ADDED
@@ -0,0 +1,47 @@

# execution data
errors = []
conversation_name = ""
serial = ""
# model = "gpt-4o-mini"
cost_ds_path = None
test_name = ''
ignore_cache = False
update_cache = False
clean_cache = False


# project data
root_path = ""
project_folder_path = ""
profiles_path = ""
types_folder_path = ""
test_cases_folder = ""
types_dict = {}
custom_personalities_folder = ""


# cost metrics
token_count_enabled = False
limit_cost = 10000000000
limit_individual_cost = 10000000000
total_cost = 0
total_individual_cost = 0

# llm
model = "gpt-4o-mini"
model_provider = "openai"


# context
default_context = [
    "You are a helpful user simulator that tests chatbots.",
    "Don't add starting sentences, for example 'Okay, here we go'. The first thing you say must already be in the role of a user.",
    "You must act like a user since the beginning of the conversation. "
    "Never recreate a whole conversation, just act like you're a user or client.",
    "Never generate a message starting with 'user:'.",
    "Sometimes, interact with what the assistant just said.",
    "Never act as the assistant, always behave as a user.",
    "Don't end the conversation until you've asked everything you need.",
    "You're testing a chatbot, so there can be random values or irrational things "
    "in your requests."
]
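config acts as process-wide mutable state: the CLI entry points presumably set these module attributes before a run, and the handlers read them at call time. A sketch of that pattern (the paths are hypothetical):

from user_sim.utils import config

config.token_count_enabled = True          # turn on token/cost accounting
config.ignore_cache = True                 # bypass the image/PDF registers
config.test_cases_folder = "./test_cases"  # hypothetical project layout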
user_sim/utils/cost_tracker.py
ADDED
@@ -0,0 +1,153 @@

import os
import csv
import pandas as pd
import tiktoken

from user_sim.utils import config
from datetime import datetime
from typing import List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, ChatGeneration, ChatResult
# BaseChatModel assumed here; the original imported a nonexistent `ChatModel`
# from langchain.chat_models.
from langchain.chat_models.base import BaseChatModel
import logging

logger = logging.getLogger('Info Logger')

cost_rates = {
    "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.0020},
    "gpt-4": {"prompt": 0.03, "completion": 0.06},
    # later, if you try e.g. Anthropic:
    "claude-v1": {"prompt": 0.0075, "completion": 0.015},
}

columns = ["Conversation",
           "Test Name",
           "Module",
           "Model",
           "Total Cost",
           "Timestamp",
           "Input Cost",
           "Input Message",
           "Output Cost",
           "Output Message"]


def create_cost_dataset(serial, test_cases_folder):
    folder = f"{test_cases_folder}/reports/__cost_reports__"
    file = f"cost_report_{serial}.csv"

    if not os.path.exists(folder):
        os.makedirs(folder)
        logger.info(f"Created cost report folder at: {folder}")

    path = f"{folder}/{file}"
    if not os.path.exists(path):
        cost_df = pd.DataFrame(columns=columns)
        cost_df.to_csv(path, index=False)
        config.cost_ds_path = path
        logger.info(f"Cost dataframe created at {path}.")

    return path


def count_message_tokens(messages: List[dict], model_name: str) -> int:
    # """Rough token count for a list of role/content dicts via tiktoken."""
    # enc = tiktoken.encoding_for_model(model_name)
    # total = 0
    # for msg in messages:
    #     # include role token + content tokens
    #     total += len(enc.encode(msg["role"]))
    #     total += len(enc.encode(msg["content"]))
    # return total
    pass


class BudgetExceeded(Exception):
    """Raised when there's no budget left for a new LLM call."""
    # Only defined in cost_tracker_v2.py in the original; declared here as
    # well so the raise in budgeted_invoke resolves.
    pass


class CostTrackingCallback(BaseCallbackHandler):
    """
    A LangChain callback that, on each LLM invocation,
      1. extracts token usage
      2. computes cost via a per-model rate table
      3. appends a row into a CSV in real time
    """
    def __init__(self):
        """
        csv_path: where to store the cost log
        cost_rates: {
            "<model_name>": {"prompt": <$/1K tokens>, "completion": <$/1K tokens>},
            ...
        }
        """
        path = create_cost_dataset(config.serial, config.test_cases_folder)
        self.serial = config.serial
        self.csv_path = path
        self.cost_rates = cost_rates

        # If first time, write header
        # if not os.path.exists(self.csv_path):
        #     with open(self.csv_path, 'w', newline='') as f:
        #         writer = csv.writer(f)
        #         writer.writerow([
        #             "timestamp",
        #             "model_name",
        #             "prompt_tokens",
        #             "completion_tokens",
        #             "total_tokens",
        #             "total_cost_usd",
        #         ])

    def on_llm_end(self, result: LLMResult, **kwargs) -> None:
        # 1. pull out usage (providers like OpenAI put it in result.llm_output["usage"])
        usage = result.llm_output.get("usage", {})
        p_tokens = usage.get("prompt_tokens", 0)
        c_tokens = usage.get("completion_tokens", 0)
        total = usage.get("total_tokens", p_tokens + c_tokens)

        # 2. determine model name key
        model_name = (
            result.llm_output.get("model_name")
            or getattr(result, "model_name", None)
            or "unknown"
        )

        # 3. look up rates (per 1K tokens) and compute
        rates = self.cost_rates.get(model_name, {})
        cost = (p_tokens / 1000) * rates.get("prompt", 0.0) \
             + (c_tokens / 1000) * rates.get("completion", 0.0)

        # 4. append to CSV (real-time)
        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                datetime.utcnow().isoformat(),
                model_name,
                p_tokens,
                c_tokens,
                total,
                f"{cost:.8f}",
            ])

    def budgeted_invoke(
            self,
            llm: BaseChatModel,
            messages: List[dict],
            **invoke_kwargs
    ) -> ChatResult:
        """
        Checks your budget, then either:
          - raises BudgetExceeded if the budget is spent
          - calls llm.invoke with max_tokens clipped to the budget
        """
        model_name = llm.client.model_name  # or however your ChatModel exposes it
        # get_max_completion_tokens is only implemented on the v2 tracker
        # (see cost_tracker_v2.py); an equivalent method is assumed here.
        max_toks = self.get_max_completion_tokens(messages, model_name)

        if max_toks is None:
            # unknown model: just invoke normally
            return llm.invoke(messages, callbacks=[self], **invoke_kwargs)

        if max_toks <= 0:
            # out of money
            raise BudgetExceeded(f"No budget remaining for model {model_name}")

        # enforce our budget
        invoke_kwargs.setdefault("config", {})
        invoke_kwargs["config"]["max_tokens"] = max_toks

        return llm.invoke(messages, callbacks=[self], **invoke_kwargs)
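To make the rate table concrete, a worked example of the per-1K-token arithmetic used in on_llm_end (the token counts are made up):

p_tokens, c_tokens = 1200, 300
rates = cost_rates["gpt-4"]  # {"prompt": 0.03, "completion": 0.06}
cost = (p_tokens / 1000) * rates["prompt"] + (c_tokens / 1000) * rates["completion"]
# 1.2 * 0.03 + 0.3 * 0.06 = 0.036 + 0.018 = 0.054 USD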
user_sim/utils/cost_tracker_v2.py
ADDED
@@ -0,0 +1,193 @@

import os
import csv
import pandas as pd
from datetime import datetime
from typing import List, Union
import tiktoken

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, ChatGeneration, ChatResult
from user_sim.utils import config
import logging

logger = logging.getLogger('Info Logger')

cost_rates = {
    "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.0020},
    "gpt-4": {"prompt": 0.03, "completion": 0.06},
    "gpt-4o-mini": {"prompt": 0.03, "completion": 0.06},
    # later, if you try e.g. Anthropic:
    "claude-v1": {"prompt": 0.0075, "completion": 0.015},
}

columns = ["Conversation",
           "Test Name",
           "Module",
           "Model",
           "Total Cost",
           "Timestamp",
           "Input Cost",
           "Input Message",
           "Output Cost",
           "Output Message"]


def create_cost_dataset(serial, test_cases_folder):

    if config.cost_ds_path is None:
        folder = f"{test_cases_folder}/reports/__cost_reports__"
        file = f"cost_report_{serial}.csv"
        if not os.path.exists(folder):
            os.makedirs(folder)
            logger.info(f"Created cost report folder at: {folder}")

        path = f"{folder}/{file}"

        cost_df = pd.DataFrame(columns=columns)
        cost_df.to_csv(path, index=False)
        config.cost_ds_path = path
        logger.info(f"Cost dataframe created at {path}.")
        return path
    else:
        return config.cost_ds_path


class BudgetExceeded(Exception):
    """Raised when there's no budget left for a new LLM call."""
    pass


def count_message_tokens(messages: List[dict], model_name: str) -> int:
    """Rough token count for a list of role/content dicts via tiktoken."""
    enc = tiktoken.encoding_for_model(model_name)
    total = 0
    for msg in messages:
        # include role token + content tokens
        total += len(enc.encode(msg["role"]))
        total += len(enc.encode(msg["content"]))
    return total


class CostTrackingCallback(BaseCallbackHandler):
    """
    1. Logs every call's token usage and cost to CSV.
    2. Tracks cumulative spend against a total budget.
    3. Provides a `budgeted_invoke` helper to enforce the budget.
    """

    def __init__(
            self,
            # csv_path: str,
            # cost_rates: dict,
            # total_budget_usd: float
    ):
        """
        csv_path: path to your log CSV
        cost_rates: {"model": {"prompt": $/1K, "completion": $/1K}, ...}
        total_budget_usd: how many dollars you're willing to spend in total
        """
        path = create_cost_dataset(config.serial, config.test_cases_folder)
        self.serial = config.serial
        self.csv_path = path
        self.cost_rates = cost_rates
        self.total_budget = 5  # hard-coded budget (USD) for now
        self.spent = 0.0

        # self.csv_path = csv_path
        # self.cost_rates = cost_rates
        # self.total_budget = total_budget_usd
        # self.spent = 0.0

        # write header if new
        # if not os.path.exists(self.csv_path):
        #     with open(self.csv_path, "w", newline="") as f:
        #         w = csv.writer(f)
        #         w.writerow([
        #             "timestamp","model_name",
        #             "prompt_tokens","completion_tokens","total_tokens",
        #             "call_cost_usd","cumulative_spent_usd"
        #         ])

    def on_llm_end(self, result: LLMResult, **kwargs) -> None:
        # pull OpenAI-style usage
        usage = result.llm_output.get("usage", {})
        p = usage.get("prompt_tokens", 0)
        c = usage.get("completion_tokens", 0)
        tot = usage.get("total_tokens", p + c)

        # model lookup
        model = result.llm_output.get("model_name") or getattr(result, "model_name", "unknown")
        rates = self.cost_rates.get(model, {"prompt": 0.0, "completion": 0.0})
        cost = (p / 1000) * rates["prompt"] + (c / 1000) * rates["completion"]

        # update spend
        self.spent += cost
        logger.info(f"Cumulative spend: {self.spent:.8f} USD")
        # append row
        with open(self.csv_path, "a", newline="") as f:
            w = csv.writer(f)
            w.writerow([
                datetime.utcnow().isoformat(),
                model, p, c, tot,
                f"{cost:.8f}",
                f"{self.spent:.8f}"
            ])

    def get_max_completion_tokens(self,
                                  messages: List[dict],
                                  model_name: str
                                  ) -> Union[int, None]:
        """
        Based on your remaining budget, estimate how many completion tokens you can afford.
        Returns None if model not in cost_rates.
        """
        if model_name not in self.cost_rates:
            return None

        rates = self.cost_rates[model_name]
        # 1. count prompt tokens, cost them
        p_tokens = count_message_tokens(messages, model_name)
        cost_prompt = (p_tokens / 1000) * rates["prompt"]
        remaining = self.total_budget - self.spent - cost_prompt

        # 2. if no budget left, signal
        if remaining <= 0:
            return 0

        # 3. convert back into max tokens for completion
        return int((remaining * 1000) / rates["completion"])


def budgeted_invoke(
        llm,
        tracker: CostTrackingCallback,
        messages: List[dict],
        **invoke_kwargs
) -> ChatResult:
    """
    Checks your budget, then either:
      - raises BudgetExceeded if the budget is spent
      - calls llm.invoke with max_tokens clipped to the budget

    (The original signature took `callbacks` but referenced an undefined
    `self` and `messages`; a tracker parameter and a messages parameter
    are assumed here instead.)
    """
    model_name = llm.model_name  # or however your ChatModel exposes it
    max_toks = tracker.get_max_completion_tokens(messages, model_name)

    if max_toks is None:
        # unknown model: just invoke normally
        return llm.invoke(messages, callbacks=[tracker], **invoke_kwargs)

    if max_toks <= 0:
        # out of money
        raise BudgetExceeded(f"No budget remaining for model {model_name}")

    # enforce our budget
    invoke_kwargs.setdefault("config", {})
    invoke_kwargs["config"]["max_tokens"] = max_toks

    return llm.invoke(messages, callbacks=[tracker], **invoke_kwargs)
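A usage sketch of the v2 tracker under the repaired signature above; the ChatOpenAI wrapper and message content are illustrative, and the $5 ceiling is the hard-coded total_budget:

from langchain.chat_models import ChatOpenAI  # assumed provider wrapper
from user_sim.utils.cost_tracker_v2 import CostTrackingCallback, budgeted_invoke, BudgetExceeded

tracker = CostTrackingCallback()
llm = ChatOpenAI(model_name="gpt-4")
messages = [{"role": "user", "content": "Summarize your refund policy."}]

try:
    result = budgeted_invoke(llm, tracker, messages)
except BudgetExceeded:
    pass  # stop the test run: the budget ceiling was reached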
user_sim/utils/errors.py
ADDED
@@ -0,0 +1,15 @@

GOAL_NOT_COMPLETED = 1001
NO_RESPONSE = 500
TIMEOUT = 504
EXCEEDED_LOOP_LIMIT = 1000
MAX_COST_EXCEEDED = 2000
MAX_INDIVIDUAL_COST_EXCEEDED = 2001

all_errors = {
    'goal_not_completed': GOAL_NOT_COMPLETED,
    'no_response': NO_RESPONSE,
    'timeout': TIMEOUT,
    'exceeded_loop_limit': EXCEEDED_LOOP_LIMIT,
    'max_cost_exceeded': MAX_COST_EXCEEDED,
    'max_individual_cost_exceeded': MAX_INDIVIDUAL_COST_EXCEEDED
}
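These codes pair with the mutable config.errors list declared in user_sim/utils/config.py; a plausible recording pattern (the tuple shape is a guess, not taken from the package):

from user_sim.utils import config
from user_sim.utils.errors import all_errors

code = all_errors['max_cost_exceeded']             # 2000
config.errors.append((code, 'max_cost_exceeded'))  # hypothetical record format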