user-simulator 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. user_sim/__init__.py +0 -0
  2. user_sim/cli/__init__.py +0 -0
  3. user_sim/cli/gen_user_profile.py +34 -0
  4. user_sim/cli/init_project.py +65 -0
  5. user_sim/cli/sensei_chat.py +481 -0
  6. user_sim/cli/sensei_check.py +103 -0
  7. user_sim/cli/validation_check.py +143 -0
  8. user_sim/core/__init__.py +0 -0
  9. user_sim/core/ask_about.py +665 -0
  10. user_sim/core/data_extraction.py +260 -0
  11. user_sim/core/data_gathering.py +134 -0
  12. user_sim/core/interaction_styles.py +147 -0
  13. user_sim/core/role_structure.py +608 -0
  14. user_sim/core/user_simulator.py +302 -0
  15. user_sim/handlers/__init__.py +0 -0
  16. user_sim/handlers/asr_module.py +128 -0
  17. user_sim/handlers/html_parser_module.py +202 -0
  18. user_sim/handlers/image_recognition_module.py +139 -0
  19. user_sim/handlers/pdf_parser_module.py +123 -0
  20. user_sim/utils/__init__.py +0 -0
  21. user_sim/utils/config.py +47 -0
  22. user_sim/utils/cost_tracker.py +153 -0
  23. user_sim/utils/cost_tracker_v2.py +193 -0
  24. user_sim/utils/errors.py +15 -0
  25. user_sim/utils/exceptions.py +47 -0
  26. user_sim/utils/languages.py +78 -0
  27. user_sim/utils/register_management.py +62 -0
  28. user_sim/utils/show_logs.py +63 -0
  29. user_sim/utils/token_cost_calculator.py +338 -0
  30. user_sim/utils/url_management.py +60 -0
  31. user_sim/utils/utilities.py +568 -0
  32. user_simulator-0.1.0.dist-info/METADATA +733 -0
  33. user_simulator-0.1.0.dist-info/RECORD +37 -0
  34. user_simulator-0.1.0.dist-info/WHEEL +5 -0
  35. user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
  36. user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
  37. user_simulator-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,139 @@
1
+ import re
2
+ import logging
3
+ from langchain.schema.messages import HumanMessage, SystemMessage
4
+ from user_sim.utils.token_cost_calculator import calculate_cost, max_input_tokens_allowed, max_output_tokens_allowed
5
+ from user_sim.utils import config
6
+ from user_sim.utils.utilities import init_model
7
+ from user_sim.utils.register_management import save_register, load_register, hash_generate
8
+
9
+
10
+ logger = logging.getLogger('Info Logger')
11
+ model = None
12
+ llm = None
13
+
14
+
15
+ image_register_name = "image_register.json"
16
+
17
def init_vision_module():
    """Populate the module-level ``model`` and ``llm`` handles.

    Both globals are assigned from the pair returned by ``init_model()``.
    ``generate_image_description`` checks ``llm`` and refuses to run while
    it is still ``None``.
    """
    global model, llm
    initialised = init_model()
    model, llm = initialised
21
+
22
+
23
+
24
def generate_image_description(image, url=True, detailed=False):
    """Ask the vision LLM for a textual description of an image.

    Args:
        image: The image reference. When ``url`` is true it is sent to the
            model unchanged; otherwise it must provide ``decode('utf-8')``
            and is embedded as a base64 PNG data URI.
        url: Whether ``image`` can be passed to the model as-is.
        detailed: Use the long, exhaustive prompt instead of the brief one.

    Returns:
        ``"(Image description: ...)"`` on success; ``"Empty data"`` when the
        module has not been initialised or the LLM call raises; ``None``
        when ``max_input_tokens_allowed`` returns a truthy value.
    """
    if url:
        image_parsed = image
    else:
        image_parsed = f"data:image/png;base64,{image.decode('utf-8')}"

    if detailed:
        # NOTE(review): the leading whitespace inside this literal is not
        # visible in the reviewed view of the file - confirm if it matters.
        prompt = ("""
        Describe in detail this image and its content.
        If there's text, describe everything you read. don't give vague descriptions.
        If there is content listed, read it as it is.
        Be as detailed as possible.
        """)
    else:
        prompt = "briefly describe this image, don't over explain, just give a simple and fast explanation of the main characteristics."

    if llm is None:
        logger.error("vision module not initialized.")
        return "Empty data"

    # A truthy result is treated as "limit exceeded" and aborts the call.
    if max_input_tokens_allowed(prompt, model, image=image):
        logger.error("Token limit was surpassed")
        return None

    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {
                    "url": image_parsed,
                },
            },
        ]
    )

    try:
        # The output cap is only applied when token counting is enabled;
        # the invocation itself is the same in both cases.
        if config.token_count_enabled:
            llm.max_tokens = max_output_tokens_allowed(model)
        output = llm.invoke([message])
        output_text = f"(Image description: {output.content})"
    except Exception as e:
        # Best effort: a failed description must not abort the caller.
        logger.error(e)
        logger.error("Couldn't get image description")
        output_text = "Empty data"

    logger.info(output_text)
    if config.token_count_enabled:
        calculate_cost(prompt, output_text, model=model, module="image recognition module", image=image)

    return output_text
