wcgw 5.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,486 @@
1
+ import base64
2
+ import json
3
+ import mimetypes
4
+ import os
5
+ import subprocess
6
+ import tempfile
7
+ import traceback
8
+ import uuid
9
+ from pathlib import Path
10
+ from typing import Literal, Optional, cast
11
+
12
+ import rich
13
+ from anthropic import Anthropic, MessageStopEvent
14
+ from anthropic.types import (
15
+ ImageBlockParam,
16
+ MessageParam,
17
+ ModelParam,
18
+ RawMessageStartEvent,
19
+ TextBlockParam,
20
+ ToolParam,
21
+ ToolResultBlockParam,
22
+ ToolUseBlockParam,
23
+ )
24
+ from dotenv import load_dotenv
25
+ from pydantic import BaseModel, ValidationError
26
+ from typer import Typer
27
+
28
+ from wcgw.client.bash_state.bash_state import BashState
29
+ from wcgw.client.common import CostData, discard_input
30
+ from wcgw.client.memory import load_memory
31
+ from wcgw.client.tool_prompts import TOOL_PROMPTS
32
+ from wcgw.client.tools import (
33
+ Context,
34
+ ImageData,
35
+ default_enc,
36
+ get_tool_output,
37
+ initialize,
38
+ parse_tool_by_name,
39
+ )
40
+
41
+
42
class Config(BaseModel):
    """Runtime configuration for the Anthropic chat loop."""

    model: ModelParam  # Anthropic model identifier used for every request
    cost_limit: float  # abort the loop once accumulated cost exceeds this
    cost_file: dict[ModelParam, CostData]  # per-model pricing table
    cost_unit: str = "$"  # currency symbol used when printing cost summaries
47
+
48
+
49
# Conversation transcript: ordered list of Anthropic message params.
History = list[MessageParam]
50
+
51
+
52
def text_from_editor(console: rich.console.Console) -> str:
    """Read a user message from stdin, falling back to $EDITOR when empty.

    Prints a prompt banner, reads one line; a non-empty line is returned
    as-is. Otherwise the user's editor (default ``vim``) is opened on a
    temporary file and its final contents are echoed and returned.
    """
    # Flush any characters already buffered on stdin before prompting.
    discard_input()
    console.print("\n---------------------------------------\n# User message")
    typed = input()
    if typed:
        return typed
    # Empty inline input: hand off to the user's preferred editor.
    chosen_editor = os.environ.get("EDITOR", "vim")
    with tempfile.NamedTemporaryFile(suffix=".tmp") as tmp:
        subprocess.run([chosen_editor, tmp.name], check=True)
        with open(tmp.name, "r") as fh:
            contents = fh.read()
        console.print(contents)
        return contents
66
+
67
+
68
def save_history(history: History, session_id: str) -> None:
    """Persist the conversation history as JSON under ``.wcgw/``.

    The filename is a slug derived from a message's content (``/`` and
    spaces replaced by ``_``, lower-cased, truncated to 60 chars) plus the
    session id.

    Args:
        history: Conversation transcript to save.
        session_id: Short id disambiguating files from the same task.
    """
    # BUG FIX: the original unconditionally read history[1], raising
    # IndexError for histories with fewer than two messages. Guard the
    # empty case and fall back to the first message when only one exists.
    if not history:
        return
    seed = history[1] if len(history) > 1 else history[0]
    myid = str(seed["content"]).replace("/", "_").replace(" ", "_").lower()[:60]
    myid += "_" + session_id
    myid = myid + ".json"

    mypath = Path(".wcgw") / myid
    mypath.parent.mkdir(parents=True, exist_ok=True)
    with open(mypath, "w") as f:
        json.dump(history, f, indent=3)
