wcgw 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wcgw might be problematic; see the linked security advisory for more details.

wcgw/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .basic import app, loop
2
+ from .tools import run as listen
wcgw/__main__.py ADDED
@@ -0,0 +1,3 @@
1
+ from wcgw.tools import run
2
+
3
+ run()
wcgw/basic.py ADDED
@@ -0,0 +1,341 @@
1
+ import json
2
+ from pathlib import Path
3
+ import sys
4
+ import traceback
5
+ from typing import Callable, DefaultDict, Optional, cast
6
+ import openai
7
+ from openai import OpenAI
8
+ from openai.types.chat import (
9
+ ChatCompletionMessageParam,
10
+ ChatCompletionAssistantMessageParam,
11
+ ChatCompletionMessage,
12
+ ParsedChatCompletionMessage,
13
+ )
14
+ import rich
15
+ from typer import Typer
16
+ import uuid
17
+
18
+ from .common import Models, discard_input
19
+ from .common import CostData, History
20
+ from .openai_utils import get_input_cost, get_output_cost
21
+ from .tools import ExecuteBash, GetShellOutputLastCommand
22
+
23
+ from .tools import (
24
+ BASH_CLF_OUTPUT,
25
+ Confirmation,
26
+ DoneFlag,
27
+ Writefile,
28
+ get_is_waiting_user_input,
29
+ get_tool_output,
30
+ SHELL,
31
+ start_shell,
32
+ which_tool,
33
+ )
34
+ import tiktoken
35
+
36
+ from urllib import parse
37
+ import subprocess
38
+ import os
39
+ import tempfile
40
+
41
+ import toml
42
+ from pydantic import BaseModel
43
+
44
+
45
+ from dotenv import load_dotenv
46
+
47
+
48
class Config(BaseModel):
    """Runtime configuration loaded from config.toml (see `loop`)."""

    # Primary chat model driving the main loop.
    model: Models
    # Secondary model; not referenced elsewhere in this file.
    secondary_model: Models
    # Hard budget: `loop` stops once accumulated cost exceeds it.
    cost_limit: float
    # Per-model token pricing used for cost accounting.
    cost_file: dict[Models, CostData]
    # Currency symbol used when printing costs.
    cost_unit: str = "$"
54
+
55
+
56
def text_from_editor(console: rich.console.Console) -> str:
    """Prompt for a user message on stdin; fall back to $EDITOR when empty.

    Returns the single stdin line when one is typed, otherwise the contents
    of the temp file after the user's editor exits.
    """
    # Drop any stray keystrokes typed while the assistant was streaming.
    discard_input()
    console.print("\n---------------------------------------\n# User message")
    typed = input()
    if typed:
        return typed
    # Empty line: open the configured editor (vim by default) on a temp file.
    chosen_editor = os.environ.get("EDITOR", "vim")
    with tempfile.NamedTemporaryFile(suffix=".tmp") as tf:
        subprocess.run([chosen_editor, tf.name], check=True)
        with open(tf.name, "r") as fh:
            contents = fh.read()
        console.print(contents)
        return contents
70
+
71
+
72
def save_history(history: History, session_id: str) -> None:
    """Persist *history* as JSON under ./.wcgw/.

    The filename is a slug of the first user message (index 1): slashes and
    spaces become underscores, lowercased, capped at 60 chars, suffixed with
    the session id.
    """
    slug = str(history[1]["content"]).replace("/", "_").replace(" ", "_").lower()[:60]
    target = Path(".wcgw") / f"{slug}_{session_id}.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w") as fh:
        json.dump(history, fh, indent=3)
