ai-plays-jackbox 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-plays-jackbox might be problematic; the registry page links to further details.

Files changed (34):
  1. ai_plays_jackbox/__init__.py +0 -1
  2. ai_plays_jackbox/bot/__init__.py +0 -1
  3. ai_plays_jackbox/bot/bot_base.py +42 -8
  4. ai_plays_jackbox/bot/bot_factory.py +15 -17
  5. ai_plays_jackbox/bot/bot_personality.py +0 -1
  6. ai_plays_jackbox/bot/jackbox5/__init__.py +0 -0
  7. ai_plays_jackbox/bot/jackbox5/bot_base.py +26 -0
  8. ai_plays_jackbox/bot/jackbox5/mad_verse_city.py +121 -0
  9. ai_plays_jackbox/bot/jackbox5/patently_stupid.py +167 -0
  10. ai_plays_jackbox/bot/jackbox6/bot_base.py +1 -1
  11. ai_plays_jackbox/bot/jackbox6/joke_boat.py +105 -0
  12. ai_plays_jackbox/bot/jackbox7/bot_base.py +1 -1
  13. ai_plays_jackbox/bot/jackbox7/quiplash3.py +8 -4
  14. ai_plays_jackbox/bot/standalone/__init__.py +0 -0
  15. ai_plays_jackbox/bot/standalone/drawful2.py +159 -0
  16. ai_plays_jackbox/cli.py +26 -13
  17. ai_plays_jackbox/constants.py +3 -0
  18. ai_plays_jackbox/llm/chat_model.py +20 -3
  19. ai_plays_jackbox/llm/chat_model_factory.py +24 -19
  20. ai_plays_jackbox/llm/gemini_model.py +86 -0
  21. ai_plays_jackbox/llm/ollama_model.py +19 -7
  22. ai_plays_jackbox/llm/openai_model.py +48 -8
  23. ai_plays_jackbox/room.py +2 -5
  24. ai_plays_jackbox/run.py +7 -10
  25. ai_plays_jackbox/ui/create_ui.py +44 -16
  26. ai_plays_jackbox/web_ui.py +1 -1
  27. ai_plays_jackbox-0.1.0.dist-info/METADATA +154 -0
  28. ai_plays_jackbox-0.1.0.dist-info/RECORD +35 -0
  29. {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.1.0.dist-info}/entry_points.txt +0 -1
  30. ai_plays_jackbox/llm/gemini_vertex_ai.py +0 -60
  31. ai_plays_jackbox-0.0.1.dist-info/METADATA +0 -88
  32. ai_plays_jackbox-0.0.1.dist-info/RECORD +0 -28
  33. {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.1.0.dist-info}/LICENSE +0 -0
  34. {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.1.0.dist-info}/WHEEL +0 -0