78
+
79
def image_description(image, detailed=False, url=True):
    """Return a description for ``image``, using the on-disk image register as cache.

    Args:
        image: Image reference forwarded to ``generate_image_description``
            and hashed with ``hash_generate`` to build the cache key.
        detailed: Forwarded to ``generate_image_description``.
        url: Forwarded to ``generate_image_description``.

    Returns:
        The cached description when one exists and ``config.update_cache``
        is false, otherwise a freshly generated one.
    """
    if config.ignore_cache:
        register = {}
        logger.info("Cache will be ignored.")
    else:
        register = load_register(image_register_name)

    image_hash = hash_generate(content=image)
    already_cached = image_hash in register

    if already_cached and not config.update_cache:
        logger.info("Retrieved information from cache.")
        return register[image_hash]

    # Either a cache miss or an explicit refresh. ``detailed`` is forwarded
    # in both cases so the flag is honoured on first generation too.
    description = generate_image_description(image, url, detailed)
    register[image_hash] = description
    if already_cached:
        logger.info("Cache updated!")

    # Persist after a refresh as well as after a miss; otherwise the
    # refreshed entry would only live in this call's local dict.
    if config.ignore_cache:
        logger.info("Images cache was ignored")
    else:
        save_register(register, image_register_name)
        logger.info("Images cache was saved!")

    return description
107
+
108
+
109
def image_processor(text):
    """Append a description after every ``<image>...</image>`` tag in ``text``.

    Args:
        text: The message to process, or ``None``.

    Returns:
        ``text`` unchanged when it is ``None`` or contains no image tags.
        Otherwise a copy in which each tag is followed by a space and the
        description produced by ``image_description`` for the tag's content.
        A tag whose description is ``None`` is left untouched.
    """
    if text is None:
        return text

    # Single pattern shared by the search and the substitution so the two
    # passes always see the same matches in the same order.
    image_tag = re.compile(r"<image>(.*?)</image>")

    images = image_tag.findall(text)
    if not images:
        return text

    descriptions = iter([image_description(image) for image in images])

    def replacer(match):
        description = next(descriptions, None)
        if description is None:
            # No usable description: keep the original tag as-is instead of
            # appending the literal text "None".
            return match.group(0)
        return f"<image>{match.group(1)}</image> {description}"

    return image_tag.sub(replacer, text)