77
+
78
+
79
def parse_user_message_special(msg: str) -> MessageParam:
    """Build a user MessageParam from raw text with ``%image`` directives.

    Lines beginning with ``%`` are treated as commands; the only supported
    command is ``%image <path>``, which inlines the file as a base64 image
    block. All other lines are accumulated into text blocks, merging
    consecutive text lines with newlines.
    """
    blocks: list[ImageBlockParam | TextBlockParam] = []
    for raw_line in msg.split("\n"):
        if not raw_line.startswith("%"):
            # Plain text: extend the trailing text block or start a new one.
            if blocks and blocks[-1]["type"] == "text":
                blocks[-1]["text"] += "\n" + raw_line
            else:
                blocks.append({"type": "text", "text": raw_line})
            continue
        tokens = raw_line[1:].strip().split(" ")
        command = tokens[0]
        assert command == "image"
        # Re-join so paths containing spaces survive the split above.
        image_path = " ".join(tokens[1:])
        with open(image_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("utf-8")
        guessed_type = mimetypes.guess_type(image_path)[0]
        blocks.append(
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": cast(
                        'Literal["image/jpeg", "image/png", "image/gif", "image/webp"]',
                        guessed_type or "image/png",
                    ),
                    "data": encoded,
                },
            }
        )
    return {"role": "user", "content": blocks}