81
+
82
+
83
# Typer CLI application; `loop` below is registered as its command.
app = Typer(pretty_exceptions_show_locals=False)
84
+
85
+
86
@app.command()
def loop(
    first_message: Optional[str] = None,
    limit: Optional[float] = None,
    resume: Optional[str] = None,
) -> tuple[str, float]:
    """Main REPL: stream chat turns with the model and execute its tool calls.

    Args:
        first_message: Optional initial user message (skips the editor once).
        limit: Optional cost-budget override for config.cost_limit.
        resume: Path to a saved .wcgw history file, or "latest" for the newest.

    Returns:
        (final task output, total cost) when a DoneFlag tool fires, otherwise
        ("Couldn't finish the task", cost) once the budget is exhausted.

    Raises:
        FileNotFoundError: If the resume file does not exist.
        ValueError: If the resumed history is too short or malformed.
    """
    load_dotenv()

    # Short random id used to name the saved history file.
    session_id = str(uuid.uuid4())[:6]

    history: History = []
    if resume:
        if resume == "latest":
            # Most recently modified file in .wcgw.
            resume_path = sorted(Path(".wcgw").iterdir(), key=os.path.getmtime)[-1]
        else:
            resume_path = Path(resume)
        if not resume_path.exists():
            raise FileNotFoundError(f"File {resume} not found")
        with resume_path.open() as f:
            history = json.load(f)
        if len(history) <= 2:
            raise ValueError("Invalid history file")
        if history[1]["role"] != "user":
            raise ValueError("Invalid history file, second message should be user")
        first_message = ""

    # Config is looked up two directories above the package (repo root).
    my_dir = os.path.dirname(__file__)
    config_file = os.path.join(my_dir, "..", "..", "config.toml")
    with open(config_file) as f:
        config_json = toml.load(f)
    config = Config.model_validate(config_json)

    # CLI --limit overrides the configured budget.
    if limit is not None:
        config.cost_limit = limit
    limit = config.cost_limit

    # o1 models have no tiktoken mapping; fall back to gpt-4o's encoding.
    enc = tiktoken.encoding_for_model(
        config.model if not config.model.startswith("o1") else "gpt-4o"
    )
    is_waiting_user_input = get_is_waiting_user_input(
        config.model, config.cost_file[config.model]
    )

    # Tool schemas advertised to the model.
    tools = [
        openai.pydantic_function_tool(
            ExecuteBash,
            description="""
Execute a bash script. Stateful (beware with subsequent calls).
Execute commands using `execute_command` attribute.
Do not use interactive commands like nano. Prefer writing simpler commands.
Last line will always be `(exit <int code>)` except if
the last line is `(waiting for input)` which will be the case if you've run any interactive command (which you shouldn't run) by mistake. You can then send input using `send_ascii` attributes.
Optionally the last line is `(won't exit)` in which case you need to kill the process if you want to run a new command.
Optionally `exit shell has restarted` is the output, in which case environment resets, you can run fresh commands.
The first line might be `(...truncated)` if the output is too long.""",
        ),
        openai.pydantic_function_tool(
            GetShellOutputLastCommand,
            description="Get output of the last command run in the shell. Use this in case you want to know status of a running program.",
        ),
        openai.pydantic_function_tool(
            Writefile,
            description="Write content to a file. Provide file path and content. Use this instead of ExecuteBash for writing files.",
        ),
    ]
    uname_sysname = os.uname().sysname
    uname_machine = os.uname().machine

    # System prompt, parameterized with host information.
    system = f"""
You're a cli assistant.

Instructions:

- You should use the provided bash execution tool to run script to complete objective.
- Do not use sudo. Do not use interactive commands.
- Ask user for confirmation before running anything major

System information:
- System: {uname_sysname}
- Machine: {uname_machine}
"""

    # When resuming right after a tool message, skip asking the user.
    has_tool_output = False
    if not history:
        history = [{"role": "system", "content": system}]
    else:
        if history[-1]["role"] == "tool":
            has_tool_output = True

    client = OpenAI()

    cost: float = 0
    input_toks = 0
    output_toks = 0
    system_console = rich.console.Console(style="blue", highlight=False)
    error_console = rich.console.Console(style="red", highlight=False)
    user_console = rich.console.Console(style="bright_black", highlight=False)
    assistant_console = rich.console.Console(style="white bold", highlight=False)

    while True:
        # Budget gate: one check per turn, before asking the user / the model.
        if cost > limit:
            system_console.print(
                f"\nCost limit exceeded. Current cost: {cost}, input tokens: {input_toks}, output tokens: {output_toks}"
            )
            break

        if not has_tool_output:
            if first_message:
                msg = first_message
                first_message = ""
            else:
                msg = text_from_editor(user_console)

            history.append({"role": "user", "content": msg})
        else:
            # Last turn produced tool output; feed it straight back to the model.
            has_tool_output = False

        cost_, input_toks_ = get_input_cost(
            config.cost_file[config.model], enc, history
        )
        cost += cost_
        input_toks += input_toks_

        stream = client.chat.completions.create(
            messages=history,
            model=config.model,
            stream=True,
            tools=tools,
        )

        system_console.print(
            "\n---------------------------------------\n# Assistant response",
            style="bold",
        )
        # Accumulates streamed tool-call argument fragments: id -> index -> args.
        tool_call_args_by_id = DefaultDict[str, DefaultDict[int, str]](
            lambda: DefaultDict(str)
        )
        _histories: History = []
        item: ChatCompletionMessageParam
        full_response: str = ""
        try:
            for chunk in stream:
                if chunk.choices[0].finish_reason == "tool_calls":
                    assert tool_call_args_by_id
                    item = {
                        "role": "assistant",
                        "content": full_response,
                        "tool_calls": [
                            {
                                # Synthetic id: original call id + fragment index.
                                "id": tool_call_id + str(toolindex),
                                "type": "function",
                                # NOTE(review): the function name is hard-coded
                                # as "execute_bash" for every tool — confirm
                                # this is intentional for the history record.
                                "function": {
                                    "arguments": tool_args,
                                    "name": "execute_bash",
                                },
                            }
                            for tool_call_id, toolcallargs in tool_call_args_by_id.items()
                            for toolindex, tool_args in toolcallargs.items()
                        ],
                    }
                    cost_, output_toks_ = get_output_cost(
                        config.cost_file[config.model], enc, item
                    )
                    cost += cost_
                    system_console.print(
                        f"\n---------------------------------------\n# Assistant invoked tools: {[which_tool(tool['function']['arguments']) for tool in item['tool_calls']]}"
                    )
                    system_console.print(f"\nTotal cost: {config.cost_unit}{cost:.3f}")
                    output_toks += output_toks_

                    _histories.append(item)
                    # Execute every collected tool call in order.
                    for tool_call_id, toolcallargs in tool_call_args_by_id.items():
                        for toolindex, tool_args in toolcallargs.items():
                            try:
                                output_or_done, cost_ = get_tool_output(
                                    json.loads(tool_args),
                                    enc,
                                    limit - cost,
                                    loop,
                                    is_waiting_user_input,
                                )
                            except Exception as e:
                                # Report the failure to the model instead of crashing.
                                output_or_done = (
                                    f"GOT EXCEPTION while calling tool. Error: {e}"
                                )
                                tb = traceback.format_exc()
                                error_console.print(output_or_done + "\n" + tb)
                                cost_ = 0
                            cost += cost_
                            system_console.print(
                                f"\nTotal cost: {config.cost_unit}{cost:.3f}"
                            )

                            if isinstance(output_or_done, DoneFlag):
                                # Task declared complete; return immediately.
                                system_console.print(
                                    f"\n# Task marked done, with output {output_or_done.task_output}",
                                )
                                system_console.print(
                                    f"\nTotal cost: {config.cost_unit}{cost:.3f}"
                                )
                                return output_or_done.task_output, cost
                            output = output_or_done

                            item = {
                                "role": "tool",
                                "content": str(output),
                                "tool_call_id": tool_call_id + str(toolindex),
                            }
                            cost_, output_toks_ = get_output_cost(
                                config.cost_file[config.model], enc, item
                            )
                            cost += cost_
                            output_toks += output_toks_

                            _histories.append(item)
                    has_tool_output = True
                    break
                elif chunk.choices[0].finish_reason:
                    # Plain text turn finished; record the assistant message.
                    assistant_console.print("")
                    item = {
                        "role": "assistant",
                        "content": full_response,
                    }
                    cost_, output_toks_ = get_output_cost(
                        config.cost_file[config.model], enc, item
                    )
                    cost += cost_
                    output_toks += output_toks_

                    system_console.print(f"\nTotal cost: {config.cost_unit}{cost:.3f}")
                    _histories.append(item)
                    break

                if chunk.choices[0].delta.tool_calls:
                    # Append streamed argument fragments for this tool call.
                    tool_call = chunk.choices[0].delta.tool_calls[0]
                    if tool_call.function and tool_call.function.arguments:
                        tool_call_args_by_id[tool_call.id or ""][tool_call.index] += (
                            tool_call.function.arguments
                        )

                chunk_str = chunk.choices[0].delta.content or ""
                assistant_console.print(chunk_str, end="")
                full_response += chunk_str
        except KeyboardInterrupt:
            # Discard this turn's partial results and redo it.
            has_tool_output = False
            input("Interrupted...enter to redo the current turn")
        else:
            history.extend(_histories)

        save_history(history, session_id)

    return "Couldn't finish the task", cost