@@ -0,0 +1,123 @@
1
+ import fitz
2
+ import base64
3
+ import re
4
+ import logging
5
+ import os
6
+ import requests
7
+ from urllib.parse import urlparse
8
+ from user_sim.utils import config
9
+ from user_sim.utils.register_management import save_register, load_register, hash_generate
10
+ from user_sim.handlers.image_recognition_module import image_description
11
+
12
+ logger = logging.getLogger('Info Logger')
13
+ current_script_dir = os.path.dirname(os.path.abspath(__file__))
14
+ project_root = os.path.abspath(os.path.join(current_script_dir, "../../.."))
15
+ pdf_register_name = "pdf_register.json"
16
+
17
+
18
def pdf_reader(pdf):
    """Extract the text and image descriptions of a PDF, using the PDF register as cache.

    Args:
        pdf: Value passed to ``fitz.open`` and to ``hash_generate`` (as
            ``content``) to build the cache key.

    Returns:
        A string of the form ``"(PDF content: ... >>)"``, taken from the
        register when present and ``config.update_cache`` is false, and
        generated otherwise.
    """
    if config.ignore_cache:
        register = {}
        logger.info("Cache will be ignored.")
    else:
        register = load_register(pdf_register_name)

    pdf_hash = hash_generate(content_type="pdf", content=pdf)

    def process_pdf(pdf_file):
        """Render every page's text and image descriptions into one string."""
        parts = []
        # The context manager closes the document even if a page fails.
        with fitz.open(pdf_file) as doc:
            for page_number in range(len(doc)):
                page = doc.load_page(page_number)
                parts.append(f"Page nª{page_number}: {page.get_text()} ")

                images = page.get_images(full=True)
                if not images:
                    continue

                parts.append("Images in this page: ")
                for img_index, img in enumerate(images):
                    xref = img[0]
                    image_bytes = doc.extract_image(xref)["image"]
                    image_base64 = base64.b64encode(image_bytes)
                    description = image_description(image_base64, detailed=False, url=False)
                    parts.append(f"Image description {img_index + 1}: {description}")

        plain_text = "".join(parts)
        return f"(PDF content: {plain_text} >>)"

    already_cached = pdf_hash in register
    if already_cached and not config.update_cache:
        output_text = register[pdf_hash]
        logger.info("Retrieved information from cache.")
    else:
        output_text = process_pdf(pdf)
        register[pdf_hash] = output_text
        if already_cached:
            logger.info("Cache updated!")

    if config.ignore_cache:
        logger.info("PDF cache was ignored.")
    else:
        save_register(register, pdf_register_name)
        logger.info("PDF cache was saved!")

    logger.info(output_text)
    return output_text
67
+
68
+
69
def get_pdf(url):
    """Download ``url`` and save it under ``<project_root>/data/pdfs`` if it is a PDF.

    The file name is taken from the ``Content-Disposition`` header when it
    carries one, otherwise from the last segment of the URL path, falling
    back to ``pdf_download``. A ``.pdf`` suffix is appended when missing.

    Args:
        url: Address of the document to fetch.

    Returns:
        The path of the saved file, or ``None`` when the server does not
        answer with HTTP 200 or the response is not ``application/pdf``.
    """
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        # Callers treat any non-None result as a file path, so a failure is
        # logged and reported as None rather than as a message string.
        logger.error(f"Error accessing the page: {response.status_code}")
        return None

    content_type = response.headers.get("Content-Type", "")
    if 'application/pdf' not in content_type:
        return None

    extension = ".pdf"
    pdfs_dir = os.path.join(project_root, "data/pdfs")
    os.makedirs(pdfs_dir, exist_ok=True)

    filename = None
    content_disposition = response.headers.get('Content-Disposition', '')
    # Stop at a quote or ';' so trailing header parameters are not captured.
    filename_match = re.search(r'filename="?([^";]+)"?', content_disposition)
    if filename_match:
        filename = filename_match.group(1).strip()

    if not filename:
        filename = os.path.basename(urlparse(url).path)

    # The name comes from the remote server: keep only its final component
    # so it cannot point outside ``pdfs_dir``.
    filename = os.path.basename(filename.replace("\\", "/"))
    if not filename:
        filename = 'pdf_download'
    if not filename.endswith(extension):
        filename += extension

    full_path = os.path.join(pdfs_dir, filename)
    with open(full_path, 'wb') as f:
        f.write(response.content)

    return full_path
113
+
114
+
115
+
116
def pdf_processor(pdf_url):
    """Download the PDF at ``pdf_url`` and return its extracted content.

    Args:
        pdf_url: Address of the PDF, or ``None``.

    Returns:
        The string produced by ``pdf_reader`` for the downloaded file, or
        ``None`` when ``pdf_url`` is ``None`` or no file was saved.
    """
    if pdf_url is None:
        return pdf_url

    pdf_path = get_pdf(pdf_url)
    # Only hand an actual saved file to the reader; this guards against
    # get_pdf returning anything else (for example an error message).
    if pdf_path is None or not os.path.isfile(pdf_path):
        return None

    return pdf_reader(pdf_path)