111
+
112
+
113
# Typer CLI entry point; local variables are hidden from tracebacks.
app = Typer(pretty_exceptions_show_locals=False)
114
+
115
+
116
@app.command()
def loop(
    first_message: Optional[str] = None,
    limit: Optional[float] = None,
    resume: Optional[str] = None,
) -> tuple[str, float]:
    """Run an interactive Anthropic chat loop with wcgw tool execution.

    Args:
        first_message: Message to send immediately instead of prompting.
        limit: Cost-limit override (in ``config.cost_unit``); default 0.1.
        resume: Memory/task id, a history-file path, or ``"latest"`` for
            the most recently modified file under ``.wcgw/``.

    Returns:
        ``(status message, accumulated cost)`` — only reached when the
        cost limit is exceeded; otherwise the loop runs until interrupted.
    """
    load_dotenv()

    # Short random id used to disambiguate saved history filenames.
    session_id = str(uuid.uuid4())[:6]

    history: History = []
    waiting_for_assistant = False
    memory = None
    if resume:
        try:
            # First try interpreting `resume` as a saved wcgw memory/task id.
            _, memory, _ = load_memory(
                resume,
                24000,  # coding_max_tokens
                8000,  # noncoding_max_tokens
                lambda x: default_enc.encoder(x),
                lambda x: default_enc.decoder(x),
            )
        except OSError:
            # Not a memory id: fall back to loading a saved history JSON.
            if resume == "latest":
                resume_path = sorted(Path(".wcgw").iterdir(), key=os.path.getmtime)[-1]
            else:
                resume_path = Path(resume)
            if not resume_path.exists():
                raise FileNotFoundError(f"File {resume} not found")
            with resume_path.open() as f:
                history = json.load(f)
            if len(history) <= 2:
                raise ValueError("Invalid history file")
            first_message = ""
            # Resume mid-turn if the last saved message wasn't the assistant's.
            waiting_for_assistant = history[-1]["role"] != "assistant"

    config = Config(
        model="claude-3-5-sonnet-20241022",
        cost_limit=0.1,
        cost_unit="$",
        # Per-model pricing, cost per million tokens.
        cost_file={
            # Claude 3.5 Haiku
            "claude-3-5-haiku-latest": CostData(
                cost_per_1m_input_tokens=0.80, cost_per_1m_output_tokens=4
            ),
            "claude-3-5-haiku-20241022": CostData(
                cost_per_1m_input_tokens=0.80, cost_per_1m_output_tokens=4
            ),
            # Claude 3.5 Sonnet
            "claude-3-5-sonnet-latest": CostData(
                cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
            ),
            "claude-3-5-sonnet-20241022": CostData(
                cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
            ),
            "claude-3-5-sonnet-20240620": CostData(
                cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
            ),
            # Claude 3 Opus
            "claude-3-opus-latest": CostData(
                cost_per_1m_input_tokens=15.0, cost_per_1m_output_tokens=75.0
            ),
            "claude-3-opus-20240229": CostData(
                cost_per_1m_input_tokens=15.0, cost_per_1m_output_tokens=75.0
            ),
            # Legacy Models
            "claude-3-haiku-20240307": CostData(
                cost_per_1m_input_tokens=0.25, cost_per_1m_output_tokens=1.25
            ),
            "claude-2.1": CostData(
                cost_per_1m_input_tokens=8.0, cost_per_1m_output_tokens=24.0
            ),
            "claude-2.0": CostData(
                cost_per_1m_input_tokens=8.0, cost_per_1m_output_tokens=24.0
            ),
        },
    )

    # CLI override takes precedence; afterwards `limit` is always a float.
    if limit is not None:
        config.cost_limit = limit
    limit = config.cost_limit

    # Expose every wcgw tool to the model except Initialize (called below).
    tools = [
        ToolParam(
            name=tool.name,
            description=tool.description or "",  # Ensure it's not None
            input_schema=tool.inputSchema,
        )
        for tool in TOOL_PROMPTS
        if tool.name != "Initialize"
    ]

    # Separate consoles so each speaker gets a distinct color.
    system_console = rich.console.Console(style="blue", highlight=False, markup=False)
    error_console = rich.console.Console(style="red", highlight=False, markup=False)
    user_console = rich.console.Console(
        style="bright_black", highlight=False, markup=False
    )
    assistant_console = rich.console.Console(
        style="white bold", highlight=False, markup=False
    )

    with BashState(
        system_console, os.getcwd(), None, None, None, None, True, None
    ) as bash_state:
        context = Context(bash_state, system_console)
        # Build the system prompt; resume memory only when one was loaded.
        system, context, _ = initialize(
            "first_call",
            context,
            os.getcwd(),
            [],
            resume if (memory and resume) else "",
            24000,  # coding_max_tokens
            8000,  # noncoding_max_tokens
            mode="wcgw",
            thread_id="",
        )

        # A resumed history ending in a tool_result means the assistant
        # still owes us a reply for that tool call.
        if history:
            if (
                (last_msg := history[-1])["role"] == "user"
                and isinstance((content := last_msg["content"]), dict)
                and content["type"] == "tool_result"
            ):
                waiting_for_assistant = True

        client = Anthropic()

        # Running totals across all turns.
        cost: float = 0
        input_toks = 0
        output_toks = 0

        while True:
            if cost > limit:
                # NOTE(review): the f-string fragments below concatenate
                # without a separator between the token counts.
                system_console.print(
                    f"\nCost limit exceeded. Current cost: {config.cost_unit}{cost:.4f}, "
                    f"input tokens: {input_toks}"
                    f"output tokens: {output_toks}"
                )
                break
            else:
                system_console.print(
                    f"\nTotal cost: {config.cost_unit}{cost:.4f}, input tokens: {input_toks}, output tokens: {output_toks}"
                )

            if not waiting_for_assistant:
                # It's the user's turn: consume the CLI first message once,
                # then prompt interactively thereafter.
                if first_message:
                    msg = first_message
                    first_message = ""
                else:
                    msg = text_from_editor(user_console)

                history.append(parse_user_message_special(msg))
            else:
                # Tool results were just appended; go straight to the model.
                waiting_for_assistant = False

            stream = client.messages.stream(
                model=config.model,
                messages=history,
                tools=tools,
                max_tokens=8096,
                system=system,
            )

            system_console.print(
                "\n---------------------------------------\n# Assistant response",
                style="bold",
            )
            # Messages accumulated this turn; merged into `history` only if
            # the turn completes without KeyboardInterrupt.
            _histories: History = []
            full_response: str = ""

            tool_calls = []
            tool_results: list[ToolResultBlockParam] = []
            try:
                with stream as stream_:
                    for chunk in stream_:
                        type_ = chunk.type
                        if isinstance(chunk, RawMessageStartEvent):
                            message_start = chunk.message
                            # Update cost based on token usage from the API response
                            input_tokens = message_start.usage.input_tokens
                            input_toks += input_tokens
                            cost += (
                                input_tokens
                                * config.cost_file[
                                    config.model
                                ].cost_per_1m_input_tokens
                            ) / 1_000_000
                        elif isinstance(chunk, MessageStopEvent):
                            message_stop = chunk.message
                            # Update cost based on output tokens
                            output_tokens = message_stop.usage.output_tokens
                            output_toks += output_tokens
                            cost += (
                                output_tokens
                                * config.cost_file[
                                    config.model
                                ].cost_per_1m_output_tokens
                            ) / 1_000_000
                            continue
                        elif type_ == "content_block_start" and hasattr(
                            chunk, "content_block"
                        ):
                            content_block = chunk.content_block
                            if (
                                hasattr(content_block, "type")
                                and content_block.type == "text"
                                and hasattr(content_block, "text")
                            ):
                                chunk_str = content_block.text
                                assistant_console.print(chunk_str, end="")
                                full_response += chunk_str
                            elif content_block.type == "tool_use":
                                if (
                                    hasattr(content_block, "input")
                                    and hasattr(content_block, "name")
                                    and hasattr(content_block, "id")
                                ):
                                    # Tool input arrives later via
                                    # input_json_delta chunks; start empty.
                                    assert content_block.input == {}
                                    tool_calls.append(
                                        {
                                            "name": str(content_block.name),
                                            "input": str(""),
                                            "done": False,
                                            "id": str(content_block.id),
                                        }
                                    )
                            else:
                                error_console.log(
                                    f"Ignoring unknown content block type {content_block.type}"
                                )
                        elif type_ == "content_block_delta" and hasattr(chunk, "delta"):
                            delta = chunk.delta
                            if hasattr(delta, "type"):
                                delta_type = str(delta.type)
                                if delta_type == "text_delta" and hasattr(
                                    delta, "text"
                                ):
                                    chunk_str = delta.text
                                    assistant_console.print(chunk_str, end="")
                                    full_response += chunk_str
                                elif delta_type == "input_json_delta" and hasattr(
                                    delta, "partial_json"
                                ):
                                    # Accumulate the streamed JSON arguments
                                    # for the most recent tool call.
                                    partial_json = delta.partial_json
                                    if isinstance(tool_calls[-1]["input"], str):
                                        tool_calls[-1]["input"] += partial_json
                                else:
                                    error_console.log(
                                        f"Ignoring unknown content block delta type {delta_type}"
                                    )
                            else:
                                raise ValueError("Content block delta has no type")
                        elif type_ == "content_block_stop":
                            # NOTE(review): "done" is never set to True, so
                            # this branch fires for every stop event once a
                            # tool call exists — confirm that's intended.
                            if tool_calls and not tool_calls[-1]["done"]:
                                tc = tool_calls[-1]
                                tool_name = str(tc["name"])
                                tool_input = str(tc["input"])
                                tool_id = str(tc["id"])

                                # Record the assistant's tool invocation.
                                _histories.append(
                                    {
                                        "role": "assistant",
                                        "content": [
                                            ToolUseBlockParam(
                                                id=tool_id,
                                                name=tool_name,
                                                input=json.loads(tool_input),
                                                type="tool_use",
                                            )
                                        ],
                                    }
                                )
                                try:
                                    tool_parsed = parse_tool_by_name(
                                        tool_name, json.loads(tool_input)
                                    )
                                except ValidationError:
                                    # Feed the validation error back to the
                                    # model as an is_error tool result.
                                    error_msg = f"Error parsing tool {tool_name}\n{traceback.format_exc()}"
                                    system_console.log(
                                        f"Error parsing tool {tool_name}"
                                    )
                                    tool_results.append(
                                        ToolResultBlockParam(
                                            type="tool_result",
                                            tool_use_id=str(tc["id"]),
                                            content=error_msg,
                                            is_error=True,
                                        )
                                    )
                                    continue

                                system_console.print(
                                    f"\n---------------------------------------\n# Assistant invoked tool: {tool_parsed}"
                                )

                                try:
                                    # `loop` is passed as the re-entry
                                    # callback; remaining budget is limit-cost.
                                    output_or_dones, _ = get_tool_output(
                                        context,
                                        tool_parsed,
                                        default_enc,
                                        limit - cost,
                                        loop,
                                        24000,  # coding_max_tokens
                                        8000,  # noncoding_max_tokens
                                    )
                                except Exception as e:
                                    # Surface tool failures to the model
                                    # instead of crashing the session.
                                    output_or_dones = [
                                        (
                                            f"GOT EXCEPTION while calling tool. Error: {e}"
                                        )
                                    ]
                                    tb = traceback.format_exc()
                                    error_console.print(
                                        str(output_or_dones) + "\n" + tb
                                    )

                                # Convert tool outputs into Anthropic
                                # text/image result blocks.
                                tool_results_content: list[
                                    TextBlockParam | ImageBlockParam
                                ] = []
                                for output in output_or_dones:
                                    if isinstance(output, ImageData):
                                        tool_results_content.append(
                                            {
                                                "type": "image",
                                                "source": {
                                                    "type": "base64",
                                                    "media_type": output.media_type,
                                                    "data": output.data,
                                                },
                                            }
                                        )

                                    else:
                                        tool_results_content.append(
                                            {
                                                "type": "text",
                                                "text": output,
                                            },
                                        )
                                tool_results.append(
                                    ToolResultBlockParam(
                                        type="tool_result",
                                        tool_use_id=str(tc["id"]),
                                        content=tool_results_content,
                                    )
                                )
                            else:
                                # Plain text block finished: record it.
                                _histories.append(
                                    {
                                        "role": "assistant",
                                        "content": full_response
                                        if full_response.strip()
                                        else "...",
                                    }  # Fixes anthropic issue of non empty response only
                                )

            except KeyboardInterrupt:
                # Discard this turn's partial output and redo it.
                waiting_for_assistant = False
                input("Interrupted...enter to redo the current turn")
            else:
                history.extend(_histories)
                if tool_results:
                    history.append({"role": "user", "content": tool_results})
                    # Tool results go back to the model before the user speaks.
                    waiting_for_assistant = True
                save_history(history, session_id)

    return "Couldn't finish the task", cost