338
+
339
+
340
# Allow running this module directly: dispatch to the Typer CLI.
if __name__ == "__main__":
    app()
wcgw/common.py ADDED
@@ -0,0 +1,47 @@
1
+ import select
2
+ import sys
3
+ import termios
4
+ import tty
5
+ from typing import Literal
6
+ from pydantic import BaseModel
7
+
8
+
9
class CostData(BaseModel):
    """Pricing for one model, expressed per million tokens."""

    # Price of 1M prompt (input) tokens.
    cost_per_1m_input_tokens: float
    # Price of 1M completion (output) tokens.
    cost_per_1m_output_tokens: float
12
+
13
+
14
+ from openai.types.chat import (
15
+ ChatCompletionMessageParam,
16
+ ChatCompletionAssistantMessageParam,
17
+ ChatCompletionMessage,
18
+ ParsedChatCompletionMessage,
19
+ )
20
+
21
# A conversation is a plain list of OpenAI chat message dicts.
History = list[ChatCompletionMessageParam]
# Model names accepted throughout the package.
Models = Literal["gpt-4o-2024-08-06", "gpt-4o-mini"]
23
+
24
+
25
def discard_input() -> None:
    """Flush pending bytes on stdin so stale keystrokes don't leak into the
    next prompt.

    Temporarily switches the terminal to cbreak (non-canonical) mode, drains
    every immediately-available character, then restores the old settings.
    """
    fd = sys.stdin.fileno()
    saved_settings = termios.tcgetattr(fd)
    try:
        # Non-canonical mode: characters become readable without a newline.
        tty.setcbreak(fd)
        # Drain one character at a time while select reports data ready.
        while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
            sys.stdin.read(1)
    finally:
        # Always restore the terminal, even if reading raised.
        termios.tcsetattr(fd, termios.TCSADRAIN, saved_settings)
File without changes
wcgw/openai_utils.py ADDED
@@ -0,0 +1,68 @@
1
+ import json
2
+ from pathlib import Path
3
+ import select
4
+ import sys
5
+ import termios
6
+ import traceback
7
+ import tty
8
+ from typing import Callable, DefaultDict, Literal, Optional, cast
9
+ import openai
10
+ from openai import OpenAI
11
+ from openai.types.chat import (
12
+ ChatCompletionMessageParam,
13
+ ChatCompletionAssistantMessageParam,
14
+ ChatCompletionMessage,
15
+ ParsedChatCompletionMessage,
16
+ )
17
+ import rich
18
+ import tiktoken
19
+ from typer import Typer
20
+ import uuid
21
+
22
+ from .common import CostData, History
23
+
24
+
25
def get_input_cost(
    cost_map: CostData, enc: tiktoken.Encoding, history: History
) -> tuple[float, int]:
    """Price the prompt side of a request.

    Encodes each message's string content with *enc*, sums the token counts,
    and converts at the per-million input-token rate.

    Returns:
        (cost, total input tokens).

    Raises:
        ValueError: If any message content is not a plain string.
    """
    total_tokens = 0
    for message in history:
        content = message["content"]
        if not isinstance(content, str):
            raise ValueError(f"Expected content to be string, got {type(content)}")
        total_tokens += len(enc.encode(content))
    return total_tokens * cost_map.cost_per_1m_input_tokens / 1_000_000, total_tokens