ai_plays_jackbox/bot/standalone/drawful2.py ADDED
@@ -0,0 +1,159 @@
1
+ import random
2
+
3
+ from loguru import logger
4
+
5
+ from ai_plays_jackbox.bot.bot_base import JackBoxBotBase
6
+
7
# Prompt sent to the chat model's sketch generator for each Drawful 2 drawing.
# Filled in with str.format(prompt=<round prompt text>, personality=<bot personality>).
_DRAWING_PROMPT_TEMPLATE = """
You are playing Drawful 2.

Generate an image with the following prompt: {prompt}

When generating your response, follow these rules:
- Your personality is: {personality}
- Make sure to implement your personality somehow into the drawing, but keep the prompt in mind
- The image must be a simple sketch
- The image must have a white background and use black for the lines
- Avoid intricate details
"""
19
+
20
+
21
+ class Drawful2Bot(JackBoxBotBase):
22
+ _drawing_completed: bool = False
23
+
24
+ def __init__(self, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+
27
+ @property
28
+ def _player_operation_key(self) -> str:
29
+ return f"player:{self._player_id}"
30
+
31
+ def _is_player_operation_key(self, operation_key: str) -> bool:
32
+ return operation_key == self._player_operation_key
33
+
34
+ @property
35
+ def _room_operation_key(self) -> str:
36
+ return "room"
37
+
38
+ def _is_room_operation_key(self, operation_key: str) -> bool:
39
+ return operation_key == self._room_operation_key
40
+
41
+ def _handle_welcome(self, data: dict):
42
+ pass
43
+
44
+ def _handle_player_operation(self, data: dict):
45
+ if not data:
46
+ return
47
+ room_state = data.get("state", None)
48
+ if not room_state:
49
+ return
50
+ prompt = data.get("prompt")
51
+ prompt_text = self._html_to_text(prompt.get("html", "")) if prompt is not None else ""
52
+
53
+ match room_state:
54
+ case "Draw":
55
+ colors = data.get("colors", ["#fb405a", "#7a2259"])
56
+ selected_color = colors[0]
57
+ canvas_height = int(data.get("size", {}).get("height", 320))
58
+ canvas_width = int(data.get("size", {}).get("width", 320))
59
+ lines = self._generate_drawing(prompt_text, canvas_height, canvas_width)
60
+ object_key = data.get("objectKey", "")
61
+ if object_key != "":
62
+ if not self._drawing_completed:
63
+ self._object_update(
64
+ object_key,
65
+ {
66
+ "lines": [{"color": selected_color, "thickness": 1, "points": l} for l in lines],
67
+ "submit": True,
68
+ },
69
+ )
70
+ # This prevents the bot from trying to draw multiple times
71
+ self._drawing_completed = True
72
+
73
+ case "EnterSingleText":
74
+ # We need to reset this once we're entering options
75
+ self._drawing_completed = False
76
+ # Listen, the bot can't see the drawing
77
+ # so they're just going to say something
78
+ text_key = data.get("textKey", "")
79
+ self._text_update(text_key, self._generate_random_response())
80
+
81
+ case "MakeSingleChoice":
82
+ # Bot still can't see the drawing
83
+ # so just pick something
84
+ if data.get("type", "single") == "repeating":
85
+ pass
86
+ choices = data.get("choices", [])
87
+ choices_as_ints = [i for i in range(0, len(choices))]
88
+ selected_choice = random.choice(choices_as_ints)
89
+ self._client_send({"action": "choose", "choice": selected_choice})
90
+
91
+ def _handle_room_operation(self, data: dict):
92
+ pass
93
+
94
+ def _generate_drawing(self, prompt: str, canvas_height: int, canvas_width: int) -> list[str]:
95
+ logger.info("Generating drawing...")
96
+ image_prompt = _DRAWING_PROMPT_TEMPLATE.format(prompt=prompt, personality=self._personality)
97
+ image_bytes = self._chat_model.generate_sketch(
98
+ image_prompt,
99
+ "",
100
+ temperature=self._chat_model._chat_model_temperature,
101
+ top_p=self._chat_model._chat_model_top_p,
102
+ )
103
+ return self._image_bytes_to_polylines(image_bytes, canvas_height, canvas_width)
104
+
105
+ def _generate_random_response(self) -> str:
106
+ possible_responses = [
107
+ "Abstract awkward silence",
108
+ "Abstract existential dread",
109
+ "Abstract late-stage capitalism",
110
+ "Abstract lost hope",
111
+ "Abstract the void",
112
+ "Baby Yoda, but weird",
113
+ "Barbie, but weird",
114
+ "Confused dentist",
115
+ "Confused gym teacher",
116
+ "Confused lawyer",
117
+ "Confused therapist",
118
+ "DJ in trouble",
119
+ "Definitely banana",
120
+ "Definitely blob",
121
+ "Definitely potato",
122
+ "Definitely spaghetti",
123
+ "Taylor Swift, but weird",
124
+ "Waluigi, but weird",
125
+ "banana with feelings",
126
+ "chicken riding a scooter",
127
+ "cloud with feelings",
128
+ "confused gym teacher",
129
+ "confused therapist",
130
+ "dentist in trouble",
131
+ "duck + existential dread",
132
+ "duck riding a scooter",
133
+ "excited janitor",
134
+ "ferret + awkward silence",
135
+ "giraffe + awkward silence",
136
+ "giraffe filing taxes",
137
+ "hamster + existential dread",
138
+ "hamster + lost hope",
139
+ "hamster riding a scooter",
140
+ "janitor in trouble",
141
+ "joyful octopus",
142
+ "lawyer in trouble",
143
+ "llama + awkward silence",
144
+ "llama + late-stage capitalism",
145
+ "lonely dentist",
146
+ "lonely hamster",
147
+ "lonely janitor",
148
+ "lonely pirate",
149
+ "mango with feelings",
150
+ "pirate in trouble",
151
+ "sad DJ",
152
+ "sad hamster",
153
+ "spaghetti with feelings",
154
+ "terrified duck",
155
+ "terrified ferret",
156
+ "terrified lawyer",
157
+ ]
158
+ chosen_response = random.choice(possible_responses)
159
+ return chosen_response
ai_plays_jackbox/cli.py CHANGED
@@ -1,6 +1,12 @@
1
1
  import argparse
2
2
 
3
- from ai_plays_jackbox import run
3
+ from ai_plays_jackbox.constants import (
4
+ DEFAULT_NUM_OF_BOTS,
5
+ DEFAULT_TEMPERATURE,
6
+ DEFAULT_TOP_P,
7
+ )
8
+ from ai_plays_jackbox.llm.chat_model_factory import CHAT_MODEL_PROVIDERS
9
+ from ai_plays_jackbox.run import run
4
10
 
5
11
 
6
12
  def _validate_room_code(string_to_check: str) -> str:
@@ -49,42 +55,49 @@ def cli():
49
55
  metavar="WXYZ",
50
56
  )