File without changes
@@ -0,0 +1,47 @@
1
# ---------------------------------------------------------------------------
# Execution state
# ---------------------------------------------------------------------------
errors = list()
conversation_name = ""
serial = ""
cost_ds_path = None          # path of the cost report CSV once it is created
test_name = ''

# Cache switches read by the image and PDF handlers.
ignore_cache = False
update_cache = False
clean_cache = False

# ---------------------------------------------------------------------------
# Project layout
# ---------------------------------------------------------------------------
root_path = ""
project_folder_path = ""
profiles_path = ""
types_folder_path = ""
test_cases_folder = ""
types_dict = dict()
custom_personalities_folder = ""

# ---------------------------------------------------------------------------
# Cost metrics
# ---------------------------------------------------------------------------
token_count_enabled = False
limit_cost = 10_000_000_000
limit_individual_cost = 10_000_000_000
total_cost = 0
total_individual_cost = 0

# ---------------------------------------------------------------------------
# LLM selection
# ---------------------------------------------------------------------------
model = "gpt-4o-mini"
model_provider = "openai"
33
+
34
+
35
+ # context
36
+ default_context = [
37
+ "You are a helpful user simulator that test chatbots.",
38
+ "Don't add starting sentences, for example 'Okay, here we go'. The first thing you say must be already in the role of a user"
39
+ "You must act like a user since the beginning of the conversation. "
40
+ "never recreate a whole conversation, just act like you're a user or client",
41
+ "never generate a message starting by 'user:'",
42
+ 'Sometimes, interact with what the assistant just said.',
43
+ 'Never act as the assistant, always behave as a user.',
44
+ "Don't end the conversation until you've asked everything you need.",
45
+ "you're testing a chatbot, so there can be random values or irrational things "
46
+ "in your requests"
47
+ ]
@@ -0,0 +1,153 @@
1
+ import os
2
+ import csv
3
+ import pandas as pd
4
+ import tiktoken
5
+
6
+ from user_sim.utils import config
7
+ from datetime import datetime
8
+ from typing import List, Union
9
+ from langchain.callbacks.base import BaseCallbackHandler
10
+ from langchain.schema import LLMResult, ChatGeneration, ChatResult
11
+ from langchain.chat_models import ChatModel
12
+ import logging
13
+
14
+ logger = logging.getLogger('Info Logger')
15
+
16
+ cost_rates = {
17
+ "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.0020},
18
+ "gpt-4": {"prompt": 0.03, "completion": 0.06},
19
+ # later, if you try e.g. Anthropic:
20
+ "claude-v1": {"prompt": 0.0075, "completion": 0.015},
21
+ }
22
+
23
+ columns = ["Conversation",
24
+ "Test Name",
25
+ "Module",
26
+ "Model",
27
+ "Total Cost",
28
+ "Timestamp",
29
+ "Input Cost",
30
+ "Input Message",
31
+ "Output Cost",
32
+ "Output Message"]
33
+
34
def create_cost_dataset(serial, test_cases_folder):
    """Make sure the cost report CSV for ``serial`` exists and return its path.

    The report lives at
    ``<test_cases_folder>/reports/__cost_reports__/cost_report_<serial>.csv``.
    A missing folder is created, and a missing file is created with only
    the header row defined by ``columns``. The path is always stored in
    ``config.cost_ds_path``, including when the file already existed.

    Args:
        serial: Identifier embedded in the report file name.
        test_cases_folder: Base folder that holds the ``reports`` directory.

    Returns:
        The path of the report file.
    """
    folder = f"{test_cases_folder}/reports/__cost_reports__"
    path = f"{folder}/cost_report_{serial}.csv"

    if not os.path.isdir(folder):
        # exist_ok avoids failing if the folder appears between the check
        # and the creation.
        os.makedirs(folder, exist_ok=True)
        logger.info(f"Created cost report folder at: {folder}")

    if not os.path.exists(path):
        pd.DataFrame(columns=columns).to_csv(path, index=False)
        logger.info(f"Cost dataframe created at {path}.")

    config.cost_ds_path = path
    return path