36
+
37
+
38
def get_output_cost(
    cost_map: CostData,
    enc: tiktoken.Encoding,
    item: ChatCompletionMessage | ChatCompletionMessageParam,
) -> tuple[float, int]:
    """Price the completion side of a message.

    Counts tokens in *item*'s content plus any tool-call arguments and
    converts at the per-million output-token rate.

    Returns:
        (cost, output tokens); tool-role dict messages are free (0, 0).

    Raises:
        ValueError: If the content is not a plain string.
    """
    if isinstance(item, ChatCompletionMessage):
        # SDK model object: content is an attribute.
        content = item.content
        if not isinstance(content, str):
            raise ValueError(f"Expected content to be string, got {type(content)}")
    else:
        # Plain dict message.
        if not isinstance(item["content"], str):
            raise ValueError(
                f"Expected content to be string, got {type(item['content'])}"
            )
        content = item["content"]
        # Tool results are not billed as model output.
        if item["role"] == "tool":
            return 0, 0
    output_tokens = len(enc.encode(content))

    # NOTE(review): `"tool_calls" in item` only matches dict messages; for a
    # pydantic message object, membership iterates (key, value) pairs and is
    # False, so model objects fall through to the elif — confirm intended.
    if "tool_calls" in item:
        item = cast(ChatCompletionAssistantMessageParam, item)
        toolcalls = item["tool_calls"]
        for tool_call in toolcalls or []:
            output_tokens += len(enc.encode(tool_call["function"]["arguments"]))
    elif isinstance(item, ParsedChatCompletionMessage):
        if item.tool_calls:
            for tool_callf in item.tool_calls:
                output_tokens += len(enc.encode(tool_callf.function.arguments))

    cost = output_tokens * cost_map.cost_per_1m_output_tokens / 1_000_000
    return cost, output_tokens
wcgw/tools.py ADDED
@@ -0,0 +1,465 @@
1
+ import asyncio
2
+ import json
3
+ import sys
4
+ import threading
5
+ import traceback
6
+ from typing import Callable, Literal, Optional, ParamSpec, Sequence, TypeVar, TypedDict
7
+ import uuid
8
+ from pydantic import BaseModel, TypeAdapter
9
+
10
+ import os
11
+ import tiktoken
12
+ import petname # type: ignore[import]
13
+ import pexpect
14
+ from typer import Typer
15
+ import websockets
16
+
17
+ import rich
18
+ import pyte
19
+ from dotenv import load_dotenv
20
+
21
+ import openai
22
+ from openai import OpenAI
23
+ from openai.types.chat import (
24
+ ChatCompletionMessageParam,
25
+ ChatCompletionAssistantMessageParam,
26
+ ChatCompletionMessage,
27
+ ParsedChatCompletionMessage,
28
+ )
29
+
30
+ from .common import CostData, Models, discard_input
31
+
32
+ from .openai_utils import get_input_cost, get_output_cost
33
+
34
+ console = rich.console.Console(style="magenta", highlight=False)
35
+
36
+ TIMEOUT = 30
37
+
38
+
39
def render_terminal_output(text: str) -> str:
    """Feed raw terminal output through a pyte virtual screen and return the
    rendered text with trailing blank lines removed."""
    # 160x500 virtual terminal; LNM mode makes "\n" imply carriage return.
    screen = pyte.Screen(160, 500)
    screen.set_mode(pyte.modes.LNM)
    stream = pyte.Stream(screen)
    stream.feed(text)
    # Filter out empty lines: scan bottom-up for the first non-blank line.
    dsp = screen.display[::-1]
    for i, line in enumerate(dsp):
        if line.strip():
            break
    else:
        # Entire screen is blank.
        i = len(dsp)
    # Keep everything up to and including the last non-blank line.
    return "\n".join(screen.display[: len(dsp) - i])
52
+
53
+
54
class Confirmation(BaseModel):
    """Tool arguments: ask the human operator a yes/no question."""

    # Question text shown to the user on stdin.
    prompt: str
56
+
57
+
58
def ask_confirmation(prompt: Confirmation) -> str:
    """Ask the user the given question on stdin.

    Returns "Yes" only for a (case-insensitive) 'y' answer, otherwise "No".
    """
    answer = input(prompt.prompt + " [y/n] ")
    if answer.lower() == "y":
        return "Yes"
    return "No"
61
+
62
+
63
class Writefile(BaseModel):
    """Tool arguments for writing a file to disk."""

    # Absolute path, or path relative to the shared shell's cwd.
    file_path: str
    # Full text content to write.
    file_content: str
66
+
67
+
68
def start_shell() -> pexpect.spawn:
    """Spawn the shared interactive bash session used by the shell tools.

    PS1 is set to the sentinel "#@@" so pexpect can detect the prompt; echo
    is disabled and canonical mode is turned off via stty.
    """
    SHELL = pexpect.spawn(
        "/bin/bash",
        env={**os.environ, **{"PS1": "#@@"}},  # sentinel prompt for expect()
        echo=False,
        encoding="utf-8",
        timeout=TIMEOUT,
    )  # type: ignore[arg-type]
    SHELL.expect("#@@")
    SHELL.sendline("stty -icanon -echo")
    SHELL.expect("#@@")
    return SHELL


