wcgw-5.5.4-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -0,0 +1,404 @@
+ import base64
+ import json
+ import mimetypes
+ import os
+ import subprocess
+ import tempfile
+ import traceback
+ import uuid
+ from pathlib import Path
+ from typing import DefaultDict, Optional, cast
+
+ import openai
+ import petname  # type: ignore[import-untyped]
+ import rich
+ import tokenizers  # type: ignore[import-untyped]
+ from dotenv import load_dotenv
+ from openai import OpenAI
+ from openai.types.chat import (
+     ChatCompletionContentPartParam,
+     ChatCompletionMessageParam,
+     ChatCompletionUserMessageParam,
+ )
+ from pydantic import BaseModel
+ from typer import Typer
+
+ from wcgw.client.bash_state.bash_state import BashState
+ from wcgw.client.common import CostData, History, Models, discard_input
+ from wcgw.client.memory import load_memory
+ from wcgw.client.tool_prompts import TOOL_PROMPTS
+ from wcgw.client.tools import (
+     Context,
+     ImageData,
+     default_enc,
+     get_tool_output,
+     initialize,
+     which_tool,
+     which_tool_name,
+ )
+
+ from .openai_utils import get_input_cost, get_output_cost
+
+
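+ # Per-session spend settings: `cost_file` maps a model name to its CostData
+ # (price per million input/output tokens), and `cost_limit` caps total spend.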
+ class Config(BaseModel):
+     model: Models
+     cost_limit: float
+     cost_file: dict[Models, CostData]
+     cost_unit: str = "$"
+
+
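+ # Read the user's message from stdin; an empty line falls back to opening
+ # $EDITOR (vim by default) on a temp file and using whatever was saved there.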
+ def text_from_editor(console: rich.console.Console) -> str:
+     # First consume all the input till now
+     discard_input()
+     console.print("\n---------------------------------------\n# User message")
+     data = input()
+     if data:
+         return data
+     editor = os.environ.get("EDITOR", "vim")
+     with tempfile.NamedTemporaryFile(suffix=".tmp") as tf:
+         subprocess.run([editor, tf.name], check=True)
+         with open(tf.name, "r") as f:
+             data = f.read()
+     console.print(data)
+     return data
+
+
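+ # Persist the transcript under .wcgw/, named after the first user message
+ # plus the short session id, e.g. .wcgw/fix_the_build_ab12cd.json
+ # (filename shown here is illustrative).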
+ def save_history(history: History, session_id: str) -> None:
+     myid = str(history[1]["content"]).replace("/", "_").replace(" ", "_").lower()[:60]
+     myid += "_" + session_id
+     myid = myid + ".json"
+
+     mypath = Path(".wcgw") / myid
+     mypath.parent.mkdir(parents=True, exist_ok=True)
+     with open(mypath, "w") as f:
+         json.dump(history, f, indent=3)
+
+
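+ # Lines beginning with `%` are in-message commands; only `%image <path>` is
+ # supported. For example, the message
+ #     what's in this screenshot?
+ #     %image /tmp/shot.png
+ # becomes one text part plus one base64-encoded image_url part.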
+ def parse_user_message_special(msg: str) -> ChatCompletionUserMessageParam:
+     # Search for lines starting with `%` and treat them as special commands
+     parts: list[ChatCompletionContentPartParam] = []
+     for line in msg.split("\n"):
+         if line.startswith("%"):
+             args = line[1:].strip().split(" ")
+             command = args[0]
+             assert command == "image"
+             image_path = " ".join(args[1:])
+             with open(image_path, "rb") as f:
+                 image_bytes = f.read()
+             image_b64 = base64.b64encode(image_bytes).decode("utf-8")
+             image_type = mimetypes.guess_type(image_path)[0]
+             dataurl = f"data:{image_type};base64,{image_b64}"
+             parts.append(
+                 {"type": "image_url", "image_url": {"url": dataurl, "detail": "auto"}}
+             )
+         else:
+             if len(parts) > 0 and parts[-1]["type"] == "text":
+                 parts[-1]["text"] += "\n" + line
+             else:
+                 parts.append({"type": "text", "text": line})
+     return {"role": "user", "content": parts}
+
+
+ app = Typer(pretty_exceptions_show_locals=False)
+
+
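+ # The interactive chat loop. Note that `loop` itself is passed into
+ # get_tool_output further down, apparently so tools can run nested
+ # conversations; it returns the final message and the accumulated cost.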
+ @app.command()
+ def loop(
+     first_message: Optional[str] = None,
+     limit: Optional[float] = None,
+     resume: Optional[str] = None,
+ ) -> tuple[str, float]:
+     load_dotenv()
+
+     session_id = str(uuid.uuid4())[:6]
+
+     history: History = []
+     waiting_for_assistant = False
+
+     memory = None
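+     # `--resume` first tries to load a saved task memory by id; if that file
+     # lookup fails (OSError) it is treated as a history JSON path, where
+     # "latest" picks the most recently modified file under .wcgw/.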
+     if resume:
+         try:
+             _, memory, _ = load_memory(
+                 resume,
+                 24000,  # coding_max_tokens
+                 8000,  # noncoding_max_tokens
+                 lambda x: default_enc.encoder(x),
+                 lambda x: default_enc.decoder(x),
+             )
+         except OSError:
+             if resume == "latest":
+                 resume_path = sorted(Path(".wcgw").iterdir(), key=os.path.getmtime)[-1]
+             else:
+                 resume_path = Path(resume)
+             if not resume_path.exists():
+                 raise FileNotFoundError(f"File {resume} not found")
+             with resume_path.open() as f:
+                 history = json.load(f)
+             if len(history) <= 2:
+                 raise ValueError("Invalid history file")
+             first_message = ""
+             waiting_for_assistant = history[-1]["role"] != "assistant"
+
+     my_dir = os.path.dirname(__file__)
+
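+     # Default model and its hard-coded per-million-token prices; the
+     # OPENAI_MODEL env var overrides the model name.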
+     config = Config(
+         model=cast(Models, os.getenv("OPENAI_MODEL", "gpt-4o-2024-08-06").lower()),
+         cost_limit=0.1,
+         cost_unit="$",
+         cost_file={
+             "gpt-4o-2024-08-06": CostData(
+                 cost_per_1m_input_tokens=5, cost_per_1m_output_tokens=15
+             ),
+         },
+     )
+
+     if limit is not None:
+         config.cost_limit = limit
+     limit = config.cost_limit
+
+     enc = tokenizers.Tokenizer.from_pretrained("Xenova/gpt-4o")
+
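+     # Publish every wcgw tool except Initialize as an OpenAI function tool,
+     # using each tool's pydantic schema.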
+     tools = [
+         openai.pydantic_function_tool(
+             which_tool_name(tool.name), description=tool.description
+         )
+         for tool in TOOL_PROMPTS
+         if tool.name != "Initialize"
+     ]
+
+     cost: float = 0
+     input_toks = 0
+     output_toks = 0
+     system_console = rich.console.Console(style="blue", highlight=False, markup=False)
+     error_console = rich.console.Console(style="red", highlight=False, markup=False)
+     user_console = rich.console.Console(
+         style="bright_black", highlight=False, markup=False
+     )
+     assistant_console = rich.console.Console(
+         style="white bold", highlight=False, markup=False
+     )
+
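+     # All tool calls run against one persistent shell session; BashState is a
+     # context manager, so the underlying shell is cleaned up on exit.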
+     with BashState(
+         system_console, os.getcwd(), None, None, None, None, True, None
+     ) as bash_state:
+         context = Context(bash_state, system_console)
+         system, context, _ = initialize(
+             "first_call",
+             context,
+             os.getcwd(),
+             [],
+             resume if (memory and resume) else "",
+             24000,  # coding_max_tokens
+             8000,  # noncoding_max_tokens
+             mode="wcgw",
+             thread_id="",
+         )
+
+         if not history:
+             history = [{"role": "system", "content": system}]
+         else:
+             if history[-1]["role"] == "tool":
+                 waiting_for_assistant = True
+
+         client = OpenAI()
+
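+         # Main turn loop: stop once accumulated cost exceeds the limit;
+         # otherwise prompt the user, unless the assistant still owes a reply
+         # to a pending tool result.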
+         while True:
+             if cost > limit:
+                 system_console.print(
+                     f"\nCost limit exceeded. Current cost: {cost}, input tokens: {input_toks}, output tokens: {output_toks}"
+                 )
+                 break
+
+             if not waiting_for_assistant:
+                 if first_message:
+                     msg = first_message
+                     first_message = ""
+                 else:
+                     msg = text_from_editor(user_console)
+
+                 history.append(parse_user_message_special(msg))
+             else:
+                 waiting_for_assistant = False
+
+             cost_, input_toks_ = get_input_cost(
+                 config.cost_file[config.model], enc, history
+             )
+             cost += cost_
+             input_toks += input_toks_
+
+             stream = client.chat.completions.create(
+                 messages=history,
+                 model=config.model,
+                 stream=True,
+                 tools=tools,
+             )
+
+             system_console.print(
+                 "\n---------------------------------------\n# Assistant response",
+                 style="bold",
+             )
+             tool_call_args_by_id = DefaultDict[str, DefaultDict[int, str]](
+                 lambda: DefaultDict(str)
+             )
+             _histories: History = []
+             item: ChatCompletionMessageParam
+             full_response: str = ""
+             image_histories: History = []
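+             # Consume the stream: text deltas are echoed as they arrive, and
+             # tool-call argument fragments are accumulated per (id, index)
+             # until a finish_reason marks the message complete.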
+             try:
+                 for chunk in stream:
+                     if chunk.choices[0].finish_reason == "tool_calls":
+                         assert tool_call_args_by_id
+                         item = {
+                             "role": "assistant",
+                             "content": full_response,
+                             "tool_calls": [
+                                 {
+                                     "id": tool_call_id + str(toolindex),
+                                     "type": "function",
+                                     "function": {
+                                         "arguments": tool_args,
+                                         "name": type(which_tool(tool_args)).__name__,
+                                     },
+                                 }
+                                 for tool_call_id, toolcallargs in tool_call_args_by_id.items()
+                                 for toolindex, tool_args in toolcallargs.items()
+                             ],
+                         }
+                         cost_, output_toks_ = get_output_cost(
+                             config.cost_file[config.model], enc, item
+                         )
+                         cost += cost_
+                         system_console.print(
+                             f"\n---------------------------------------\n# Assistant invoked tools: {[which_tool(tool['function']['arguments']) for tool in item['tool_calls']]}"
+                         )
+                         system_console.print(
+                             f"\nTotal cost: {config.cost_unit}{cost:.3f}"
+                         )
+                         output_toks += output_toks_
+
+                         _histories.append(item)
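+                         # Execute each completed tool call; a failure is fed
+                         # back to the model as the tool's output instead of
+                         # crashing the session.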
+                         for tool_call_id, toolcallargs in tool_call_args_by_id.items():
+                             for toolindex, tool_args in toolcallargs.items():
+                                 try:
+                                     output_or_dones, cost_ = get_tool_output(
+                                         context,
+                                         json.loads(tool_args),
+                                         enc,
+                                         limit - cost,
+                                         loop,
+                                         24000,  # coding_max_tokens
+                                         8000,  # noncoding_max_tokens
+                                     )
+                                     output_or_done = output_or_dones[0]
+                                 except Exception as e:
+                                     output_or_done = (
+                                         f"GOT EXCEPTION while calling tool. Error: {e}"
+                                     )
+                                     tb = traceback.format_exc()
+                                     error_console.print(output_or_done + "\n" + tb)
+                                     cost_ = 0
+                                 cost += cost_
+                                 system_console.print(
+                                     f"\nTotal cost: {config.cost_unit}{cost:.3f}"
+                                 )
+
+                                 output = output_or_done
+
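+                                 # Tool results are text-only in the chat API,
+                                 # so image outputs are instead appended as a
+                                 # separate user message, and the tool result
+                                 # only references the generated image id.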
+                                 if isinstance(output, ImageData):
+                                     randomId = petname.Generate(2, "-")
+                                     if not image_histories:
+                                         image_histories.extend(
+                                             [
+                                                 {
+                                                     "role": "assistant",
+                                                     "content": f"Share images with ids: {randomId}",
+                                                 },
+                                                 {
+                                                     "role": "user",
+                                                     "content": [
+                                                         {
+                                                             "type": "image_url",
+                                                             "image_url": {
+                                                                 "url": output.dataurl,
+                                                                 "detail": "auto",
+                                                             },
+                                                         }
+                                                     ],
+                                                 },
+                                             ]
+                                         )
+                                     else:
+                                         image_histories[0]["content"] += ", " + randomId
+                                         second_content = image_histories[1]["content"]
+                                         assert isinstance(second_content, list)
+                                         second_content.append(
+                                             {
+                                                 "type": "image_url",
+                                                 "image_url": {
+                                                     "url": output.dataurl,
+                                                     "detail": "auto",
+                                                 },
+                                             }
+                                         )
+
+                                     item = {
+                                         "role": "tool",
+                                         "content": f"Ask user for image id: {randomId}",
+                                         "tool_call_id": tool_call_id + str(toolindex),
+                                     }
+                                 else:
+                                     item = {
+                                         "role": "tool",
+                                         "content": str(output),
+                                         "tool_call_id": tool_call_id + str(toolindex),
+                                     }
+                                 cost_, output_toks_ = get_output_cost(
+                                     config.cost_file[config.model], enc, item
+                                 )
+                                 cost += cost_
+                                 output_toks += output_toks_
+
+                                 _histories.append(item)
+                         waiting_for_assistant = True
+                         break
+                     elif chunk.choices[0].finish_reason:
+                         assistant_console.print("")
+                         item = {
+                             "role": "assistant",
+                             "content": full_response,
+                         }
+                         cost_, output_toks_ = get_output_cost(
+                             config.cost_file[config.model], enc, item
+                         )
+                         cost += cost_
+                         output_toks += output_toks_
+
+                         system_console.print(
+                             f"\nTotal cost: {config.cost_unit}{cost:.3f}"
+                         )
+                         _histories.append(item)
+                         break
+
+                     if chunk.choices[0].delta.tool_calls:
+                         tool_call = chunk.choices[0].delta.tool_calls[0]
+                         if tool_call.function and tool_call.function.arguments:
+                             tool_call_args_by_id[tool_call.id or ""][
+                                 tool_call.index
+                             ] += tool_call.function.arguments
+
+                     chunk_str = chunk.choices[0].delta.content or ""
+                     assistant_console.print(chunk_str, end="")
+                     full_response += chunk_str
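+             # Ctrl-C discards this turn's partial results: the except path
+             # skips the else-block below, so _histories is never committed to
+             # history, and the turn is redone from the same state.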
+             except KeyboardInterrupt:
+                 waiting_for_assistant = False
+                 input("Interrupted...enter to redo the current turn")
+             else:
+                 history.extend(_histories)
+                 history.extend(image_histories)
+                 save_history(history, session_id)
+
+     return "Couldn't finish the task", cost
+
+
+ if __name__ == "__main__":
+     app()
@@ -0,0 +1,67 @@
+ from typing import cast
+
+ from openai.types.chat import (
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionMessage,
+     ChatCompletionMessageParam,
+     ParsedChatCompletionMessage,
+ )
+ from tokenizers import Tokenizer  # type: ignore[import-untyped]
+
+ from wcgw.client.common import CostData, History
+
+
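+ # Token-count the full message list and price it at the model's input rate.
+ # E.g. 2,000 input tokens at $5 per 1M tokens costs 2000 * 5 / 1_000_000 = $0.01.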
+ def get_input_cost(
+     cost_map: CostData, enc: Tokenizer, history: History
+ ) -> tuple[float, int]:
+     input_tokens = 0
+     for msg in history:
+         content = msg["content"]
+         refusal = msg.get("refusal")
+         if isinstance(content, list):
+             for part in content:
+                 if "text" in part:
+                     input_tokens += len(enc.encode(part["text"]))
+         elif content is None:
+             if refusal is None:
+                 raise ValueError("Expected content or refusal to be present")
+             input_tokens += len(enc.encode(str(refusal)))
+         elif not isinstance(content, str):
+             raise ValueError(f"Expected content to be string, got {type(content)}")
+         else:
+             input_tokens += len(enc.encode(content))
+     cost = input_tokens * cost_map.cost_per_1m_input_tokens / 1_000_000
+     return cost, input_tokens
+
+
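+ # Price an assistant message at the output rate. Tool messages return (0, 0):
+ # their text is produced by the tools rather than the model, and it is billed
+ # at the input rate on the next turn anyway, since get_input_cost walks the
+ # whole history.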
+ def get_output_cost(
+     cost_map: CostData,
+     enc: Tokenizer,
+     item: ChatCompletionMessage | ChatCompletionMessageParam,
+ ) -> tuple[float, int]:
+     if isinstance(item, ChatCompletionMessage):
+         content = item.content
+         if not isinstance(content, str):
+             raise ValueError(f"Expected content to be string, got {type(content)}")
+     else:
+         if not isinstance(item["content"], str):
+             raise ValueError(
+                 f"Expected content to be string, got {type(item['content'])}"
+             )
+         content = item["content"]
+         if item["role"] == "tool":
+             return 0, 0
+     output_tokens = len(enc.encode(content))
+
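+     # Arguments the model generated for tool calls count as output tokens
+     # too, whether the message is a plain dict or a parsed pydantic message.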
+ if "tool_calls" in item:
57
+ item = cast(ChatCompletionAssistantMessageParam, item)
58
+ toolcalls = item["tool_calls"]
59
+ for tool_call in toolcalls or []:
60
+ output_tokens += len(enc.encode(tool_call["function"]["arguments"]))
61
+ elif isinstance(item, ParsedChatCompletionMessage):
62
+ if item.tool_calls:
63
+ for tool_callf in item.tool_calls:
64
+ output_tokens += len(enc.encode(tool_callf.function.arguments))
65
+
66
+ cost = output_tokens * cost_map.cost_per_1m_output_tokens / 1_000_000
67
+ return cost, output_tokens