50
+
51
def count_message_tokens(messages: List[dict], model_name: str) -> int:
    """Placeholder token counter: does no work and returns ``None``.

    NOTE(review): the annotation promises an ``int`` but nothing is
    computed here; callers must not rely on a numeric result.
    """
    return None
61
+
62
+
63
class CostTrackingCallback(BaseCallbackHandler):
    """
    A LangChain callback that, on each LLM invocation,
    1. extracts token usage
    2. computes cost via a per-model rate table
    3. appends a row into a CSV in real time
    """

    def __init__(self):
        """Bind the callback to the cost report of the current run.

        The report path comes from ``create_cost_dataset`` using
        ``config.serial`` and ``config.test_cases_folder``; rates come from
        the module-level ``cost_rates`` table (USD per 1K tokens).
        """
        path = create_cost_dataset(config.serial, config.test_cases_folder)
        self.serial = config.serial
        self.csv_path = path
        self.cost_rates = cost_rates

    def on_llm_end(self, result: LLMResult, **kwargs) -> None:
        """Append one row with the token usage and cost of a finished call.

        NOTE(review): the row written here has six values while the header
        created from ``columns`` has ten - confirm the intended schema.
        """
        # 1. pull out usage; llm_output may be None, and the counters may be
        #    stored under "usage" or "token_usage".
        llm_output = result.llm_output or {}
        usage = llm_output.get("usage") or llm_output.get("token_usage") or {}
        p_tokens = usage.get("prompt_tokens", 0)
        c_tokens = usage.get("completion_tokens", 0)
        total = usage.get("total_tokens", p_tokens + c_tokens)

        # 2. determine model name key
        model_name = (
            llm_output.get("model_name")
            or getattr(result, "model_name", None)
            or "unknown"
        )

        # 3. look up rates (per 1K tokens) and compute; unknown models cost 0
        rates = self.cost_rates.get(model_name, {})
        cost = (p_tokens / 1000) * rates.get("prompt", 0.0) \
            + (c_tokens / 1000) * rates.get("completion", 0.0)

        # 4. append to CSV (real-time)
        with open(self.csv_path, 'a', newline='') as f:
            csv.writer(f).writerow([
                datetime.utcnow().isoformat(),
                model_name,
                p_tokens,
                c_tokens,
                total,
                f"{cost:.8f}",
            ])

    def budgeted_invoke(
        self,
        llm: ChatModel,
        messages: List[dict],
        **invoke_kwargs
    ) -> ChatResult:
        """Invoke ``llm`` with this callback attached, honouring a budget if one exists.

        This class defines no budget of its own. A completion-token limit is
        only applied when the instance provides a
        ``get_max_completion_tokens(messages, model_name)`` method (for
        example through a subclass); otherwise the model is invoked normally.

        Raises:
            BudgetExceeded: when the limit provider reports that no tokens
                can be afforded.
        """
        model_name = getattr(llm, "model_name", None) or llm.client.model_name

        limit_provider = getattr(self, "get_max_completion_tokens", None)
        max_toks = limit_provider(messages, model_name) if callable(limit_provider) else None

        # Callbacks travel inside the run config: the chat model already
        # forwards config["callbacks"], so also passing ``callbacks=`` as a
        # keyword would supply the argument twice.
        run_config = dict(invoke_kwargs.pop("config", None) or {})
        callbacks = []
        for group in (run_config.get("callbacks"), invoke_kwargs.pop("callbacks", None)):
            if isinstance(group, (list, tuple)):
                callbacks.extend(group)
            elif group is not None:
                callbacks.append(group)
        if self not in callbacks:
            callbacks.append(self)
        run_config["callbacks"] = callbacks

        if max_toks is None:
            # unknown model or no budget tracking: just invoke normally
            return llm.invoke(messages, config=run_config, **invoke_kwargs)

        if max_toks <= 0:
            # out of money; the exception type lives in the v2 tracker module
            from user_sim.utils.cost_tracker_v2 import BudgetExceeded
            raise BudgetExceeded(f"No budget remaining for model {model_name}")

        # NOTE(review): the limit is kept in the run config as before;
        # confirm that the model actually honours config["max_tokens"].
        run_config["max_tokens"] = max_toks
        return llm.invoke(messages, config=run_config, **invoke_kwargs)