# Module-level singleton shell, started at import time (side effect).
SHELL = start_shell()
83
+
84
+
85
def _get_exit_code() -> int:
    """Ask the shared shell for the exit status of its last command.

    Raises ValueError when the captured output is not an integer (malformed
    terminal output); callers restart the shell in that case.
    """
    SHELL.sendline("echo $?")
    SHELL.expect("#@@")
    assert isinstance(SHELL.before, str)
    return int((SHELL.before))
90
+
91
+
92
+ Specials = Literal["Key-up", "Key-down", "Key-left", "Key-right", "Enter", "Ctrl-c"]
93
+
94
+
95
class ExecuteBash(BaseModel):
    """Tool arguments: run a command OR send raw keys to the shared shell."""

    # Single-line bash command to execute (use instead of send_ascii).
    execute_command: Optional[str] = None
    # ASCII codes / named keys to feed to the currently-running program.
    send_ascii: Optional[Sequence[int | Specials]] = None
98
+
99
+
100
class GetShellOutputLastCommand(BaseModel):
    """Tool arguments: fetch the latest output of the running shell command."""

    # Fixed discriminator identifying this tool.
    type: Literal["get_output_of_last_command"] = "get_output_of_last_command"
102
+
103
+
104
+ BASH_CLF_OUTPUT = Literal["running", "waiting_for_input", "wont_exit"]
105
+ BASH_STATE: BASH_CLF_OUTPUT = "running"
106
+
107
+
108
def get_output_of_last_command(enc: tiktoken.Encoding) -> str:
    """Render whatever the shared shell has produced so far.

    Output longer than ~2048 tokens of *enc* is truncated to its tail and
    prefixed with a truncation marker.
    """
    global SHELL, BASH_STATE
    output = render_terminal_output(SHELL.before)

    tokens = enc.encode(output)
    if len(tokens) >= 2048:
        # Keep only the tail; flag the truncation for the model.
        output = "...(truncated)\n" + enc.decode(tokens[-2047:])

    return output
117
+
118
+
119
+ WETTING_INPUT_MESSAGE = """A command is already running waiting for input. NOTE: You can't run multiple shell sessions, likely a previous program hasn't exited.
120
+ 1. Get its output using `GetShellOutputLastCommand` OR
121
+ 2. Use `send_ascii` to give inputs to the running program, don't use `execute_command` OR
122
+ 3. kill the previous program by sending ctrl+c first using `send_ascii`"""
123
+
124
+
125
def execute_bash(
    enc: tiktoken.Encoding,
    bash_arg: ExecuteBash,
    is_waiting_user_input: Callable[[str], tuple[BASH_CLF_OUTPUT, float]],
) -> tuple[str, float]:
    """Run a command (or send keys) in the shared shell and collect output.

    Args:
        enc: Tokenizer used to truncate long output to ~2048 tokens.
        bash_arg: Either an `execute_command` string or a `send_ascii` sequence.
        is_waiting_user_input: Classifier deciding whether a still-running
            program is waiting for input or will never exit.

    Returns:
        (rendered output with a status last line, classifier cost).

    Raises:
        ValueError: If the shell is busy or the command is multi-line.
        TimeoutError: If no prompt appears within the escalating wait budget.
    """
    global SHELL, BASH_STATE
    try:
        if bash_arg.execute_command:
            # Refuse new commands while a previous program still owns the shell.
            if BASH_STATE == "waiting_for_input":
                raise ValueError(WETTING_INPUT_MESSAGE)
            elif BASH_STATE == "wont_exit":
                raise ValueError(
                    """A command is already running that hasn't exited. NOTE: You can't run multiple shell sessions, likely a previous program is in infinite loop.
Kill the previous program by sending ctrl+c first using `send_ascii`"""
                )
            command = bash_arg.execute_command.strip()

            if "\n" in command:
                raise ValueError(
                    "Command should not contain newline character in middle. Run only one command at a time."
                )

            console.print(f"$ {command}")
            SHELL.sendline(command)
        elif bash_arg.send_ascii:
            # Forward raw key codes / named keys to the running program.
            console.print(f"Sending ASCII sequence: {bash_arg.send_ascii}")
            for char in bash_arg.send_ascii:
                if isinstance(char, int):
                    SHELL.send(chr(char))
                if char == "Key-up":
                    SHELL.send("\033[A")
                elif char == "Key-down":
                    SHELL.send("\033[B")
                elif char == "Key-left":
                    SHELL.send("\033[D")
                elif char == "Key-right":
                    SHELL.send("\033[C")
                elif char == "Enter":
                    SHELL.send("\n")
                elif char == "Ctrl-c":
                    SHELL.sendintr()
        else:
            raise Exception("Nothing to send")
        BASH_STATE = "running"

    except KeyboardInterrupt:
        # User aborted mid-send: restart the shell to a clean state.
        SHELL.close(True)
        SHELL = start_shell()
        raise

    # Poll for the sentinel prompt with an escalating wait (5, 10, 15, ...).
    wait = timeout = 5
    index = SHELL.expect(["#@@", pexpect.TIMEOUT], timeout=wait)
    running = ""
    while index == 1:
        if wait > TIMEOUT:
            raise TimeoutError("Timeout while waiting for shell prompt")

        # Echo only the newly produced portion of the output.
        text = SHELL.before
        print(text[len(running) :])
        running = text

        # Classify partial output: still running, awaiting input, or stuck.
        text = render_terminal_output(text)
        BASH_STATE, cost = is_waiting_user_input(text)
        if BASH_STATE == "waiting_for_input" or BASH_STATE == "wont_exit":
            tokens = enc.encode(text)

            if len(tokens) >= 2048:
                text = "...(truncated)\n" + enc.decode(tokens[-2047:])

            last_line = (
                "(waiting for input)"
                if BASH_STATE == "waiting_for_input"
                else "(won't exit)"
            )
            return text + f"\n{last_line}", cost
        index = SHELL.expect(["#@@", pexpect.TIMEOUT], timeout=wait)
        wait += timeout

    # Prompt seen: command finished normally.
    assert isinstance(SHELL.before, str)
    output = render_terminal_output(SHELL.before)

    tokens = enc.encode(output)
    if len(tokens) >= 2048:
        output = "...(truncated)\n" + enc.decode(tokens[-2047:])

    try:
        exit_code = _get_exit_code()
        output += f"\n(exit {exit_code})"

    except ValueError:
        console.print("Malformed output, restarting shell", style="red")
        # Malformed output, restart shell
        SHELL.close(True)
        SHELL = start_shell()
        output = "(exit shell has restarted)"
    return output, 0