51
57
  parser.add_argument(
52
- "--chat-model-name",
58
+ "--chat-model-provider",
53
59
  required=True,
54
60
  help="Choose which chat model platform to use",
55
- choices=("ollama", "openai", "gemini"),
61
+ choices=list(CHAT_MODEL_PROVIDERS.keys()),
62
+ type=str,
63
+ )
64
+ parser.add_argument(
65
+ "--chat-model-name",
66
+ required=False,
67
+ help="Choose which chat model to use (Will default to default for provider)",
56
68
  type=str,
57
69
  )
58
70
  parser.add_argument(
59
71
  "--num-of-bots",
60
72
  required=False,
61
- default=4,
62
- help="How many bots to have play (Defaults to 4)",
73
+ default=DEFAULT_NUM_OF_BOTS,
74
+ help="How many bots to have play",
63
75
  type=_validate_num_of_bots,
64
- metavar="4",
76
+ metavar=str(DEFAULT_NUM_OF_BOTS),
65
77
  )
66
78
  parser.add_argument(
67
79
  "--temperature",
68
80
  required=False,
69
- default=0.5,
70
- help="Temperature for Gen AI (Defaults to 0.5)",
81
+ default=DEFAULT_TEMPERATURE,
82
+ help="Temperature for Gen AI",
71
83
  type=_validate_temperature,
72
- metavar="0.5",
84
+ metavar=str(DEFAULT_TEMPERATURE),
73
85
  )
74
86
  parser.add_argument(
75
87
  "--top-p",
76
88
  required=False,
77
- default=0.9,
78
- help="Top P for Gen AI (Defaults to 0.9)",
89
+ default=DEFAULT_TOP_P,
90
+ help="Top P for Gen AI",
79
91
  type=_validate_top_p,
80
- metavar="0.9",
92
+ metavar=str(DEFAULT_TOP_P),
81
93
  )
82
94
  args = parser.parse_args()
83
95
 
84
96
  run(
85
97
  args.room_code.upper(),
86
- num_of_bots=args.num_of_bots,
98
+ args.chat_model_provider,
87
99
  chat_model_name=args.chat_model_name,
100
+ num_of_bots=args.num_of_bots,
88
101
  chat_model_temperature=args.temperature,
89
102
  chat_model_top_p=args.top_p,
90
103
  )
ai_plays_jackbox/constants.py CHANGED
@@ -1 +1,4 @@
1
1
# Host name of Jackbox's "ecast" service (imported by room.py).
ECAST_HOST: str = "ecast.jackboxgames.com"