@@ -0,0 +1,193 @@
1
+ import os
2
+ import csv
3
+ import pandas as pd
4
+ from datetime import datetime
5
+ from typing import List, Union
6
+ import tiktoken
7
+
8
+ from langchain.callbacks.base import BaseCallbackHandler
9
+ from langchain.schema import LLMResult, ChatGeneration, ChatResult
10
+ from user_sim.utils import config
11
+ import logging
12
+
13
+ logger = logging.getLogger('Info Logger')
14
+
15
+ cost_rates = {
16
+ "gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.0020},
17
+ "gpt-4": {"prompt": 0.03, "completion": 0.06},
18
+ "gpt-4o-mini": {"prompt": 0.03, "completion": 0.06},
19
+ # later, if you try e.g. Anthropic:
20
+ "claude-v1": {"prompt": 0.0075, "completion": 0.015},
21
+ }
22
+
23
+ columns = ["Conversation",
24
+ "Test Name",
25
+ "Module",
26
+ "Model",
27
+ "Total Cost",
28
+ "Timestamp",
29
+ "Input Cost",
30
+ "Input Message",
31
+ "Output Cost",
32
+ "Output Message"]
33
+
34
def create_cost_dataset(serial, test_cases_folder):
    """Create the cost report CSV once per run and return its path.

    When ``config.cost_ds_path`` is already set it is returned untouched.
    Otherwise the report folder is created if missing, a CSV containing
    only the ``columns`` header is written, and its path is stored in
    ``config.cost_ds_path``.
    """
    known_path = config.cost_ds_path
    if known_path is not None:
        return known_path

    folder = f"{test_cases_folder}/reports/__cost_reports__"
    if not os.path.exists(folder):
        os.makedirs(folder)
        logger.info(f"Created cost report folder at: {folder}")

    path = f"{folder}/cost_report_{serial}.csv"
    empty_report = pd.DataFrame(columns=columns)
    empty_report.to_csv(path, index=False)

    config.cost_ds_path = path
    logger.info(f"Cost dataframe created at {path}.")
    return path
52
+
53
+
54
+
55
class BudgetExceeded(Exception):
    """Signals that the remaining budget cannot pay for another LLM call."""
58
+
59
def count_message_tokens(messages: List[dict], model_name: str) -> int:
    """Rough token count for a list of role/content dicts via tiktoken.

    Args:
        messages: Dicts with string ``"role"`` and ``"content"`` entries.
        model_name: Model whose tokenizer should be used. Models tiktoken
            cannot map to a tokenizer fall back to ``cl100k_base``, so the
            count is then only an approximation.

    Returns:
        The number of tokens in all roles and contents combined.
    """
    try:
        enc = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # tiktoken has no mapping for this model name
        enc = tiktoken.get_encoding("cl100k_base")

    # include role tokens + content tokens for every message
    return sum(
        len(enc.encode(msg["role"])) + len(enc.encode(msg["content"]))
        for msg in messages
    )