221
+
222
+
223
+ Param = ParamSpec("Param")
224
+
225
+ T = TypeVar("T")
226
+
227
+
228
def ensure_no_previous_output(func: Callable[Param, T]) -> Callable[Param, T]:
    """Decorator that blocks *func* while the shared shell is still busy.

    Raises ValueError if the previous command is waiting for interactive
    input or is a non-exiting process; otherwise calls through unchanged.
    """
    from functools import wraps

    # BUG FIX: the original wrapper lacked functools.wraps, so decorated
    # functions lost their __name__/__doc__/signature metadata.
    @wraps(func)
    def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> T:
        global BASH_STATE
        if BASH_STATE == "waiting_for_input":
            raise ValueError(WETTING_INPUT_MESSAGE)
        elif BASH_STATE == "wont_exit":
            raise ValueError(
                "A command is already running that hasn't exited. NOTE: You can't run multiple shell sessions, likely the previous program is in infinite loop. Please kill the previous program by sending ctrl+c first."
            )
        return func(*args, **kwargs)

    return wrapper
240
+
241
+
242
@ensure_no_previous_output
def write_file(writefile: Writefile) -> str:
    """Write writefile.file_content to writefile.file_path, creating parents.

    Relative paths are resolved against the shared shell's current working
    directory (queried via `pwd`). Returns "Success" or an "Error: ..." text.
    """
    if not os.path.isabs(writefile.file_path):
        # Ask the persistent shell where it currently is.
        SHELL.sendline("pwd")
        SHELL.expect("#@@")
        assert isinstance(SHELL.before, str)
        current_dir = SHELL.before.strip()
        writefile.file_path = os.path.join(current_dir, writefile.file_path)
    os.makedirs(os.path.dirname(writefile.file_path), exist_ok=True)
    try:
        with open(writefile.file_path, "w") as handle:
            handle.write(writefile.file_content)
    except OSError as e:
        console.print(f"Error: {e}", style="red")
        return f"Error: {e}"
    console.print(f"File written to {writefile.file_path}")
    return "Success"
259
+
260
+
261
class DoneFlag(BaseModel):
    """Sentinel tool model: the assistant declares the task finished."""

    # Final result text to surface to the caller.
    task_output: str
263
+
264
+
265
def mark_finish(done: DoneFlag) -> DoneFlag:
    """Identity handler for the DoneFlag tool; completion is acted on upstream."""
    return done
267
+
268
+
269
class AIAssistant(BaseModel):
    """Tool arguments for delegating a sub-task to a nested assistant loop."""

    # Instruction to give the sub-assistant.
    instruction: str
    # Description of the expected result (not used by the handler here).
    desired_output: str
272
+
273
+
274
def take_help_of_ai_assistant(
    aiassistant: AIAssistant,
    limit: float,
    loop_call: Callable[[str, float], tuple[str, float]],
) -> tuple[str, float]:
    """Run a nested assistant loop on the sub-task instruction.

    Returns the (output, cost) pair produced by *loop_call*.
    """
    return loop_call(aiassistant.instruction, limit)
281
+
282
+
283
class AddTasks(BaseModel):
    """Tool arguments: record a new task statement."""

    # Free-form description of the task.
    task_statement: str
285
+
286
+
287
def add_task(addtask: AddTasks) -> str:
    """Return a generated two-word pet-name id for the task.

    NOTE(review): the task statement itself is not stored anywhere here.
    """
    petname_id = petname.Generate(2, "-")
    return petname_id
290
+
291
+
292
class RemoveTask(BaseModel):
    """Tool arguments: remove a task by its id."""

    # Id previously returned by add_task.
    task_id: str
294
+
295
+
296
def remove_task(removetask: RemoveTask) -> str:
    """Stub: acknowledge task removal; always reports "removed"."""
    return "removed"