483
+
484
+
485
if __name__ == "__main__":
    app()  # run the Typer CLI when executed directly
wcgw_cli/cli.py ADDED
@@ -0,0 +1,40 @@
1
+ import importlib
2
+ from typing import Optional
3
+
4
+ import typer
5
+ from typer import Typer
6
+
7
+ from wcgw_cli.anthropic_client import loop as claude_loop
8
+ from wcgw_cli.openai_client import loop as openai_loop
9
+
10
# Typer CLI entry point; local variables are hidden from tracebacks.
app = Typer(pretty_exceptions_show_locals=False)
11
+
12
+
13
@app.command()
def loop(
    claude: bool = False,
    first_message: Optional[str] = None,
    limit: Optional[float] = None,
    resume: Optional[str] = None,
    version: bool = typer.Option(False, "--version", "-v"),
) -> tuple[str, float]:
    """Dispatch to the Anthropic or OpenAI interactive chat loop.

    Args:
        claude: Use the Anthropic client when True, OpenAI otherwise.
        first_message: Message to send immediately instead of prompting.
        limit: Cost limit forwarded to the chosen loop.
        resume: Task id / history file forwarded to the chosen loop.
        version: Print the installed wcgw version and exit.

    Returns:
        The (status message, cost) tuple from the selected loop.
    """
    if version:
        # BUG FIX: `import importlib` alone does not guarantee the
        # `importlib.metadata` submodule is loaded; import it explicitly
        # so `--version` cannot fail with AttributeError.
        import importlib.metadata

        version_ = importlib.metadata.version("wcgw")
        print(f"wcgw version: {version_}")
        exit()
    if claude:
        return claude_loop(
            first_message=first_message,
            limit=limit,
            resume=resume,
        )
    else:
        return openai_loop(
            first_message=first_message,
            limit=limit,
            resume=resume,
        )
37
+
38
+
39
if __name__ == "__main__":
    app()  # run the Typer CLI when executed directly