68
+
69
class CostTrackingCallback(BaseCallbackHandler):
    """
    1. Logs every call's token usage and cost to CSV.
    2. Tracks cumulative spend against a total budget.
    3. Provides a `budgeted_invoke` helper to enforce the budget.
    """

    def __init__(self):
        """Bind the callback to the cost report of the current run.

        The report path comes from ``create_cost_dataset``; rates come from
        the module-level ``cost_rates`` table (USD per 1K tokens). The total
        budget is fixed at 5 and the running spend starts at zero.
        """
        path = create_cost_dataset(config.serial, config.test_cases_folder)
        self.serial = config.serial
        self.csv_path = path
        self.cost_rates = cost_rates
        self.total_budget = 5
        self.spent = 0.0

    def on_llm_end(self, result: LLMResult, **kwargs) -> None:
        """Add the finished call's cost to ``self.spent`` and log it to the CSV.

        NOTE(review): the row written here has seven values while the header
        created from ``columns`` has ten - confirm the intended schema.
        """
        # llm_output may be None, and the counters may be stored under
        # "usage" or "token_usage".
        llm_output = result.llm_output or {}
        usage = llm_output.get("usage") or llm_output.get("token_usage") or {}
        p = usage.get("prompt_tokens", 0)
        c = usage.get("completion_tokens", 0)
        tot = usage.get("total_tokens", p + c)

        # model lookup; unknown models are charged nothing
        model = llm_output.get("model_name") or getattr(result, "model_name", "unknown")
        rates = self.cost_rates.get(model, {"prompt": 0.0, "completion": 0.0})
        cost = (p / 1000) * rates["prompt"] + (c / 1000) * rates["completion"]

        # update spend
        self.spent += cost
        logger.debug(f"Cumulative spend: {self.spent}")

        # append row
        with open(self.csv_path, "a", newline="") as f:
            csv.writer(f).writerow([
                datetime.utcnow().isoformat(),
                model, p, c, tot,
                f"{cost:.8f}",
                f"{self.spent:.8f}",
            ])

    def get_max_completion_tokens(self,
                                  messages: List[dict],
                                  model_name: str
                                  ) -> Union[int, None]:
        """
        Based on your remaining budget, estimate how many completion tokens you can afford.
        Returns None if model not in cost_rates (or its completion rate is
        not positive, in which case no limit can be derived), and 0 when the
        budget is exhausted.
        """
        if model_name not in self.cost_rates:
            return None

        rates = self.cost_rates[model_name]
        if rates["completion"] <= 0:
            # free completions: nothing to divide by, so no limit applies
            return None

        # 1. count prompt tokens, cost them
        p_tokens = count_message_tokens(messages, model_name)
        cost_prompt = (p_tokens / 1000) * rates["prompt"]
        remaining = self.total_budget - self.spent - cost_prompt

        # 2. if no budget left, signal
        if remaining <= 0:
            return 0

        # 3. convert back into max tokens for completion
        return int((remaining * 1000) / rates["completion"])

    def budgeted_invoke(
        self,
        llm,
        messages,
        **invoke_kwargs
    ) -> ChatResult:
        """
        Checks your budget, then either:
          • raises BudgetExceeded if the budget is spent
          • calls llm.invoke with this callback attached and the affordable
            completion-token limit recorded in the run config

        ``callbacks`` given as a keyword are merged into the run config
        together with this tracker.
        """
        model_name = llm.model_name  # or however your ChatModel exposes it
        max_toks = self.get_max_completion_tokens(messages, model_name)

        # Callbacks travel inside the run config: the chat model already
        # forwards config["callbacks"], so also passing ``callbacks=`` as a
        # keyword would supply the argument twice.
        run_config = dict(invoke_kwargs.pop("config", None) or {})
        callbacks = []
        for group in (run_config.get("callbacks"), invoke_kwargs.pop("callbacks", None)):
            if isinstance(group, (list, tuple)):
                callbacks.extend(group)
            elif group is not None:
                callbacks.append(group)
        if self not in callbacks:
            callbacks.append(self)
        run_config["callbacks"] = callbacks

        if max_toks is None:
            # unknown model: just invoke normally
            return llm.invoke(messages, config=run_config, **invoke_kwargs)

        if max_toks <= 0:
            # out of money
            raise BudgetExceeded(f"No budget remaining for model {model_name}")

        # NOTE(review): the limit is kept in the run config as before;
        # confirm that the model actually honours config["max_tokens"].
        run_config["max_tokens"] = max_toks
        return llm.invoke(messages, config=run_config, **invoke_kwargs)
@@ -0,0 +1,15 @@
1
# Numeric error codes.
NO_RESPONSE = 500
TIMEOUT = 504
EXCEEDED_LOOP_LIMIT = 1000
GOAL_NOT_COMPLETED = 1001
MAX_COST_EXCEEDED = 2000
MAX_INDIVIDUAL_COST_EXCEEDED = 2001

# Lookup from the lower-case error name to its code.
all_errors = dict(
    goal_not_completed=GOAL_NOT_COMPLETED,
    no_response=NO_RESPONSE,
    timeout=TIMEOUT,
    exceeded_loop_limit=EXCEEDED_LOOP_LIMIT,
    max_cost_exceeded=MAX_COST_EXCEEDED,
    max_individual_cost_exceeded=MAX_INDIVIDUAL_COST_EXCEEDED,
)
+ }