298
+
299
+
300
def which_tool(args: str) -> BaseModel:
    """Parse a raw JSON tool-arguments string into one of the known tool
    models (used for display in `basic.loop`)."""
    adapter = TypeAdapter[
        Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag
    ](Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag)
    return adapter.validate_python(json.loads(args))
305
+
306
+
307
def get_tool_output(
    args: dict | Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag,
    enc: tiktoken.Encoding,
    limit: float,
    loop_call: Callable[[str, float], tuple[str, float]],
    is_waiting_user_input: Callable[[str], tuple[BASH_CLF_OUTPUT, float]],
) -> tuple[str | DoneFlag, float]:
    """Dispatch a parsed (or raw-dict) tool invocation to its implementation.

    Args:
        args: Raw JSON dict of tool arguments, or an already-parsed tool model.
        enc: Tokenizer used by tools that truncate long shell output.
        limit: Remaining cost budget, forwarded to the AI-assistant sub-loop.
        loop_call: Runs a nested assistant loop: (instruction, limit) -> (output, cost).
        is_waiting_user_input: Classifier for the bash tool's running programs.

    Returns:
        (tool output or DoneFlag, cost incurred by the tool call).

    Raises:
        ValueError: If the argument matches no known tool.
    """
    if isinstance(args, dict):
        # NOTE: GetShellOutputLastCommand and AddTasks are not in this union,
        # so raw dicts can only resolve to the five models below.
        adapter = TypeAdapter[
            Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag
        ](Confirmation | ExecuteBash | Writefile | AIAssistant | DoneFlag)
        arg = adapter.validate_python(args)
    else:
        arg = args
    output: tuple[str | DoneFlag, float]
    if isinstance(arg, Confirmation):
        console.print("Calling ask confirmation tool")
        output = ask_confirmation(arg), 0.0
    elif isinstance(arg, ExecuteBash):
        console.print("Calling execute bash tool")
        output = execute_bash(enc, arg, is_waiting_user_input)
    elif isinstance(arg, Writefile):
        console.print("Calling write file tool")
        output = write_file(arg), 0
    elif isinstance(arg, DoneFlag):
        console.print("Calling mark finish tool")
        output = mark_finish(arg), 0.0
    elif isinstance(arg, AIAssistant):
        console.print("Calling AI assistant tool")
        output = take_help_of_ai_assistant(arg, limit, loop_call)
    elif isinstance(arg, AddTasks):
        console.print("Calling add task tool")
        output = add_task(arg), 0
    elif isinstance(arg, GetShellOutputLastCommand):
        # BUG FIX: the original tested isinstance(arg, get_output_of_last_command),
        # i.e. against a *function*, which raises TypeError whenever reached.
        console.print("Calling get output of last program tool")
        output = get_output_of_last_command(enc), 0
    else:
        raise ValueError(f"Unknown tool: {arg}")

    console.print(str(output[0]))
    return output
348
+
349
+
350
+ History = list[ChatCompletionMessageParam]
351
+
352
+
353
def get_is_waiting_user_input(model: Models, cost_data: CostData):
    """Build a classifier closure that asks *model* whether a program's
    stdout indicates it is waiting for input, will never exit, or is normal.

    Returns:
        A callable (output text) -> (classification, classification cost).
    """
    # o1 models have no tiktoken mapping; fall back to gpt-4o's encoding.
    enc = tiktoken.encoding_for_model(model if not model.startswith("o1") else "gpt-4o")
    system_prompt = """You need to classify if a bash program is waiting for user input based on its stdout, or if it won't exit. You'll be given the output of any program.
Return `waiting_for_input` if the program is waiting for INTERACTIVE input only, Return false if it's waiting for external resources or just waiting to finish.
Return `wont_exit` if the program won't exit, for example if it's a server.
Return `normal` otherwise.
"""
    # Shared mutable history: every classification turn is appended here.
    history: History = [{"role": "system", "content": system_prompt}]
    client = OpenAI()

    # Structured-output schema for the classification.
    class ExpectedOutput(BaseModel):
        output_classified: BASH_CLF_OUTPUT

    def is_waiting_user_input(output: str) -> tuple[BASH_CLF_OUTPUT, float]:
        # Send only last 30 lines
        output = "\n".join(output.split("\n")[-30:])
        # Send only max last 200 tokens
        output = enc.decode(enc.encode(output)[-200:])

        history.append({"role": "user", "content": output})
        response = client.beta.chat.completions.parse(
            model=model, messages=history, response_format=ExpectedOutput
        )
        parsed = response.choices[0].message.parsed
        if parsed is None:
            raise ValueError("No parsed output")
        # Charge both the prompt (entire history) and the parsed reply.
        cost = (
            get_input_cost(cost_data, enc, history)[0]
            + get_output_cost(cost_data, enc, response.choices[0].message)[0]
        )
        return parsed.output_classified, cost

    return is_waiting_user_input
386
+
387
+
388
+ default_enc = tiktoken.encoding_for_model("gpt-4o")
389
+ default_model: Models = "gpt-4o-2024-08-06"
390
+ default_cost = CostData(cost_per_1m_input_tokens=0.15, cost_per_1m_output_tokens=0.6)
391
+ curr_cost = 0.0
392
+
393
+
394
class Mdata(BaseModel):
    """Envelope for tool requests received over the relay websocket."""

    # The tool invocation to execute locally.
    data: ExecuteBash | Writefile