# Defaults for the CLI's --temperature, --top-p and --num-of-bots options.
DEFAULT_TEMPERATURE: float = 0.5
DEFAULT_TOP_P: float = 0.9
DEFAULT_NUM_OF_BOTS: int = 4
ai_plays_jackbox/llm/chat_model.py CHANGED
@@ -3,13 +3,20 @@ from typing import Optional
3
3
 
4
4
 
5
5
  class ChatModel(ABC):
6
- _chat_model_temperature: float = 0.5
7
- _chat_model_top_p: float = 0.9
6
+ _model: str
7
+ _chat_model_temperature: float
8
+ _chat_model_top_p: float
8
9
 
9
- def __init__(self, chat_model_temperature: float = 0.5, chat_model_top_p: float = 0.9):
10
+ def __init__(self, model: str, chat_model_temperature: float = 0.5, chat_model_top_p: float = 0.9):
11
+ self._model = model
10
12
  self._chat_model_temperature = chat_model_temperature
11
13
  self._chat_model_top_p = chat_model_top_p
12
14
 
15
+ @classmethod
16
+ @abstractmethod
17
+ def get_default_model(self) -> str:
18
+ pass
19
+
13
20
  @abstractmethod
14
21
  def generate_text(
15
22
  self,
@@ -20,3 +27,13 @@ class ChatModel(ABC):
20
27
  top_p: Optional[float] = None,
21
28
  ) -> str:
22
29
  pass
30
+
31
+ @abstractmethod
32
+ def generate_sketch(
33
+ self,
34
+ prompt: str,
35
+ instructions: str,
36
+ temperature: Optional[float] = None,
37
+ top_p: Optional[float] = None,
38
+ ) -> bytes:
39
+ pass
ai_plays_jackbox/llm/chat_model_factory.py CHANGED
@@ -1,30 +1,35 @@
1
+ from typing import Optional
2
+
1
3
  from ai_plays_jackbox.llm.chat_model import ChatModel
2
- from ai_plays_jackbox.llm.gemini_vertex_ai import GeminiVertextAIModel
4
+ from ai_plays_jackbox.llm.gemini_model import GeminiModel
3
5
  from ai_plays_jackbox.llm.ollama_model import OllamaModel
4
6
  from ai_plays_jackbox.llm.openai_model import OpenAIModel
5
7
 
8
# Maps each provider name accepted by the CLI's --chat-model-provider option
# to the ChatModel subclass that implements it. Insertion order is the order
# in which the choices are listed.
CHAT_MODEL_PROVIDERS: dict[str, type[ChatModel]] = dict(
    openai=OpenAIModel,
    gemini=GeminiModel,
    ollama=OllamaModel,
)
13
+
6
14
 
7
15
class ChatModelFactory:
    """Builds ChatModel instances from a provider name."""

    @staticmethod
    def get_chat_model(
        chat_model_provider: str,
        chat_model_name: Optional[str] = None,
        chat_model_temperature: float = 0.5,
        chat_model_top_p: float = 0.9,
    ) -> ChatModel:
        """Instantiate the chat model for *chat_model_provider*.

        Args:
            chat_model_provider: Key of CHAT_MODEL_PROVIDERS. Matching is
                case-insensitive and ignores surrounding whitespace.
            chat_model_name: Model identifier. ``None``, an empty string or
                whitespace falls back to the provider's ``get_default_model()``.
            chat_model_temperature: Default sampling temperature for the model.
            chat_model_top_p: Default nucleus-sampling value for the model.

        Returns:
            A new instance of the provider's ChatModel subclass.

        Raises:
            ValueError: If the provider is not registered.
        """
        provider_key = chat_model_provider.strip().lower()
        model_class = CHAT_MODEL_PROVIDERS.get(provider_key)
        if model_class is None:
            raise ValueError(f"Unknown chat model provider: {provider_key}")

        # A blank name (e.g. an empty form field) means "use the provider default".
        model_name = chat_model_name.strip() if chat_model_name else ""
        if not model_name:
            model_name = model_class.get_default_model()

        return model_class(
            model_name,
            chat_model_temperature=chat_model_temperature,
            chat_model_top_p=chat_model_top_p,
        )
