codex-chat-bot 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/.gitignore +2 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/PKG-INFO +1 -1
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/pyproject.toml +1 -1
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/src/codex_chat_bot/session.py +64 -5
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/tests/test_session.py +128 -12
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/README.md +0 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/src/codex_chat_bot/__init__.py +0 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/src/codex_chat_bot/cli.py +0 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/src/codex_chat_bot/config.py +0 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/src/codex_chat_bot/errors.py +0 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/tests/test_cli.py +0 -0
- {codex_chat_bot-0.1.2 → codex_chat_bot-0.1.4}/tests/test_config.py +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import sys
|
|
4
5
|
from collections.abc import Mapping, Sequence
|
|
5
6
|
from dataclasses import dataclass
|
|
6
7
|
from pathlib import Path
|
|
@@ -15,6 +16,10 @@ SYSTEM_USERNAME = "system"
|
|
|
15
16
|
DEVELOPER_USERNAME = "developer"
|
|
16
17
|
DEFAULT_USERNAME = "user"
|
|
17
18
|
ASSISTANT_USERNAME = "assistant"
|
|
19
|
+
MAX_REQUEST_HISTORY_MESSAGES = 5
|
|
20
|
+
ANSI_YELLOW = "\033[1;33m"
|
|
21
|
+
ANSI_RESET = "\033[0m"
|
|
22
|
+
EMPTY_RESPONSE_RETRY_WARNING = "Warning: The robot returned an empty string, and is currently retrying."
|
|
18
23
|
|
|
19
24
|
|
|
20
25
|
@dataclass(frozen=True)
|
|
@@ -136,22 +141,29 @@ class ChatSession:
|
|
|
136
141
|
|
|
137
142
|
self._messages.append(Message(role="user", content=message, username=username))
|
|
138
143
|
try:
|
|
139
|
-
response = self.
|
|
144
|
+
response, text = self._create_response_until_text(extra_request_args)
|
|
140
145
|
except Exception:
|
|
141
146
|
self._messages.pop()
|
|
142
147
|
self._save_bound_history()
|
|
143
148
|
raise
|
|
144
149
|
|
|
145
|
-
text = _extract_response_text(response)
|
|
146
150
|
self._messages.append(Message(role="assistant", content=text, username=ASSISTANT_USERNAME))
|
|
147
151
|
self._trim_history()
|
|
148
152
|
self._save_bound_history()
|
|
149
153
|
return ChatResponse(text=text, raw=response, messages=self.messages)
|
|
150
154
|
|
|
155
|
+
def _create_response_until_text(self, extra_request_args: Mapping[str, Any]) -> tuple[Any, str]:
|
|
156
|
+
while True:
|
|
157
|
+
response = self._client.responses.create(**self._request_payload(extra_request_args))
|
|
158
|
+
text = _extract_response_text(response)
|
|
159
|
+
if text != "":
|
|
160
|
+
return response, text
|
|
161
|
+
_warn_empty_response_retry()
|
|
162
|
+
|
|
151
163
|
def _request_payload(self, extra_request_args: Mapping[str, Any]) -> dict[str, Any]:
|
|
152
164
|
payload: dict[str, Any] = {
|
|
153
165
|
"model": self.config.model,
|
|
154
|
-
"input": [message.to_api() for message in self._messages],
|
|
166
|
+
"input": [message.to_api() for message in _latest_messages(self._messages, MAX_REQUEST_HISTORY_MESSAGES)],
|
|
155
167
|
}
|
|
156
168
|
if self.config.temperature is not None:
|
|
157
169
|
payload["temperature"] = self.config.temperature
|
|
@@ -165,7 +177,7 @@ class ChatSession:
|
|
|
165
177
|
rules = self.config.system_rules if system_rules is None else system_rules
|
|
166
178
|
|
|
167
179
|
clean_rules = [str(rule).strip() for rule in rules if str(rule).strip()]
|
|
168
|
-
return "
|
|
180
|
+
return " ".join(clean_rules)
|
|
169
181
|
|
|
170
182
|
def _trim_history(self) -> None:
|
|
171
183
|
max_messages = self.config.max_history_messages
|
|
@@ -213,6 +225,10 @@ class ChatSession:
|
|
|
213
225
|
|
|
214
226
|
|
|
215
227
|
def _extract_response_text(response: Any) -> str:
|
|
228
|
+
error_json = _extract_response_error_json(response)
|
|
229
|
+
if error_json is not None:
|
|
230
|
+
return error_json
|
|
231
|
+
|
|
216
232
|
output_text = _get_value(response, "output_text")
|
|
217
233
|
if isinstance(output_text, str):
|
|
218
234
|
return output_text
|
|
@@ -230,6 +246,36 @@ def _extract_response_text(response: Any) -> str:
|
|
|
230
246
|
raise ResponseTextError("could not extract text from the model response")
|
|
231
247
|
|
|
232
248
|
|
|
249
|
+
def _warn_empty_response_retry() -> None:
|
|
250
|
+
print(f"{ANSI_YELLOW}{EMPTY_RESPONSE_RETRY_WARNING}{ANSI_RESET}", file=sys.stderr, flush=True)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _extract_response_error_json(response: Any) -> str | None:
|
|
254
|
+
error = _get_value(response, "error")
|
|
255
|
+
if error is None:
|
|
256
|
+
return None
|
|
257
|
+
return json.dumps({"error": _to_jsonable(error)}, ensure_ascii=False, indent=4)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _to_jsonable(value: Any) -> Any:
|
|
261
|
+
if isinstance(value, Mapping):
|
|
262
|
+
return {str(key): _to_jsonable(item) for key, item in value.items()}
|
|
263
|
+
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
|
264
|
+
return [_to_jsonable(item) for item in value]
|
|
265
|
+
if isinstance(value, (str, int, float, bool)) or value is None:
|
|
266
|
+
return value
|
|
267
|
+
|
|
268
|
+
model_dump = getattr(value, "model_dump", None)
|
|
269
|
+
if callable(model_dump):
|
|
270
|
+
return _to_jsonable(model_dump(mode="json"))
|
|
271
|
+
|
|
272
|
+
to_dict = getattr(value, "to_dict", None)
|
|
273
|
+
if callable(to_dict):
|
|
274
|
+
return _to_jsonable(to_dict())
|
|
275
|
+
|
|
276
|
+
return str(value)
|
|
277
|
+
|
|
278
|
+
|
|
233
279
|
def _messages_from_history_payload(payload: Any) -> tuple[Message, ...]:
|
|
234
280
|
if isinstance(payload, list):
|
|
235
281
|
raw_messages = payload
|
|
@@ -256,6 +302,19 @@ def _messages_from_history_payload(payload: Any) -> tuple[Message, ...]:
|
|
|
256
302
|
return tuple(messages)
|
|
257
303
|
|
|
258
304
|
|
|
305
|
+
def _latest_messages(messages: Sequence[Message], max_messages: int) -> tuple[Message, ...]:
|
|
306
|
+
if max_messages <= 0:
|
|
307
|
+
return tuple(messages)
|
|
308
|
+
|
|
309
|
+
non_system_indexes = [index for index, message in enumerate(messages) if message.role != "system"]
|
|
310
|
+
kept_non_system_indexes = set(non_system_indexes[-max_messages:])
|
|
311
|
+
return tuple(
|
|
312
|
+
message
|
|
313
|
+
for index, message in enumerate(messages)
|
|
314
|
+
if message.role == "system" or index in kept_non_system_indexes
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
|
|
259
318
|
def normalize_username(username: str) -> str:
|
|
260
319
|
username = str(username).strip()
|
|
261
320
|
if not username:
|
|
@@ -279,7 +338,7 @@ def _username_from_history_item(item: Mapping[str, Any], index: int, role: Role)
|
|
|
279
338
|
|
|
280
339
|
|
|
281
340
|
def _format_user_content_for_api(content: str, username: str) -> str:
|
|
282
|
-
return f"
|
|
341
|
+
return f"{username}: {content}"
|
|
283
342
|
|
|
284
343
|
|
|
285
344
|
def _default_username_for_role(role: Role) -> str:
|
|
@@ -4,7 +4,12 @@ import json
|
|
|
4
4
|
import pytest
|
|
5
5
|
|
|
6
6
|
from codex_chat_bot import ChatConfig, ChatSession, Message
|
|
7
|
-
from codex_chat_bot.session import
|
|
7
|
+
from codex_chat_bot.session import (
|
|
8
|
+
ANSI_RESET,
|
|
9
|
+
ANSI_YELLOW,
|
|
10
|
+
EMPTY_RESPONSE_RETRY_WARNING,
|
|
11
|
+
_extract_response_text,
|
|
12
|
+
)
|
|
8
13
|
|
|
9
14
|
|
|
10
15
|
class FakeResponses:
|
|
@@ -14,8 +19,8 @@ class FakeResponses:
|
|
|
14
19
|
def create(self, **kwargs):
|
|
15
20
|
self.calls.append(kwargs)
|
|
16
21
|
user_message = kwargs["input"][-1]["content"]
|
|
17
|
-
if
|
|
18
|
-
user_message = user_message.split("
|
|
22
|
+
if ": " in user_message:
|
|
23
|
+
user_message = user_message.split(": ", 1)[1]
|
|
19
24
|
return SimpleNamespace(output_text=f"answer: {user_message}")
|
|
20
25
|
|
|
21
26
|
|
|
@@ -24,12 +29,33 @@ class FakeClient:
|
|
|
24
29
|
self.responses = FakeResponses()
|
|
25
30
|
|
|
26
31
|
|
|
32
|
+
class ScriptedResponses:
|
|
33
|
+
def __init__(self, output_texts):
|
|
34
|
+
self.output_texts = list(output_texts)
|
|
35
|
+
self.calls = []
|
|
36
|
+
|
|
37
|
+
def create(self, **kwargs):
|
|
38
|
+
self.calls.append(kwargs)
|
|
39
|
+
return SimpleNamespace(output_text=self.output_texts.pop(0))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ScriptedClient:
|
|
43
|
+
def __init__(self, output_texts):
|
|
44
|
+
self.responses = ScriptedResponses(output_texts)
|
|
45
|
+
|
|
46
|
+
|
|
27
47
|
def make_session(**config_overrides):
|
|
28
48
|
config = ChatConfig(api_key="test-key", base_url="https://api.example/v1", model="test-model", **config_overrides)
|
|
29
49
|
client = FakeClient()
|
|
30
50
|
return ChatSession(config=config, client=client), client
|
|
31
51
|
|
|
32
52
|
|
|
53
|
+
def make_scripted_session(output_texts, **config_overrides):
|
|
54
|
+
config = ChatConfig(api_key="test-key", base_url="https://api.example/v1", model="test-model", **config_overrides)
|
|
55
|
+
client = ScriptedClient(output_texts)
|
|
56
|
+
return ChatSession(config=config, client=client), client
|
|
57
|
+
|
|
58
|
+
|
|
33
59
|
def test_session_sends_full_single_session_history():
|
|
34
60
|
session, client = make_session(system_rules=("Follow the test.",))
|
|
35
61
|
|
|
@@ -39,9 +65,9 @@ def test_session_sends_full_single_session_history():
|
|
|
39
65
|
second_input = client.responses.calls[1]["input"]
|
|
40
66
|
assert second_input == [
|
|
41
67
|
{"role": "system", "content": "Follow the test."},
|
|
42
|
-
{"role": "user", "content": "
|
|
68
|
+
{"role": "user", "content": "user: hello"},
|
|
43
69
|
{"role": "assistant", "content": "answer: hello"},
|
|
44
|
-
{"role": "user", "content": "
|
|
70
|
+
{"role": "user", "content": "user: what did I say?"},
|
|
45
71
|
]
|
|
46
72
|
|
|
47
73
|
|
|
@@ -54,9 +80,9 @@ def test_session_sends_usernames_as_model_visible_user_context():
|
|
|
54
80
|
second_input = client.responses.calls[1]["input"]
|
|
55
81
|
assert second_input == [
|
|
56
82
|
{"role": "system", "content": "Follow the test."},
|
|
57
|
-
{"role": "user", "content": "
|
|
83
|
+
{"role": "user", "content": "alice: hello"},
|
|
58
84
|
{"role": "assistant", "content": "answer: hello"},
|
|
59
|
-
{"role": "user", "content": "
|
|
85
|
+
{"role": "user", "content": "bob: same chat, different person"},
|
|
60
86
|
]
|
|
61
87
|
assert session.messages == (
|
|
62
88
|
Message(role="system", content="Follow the test.", username="system"),
|
|
@@ -67,6 +93,71 @@ def test_session_sends_usernames_as_model_visible_user_context():
|
|
|
67
93
|
)
|
|
68
94
|
|
|
69
95
|
|
|
96
|
+
def test_session_retries_empty_model_text_until_non_empty(capsys):
|
|
97
|
+
session, client = make_scripted_session(["", "", "answer: hello"], system_rules=("Follow the test.",))
|
|
98
|
+
|
|
99
|
+
response = session.send("hello")
|
|
100
|
+
captured = capsys.readouterr()
|
|
101
|
+
|
|
102
|
+
assert response.text == "answer: hello"
|
|
103
|
+
assert response.raw.output_text == "answer: hello"
|
|
104
|
+
assert captured.err.count(f"{ANSI_YELLOW}{EMPTY_RESPONSE_RETRY_WARNING}{ANSI_RESET}") == 2
|
|
105
|
+
assert len(client.responses.calls) == 3
|
|
106
|
+
assert [call["input"] for call in client.responses.calls] == [
|
|
107
|
+
[
|
|
108
|
+
{"role": "system", "content": "Follow the test."},
|
|
109
|
+
{"role": "user", "content": "user: hello"},
|
|
110
|
+
],
|
|
111
|
+
[
|
|
112
|
+
{"role": "system", "content": "Follow the test."},
|
|
113
|
+
{"role": "user", "content": "user: hello"},
|
|
114
|
+
],
|
|
115
|
+
[
|
|
116
|
+
{"role": "system", "content": "Follow the test."},
|
|
117
|
+
{"role": "user", "content": "user: hello"},
|
|
118
|
+
],
|
|
119
|
+
]
|
|
120
|
+
assert session.messages == (
|
|
121
|
+
Message(role="system", content="Follow the test.", username="system"),
|
|
122
|
+
Message(role="user", content="hello", username="user"),
|
|
123
|
+
Message(role="assistant", content="answer: hello", username="assistant"),
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def test_session_keeps_retrying_empty_model_texts(capsys):
|
|
128
|
+
session, client = make_scripted_session(["", "", "", "", "", "answer: hello"], system_rules=("Follow the test.",))
|
|
129
|
+
|
|
130
|
+
response = session.send("hello")
|
|
131
|
+
captured = capsys.readouterr()
|
|
132
|
+
|
|
133
|
+
assert response.text == "answer: hello"
|
|
134
|
+
assert response.raw.output_text == "answer: hello"
|
|
135
|
+
assert captured.err.count(f"{ANSI_YELLOW}{EMPTY_RESPONSE_RETRY_WARNING}{ANSI_RESET}") == 5
|
|
136
|
+
assert len(client.responses.calls) == 6
|
|
137
|
+
assert session.messages == (
|
|
138
|
+
Message(role="system", content="Follow the test.", username="system"),
|
|
139
|
+
Message(role="user", content="hello", username="user"),
|
|
140
|
+
Message(role="assistant", content="answer: hello", username="assistant"),
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def test_session_returns_response_error_json_as_message_text():
|
|
145
|
+
session, client = make_session(system_rules=("Follow the test.",))
|
|
146
|
+
client.responses.create = lambda **kwargs: SimpleNamespace(
|
|
147
|
+
error={"code": "model_error", "message": "generation failed"}
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
response = session.send("hello")
|
|
151
|
+
|
|
152
|
+
assert json.loads(response.text) == {
|
|
153
|
+
"error": {
|
|
154
|
+
"code": "model_error",
|
|
155
|
+
"message": "generation failed",
|
|
156
|
+
}
|
|
157
|
+
}
|
|
158
|
+
assert session.messages[-1] == Message(role="assistant", content=response.text, username="assistant")
|
|
159
|
+
|
|
160
|
+
|
|
70
161
|
def test_session_adds_system_rules_to_system_message():
|
|
71
162
|
session, client = make_session(
|
|
72
163
|
system_rules=("Follow the test.", "Answer in English.", "Keep answers short."),
|
|
@@ -76,7 +167,7 @@ def test_session_adds_system_rules_to_system_message():
|
|
|
76
167
|
|
|
77
168
|
assert client.responses.calls[0]["input"][0] == {
|
|
78
169
|
"role": "system",
|
|
79
|
-
"content": "Follow the test
|
|
170
|
+
"content": "Follow the test. Answer in English. Keep answers short.",
|
|
80
171
|
}
|
|
81
172
|
|
|
82
173
|
|
|
@@ -112,6 +203,31 @@ def test_history_limit_keeps_latest_non_system_messages():
|
|
|
112
203
|
)
|
|
113
204
|
|
|
114
205
|
|
|
206
|
+
def test_request_payload_keeps_all_system_messages_and_latest_five_non_system_messages():
|
|
207
|
+
session, client = make_session(system_rules=("ignored",))
|
|
208
|
+
history_messages = [
|
|
209
|
+
{"role": "system", "content": "Persisted system."},
|
|
210
|
+
{"role": "user", "content": "old user"},
|
|
211
|
+
{"role": "system", "content": "Later system."},
|
|
212
|
+
]
|
|
213
|
+
history_messages.extend(
|
|
214
|
+
{"role": "user" if index % 2 == 0 else "assistant", "content": f"message {index}"}
|
|
215
|
+
for index in range(8)
|
|
216
|
+
)
|
|
217
|
+
session.load_history_json(json.dumps({"messages": history_messages}))
|
|
218
|
+
|
|
219
|
+
session.ask("new")
|
|
220
|
+
|
|
221
|
+
request_input = client.responses.calls[0]["input"]
|
|
222
|
+
assert len(request_input) == 7
|
|
223
|
+
assert request_input[0] == {"role": "system", "content": "Persisted system."}
|
|
224
|
+
assert request_input[1] == {"role": "system", "content": "Later system."}
|
|
225
|
+
assert request_input[2] == {"role": "user", "content": "user: message 4"}
|
|
226
|
+
assert request_input[-2] == {"role": "assistant", "content": "message 7"}
|
|
227
|
+
assert request_input[-1] == {"role": "user", "content": "user: new"}
|
|
228
|
+
assert session.messages[1] == Message(role="user", content="old user", username="user")
|
|
229
|
+
|
|
230
|
+
|
|
115
231
|
def test_session_exports_and_imports_history_json():
|
|
116
232
|
session, _ = make_session(system_rules=("Follow the test.",))
|
|
117
233
|
|
|
@@ -135,9 +251,9 @@ def test_session_exports_and_imports_history_json():
|
|
|
135
251
|
|
|
136
252
|
assert client.responses.calls[0]["input"] == [
|
|
137
253
|
{"role": "system", "content": "Follow the test."},
|
|
138
|
-
{"role": "user", "content": "
|
|
254
|
+
{"role": "user", "content": "user: hello"},
|
|
139
255
|
{"role": "assistant", "content": "answer: hello"},
|
|
140
|
-
{"role": "user", "content": "
|
|
256
|
+
{"role": "user", "content": "user: continue"},
|
|
141
257
|
]
|
|
142
258
|
|
|
143
259
|
|
|
@@ -184,9 +300,9 @@ def test_session_bind_history_loads_existing_file_and_saves_updates(tmp_path):
|
|
|
184
300
|
|
|
185
301
|
assert client.responses.calls[0]["input"] == [
|
|
186
302
|
{"role": "system", "content": "Persisted system."},
|
|
187
|
-
{"role": "user", "content": "
|
|
303
|
+
{"role": "user", "content": "user: old"},
|
|
188
304
|
{"role": "assistant", "content": "answer: old"},
|
|
189
|
-
{"role": "user", "content": "
|
|
305
|
+
{"role": "user", "content": "user: new"},
|
|
190
306
|
]
|
|
191
307
|
assert json.loads(history_file.read_text(encoding="utf-8")) == {
|
|
192
308
|
"messages": [
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|