ai-plays-jackbox 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-plays-jackbox might be problematic. Click here for more details.
- ai_plays_jackbox/__init__.py +0 -1
- ai_plays_jackbox/bot/__init__.py +0 -1
- ai_plays_jackbox/bot/bot_base.py +42 -8
- ai_plays_jackbox/bot/bot_factory.py +18 -17
- ai_plays_jackbox/bot/bot_personality.py +0 -1
- ai_plays_jackbox/bot/jackbox5/__init__.py +0 -0
- ai_plays_jackbox/bot/jackbox5/bot_base.py +26 -0
- ai_plays_jackbox/bot/jackbox5/mad_verse_city.py +121 -0
- ai_plays_jackbox/bot/jackbox5/patently_stupid.py +167 -0
- ai_plays_jackbox/bot/jackbox6/bot_base.py +1 -1
- ai_plays_jackbox/bot/jackbox6/dictionarium.py +105 -0
- ai_plays_jackbox/bot/jackbox6/joke_boat.py +105 -0
- ai_plays_jackbox/bot/jackbox7/bot_base.py +1 -1
- ai_plays_jackbox/bot/jackbox7/quiplash3.py +8 -4
- ai_plays_jackbox/bot/jackbox8/__init__.py +0 -0
- ai_plays_jackbox/bot/jackbox8/bot_base.py +20 -0
- ai_plays_jackbox/bot/jackbox8/job_job.py +205 -0
- ai_plays_jackbox/bot/standalone/__init__.py +0 -0
- ai_plays_jackbox/bot/standalone/drawful2.py +159 -0
- ai_plays_jackbox/cli/__init__.py +0 -0
- ai_plays_jackbox/{cli.py → cli/main.py} +28 -15
- ai_plays_jackbox/constants.py +3 -0
- ai_plays_jackbox/llm/chat_model.py +20 -3
- ai_plays_jackbox/llm/chat_model_factory.py +24 -19
- ai_plays_jackbox/llm/gemini_model.py +86 -0
- ai_plays_jackbox/llm/ollama_model.py +19 -7
- ai_plays_jackbox/llm/openai_model.py +48 -8
- ai_plays_jackbox/room/__init__.py +0 -0
- ai_plays_jackbox/{room.py → room/room.py} +2 -5
- ai_plays_jackbox/run.py +8 -11
- ai_plays_jackbox/scripts/lint.py +18 -0
- ai_plays_jackbox/ui/main.py +12 -0
- ai_plays_jackbox/ui/startup.py +248 -0
- ai_plays_jackbox-0.2.0.dist-info/METADATA +156 -0
- ai_plays_jackbox-0.2.0.dist-info/RECORD +42 -0
- ai_plays_jackbox-0.2.0.dist-info/entry_points.txt +4 -0
- ai_plays_jackbox/llm/gemini_vertex_ai.py +0 -60
- ai_plays_jackbox/ui/create_ui.py +0 -169
- ai_plays_jackbox/web_ui.py +0 -12
- ai_plays_jackbox-0.0.1.dist-info/METADATA +0 -88
- ai_plays_jackbox-0.0.1.dist-info/RECORD +0 -28
- ai_plays_jackbox-0.0.1.dist-info/entry_points.txt +0 -5
- {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.2.0.dist-info}/LICENSE +0 -0
- {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -1,30 +1,35 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
1
3
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
2
|
-
from ai_plays_jackbox.llm.
|
|
4
|
+
from ai_plays_jackbox.llm.gemini_model import GeminiModel
|
|
3
5
|
from ai_plays_jackbox.llm.ollama_model import OllamaModel
|
|
4
6
|
from ai_plays_jackbox.llm.openai_model import OpenAIModel
|
|
5
7
|
|
|
8
|
+
# Registry of supported chat model providers: lower-case provider key -> ChatModel subclass.
# Insertion order matters: the UI uses the first key as its default provider.
CHAT_MODEL_PROVIDERS: dict[str, type[ChatModel]] = dict(
    openai=OpenAIModel,
    gemini=GeminiModel,
    ollama=OllamaModel,
)
|
|
13
|
+
|
|
6
14
|
|
|
7
15
|
class ChatModelFactory:
    """Creates ChatModel instances for the providers registered in CHAT_MODEL_PROVIDERS."""

    @staticmethod
    def get_chat_model(
        chat_model_provider: str,
        chat_model_name: Optional[str] = None,
        chat_model_temperature: float = 0.5,
        chat_model_top_p: float = 0.9,
    ) -> ChatModel:
        """Instantiate the chat model for ``chat_model_provider``.

        Args:
            chat_model_provider: Provider key, matched case-insensitively and ignoring
                surrounding whitespace (e.g. "openai", "gemini", "ollama").
            chat_model_name: Model to use. ``None`` or a blank string (such as an emptied
                UI text field) selects the provider's ``get_default_model()``.
            chat_model_temperature: Default sampling temperature handed to the model.
            chat_model_top_p: Default nucleus-sampling ``top_p`` handed to the model.

        Returns:
            A new instance of the provider's ChatModel subclass.

        Raises:
            ValueError: If the provider is not registered in CHAT_MODEL_PROVIDERS.
        """
        provider_key = chat_model_provider.strip().lower()
        model_class = CHAT_MODEL_PROVIDERS.get(provider_key)
        if model_class is None:
            known_providers = ", ".join(CHAT_MODEL_PROVIDERS)
            raise ValueError(
                f"Unknown chat model provider: {chat_model_provider} (expected one of: {known_providers})"
            )

        # A blank name would otherwise be passed straight to the provider, whose constructor
        # looks the model up and fails; treat it like "not given" instead.
        if chat_model_name is None or not chat_model_name.strip():
            chat_model_name = model_class.get_default_model()
        else:
            chat_model_name = chat_model_name.strip()

        return model_class(
            chat_model_name,
            chat_model_temperature=chat_model_temperature,
            chat_model_top_p=chat_model_top_p,
        )
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from google import genai
|
|
5
|
+
from google.genai.types import GenerateContentConfig
|
|
6
|
+
from loguru import logger
|
|
7
|
+
|
|
8
|
+
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GeminiModel(ChatModel):
    """ChatModel backed by the ``google-genai`` SDK.

    Client settings are read from the environment: ``GOOGLE_GENAI_USE_VERTEXAI`` (a boolean
    flag such as "1"/"true"), ``GOOGLE_GEMINI_DEVELOPER_API_KEY``, ``GOOGLE_CLOUD_PROJECT``
    and ``GOOGLE_CLOUD_LOCATION``.
    """

    _gemini_vertex_ai_client: genai.Client

    # Image-capable model used by generate_sketch; independent of the text model in ``self._model``.
    _sketch_model: str = "gemini-2.0-flash-preview-image-generation"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._gemini_vertex_ai_client = genai.Client(
            # Parse the flag explicitly: bool(os.environ.get(...)) treated "false"/"0" as True.
            vertexai=self._env_flag("GOOGLE_GENAI_USE_VERTEXAI"),
            api_key=os.environ.get("GOOGLE_GEMINI_DEVELOPER_API_KEY"),
            project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
            location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
        )

        # Check connection and if model exists, this will hard fail if connection can't be made
        # Or if the model is not found. ``Models.get`` takes ``model`` as a keyword-only argument.
        _ = self._gemini_vertex_ai_client.models.get(model=self._model)

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when environment variable ``name`` holds a truthy flag ("1", "true", "yes", "on")."""
        return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}

    @classmethod
    def get_default_model(cls):
        """Return the model name used when no explicit chat model name is given."""
        return "gemini-2.0-flash-001"

    def generate_text(
        self,
        prompt: str,
        instructions: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Generate a single-line text reply to ``prompt`` guided by ``instructions``.

        ``temperature``/``top_p`` default to the values the model was constructed with.
        Returns the stripped text with newlines replaced by spaces ("" if no text came back).
        """
        if temperature is None:
            temperature = self._chat_model_temperature
        if top_p is None:
            top_p = self._chat_model_top_p

        chat_response = self._gemini_vertex_ai_client.models.generate_content(
            model=self._model,
            contents=prompt,
            config=GenerateContentConfig(
                system_instruction=[instructions],
                max_output_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ),
        )

        # ``text`` is Optional: str(None) would leak the literal "None" into the game.
        # Newlines become spaces (as in the Ollama backend) so words on separate lines don't fuse.
        text = (chat_response.text or "").strip().replace("\n", " ")
        logger.info(f"Generated text: {text}")
        return text

    def generate_sketch(
        self,
        prompt: str,
        instructions: str,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> bytes:
        """Generate an image for ``prompt`` with the image-generation model.

        ``temperature``/``top_p`` default to the model's configured values, as in generate_text.
        Returns the raw bytes of the first inline image part, or b"" if no image was returned.
        """
        if temperature is None:
            temperature = self._chat_model_temperature
        if top_p is None:
            top_p = self._chat_model_top_p

        image_gen_response = self._gemini_vertex_ai_client.models.generate_content(
            model=self._sketch_model,
            contents=prompt,
            config=GenerateContentConfig(
                system_instruction=[instructions],
                temperature=temperature,
                top_p=top_p,
                # The Gemini 2.0 Flash image-generation model does not accept image-only output;
                # request text and image, then pick the image part out below.
                response_modalities=["TEXT", "IMAGE"],
            ),
        )

        if (
            image_gen_response.candidates
            and image_gen_response.candidates[0].content
            and image_gen_response.candidates[0].content.parts
        ):
            for part in image_gen_response.candidates[0].content.parts:
                # Text parts have no inline_data; skip them.
                if part.inline_data is not None and part.inline_data.data is not None:
                    return part.inline_data.data

        return b""
|
|
@@ -1,20 +1,23 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from loguru import logger
|
|
4
|
-
from ollama import Options, chat,
|
|
4
|
+
from ollama import Options, chat, show
|
|
5
5
|
|
|
6
6
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class OllamaModel(ChatModel):
|
|
10
|
-
_model: str
|
|
11
10
|
|
|
12
|
-
def __init__(self,
|
|
11
|
+
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # Fail fast: ``show`` hard-fails if the Ollama server cannot be reached
    # or if the configured model is not found.
    show(self._model)

@classmethod
def get_default_model(cls):
    """Return the Ollama model tag used when no chat model name is supplied."""
    return "gemma3:12b"
|
|
18
21
|
|
|
19
22
|
def generate_text(
|
|
20
23
|
self,
|
|
@@ -36,6 +39,15 @@ class OllamaModel(ChatModel):
|
|
|
36
39
|
stream=False,
|
|
37
40
|
options=Options(num_predict=max_tokens, temperature=temperature, top_p=top_p),
|
|
38
41
|
)
|
|
39
|
-
text = chat_response.message.content.strip().replace("\n", " ")
|
|
42
|
+
text = str(chat_response.message.content).strip().replace("\n", " ")
|
|
40
43
|
logger.info(f"Generated text: {text}")
|
|
41
44
|
return text
|
|
45
|
+
|
|
46
|
+
def generate_sketch(
    self,
    prompt: str,
    instructions: str,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> bytes:
    """Image generation is not implemented for the Ollama backend.

    Raises:
        NotImplementedError: Always. It is a subclass of Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    raise NotImplementedError("Ollama model not supported yet for sketches")
|
|
@@ -1,23 +1,32 @@
|
|
|
1
|
+
import base64
|
|
1
2
|
import os
|
|
2
3
|
from typing import Optional
|
|
3
4
|
|
|
4
5
|
from loguru import logger
|
|
5
6
|
from openai import OpenAI
|
|
7
|
+
from openai.types.chat import (
|
|
8
|
+
ChatCompletionDeveloperMessageParam,
|
|
9
|
+
ChatCompletionUserMessageParam,
|
|
10
|
+
)
|
|
11
|
+
from openai.types.responses import Response
|
|
6
12
|
|
|
7
13
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
8
14
|
|
|
9
15
|
|
|
10
16
|
class OpenAIModel(ChatModel):
|
|
11
|
-
_model: str
|
|
12
17
|
_open_ai_client: OpenAI
|
|
13
18
|
|
|
14
|
-
def __init__(self,
|
|
19
|
+
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    api_key = os.environ.get("OPENAI_API_KEY")
    self._open_ai_client = OpenAI(api_key=api_key)

    # Fail fast: retrieving the configured model hard-fails if the API cannot be
    # reached or if no model named ``self._model`` exists.
    self._open_ai_client.models.retrieve(self._model)

@classmethod
def get_default_model(cls):
    """Return the model name used when no chat model name is supplied."""
    return "gpt-4o-mini"
|
|
21
30
|
|
|
22
31
|
def generate_text(
|
|
23
32
|
self,
|
|
@@ -32,15 +41,46 @@ class OpenAIModel(ChatModel):
|
|
|
32
41
|
if top_p is None:
|
|
33
42
|
top_p = self._chat_model_top_p
|
|
34
43
|
|
|
35
|
-
instructions_formatted = {"role": "developer", "content": instructions}
|
|
36
44
|
chat_response = self._open_ai_client.chat.completions.create(
|
|
37
45
|
model=self._model,
|
|
38
|
-
messages=[
|
|
46
|
+
messages=[
|
|
47
|
+
ChatCompletionDeveloperMessageParam(content=instructions, role="developer"),
|
|
48
|
+
ChatCompletionUserMessageParam(content=prompt, role="user"),
|
|
49
|
+
],
|
|
39
50
|
stream=False,
|
|
40
51
|
max_completion_tokens=max_tokens,
|
|
41
52
|
temperature=temperature,
|
|
42
53
|
top_p=top_p,
|
|
43
54
|
)
|
|
44
|
-
text = chat_response.choices[0].message.content.strip().replace("\n", "")
|
|
55
|
+
text = str(chat_response.choices[0].message.content).strip().replace("\n", "")
|
|
45
56
|
logger.info(f"Generated text: {text}")
|
|
46
57
|
return text
|
|
58
|
+
|
|
59
|
+
def generate_sketch(
    self,
    prompt: str,
    instructions: str,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> bytes:
    """Generate an image for ``prompt`` using the Responses API ``image_generation`` tool.

    ``temperature``/``top_p`` default to the model's configured values, as in generate_text.
    Returns the decoded image bytes, or b"" when the response contains no image.
    """
    if temperature is None:
        temperature = self._chat_model_temperature
    if top_p is None:
        top_p = self._chat_model_top_p

    image_gen_response: Response = self._open_ai_client.responses.create(
        model=self._model,
        instructions=instructions,
        input=prompt,
        temperature=temperature,
        top_p=top_p,
        tools=[
            {
                "type": "image_generation",
                "quality": "low",
                "size": "1024x1024",
            }
        ],
    )

    # The image comes back base64-encoded in the ``result`` of an image_generation_call item.
    # ``result`` may be None; skip such items instead of decoding the string "None" into garbage.
    for output in image_gen_response.output:
        if output.type == "image_generation_call" and output.result:
            return base64.b64decode(output.result)

    return b""
|
|
File without changes
|
|
@@ -11,7 +11,6 @@ from ai_plays_jackbox.bot.bot_factory import JackBoxBotFactory
|
|
|
11
11
|
from ai_plays_jackbox.bot.bot_personality import JackBoxBotVariant
|
|
12
12
|
from ai_plays_jackbox.constants import ECAST_HOST
|
|
13
13
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
14
|
-
from ai_plays_jackbox.llm.ollama_model import OllamaModel
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
class JackBoxRoom:
|
|
@@ -24,12 +23,10 @@ class JackBoxRoom:
|
|
|
24
23
|
def play(
|
|
25
24
|
self,
|
|
26
25
|
room_code: str,
|
|
26
|
+
chat_model: ChatModel,
|
|
27
27
|
num_of_bots: int = 4,
|
|
28
28
|
bots_in_play: Optional[list] = None,
|
|
29
|
-
chat_model: Optional[ChatModel] = None,
|
|
30
29
|
):
|
|
31
|
-
if chat_model is None:
|
|
32
|
-
chat_model = OllamaModel()
|
|
33
30
|
room_type = self._get_room_type(room_code)
|
|
34
31
|
if not room_type:
|
|
35
32
|
logger.error(f"Unable to find room {room_code}")
|
|
@@ -45,9 +42,9 @@ class JackBoxRoom:
|
|
|
45
42
|
for b in bots_to_make:
|
|
46
43
|
bot = bot_factory.get_bot(
|
|
47
44
|
room_type,
|
|
45
|
+
chat_model,
|
|
48
46
|
name=b.value.name,
|
|
49
47
|
personality=b.value.personality,
|
|
50
|
-
chat_model=chat_model,
|
|
51
48
|
)
|
|
52
49
|
self._bots.append(bot)
|
|
53
50
|
with self._lock:
|
ai_plays_jackbox/run.py
CHANGED
|
@@ -1,26 +1,23 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from ai_plays_jackbox.llm.chat_model_factory import ChatModelFactory
|
|
4
|
-
from ai_plays_jackbox.room import JackBoxRoom
|
|
4
|
+
from ai_plays_jackbox.room.room import JackBoxRoom
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def run(
    room_code: str,
    chat_model_provider: str,
    chat_model_name: Optional[str] = None,
    num_of_bots: int = 4,
    bots_in_play: Optional[list] = None,
    chat_model_temperature: float = 0.5,
    chat_model_top_p: float = 0.9,
):
    """Run a set of bots through a game of JackBox given a room code.

    Args:
        room_code (str): The room code.
        chat_model_provider (str): Provider key understood by ChatModelFactory (e.g. "openai").
        chat_model_name (Optional[str]): Model name; None selects the provider's default model.
        num_of_bots (int, optional): The number of bots to participate. Defaults to 4.
        bots_in_play (Optional[list]): Forwarded unchanged to JackBoxRoom.play.
        chat_model_temperature (float, optional): Sampling temperature. Defaults to 0.5.
        chat_model_top_p (float, optional): Nucleus-sampling top_p. Defaults to 0.9.
    """
    model = ChatModelFactory.get_chat_model(
        chat_model_provider=chat_model_provider,
        chat_model_name=chat_model_name,
        chat_model_temperature=chat_model_temperature,
        chat_model_top_p=chat_model_top_p,
    )
    JackBoxRoom().play(room_code, model, num_of_bots=num_of_bots, bots_in_play=bots_in_play)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def run():
    """Run the project's formatters and checks in order, stopping at the first failure.

    Every tool is executed through ``poetry run`` against the ``ai_plays_jackbox`` package;
    ``check=True`` makes a failing tool raise ``subprocess.CalledProcessError``.
    """
    lint_steps = (
        ("autoflake", "--in-place", "--recursive", "--remove-all-unused-imports", "--verbose", "ai_plays_jackbox"),
        ("isort", "--profile", "black", "--project=ai_plays_jackbox", "ai_plays_jackbox"),
        ("black", "-l", "120", "ai_plays_jackbox"),
        ("mypy", "ai_plays_jackbox"),
    )

    for step in lint_steps:
        print("\n>>> Running: " + " ".join(step))
        subprocess.run(["poetry", "run", *step], check=True)


if __name__ == "__main__":
    run()
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
from multiprocessing import Process, Queue
|
|
2
|
+
|
|
3
|
+
import psutil
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from nicegui import app, ui
|
|
6
|
+
|
|
7
|
+
from ai_plays_jackbox.bot.bot_personality import JackBoxBotVariant
|
|
8
|
+
from ai_plays_jackbox.constants import (
|
|
9
|
+
DEFAULT_NUM_OF_BOTS,
|
|
10
|
+
DEFAULT_TEMPERATURE,
|
|
11
|
+
DEFAULT_TOP_P,
|
|
12
|
+
)
|
|
13
|
+
from ai_plays_jackbox.llm.chat_model_factory import CHAT_MODEL_PROVIDERS
|
|
14
|
+
from ai_plays_jackbox.run import run
|
|
15
|
+
|
|
16
|
+
# Log lines emitted by the bot process (see _start) travel to the UI through this queue.
LOG_QUEUE: Queue = Queue()
# The ui.log widget; assigned by _build_log_display().
LOG_DISPLAY = None
# The "Select All" checkbox; assigned by _setup_bot_variant_display().
SELECT_ALL_BOT_VARIANTS = None
# Bot variant name -> its checkbox; filled in by _setup_bot_variant_display().
BOT_VARIANT_CHECKBOX_STATES: dict = dict()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _format_log(record):
    """Return the loguru format *template* for the UI log sink.

    Loguru calls this once per record and then formats the returned string itself
    (``format_map`` over the record). Fields must therefore be referenced as placeholders
    such as ``{message}`` rather than interpolated here. Interpolating the message directly
    made any ``{``/``}`` in it (e.g. JSON produced by an LLM) be parsed again as format
    fields. A dynamic format must also include ``{exception}`` explicitly, otherwise
    ``logger.exception`` tracebacks are dropped.

    ``record`` is accepted because loguru passes it; the template is the same for every record.
    """
    return (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<cyan>{level:<8}</cyan> | "
        "<red>{thread.name:<12}</red> | "
        "{message}\n{exception}"
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _build_log_display():
    """Create the full-width log panel and a 0.5 s timer that drains LOG_QUEUE into it."""
    global LOG_DISPLAY
    log_row = ui.row().classes("w-full")
    with log_row:
        log_widget = ui.log(max_lines=100)
        LOG_DISPLAY = log_widget.classes("h-64 overflow-auto bg-black text-white")
        ui.timer(interval=0.5, callback=_poll_log_queue)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _poll_log_queue():
    """Move every pending log line from LOG_QUEUE into the on-page log widget.

    Drains with ``get_nowait`` until ``queue.Empty`` rather than looping on ``empty()``.
    ``multiprocessing.Queue.empty()`` is documented as unreliable, so the old loop could call
    ``get_nowait`` on a queue that then raised ``Empty``. That surfaced as a spurious
    "[ERROR] Failed to read log: " line.
    """
    from queue import Empty

    if LOG_DISPLAY is None:
        # No widget to render into yet; leave messages queued for a later tick.
        return

    while True:
        try:
            log_msg = LOG_QUEUE.get_nowait()
        except Empty:
            break
        except Exception as e:
            LOG_DISPLAY.push(f"[ERROR] Failed to read log: {e}")
            break
        LOG_DISPLAY.push(log_msg)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _start(
    room_code: str,
    chat_model_provider: str,
    chat_model_name: str,
    num_of_bots: int,
    bots_in_play: list[str],
    temperature: float,
    top_p: float,
    log_queue: Queue,
):
    """Entry point of the bot process: forward log output to ``log_queue`` and run the game.

    The room code is stripped and upper-cased before use. Any exception raised by ``run``
    is reported through ``logger.exception`` instead of propagating.
    """
    # Forward only the rendered text. Loguru hands sinks a Message (a str subclass carrying the
    # full record), and shipping that record across the process boundary is unnecessary and
    # fails to pickle if ``extra`` holds unpicklable objects.
    logger.add(lambda msg: log_queue.put(str(msg)), format=_format_log)

    try:
        run(
            room_code.strip().upper(),
            chat_model_provider,
            chat_model_name=chat_model_name,
            num_of_bots=num_of_bots,
            bots_in_play=bots_in_play,
            chat_model_temperature=temperature,
            chat_model_top_p=top_p,
        )
    except Exception:
        logger.exception("Bot startup failed")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _is_game_process_alive():
    """Return True while the bot process recorded in ``app.storage.general["game_pid"]`` runs.

    The child is started as a daemon and never joined, so once it exits it remains a zombie
    until reaped. ``psutil.Process.is_running()`` still returns True for zombies, which kept
    the UI on "Running..." and blocked new games. Zombies are therefore treated as dead.
    When the process is gone, the stored PID is cleared.
    """
    game_pid = app.storage.general.get("game_pid", None)
    if game_pid is None:
        return False

    try:
        game_process = psutil.Process(game_pid)
        is_game_alive = game_process.is_running() and game_process.status() != psutil.STATUS_ZOMBIE
    except psutil.Error:
        # NoSuchProcess/ZombieProcess: the PID vanished (the old pid_exists()+Process() pair
        # could race into this). AccessDenied: the PID now belongs to a process we cannot
        # inspect, so it is not our game.
        is_game_alive = False

    if not is_game_alive:
        # Only write when there is something to clear; this runs on a 0.5 s timer.
        app.storage.general["game_pid"] = None
    return is_game_alive
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _handle_start_click(
    room_code: str,
    chat_model_provider: str,
    chat_model_name: str,
    num_of_bots: int,
    temperature: float,
    top_p: float,
):
    """Launch the bots in a daemon process unless a game process is already alive.

    The checked bot personalities are read from BOT_VARIANT_CHECKBOX_STATES. The child's PID
    is stored in ``app.storage.general["game_pid"]`` so later checks can find it.
    """
    if _is_game_process_alive():
        return

    logger.info("Starting...")
    selected_variants = [name for name, checkbox in BOT_VARIANT_CHECKBOX_STATES.items() if checkbox.value]
    start_args = (
        room_code,
        chat_model_provider,
        chat_model_name,
        num_of_bots,
        selected_variants,
        temperature,
        top_p,
        LOG_QUEUE,
    )
    game_process = Process(target=_start, args=start_args, daemon=True)
    game_process.start()
    app.storage.general["game_pid"] = game_process.pid
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _select_all_bot_variants_changed():
    """Copy the "Select All" checkbox state onto every bot-personality checkbox."""
    for variant_name in BOT_VARIANT_CHECKBOX_STATES:
        BOT_VARIANT_CHECKBOX_STATES[variant_name].value = SELECT_ALL_BOT_VARIANTS.value
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _sync_select_all_bot_variants():
    """Tick "Select All" exactly when no bot-personality checkbox is left unticked."""
    any_unchecked = any(not checkbox.value for checkbox in BOT_VARIANT_CHECKBOX_STATES.values())
    SELECT_ALL_BOT_VARIANTS.value = not any_unchecked
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _setup_bot_variant_display():
    """Render the bot-personality picker: a "Select All" toggle plus one row per JackBoxBotVariant.

    Assigns SELECT_ALL_BOT_VARIANTS and fills BOT_VARIANT_CHECKBOX_STATES (keyed by variant
    name); every checkbox starts ticked.
    """
    global SELECT_ALL_BOT_VARIANTS

    def _on_variant_toggled(_event):
        # Any individual change may flip the state of "Select All".
        _sync_select_all_bot_variants()

    with ui.list().props("bordered separator").classes("w-full"):
        with ui.item_label("Bot Personalities").props("header").classes("text-bold"):
            SELECT_ALL_BOT_VARIANTS = ui.checkbox(text="Select All", value=True)
            SELECT_ALL_BOT_VARIANTS.on("update:model-value", lambda _event: _select_all_bot_variants_changed())
        ui.separator()
        with ui.element("div").classes("overflow-y-auto h-64"):
            for bot_variant in JackBoxBotVariant:
                with ui.item():
                    with ui.item_section().props("avatar"):
                        variant_checkbox = ui.checkbox(value=True)
                        variant_checkbox.on("update:model-value", _on_variant_toggled)
                        BOT_VARIANT_CHECKBOX_STATES[bot_variant.name] = variant_checkbox
                    with ui.item_section():
                        ui.item_label(bot_variant.value.name)
                        ui.item_label(bot_variant.value.personality).props("caption")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def startup():
    """Build the AI Plays JackBox control page.

    Renders the log panel, a settings column (bot count, chat model provider and name, room
    code, start button, temperature/top_p sliders) and the bot-personality picker.
    """
    ui.page_title("AI Plays JackBox")
    ui.label("🤖 AI Plays JackBox").classes("text-2xl font-bold")

    _build_log_display()

    with ui.grid(columns=16).classes("w-full gap-0"):
        with ui.column().classes("col-span-1"):
            pass
        with ui.column().classes("col-span-7"):
            with ui.row():
                ui.label("Number of Bots")
                num_of_bots_label = ui.label(str(DEFAULT_NUM_OF_BOTS))
                num_of_bots = ui.slider(
                    min=1,
                    max=10,
                    value=DEFAULT_NUM_OF_BOTS,
                    step=1,
                    on_change=lambda e: num_of_bots_label.set_text(f"{e.value}"),
                )
                chat_model_provider = ui.select(
                    list(CHAT_MODEL_PROVIDERS.keys()),
                    label="Chat Model Provider",
                    value=list(CHAT_MODEL_PROVIDERS.keys())[0],
                    # Switching provider resets the model name to that provider's default.
                    on_change=lambda e: chat_model_name.set_value(CHAT_MODEL_PROVIDERS[e.value].get_default_model()),
                ).classes("w-1/3")

                chat_model_name = ui.input(
                    label="Chat Model Name",
                    value=CHAT_MODEL_PROVIDERS[chat_model_provider.value].get_default_model(),
                ).classes("w-1/3")

                room_code = (
                    ui.input(
                        label="Room Code",
                        placeholder="ABCD",
                        validation={
                            "must be letters only": lambda value: value.isalpha(),
                            "must be 4 letters": lambda value: len(value) == 4,
                        },
                    )
                    .props("uppercase")
                    .classes("w-1/4")
                )
                start_button = (
                    ui.button(
                        "Start Bots",
                        on_click=lambda _: _handle_start_click(
                            room_code.value,
                            chat_model_provider.value,
                            chat_model_name.value,
                            num_of_bots.value,
                            temperature.value,
                            top_p.value,
                        ),
                    )
                    # Enabled only once a room code is entered and passes validation.
                    .bind_enabled_from(room_code, "error", lambda error: room_code.value and not error)
                    .classes("w-full")
                )

                def _refresh_start_button():
                    # Check the process once per tick. Two separate checks could disagree if the
                    # game exits in between, leaving the colour and the label out of sync.
                    running = _is_game_process_alive()
                    start_button.props(f"color={'blue' if running else 'green'}").set_text(
                        "Running..." if running else "Start Bots"
                    )

                ui.timer(interval=0.5, callback=_refresh_start_button)

                ui.label("Advanced Options").classes("w-full text-xl font-bold")

                ui.label("Temperature").classes("w-1/4").tooltip(
                    """
What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
make the output more random, while lower values like 0.2 will make it more
focused and deterministic. We generally recommend altering this or `top_p` but
not both."""
                )
                temperature_label = ui.label(str(DEFAULT_TEMPERATURE)).classes("w-1/6")
                temperature = ui.slider(
                    min=0.0,
                    max=2.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.1,
                    on_change=lambda e: temperature_label.set_text(f"{e.value}"),
                ).classes("w-1/2")

                ui.label("Top P").classes("w-1/4").tooltip(
                    """
An alternative to sampling with temperature, called nucleus sampling, where the
model considers the results of the tokens with top_p probability mass. So 0.1
means only the tokens comprising the top 10% probability mass are considered."""
                )
                top_p_label = ui.label(str(DEFAULT_TOP_P)).classes("w-1/6")
                top_p = ui.slider(
                    min=0.0,
                    max=1.0,
                    value=DEFAULT_TOP_P,
                    step=0.1,
                    on_change=lambda e: top_p_label.set_text(f"{e.value}"),
                ).classes("w-1/2")

        with ui.column().classes("col-span-1"):
            pass

        with ui.column().classes("col-span-6"):
            _setup_bot_variant_display()
|