ai_plays_jackbox/llm/gemini_model.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ from google import genai
5
+ from google.genai.types import GenerateContentConfig
6
+ from loguru import logger
7
+
8
+ from ai_plays_jackbox.llm.chat_model import ChatModel
9
+
10
+
11
class GeminiModel(ChatModel):
    """ChatModel backed by Google Gemini through the ``google-genai`` client.

    Connection settings are read from the environment:

    * ``GOOGLE_GENAI_USE_VERTEXAI``: "true" or "1" (case-insensitive) selects Vertex AI.
    * ``GOOGLE_GEMINI_DEVELOPER_API_KEY``: the API key.
    * ``GOOGLE_CLOUD_PROJECT`` / ``GOOGLE_CLOUD_LOCATION``: Vertex AI project and region.
    """

    _gemini_vertex_ai_client: genai.Client

    # Sketches are always produced by this image-capable model, independent of self._model.
    _IMAGE_GENERATION_MODEL: str = "gemini-2.0-flash-preview-image-generation"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._gemini_vertex_ai_client = genai.Client(
            # bool() of any non-empty string (including "false" or "0") is True,
            # so the flag is parsed explicitly instead.
            vertexai=self._env_flag("GOOGLE_GENAI_USE_VERTEXAI"),
            api_key=os.environ.get("GOOGLE_GEMINI_DEVELOPER_API_KEY"),
            project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
            location=os.environ.get("GOOGLE_CLOUD_LOCATION"),
        )

        # Check the connection and that the model exists. This hard-fails if
        # the connection can't be made or the model is not found.
        # google-genai's `models.get` is keyword-only, so `model=` must be named.
        self._gemini_vertex_ai_client.models.get(model=self._model)

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when environment variable *name* is "true" or "1" (case-insensitive)."""
        return os.environ.get(name, "").strip().lower() in ("true", "1")

    @classmethod
    def get_default_model(cls):
        return "gemini-2.0-flash-001"

    def generate_text(
        self,
        prompt: str,
        instructions: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Generate a single-line text reply to *prompt*.

        ``None`` sampling arguments fall back to the instance defaults. Newlines
        in the reply are replaced with spaces. If the response has no text, an
        empty string is returned.
        """
        if temperature is None:
            temperature = self._chat_model_temperature
        if top_p is None:
            top_p = self._chat_model_top_p

        chat_response = self._gemini_vertex_ai_client.models.generate_content(
            model=self._model,
            contents=prompt,
            config=GenerateContentConfig(
                # Send no system instruction rather than an empty one.
                system_instruction=[instructions] if instructions else None,
                max_output_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ),
        )

        # `.text` can be None (e.g. a blocked response); don't turn that into "None".
        # Replace newlines with spaces so words on adjacent lines don't run together.
        text = (chat_response.text or "").strip().replace("\n", " ")
        logger.info(f"Generated text: {text}")
        return text

    def generate_sketch(
        self,
        prompt: str,
        instructions: str,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> bytes:
        """Generate an image for *prompt*.

        Returns the inline data of the first image part in the first candidate,
        or b"" if the response contains no image.
        """
        if temperature is None:
            temperature = self._chat_model_temperature
        if top_p is None:
            top_p = self._chat_model_top_p

        image_gen_response = self._gemini_vertex_ai_client.models.generate_content(
            model=self._IMAGE_GENERATION_MODEL,
            contents=prompt,
            config=GenerateContentConfig(
                system_instruction=[instructions] if instructions else None,
                temperature=temperature,
                top_p=top_p,
                # This model does not support image-only output; it must be asked
                # for TEXT and IMAGE. Any text parts are skipped below.
                response_modalities=["TEXT", "IMAGE"],
            ),
        )

        candidates = image_gen_response.candidates
        if candidates and candidates[0].content and candidates[0].content.parts:
            for part in candidates[0].content.parts:
                if part.inline_data is not None and part.inline_data.data is not None:
                    return part.inline_data.data

        logger.warning("Gemini returned no image data for the sketch prompt")
        return b""
ai_plays_jackbox/llm/ollama_model.py CHANGED
@@ -1,20 +1,23 @@
1
1
  from typing import Optional
2
2
 
3
3
  from loguru import logger
4
- from ollama import Options, chat, list
4
+ from ollama import Options, chat, show
5
5
 
6
6
  from ai_plays_jackbox.llm.chat_model import ChatModel
7
7
 
8
8
 
9
9
  class OllamaModel(ChatModel):
10
- _model: str
11
10
 
12
def __init__(self, *args, **kwargs):
    """Create the model and verify up front that Ollama can serve ``self._model``.

    All arguments are forwarded unchanged to ChatModel.__init__.
    """
    super().__init__(*args, **kwargs)
    # `show` fails if the Ollama server can't be reached or the model is not
    # found, so a misconfiguration surfaces here rather than mid-game.
    show(self._model)
17
+
18
+ @classmethod
19
+ def get_default_model(cls):
20
+ return "gemma3:12b"
18
21
 
19
22
  def generate_text(
20
23
  self,
@@ -36,6 +39,15 @@ class OllamaModel(ChatModel):
36
39
  stream=False,
37
40
  options=Options(num_predict=max_tokens, temperature=temperature, top_p=top_p),
38
41
  )
39
- text = chat_response.message.content.strip().replace("\n", " ")
42
+ text = str(chat_response.message.content).strip().replace("\n", " ")
40
43
  logger.info(f"Generated text: {text}")
41
44
  return text
45
+
46
+ def generate_sketch(
47
+ self,
48
+ prompt: str,
49
+ instructions: str,
50
+ temperature: Optional[float] = None,
51
+ top_p: Optional[float] = None,
52
+ ) -> bytes:
53
+ raise Exception("Ollama model not supported yet for sketches")
ai_plays_jackbox/llm/openai_model.py CHANGED
@@ -1,23 +1,32 @@
1
+ import base64
1
2
  import os
2
3
  from typing import Optional
3
4
 
4
5
  from loguru import logger
5
6
  from openai import OpenAI
7
+ from openai.types.chat import (
8
+ ChatCompletionDeveloperMessageParam,
9
+ ChatCompletionUserMessageParam,
10
+ )
11
+ from openai.types.responses import Response
6
12
 
7
13
  from ai_plays_jackbox.llm.chat_model import ChatModel
8
14
 
9
15
 
10
16
  class OpenAIModel(ChatModel):
11
- _model: str
12
17
  _open_ai_client: OpenAI
13
18
 
14
def __init__(self, *args, **kwargs):
    """Create the OpenAI client and verify that ``self._model`` can be retrieved.

    All arguments are forwarded unchanged to ChatModel.__init__. The API key
    is read from the OPENAI_API_KEY environment variable.
    """
    super().__init__(*args, **kwargs)
    api_key = os.environ.get("OPENAI_API_KEY")
    self._open_ai_client = OpenAI(api_key=api_key)

    # Retrieving the model fails if the connection can't be made or the model
    # is not found, so misconfiguration surfaces at construction time.
    self._open_ai_client.models.retrieve(self._model)
26
+
27
+ @classmethod
28
+ def get_default_model(cls):
29
+ return "gpt-4o-mini"
21
30
 
22
31
  def generate_text(
23
32
  self,
@@ -32,15 +41,46 @@ class OpenAIModel(ChatModel):
32
41
  if top_p is None:
33
42
  top_p = self._chat_model_top_p
34
43
 
35
- instructions_formatted = {"role": "developer", "content": instructions}
36
44
  chat_response = self._open_ai_client.chat.completions.create(
37
45
  model=self._model,
38
- messages=[instructions_formatted, {"role": "user", "content": prompt}],
46
+ messages=[
47
+ ChatCompletionDeveloperMessageParam(content=instructions, role="developer"),
48
+ ChatCompletionUserMessageParam(content=prompt, role="user"),
49
+ ],
39
50
  stream=False,
40
51
  max_completion_tokens=max_tokens,
41
52
  temperature=temperature,
42
53
  top_p=top_p,
43
54
  )
44
- text = chat_response.choices[0].message.content.strip().replace("\n", "")
55
+ text = str(chat_response.choices[0].message.content).strip().replace("\n", "")
45
56
  logger.info(f"Generated text: {text}")
46
57
  return text
58
+
59
+ def generate_sketch(
60
+ self,
61
+ prompt: str,
62
+ instructions: str,
63
+ temperature: Optional[float] = None,
64
+ top_p: Optional[float] = None,
65
+ ) -> bytes:
66
+ image_gen_response: Response = self._open_ai_client.responses.create(
67
+ model=self._model,
68
+ instructions=instructions,
69
+ input=prompt,
70
+ temperature=temperature,
71
+ top_p=top_p,
72
+ tools=[
73
+ {
74
+ "type": "image_generation",
75
+ "quality": "low",
76
+ "size": "1024x1024",
77
+ }
78
+ ],
79
+ )
80
+ # Save the image to a file
81
+ image_data = [output.result for output in image_gen_response.output if output.type == "image_generation_call"]
82
+ image_base64 = ""
83
+ if image_data:
84
+ image_base64 = str(image_data[0])
85
+
86
+ return base64.b64decode(image_base64)
ai_plays_jackbox/room.py CHANGED
@@ -11,7 +11,6 @@ from ai_plays_jackbox.bot.bot_factory import JackBoxBotFactory
11
11
  from ai_plays_jackbox.bot.bot_personality import JackBoxBotVariant
12
12
  from ai_plays_jackbox.constants import ECAST_HOST
13
13
  from ai_plays_jackbox.llm.chat_model import ChatModel
14
- from ai_plays_jackbox.llm.ollama_model import OllamaModel
15
14
 
16
15
 
17
16
  class JackBoxRoom:
@@ -24,12 +23,10 @@ class JackBoxRoom:
24
23
  def play(
25
24
  self,
26
25
  room_code: str,
26
+ chat_model: ChatModel,
27
27
  num_of_bots: int = 4,
28
28
  bots_in_play: Optional[list] = None,
29
- chat_model: Optional[ChatModel] = None,
30
29
  ):
31
- if chat_model is None:
32
- chat_model = OllamaModel()
33
30
  room_type = self._get_room_type(room_code)
34
31
  if not room_type:
35
32
  logger.error(f"Unable to find room {room_code}")
@@ -45,9 +42,9 @@ class JackBoxRoom:
45
42
  for b in bots_to_make:
46
43
  bot = bot_factory.get_bot(
47
44
  room_type,
45
+ chat_model,
48
46
  name=b.value.name,
49
47
  personality=b.value.personality,
50
- chat_model=chat_model,
51
48
  )
52
49
  self._bots.append(bot)
53
50
  with self._lock:
ai_plays_jackbox/run.py CHANGED
@@ -6,21 +6,18 @@ from ai_plays_jackbox.room import JackBoxRoom
6
6
 
7
7
def run(
    room_code: str,
    chat_model_provider: str,
    chat_model_name: Optional[str] = None,
    num_of_bots: int = 4,
    bots_in_play: Optional[list] = None,
    chat_model_temperature: float = 0.5,
    chat_model_top_p: float = 0.9,
):
    """Run a set of bots through a game of Jackbox given a room code.

    Args:
        room_code: The room code to join.
        chat_model_provider: Provider name resolved by ChatModelFactory
            (its registry contains "openai", "gemini" and "ollama").
        chat_model_name: Model to use. ``None`` means the provider's default model.
        num_of_bots: Number of bots to participate. Defaults to 4.
        bots_in_play: Optional list forwarded to JackBoxRoom.play. Presumably
            the specific bot variants to use; confirm against JackBoxRoom.
        chat_model_temperature: Default sampling temperature for the chat model.
        chat_model_top_p: Default top-p for the chat model.

    Raises:
        ValueError: From ChatModelFactory.get_chat_model if the provider is unknown.
    """
    chat_model = ChatModelFactory.get_chat_model(
        chat_model_provider,
        chat_model_name=chat_model_name,
        chat_model_temperature=chat_model_temperature,
        chat_model_top_p=chat_model_top_p,
    )
    room = JackBoxRoom()
    room.play(room_code, chat_model, num_of_bots=num_of_bots, bots_in_play=bots_in_play)