ai-plays-jackbox 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-plays-jackbox might be problematic. Click here for more details.
- ai_plays_jackbox/__init__.py +0 -1
- ai_plays_jackbox/bot/__init__.py +0 -1
- ai_plays_jackbox/bot/bot_base.py +42 -8
- ai_plays_jackbox/bot/bot_factory.py +15 -17
- ai_plays_jackbox/bot/bot_personality.py +0 -1
- ai_plays_jackbox/bot/jackbox5/__init__.py +0 -0
- ai_plays_jackbox/bot/jackbox5/bot_base.py +26 -0
- ai_plays_jackbox/bot/jackbox5/mad_verse_city.py +121 -0
- ai_plays_jackbox/bot/jackbox5/patently_stupid.py +167 -0
- ai_plays_jackbox/bot/jackbox6/bot_base.py +1 -1
- ai_plays_jackbox/bot/jackbox6/joke_boat.py +105 -0
- ai_plays_jackbox/bot/jackbox7/bot_base.py +1 -1
- ai_plays_jackbox/bot/jackbox7/quiplash3.py +8 -4
- ai_plays_jackbox/bot/standalone/__init__.py +0 -0
- ai_plays_jackbox/bot/standalone/drawful2.py +159 -0
- ai_plays_jackbox/cli.py +26 -13
- ai_plays_jackbox/constants.py +3 -0
- ai_plays_jackbox/llm/chat_model.py +20 -3
- ai_plays_jackbox/llm/chat_model_factory.py +24 -19
- ai_plays_jackbox/llm/gemini_model.py +86 -0
- ai_plays_jackbox/llm/ollama_model.py +19 -7
- ai_plays_jackbox/llm/openai_model.py +48 -8
- ai_plays_jackbox/room.py +2 -5
- ai_plays_jackbox/run.py +7 -10
- ai_plays_jackbox/ui/create_ui.py +44 -16
- ai_plays_jackbox/web_ui.py +1 -1
- ai_plays_jackbox-0.1.0.dist-info/METADATA +154 -0
- ai_plays_jackbox-0.1.0.dist-info/RECORD +35 -0
- {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.1.0.dist-info}/entry_points.txt +0 -1
- ai_plays_jackbox/llm/gemini_vertex_ai.py +0 -60
- ai_plays_jackbox-0.0.1.dist-info/METADATA +0 -88
- ai_plays_jackbox-0.0.1.dist-info/RECORD +0 -28
- {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.1.0.dist-info}/LICENSE +0 -0
- {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.1.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import random
|
|
2
|
+
|
|
3
|
+
from loguru import logger
|
|
4
|
+
|
|
5
|
+
from ai_plays_jackbox.bot.bot_base import JackBoxBotBase
|
|
6
|
+
|
|
7
|
+
_DRAWING_PROMPT_TEMPLATE = """
|
|
8
|
+
You are playing Drawful 2.
|
|
9
|
+
|
|
10
|
+
Generate an image with the following prompt: {prompt}
|
|
11
|
+
|
|
12
|
+
When generating your response, follow these rules:
|
|
13
|
+
- Your personality is: {personality}
|
|
14
|
+
- Make sure to implement your personality somehow into the drawing, but keep the prompt in mind
|
|
15
|
+
- The image must be a simple sketch
|
|
16
|
+
- The image must have a white background and use black for the lines
|
|
17
|
+
- Avoid intricate details
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Drawful2Bot(JackBoxBotBase):
|
|
22
|
+
_drawing_completed: bool = False
|
|
23
|
+
|
|
24
|
+
def __init__(self, *args, **kwargs):
|
|
25
|
+
super().__init__(*args, **kwargs)
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def _player_operation_key(self) -> str:
|
|
29
|
+
return f"player:{self._player_id}"
|
|
30
|
+
|
|
31
|
+
def _is_player_operation_key(self, operation_key: str) -> bool:
|
|
32
|
+
return operation_key == self._player_operation_key
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def _room_operation_key(self) -> str:
|
|
36
|
+
return "room"
|
|
37
|
+
|
|
38
|
+
def _is_room_operation_key(self, operation_key: str) -> bool:
|
|
39
|
+
return operation_key == self._room_operation_key
|
|
40
|
+
|
|
41
|
+
def _handle_welcome(self, data: dict):
|
|
42
|
+
pass
|
|
43
|
+
|
|
44
|
+
def _handle_player_operation(self, data: dict):
|
|
45
|
+
if not data:
|
|
46
|
+
return
|
|
47
|
+
room_state = data.get("state", None)
|
|
48
|
+
if not room_state:
|
|
49
|
+
return
|
|
50
|
+
prompt = data.get("prompt")
|
|
51
|
+
prompt_text = self._html_to_text(prompt.get("html", "")) if prompt is not None else ""
|
|
52
|
+
|
|
53
|
+
match room_state:
|
|
54
|
+
case "Draw":
|
|
55
|
+
colors = data.get("colors", ["#fb405a", "#7a2259"])
|
|
56
|
+
selected_color = colors[0]
|
|
57
|
+
canvas_height = int(data.get("size", {}).get("height", 320))
|
|
58
|
+
canvas_width = int(data.get("size", {}).get("width", 320))
|
|
59
|
+
lines = self._generate_drawing(prompt_text, canvas_height, canvas_width)
|
|
60
|
+
object_key = data.get("objectKey", "")
|
|
61
|
+
if object_key != "":
|
|
62
|
+
if not self._drawing_completed:
|
|
63
|
+
self._object_update(
|
|
64
|
+
object_key,
|
|
65
|
+
{
|
|
66
|
+
"lines": [{"color": selected_color, "thickness": 1, "points": l} for l in lines],
|
|
67
|
+
"submit": True,
|
|
68
|
+
},
|
|
69
|
+
)
|
|
70
|
+
# This prevents the bot from trying to draw multiple times
|
|
71
|
+
self._drawing_completed = True
|
|
72
|
+
|
|
73
|
+
case "EnterSingleText":
|
|
74
|
+
# We need to reset this once we're entering options
|
|
75
|
+
self._drawing_completed = False
|
|
76
|
+
# Listen, the bot can't see the drawing
|
|
77
|
+
# so they're just going to say something
|
|
78
|
+
text_key = data.get("textKey", "")
|
|
79
|
+
self._text_update(text_key, self._generate_random_response())
|
|
80
|
+
|
|
81
|
+
case "MakeSingleChoice":
|
|
82
|
+
# Bot still can't see the drawing
|
|
83
|
+
# so just pick something
|
|
84
|
+
if data.get("type", "single") == "repeating":
|
|
85
|
+
pass
|
|
86
|
+
choices = data.get("choices", [])
|
|
87
|
+
choices_as_ints = [i for i in range(0, len(choices))]
|
|
88
|
+
selected_choice = random.choice(choices_as_ints)
|
|
89
|
+
self._client_send({"action": "choose", "choice": selected_choice})
|
|
90
|
+
|
|
91
|
+
def _handle_room_operation(self, data: dict):
|
|
92
|
+
pass
|
|
93
|
+
|
|
94
|
+
def _generate_drawing(self, prompt: str, canvas_height: int, canvas_width: int) -> list[str]:
|
|
95
|
+
logger.info("Generating drawing...")
|
|
96
|
+
image_prompt = _DRAWING_PROMPT_TEMPLATE.format(prompt=prompt, personality=self._personality)
|
|
97
|
+
image_bytes = self._chat_model.generate_sketch(
|
|
98
|
+
image_prompt,
|
|
99
|
+
"",
|
|
100
|
+
temperature=self._chat_model._chat_model_temperature,
|
|
101
|
+
top_p=self._chat_model._chat_model_top_p,
|
|
102
|
+
)
|
|
103
|
+
return self._image_bytes_to_polylines(image_bytes, canvas_height, canvas_width)
|
|
104
|
+
|
|
105
|
+
def _generate_random_response(self) -> str:
|
|
106
|
+
possible_responses = [
|
|
107
|
+
"Abstract awkward silence",
|
|
108
|
+
"Abstract existential dread",
|
|
109
|
+
"Abstract late-stage capitalism",
|
|
110
|
+
"Abstract lost hope",
|
|
111
|
+
"Abstract the void",
|
|
112
|
+
"Baby Yoda, but weird",
|
|
113
|
+
"Barbie, but weird",
|
|
114
|
+
"Confused dentist",
|
|
115
|
+
"Confused gym teacher",
|
|
116
|
+
"Confused lawyer",
|
|
117
|
+
"Confused therapist",
|
|
118
|
+
"DJ in trouble",
|
|
119
|
+
"Definitely banana",
|
|
120
|
+
"Definitely blob",
|
|
121
|
+
"Definitely potato",
|
|
122
|
+
"Definitely spaghetti",
|
|
123
|
+
"Taylor Swift, but weird",
|
|
124
|
+
"Waluigi, but weird",
|
|
125
|
+
"banana with feelings",
|
|
126
|
+
"chicken riding a scooter",
|
|
127
|
+
"cloud with feelings",
|
|
128
|
+
"confused gym teacher",
|
|
129
|
+
"confused therapist",
|
|
130
|
+
"dentist in trouble",
|
|
131
|
+
"duck + existential dread",
|
|
132
|
+
"duck riding a scooter",
|
|
133
|
+
"excited janitor",
|
|
134
|
+
"ferret + awkward silence",
|
|
135
|
+
"giraffe + awkward silence",
|
|
136
|
+
"giraffe filing taxes",
|
|
137
|
+
"hamster + existential dread",
|
|
138
|
+
"hamster + lost hope",
|
|
139
|
+
"hamster riding a scooter",
|
|
140
|
+
"janitor in trouble",
|
|
141
|
+
"joyful octopus",
|
|
142
|
+
"lawyer in trouble",
|
|
143
|
+
"llama + awkward silence",
|
|
144
|
+
"llama + late-stage capitalism",
|
|
145
|
+
"lonely dentist",
|
|
146
|
+
"lonely hamster",
|
|
147
|
+
"lonely janitor",
|
|
148
|
+
"lonely pirate",
|
|
149
|
+
"mango with feelings",
|
|
150
|
+
"pirate in trouble",
|
|
151
|
+
"sad DJ",
|
|
152
|
+
"sad hamster",
|
|
153
|
+
"spaghetti with feelings",
|
|
154
|
+
"terrified duck",
|
|
155
|
+
"terrified ferret",
|
|
156
|
+
"terrified lawyer",
|
|
157
|
+
]
|
|
158
|
+
chosen_response = random.choice(possible_responses)
|
|
159
|
+
return chosen_response
|
ai_plays_jackbox/cli.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
|
|
3
|
-
from ai_plays_jackbox import
|
|
3
|
+
from ai_plays_jackbox.constants import (
|
|
4
|
+
DEFAULT_NUM_OF_BOTS,
|
|
5
|
+
DEFAULT_TEMPERATURE,
|
|
6
|
+
DEFAULT_TOP_P,
|
|
7
|
+
)
|
|
8
|
+
from ai_plays_jackbox.llm.chat_model_factory import CHAT_MODEL_PROVIDERS
|
|
9
|
+
from ai_plays_jackbox.run import run
|
|
4
10
|
|
|
5
11
|
|
|
6
12
|
def _validate_room_code(string_to_check: str) -> str:
|
|
@@ -49,42 +55,49 @@ def cli():
|
|
|
49
55
|
metavar="WXYZ",
|
|
50
56
|
)
|
|
51
57
|
parser.add_argument(
|
|
52
|
-
"--chat-model-
|
|
58
|
+
"--chat-model-provider",
|
|
53
59
|
required=True,
|
|
54
60
|
help="Choose which chat model platform to use",
|
|
55
|
-
choices=(
|
|
61
|
+
choices=list(CHAT_MODEL_PROVIDERS.keys()),
|
|
62
|
+
type=str,
|
|
63
|
+
)
|
|
64
|
+
parser.add_argument(
|
|
65
|
+
"--chat-model-name",
|
|
66
|
+
required=False,
|
|
67
|
+
help="Choose which chat model to use (Will default to default for provider)",
|
|
56
68
|
type=str,
|
|
57
69
|
)
|
|
58
70
|
parser.add_argument(
|
|
59
71
|
"--num-of-bots",
|
|
60
72
|
required=False,
|
|
61
|
-
default=
|
|
62
|
-
help="How many bots to have play
|
|
73
|
+
default=DEFAULT_NUM_OF_BOTS,
|
|
74
|
+
help="How many bots to have play",
|
|
63
75
|
type=_validate_num_of_bots,
|
|
64
|
-
metavar=
|
|
76
|
+
metavar=str(DEFAULT_NUM_OF_BOTS),
|
|
65
77
|
)
|
|
66
78
|
parser.add_argument(
|
|
67
79
|
"--temperature",
|
|
68
80
|
required=False,
|
|
69
|
-
default=
|
|
70
|
-
help="Temperature for Gen AI
|
|
81
|
+
default=DEFAULT_TEMPERATURE,
|
|
82
|
+
help="Temperature for Gen AI",
|
|
71
83
|
type=_validate_temperature,
|
|
72
|
-
metavar=
|
|
84
|
+
metavar=str(DEFAULT_TEMPERATURE),
|
|
73
85
|
)
|
|
74
86
|
parser.add_argument(
|
|
75
87
|
"--top-p",
|
|
76
88
|
required=False,
|
|
77
|
-
default=
|
|
78
|
-
help="Top P for Gen AI
|
|
89
|
+
default=DEFAULT_TOP_P,
|
|
90
|
+
help="Top P for Gen AI",
|
|
79
91
|
type=_validate_top_p,
|
|
80
|
-
metavar=
|
|
92
|
+
metavar=str(DEFAULT_TOP_P),
|
|
81
93
|
)
|
|
82
94
|
args = parser.parse_args()
|
|
83
95
|
|
|
84
96
|
run(
|
|
85
97
|
args.room_code.upper(),
|
|
86
|
-
|
|
98
|
+
args.chat_model_provider,
|
|
87
99
|
chat_model_name=args.chat_model_name,
|
|
100
|
+
num_of_bots=args.num_of_bots,
|
|
88
101
|
chat_model_temperature=args.temperature,
|
|
89
102
|
chat_model_top_p=args.top_p,
|
|
90
103
|
)
|
ai_plays_jackbox/constants.py
CHANGED
|
@@ -3,13 +3,20 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class ChatModel(ABC):
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
_model: str
|
|
7
|
+
_chat_model_temperature: float
|
|
8
|
+
_chat_model_top_p: float
|
|
8
9
|
|
|
9
|
-
def __init__(self, chat_model_temperature: float = 0.5, chat_model_top_p: float = 0.9):
|
|
10
|
+
def __init__(self, model: str, chat_model_temperature: float = 0.5, chat_model_top_p: float = 0.9):
|
|
11
|
+
self._model = model
|
|
10
12
|
self._chat_model_temperature = chat_model_temperature
|
|
11
13
|
self._chat_model_top_p = chat_model_top_p
|
|
12
14
|
|
|
15
|
+
@classmethod
|
|
16
|
+
@abstractmethod
|
|
17
|
+
def get_default_model(self) -> str:
|
|
18
|
+
pass
|
|
19
|
+
|
|
13
20
|
@abstractmethod
|
|
14
21
|
def generate_text(
|
|
15
22
|
self,
|
|
@@ -20,3 +27,13 @@ class ChatModel(ABC):
|
|
|
20
27
|
top_p: Optional[float] = None,
|
|
21
28
|
) -> str:
|
|
22
29
|
pass
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
def generate_sketch(
|
|
33
|
+
self,
|
|
34
|
+
prompt: str,
|
|
35
|
+
instructions: str,
|
|
36
|
+
temperature: Optional[float] = None,
|
|
37
|
+
top_p: Optional[float] = None,
|
|
38
|
+
) -> bytes:
|
|
39
|
+
pass
|
|
@@ -1,30 +1,35 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
1
3
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
2
|
-
from ai_plays_jackbox.llm.
|
|
4
|
+
from ai_plays_jackbox.llm.gemini_model import GeminiModel
|
|
3
5
|
from ai_plays_jackbox.llm.ollama_model import OllamaModel
|
|
4
6
|
from ai_plays_jackbox.llm.openai_model import OpenAIModel
|
|
5
7
|
|
|
8
|
+
CHAT_MODEL_PROVIDERS: dict[str, type[ChatModel]] = {
|
|
9
|
+
"openai": OpenAIModel,
|
|
10
|
+
"gemini": GeminiModel,
|
|
11
|
+
"ollama": OllamaModel,
|
|
12
|
+
}
|
|
13
|
+
|
|
6
14
|
|
|
7
15
|
class ChatModelFactory:
    """Creates ``ChatModel`` instances from a provider name."""

    @staticmethod
    def get_chat_model(
        chat_model_provider: str,
        chat_model_name: Optional[str] = None,
        chat_model_temperature: float = 0.5,
        chat_model_top_p: float = 0.9,
    ) -> ChatModel:
        """Instantiate the chat model registered for ``chat_model_provider``.

        Args:
            chat_model_provider: Key into ``CHAT_MODEL_PROVIDERS``; matched
                case-insensitively.
            chat_model_name: Model to use. When ``None`` the provider class's
                ``get_default_model()`` is used.
            chat_model_temperature: Passed through to the model constructor.
            chat_model_top_p: Passed through to the model constructor.

        Returns:
            A new instance of the provider's ``ChatModel`` subclass.

        Raises:
            ValueError: If the provider is not registered.
        """
        chat_model_provider = chat_model_provider.lower()
        # Single lookup: resolves the class and detects an unknown provider.
        model_class = CHAT_MODEL_PROVIDERS.get(chat_model_provider)
        if model_class is None:
            raise ValueError(f"Unknown chat model provider: {chat_model_provider}")

        if chat_model_name is None:
            chat_model_name = model_class.get_default_model()

        return model_class(
            chat_model_name,
            chat_model_temperature=chat_model_temperature,
            chat_model_top_p=chat_model_top_p,
        )
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from google import genai
|
|
5
|
+
from google.genai.types import GenerateContentConfig
|
|
6
|
+
from loguru import logger
|
|
7
|
+
|
|
8
|
+
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GeminiModel(ChatModel):
    """``ChatModel`` implementation backed by the ``google.genai`` client.

    The client is configured from environment variables:
    ``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_GEMINI_DEVELOPER_API_KEY``,
    ``GOOGLE_CLOUD_PROJECT`` and ``GOOGLE_CLOUD_LOCATION``.
    """

    _gemini_vertex_ai_client: genai.Client

    # Values of a boolean environment variable that mean "disabled". Anything
    # else that is non-empty keeps meaning "enabled", as before.
    _FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})

    def __init__(self, *args, **kwargs):
        """Create the client and verify that the configured model is reachable.

        All arguments are forwarded unchanged to ``ChatModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._gemini_vertex_ai_client = genai.Client(
            vertexai=self._env_flag("GOOGLE_GENAI_USE_VERTEXAI"),
            api_key=os.environ.get("GOOGLE_GEMINI_DEVELOPER_API_KEY"),
            project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
            location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
        )

        # Check connection and if model exists, this will hard fail if connection can't be made
        # Or if the model is not found
        _ = self._gemini_vertex_ai_client.models.get(model=self._model)

    @classmethod
    def _env_flag(cls, name: str) -> bool:
        """Read a boolean environment variable.

        ``bool("false")`` and ``bool("0")`` are both True, so the raw string is
        compared against the known "disabled" spellings instead.
        """
        return os.environ.get(name, "").strip().lower() not in cls._FALSY_ENV_VALUES

    @classmethod
    def get_default_model(cls) -> str:
        """Return the model name used when the caller does not pick one."""
        return "gemini-2.0-flash-001"

    def generate_text(
        self,
        prompt: str,
        instructions: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Generate a single-line text response.

        Args:
            prompt: User content sent to the model.
            instructions: Sent as the system instruction.
            max_tokens: Optional cap on output tokens.
            temperature: Sampling temperature; defaults to the instance value.
            top_p: Nucleus sampling value; defaults to the instance value.

        Returns:
            The response text, stripped and with newlines removed. Empty when
            the response carries no text.
        """
        if temperature is None:
            temperature = self._chat_model_temperature
        if top_p is None:
            top_p = self._chat_model_top_p

        chat_response = self._gemini_vertex_ai_client.models.generate_content(
            model=self._model,
            contents=prompt,
            config=GenerateContentConfig(
                system_instruction=[instructions],
                max_output_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ),
        )

        # `text` can be None; str(None) would leak the literal "None" into the game.
        text = (chat_response.text or "").strip().replace("\n", "")
        logger.info(f"Generated text: {text}")
        return text

    def generate_sketch(
        self,
        prompt: str,
        instructions: str,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> bytes:
        """Generate an image and return its raw bytes.

        Args:
            prompt: User content describing the image.
            instructions: Sent as the system instruction.
            temperature: Sampling temperature; defaults to the instance value.
            top_p: Nucleus sampling value; defaults to the instance value.

        Returns:
            The inline data of the first response part that has any, or
            ``b""`` when the response contains no image data.
        """
        # Same fallback as generate_text so both methods honour the instance settings.
        if temperature is None:
            temperature = self._chat_model_temperature
        if top_p is None:
            top_p = self._chat_model_top_p

        # The image model is fixed here; self._model is only used for text.
        image_gen_response = self._gemini_vertex_ai_client.models.generate_content(
            model="gemini-2.0-flash-preview-image-generation",
            contents=prompt,
            config=GenerateContentConfig(
                system_instruction=[instructions],
                temperature=temperature,
                top_p=top_p,
                response_modalities=["IMAGE"],
            ),
        )

        if (
            image_gen_response.candidates
            and image_gen_response.candidates[0].content
            and image_gen_response.candidates[0].content.parts
        ):
            for part in image_gen_response.candidates[0].content.parts:
                if part.inline_data is not None and part.inline_data.data is not None:
                    return part.inline_data.data

        return b""
|
|
@@ -1,20 +1,23 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from loguru import logger
|
|
4
|
-
from ollama import Options, chat,
|
|
4
|
+
from ollama import Options, chat, show
|
|
5
5
|
|
|
6
6
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class OllamaModel(ChatModel):
|
|
10
|
-
_model: str
|
|
11
10
|
|
|
12
|
-
def __init__(self,
|
|
11
|
+
def __init__(self, *args, **kwargs):
|
|
13
12
|
super().__init__(*args, **kwargs)
|
|
14
|
-
self._model = model
|
|
15
13
|
|
|
16
|
-
# Check connection, this will hard fail if connection can't be made
|
|
17
|
-
|
|
14
|
+
# Check connection and if model exists, this will hard fail if connection can't be made
|
|
15
|
+
# Or if the model is not found
|
|
16
|
+
_ = show(self._model)
|
|
17
|
+
|
|
18
|
+
@classmethod
|
|
19
|
+
def get_default_model(cls):
|
|
20
|
+
return "gemma3:12b"
|
|
18
21
|
|
|
19
22
|
def generate_text(
|
|
20
23
|
self,
|
|
@@ -36,6 +39,15 @@ class OllamaModel(ChatModel):
|
|
|
36
39
|
stream=False,
|
|
37
40
|
options=Options(num_predict=max_tokens, temperature=temperature, top_p=top_p),
|
|
38
41
|
)
|
|
39
|
-
text = chat_response.message.content.strip().replace("\n", " ")
|
|
42
|
+
text = str(chat_response.message.content).strip().replace("\n", " ")
|
|
40
43
|
logger.info(f"Generated text: {text}")
|
|
41
44
|
return text
|
|
45
|
+
|
|
46
|
+
def generate_sketch(
|
|
47
|
+
self,
|
|
48
|
+
prompt: str,
|
|
49
|
+
instructions: str,
|
|
50
|
+
temperature: Optional[float] = None,
|
|
51
|
+
top_p: Optional[float] = None,
|
|
52
|
+
) -> bytes:
|
|
53
|
+
raise Exception("Ollama model not supported yet for sketches")
|
|
@@ -1,23 +1,32 @@
|
|
|
1
|
+
import base64
|
|
1
2
|
import os
|
|
2
3
|
from typing import Optional
|
|
3
4
|
|
|
4
5
|
from loguru import logger
|
|
5
6
|
from openai import OpenAI
|
|
7
|
+
from openai.types.chat import (
|
|
8
|
+
ChatCompletionDeveloperMessageParam,
|
|
9
|
+
ChatCompletionUserMessageParam,
|
|
10
|
+
)
|
|
11
|
+
from openai.types.responses import Response
|
|
6
12
|
|
|
7
13
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
8
14
|
|
|
9
15
|
|
|
10
16
|
class OpenAIModel(ChatModel):
|
|
11
|
-
_model: str
|
|
12
17
|
_open_ai_client: OpenAI
|
|
13
18
|
|
|
14
|
-
def __init__(self,
|
|
19
|
+
def __init__(self, *args, **kwargs):
|
|
15
20
|
super().__init__(*args, **kwargs)
|
|
16
21
|
self._open_ai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
|
17
|
-
self._model = model
|
|
18
22
|
|
|
19
|
-
# Check connection, this will hard fail if connection can't be made
|
|
20
|
-
|
|
23
|
+
# Check connection and if model exists, this will hard fail if connection can't be made
|
|
24
|
+
# Or if the model is not found
|
|
25
|
+
_ = self._open_ai_client.models.retrieve(self._model)
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def get_default_model(cls):
|
|
29
|
+
return "gpt-4o-mini"
|
|
21
30
|
|
|
22
31
|
def generate_text(
|
|
23
32
|
self,
|
|
@@ -32,15 +41,46 @@ class OpenAIModel(ChatModel):
|
|
|
32
41
|
if top_p is None:
|
|
33
42
|
top_p = self._chat_model_top_p
|
|
34
43
|
|
|
35
|
-
instructions_formatted = {"role": "developer", "content": instructions}
|
|
36
44
|
chat_response = self._open_ai_client.chat.completions.create(
|
|
37
45
|
model=self._model,
|
|
38
|
-
messages=[
|
|
46
|
+
messages=[
|
|
47
|
+
ChatCompletionDeveloperMessageParam(content=instructions, role="developer"),
|
|
48
|
+
ChatCompletionUserMessageParam(content=prompt, role="user"),
|
|
49
|
+
],
|
|
39
50
|
stream=False,
|
|
40
51
|
max_completion_tokens=max_tokens,
|
|
41
52
|
temperature=temperature,
|
|
42
53
|
top_p=top_p,
|
|
43
54
|
)
|
|
44
|
-
text = chat_response.choices[0].message.content.strip().replace("\n", "")
|
|
55
|
+
text = str(chat_response.choices[0].message.content).strip().replace("\n", "")
|
|
45
56
|
logger.info(f"Generated text: {text}")
|
|
46
57
|
return text
|
|
58
|
+
|
|
59
|
+
def generate_sketch(
|
|
60
|
+
self,
|
|
61
|
+
prompt: str,
|
|
62
|
+
instructions: str,
|
|
63
|
+
temperature: Optional[float] = None,
|
|
64
|
+
top_p: Optional[float] = None,
|
|
65
|
+
) -> bytes:
|
|
66
|
+
image_gen_response: Response = self._open_ai_client.responses.create(
|
|
67
|
+
model=self._model,
|
|
68
|
+
instructions=instructions,
|
|
69
|
+
input=prompt,
|
|
70
|
+
temperature=temperature,
|
|
71
|
+
top_p=top_p,
|
|
72
|
+
tools=[
|
|
73
|
+
{
|
|
74
|
+
"type": "image_generation",
|
|
75
|
+
"quality": "low",
|
|
76
|
+
"size": "1024x1024",
|
|
77
|
+
}
|
|
78
|
+
],
|
|
79
|
+
)
|
|
80
|
+
# Save the image to a file
|
|
81
|
+
image_data = [output.result for output in image_gen_response.output if output.type == "image_generation_call"]
|
|
82
|
+
image_base64 = ""
|
|
83
|
+
if image_data:
|
|
84
|
+
image_base64 = str(image_data[0])
|
|
85
|
+
|
|
86
|
+
return base64.b64decode(image_base64)
|
ai_plays_jackbox/room.py
CHANGED
|
@@ -11,7 +11,6 @@ from ai_plays_jackbox.bot.bot_factory import JackBoxBotFactory
|
|
|
11
11
|
from ai_plays_jackbox.bot.bot_personality import JackBoxBotVariant
|
|
12
12
|
from ai_plays_jackbox.constants import ECAST_HOST
|
|
13
13
|
from ai_plays_jackbox.llm.chat_model import ChatModel
|
|
14
|
-
from ai_plays_jackbox.llm.ollama_model import OllamaModel
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
class JackBoxRoom:
|
|
@@ -24,12 +23,10 @@ class JackBoxRoom:
|
|
|
24
23
|
def play(
|
|
25
24
|
self,
|
|
26
25
|
room_code: str,
|
|
26
|
+
chat_model: ChatModel,
|
|
27
27
|
num_of_bots: int = 4,
|
|
28
28
|
bots_in_play: Optional[list] = None,
|
|
29
|
-
chat_model: Optional[ChatModel] = None,
|
|
30
29
|
):
|
|
31
|
-
if chat_model is None:
|
|
32
|
-
chat_model = OllamaModel()
|
|
33
30
|
room_type = self._get_room_type(room_code)
|
|
34
31
|
if not room_type:
|
|
35
32
|
logger.error(f"Unable to find room {room_code}")
|
|
@@ -45,9 +42,9 @@ class JackBoxRoom:
|
|
|
45
42
|
for b in bots_to_make:
|
|
46
43
|
bot = bot_factory.get_bot(
|
|
47
44
|
room_type,
|
|
45
|
+
chat_model,
|
|
48
46
|
name=b.value.name,
|
|
49
47
|
personality=b.value.personality,
|
|
50
|
-
chat_model=chat_model,
|
|
51
48
|
)
|
|
52
49
|
self._bots.append(bot)
|
|
53
50
|
with self._lock:
|
ai_plays_jackbox/run.py
CHANGED
|
@@ -6,21 +6,18 @@ from ai_plays_jackbox.room import JackBoxRoom
|
|
|
6
6
|
|
|
7
7
|
def run(
|
|
8
8
|
room_code: str,
|
|
9
|
+
chat_model_provider: str,
|
|
10
|
+
chat_model_name: Optional[str] = None,
|
|
9
11
|
num_of_bots: int = 4,
|
|
10
12
|
bots_in_play: Optional[list] = None,
|
|
11
|
-
chat_model_name: str = "ollama",
|
|
12
13
|
chat_model_temperature: float = 0.5,
|
|
13
14
|
chat_model_top_p: float = 0.9,
|
|
14
15
|
):
|
|
15
|
-
"""Will run a set of bots through a game of JackBox given a room code.
|
|
16
|
-
|
|
17
|
-
Args:
|
|
18
|
-
room_code (str): The room code.
|
|
19
|
-
num_of_bots (int, optional): The number of bots to participate. Defaults to 4.
|
|
20
|
-
chat_model (str, optional): The chat model to use to generate responses. Defaults to "ollama".
|
|
21
|
-
"""
|
|
22
16
|
chat_model = ChatModelFactory.get_chat_model(
|
|
23
|
-
|
|
17
|
+
chat_model_provider,
|
|
18
|
+
chat_model_name=chat_model_name,
|
|
19
|
+
chat_model_temperature=chat_model_temperature,
|
|
20
|
+
chat_model_top_p=chat_model_top_p,
|
|
24
21
|
)
|
|
25
22
|
room = JackBoxRoom()
|
|
26
|
-
room.play(room_code, num_of_bots=num_of_bots, bots_in_play=bots_in_play
|
|
23
|
+
room.play(room_code, chat_model, num_of_bots=num_of_bots, bots_in_play=bots_in_play)
|