396
+
397
+
398
+ execution_lock = threading.Lock()
399
+
400
+
401
def execute_user_input() -> None:
    """Forever: read a line typed by the local user and forward it to the
    shared shell as raw keystrokes (serialized by `execution_lock`)."""
    while True:
        discard_input()
        user_input = input()
        if user_input:
            with execution_lock:
                try:
                    console.log(
                        execute_bash(
                            default_enc,
                            ExecuteBash(
                                # Typed line plus newline, sent as key codes.
                                send_ascii=[ord(x) for x in user_input] + [ord("\n")]
                            ),
                            # Fixed classifier: treat the program as non-exiting.
                            lambda x: ("wont_exit", 0),
                        )[0]
                    )
                except Exception as e:
                    traceback.print_exc()
                    console.log(f"Error: {e}")
420
+
421
+
422
async def register_client(server_url: str) -> None:
    """Connect to the relay server and execute tool requests it pushes.

    Registers under a fresh UUID, then loops: receive an Mdata-encoded tool
    call, run it locally, and send the textual result back.
    """
    global default_enc, default_model, curr_cost
    # Generate a unique UUID for this client
    client_uuid = str(uuid.uuid4())
    print(f"Connecting with UUID: {client_uuid}")

    # Create the WebSocket connection
    async with websockets.connect(f"{server_url}/{client_uuid}") as websocket:
        try:
            while True:
                # Wait to receive data from the server
                message = await websocket.recv()
                print(message, type(message))
                mdata = Mdata.model_validate_json(message)
                # Serialize against the local manual-input thread.
                with execution_lock:
                    is_waiting_user_input = get_is_waiting_user_input(
                        default_model, default_cost
                    )
                    try:
                        output, cost = get_tool_output(
                            mdata.data,
                            default_enc,
                            0.0,  # no budget for nested AI-assistant calls
                            lambda x, y: ("", 0),
                            is_waiting_user_input,
                        )
                        curr_cost += cost
                        print(f"{curr_cost=}")
                    except Exception as e:
                        # Report the failure text back instead of crashing.
                        output = f"GOT EXCEPTION while calling tool. Error: {e}"
                        traceback.print_exc()
                assert not isinstance(output, DoneFlag)
                await websocket.send(output)

        except websockets.ConnectionClosed:
            print(f"Connection closed for UUID: {client_uuid}")
458
+
459
+
460
def run() -> None:
    """CLI entry point: connect this machine's shell tools to a relay server.

    Uses argv[1] as the websocket URL when given, else a localhost default.
    """
    server_url = sys.argv[1] if len(sys.argv) > 1 else "ws://localhost:8000/register"
    asyncio.run(register_client(server_url))
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.3
2
+ Name: wcgw
3
+ Version: 0.1.0
4
+ Summary: What could go wrong giving full shell access to chatgpt?
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: fastapi>=0.115.0
7
+ Requires-Dist: mypy>=1.11.2
8
+ Requires-Dist: openai>=1.46.0
9
+ Requires-Dist: petname>=2.6
10
+ Requires-Dist: pexpect>=4.9.0
11
+ Requires-Dist: pyte>=0.8.2
12
+ Requires-Dist: python-dotenv>=1.0.1
13
+ Requires-Dist: rich>=13.8.1
14
+ Requires-Dist: shell>=1.0.1
15
+ Requires-Dist: tiktoken==0.7.0
16
+ Requires-Dist: toml>=0.10.2
17
+ Requires-Dist: typer>=0.12.5
18
+ Requires-Dist: types-pexpect>=4.9.0.20240806
19
+ Requires-Dist: uvicorn>=0.31.0
20
+ Requires-Dist: websockets>=13.1
21
+ Description-Content-Type: text/markdown
22
+
23
+ # What could go wrong giving full shell access to Chatgpt?
@@ -0,0 +1,11 @@
1
+ wcgw/__init__.py,sha256=okSsOWpTKDjEQzgOin3Kdpx4Mc3MFX1RunjopHQSIWE,62
2
+ wcgw/__main__.py,sha256=MjJnFwfYzA1rW47xuSP1EVsi53DTHeEGqESkQwsELFQ,34
3
+ wcgw/basic.py,sha256=BiVjIwrtiz93SkUedDXSwtfVMKoV8-zEWeFKBIamVSQ,12372
4
+ wcgw/common.py,sha256=jn39zTpaFUO1PWof_z7qBmklaZH5G1blzjlBvez0cg4,1225
5
+ wcgw/openai_adapters.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ wcgw/openai_utils.py,sha256=4Hr9S2-WT8xhDdu3b2YVoX1l9AVxwCFdi_GJbQAx7Us,2202
7
+ wcgw/tools.py,sha256=k6Xoaq_kqLnC6cgBQ9NGovwNDtam5zV7DPTNmq9zSbo,15069
8
+ wcgw-0.1.0.dist-info/METADATA,sha256=oM2m4AYiPXrCPytISxbWfGWCMcS3t7lgw0TiitTIb-s,701
9
+ wcgw-0.1.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
10
+ wcgw-0.1.0.dist-info/entry_points.txt,sha256=T-IH7w6Vc650hr8xksC8kJfbJR4uwN8HDudejwDwrNM,59
11
+ wcgw-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.25.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ wcgw = wcgw:listen
3
+ wcgw_local = wcgw:app