ai-plays-jackbox 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-plays-jackbox might be problematic. Click here for more details.

Files changed (44):
  1. ai_plays_jackbox/__init__.py +0 -1
  2. ai_plays_jackbox/bot/__init__.py +0 -1
  3. ai_plays_jackbox/bot/bot_base.py +42 -8
  4. ai_plays_jackbox/bot/bot_factory.py +18 -17
  5. ai_plays_jackbox/bot/bot_personality.py +0 -1
  6. ai_plays_jackbox/bot/jackbox5/__init__.py +0 -0
  7. ai_plays_jackbox/bot/jackbox5/bot_base.py +26 -0
  8. ai_plays_jackbox/bot/jackbox5/mad_verse_city.py +121 -0
  9. ai_plays_jackbox/bot/jackbox5/patently_stupid.py +167 -0
  10. ai_plays_jackbox/bot/jackbox6/bot_base.py +1 -1
  11. ai_plays_jackbox/bot/jackbox6/dictionarium.py +105 -0
  12. ai_plays_jackbox/bot/jackbox6/joke_boat.py +105 -0
  13. ai_plays_jackbox/bot/jackbox7/bot_base.py +1 -1
  14. ai_plays_jackbox/bot/jackbox7/quiplash3.py +8 -4
  15. ai_plays_jackbox/bot/jackbox8/__init__.py +0 -0
  16. ai_plays_jackbox/bot/jackbox8/bot_base.py +20 -0
  17. ai_plays_jackbox/bot/jackbox8/job_job.py +205 -0
  18. ai_plays_jackbox/bot/standalone/__init__.py +0 -0
  19. ai_plays_jackbox/bot/standalone/drawful2.py +159 -0
  20. ai_plays_jackbox/cli/__init__.py +0 -0
  21. ai_plays_jackbox/{cli.py → cli/main.py} +28 -15
  22. ai_plays_jackbox/constants.py +3 -0
  23. ai_plays_jackbox/llm/chat_model.py +20 -3
  24. ai_plays_jackbox/llm/chat_model_factory.py +24 -19
  25. ai_plays_jackbox/llm/gemini_model.py +86 -0
  26. ai_plays_jackbox/llm/ollama_model.py +19 -7
  27. ai_plays_jackbox/llm/openai_model.py +48 -8
  28. ai_plays_jackbox/room/__init__.py +0 -0
  29. ai_plays_jackbox/{room.py → room/room.py} +2 -5
  30. ai_plays_jackbox/run.py +8 -11
  31. ai_plays_jackbox/scripts/lint.py +18 -0
  32. ai_plays_jackbox/ui/main.py +12 -0
  33. ai_plays_jackbox/ui/startup.py +248 -0
  34. ai_plays_jackbox-0.2.0.dist-info/METADATA +156 -0
  35. ai_plays_jackbox-0.2.0.dist-info/RECORD +42 -0
  36. ai_plays_jackbox-0.2.0.dist-info/entry_points.txt +4 -0
  37. ai_plays_jackbox/llm/gemini_vertex_ai.py +0 -60
  38. ai_plays_jackbox/ui/create_ui.py +0 -169
  39. ai_plays_jackbox/web_ui.py +0 -12
  40. ai_plays_jackbox-0.0.1.dist-info/METADATA +0 -88
  41. ai_plays_jackbox-0.0.1.dist-info/RECORD +0 -28
  42. ai_plays_jackbox-0.0.1.dist-info/entry_points.txt +0 -5
  43. {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.2.0.dist-info}/LICENSE +0 -0
  44. {ai_plays_jackbox-0.0.1.dist-info → ai_plays_jackbox-0.2.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,105 @@
1
+ import random
2
+
3
+ from loguru import logger
4
+
5
+ from ai_plays_jackbox.bot.jackbox6.bot_base import JackBox6BotBase
6
+
7
+ _TOPIC_PROMPT_TEMPLATE = """
8
+ You are playing Joke Boat.
9
+
10
+ You are being asked to come up with a topic that is {placeholder}.
11
+
12
+ When generating your response, follow these rules:
13
+ - Your personality is: {personality}
14
+ - You response must be {max_length} characters or less
15
+ - Your response should be a single word.
16
+ - Do not include quotes in your response or any newlines, just the response itself
17
+ """
18
+
19
+ _PUNCHLINE_INSTRUCTIONS_TEMPLATE = """
20
+ You are playing Joke Boat. You need to fill in the given prompt with a punchline.
21
+
22
+ When generating your response, follow these rules:
23
+ - Your personality is: {personality}
24
+ - You response must be {max_length} characters or less
25
+ - Do not include quotes in your response or any newlines, just the response itself
26
+ """
27
+
28
+
29
+ class JokeBoatBot(JackBox6BotBase):
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def _handle_welcome(self, data: dict):
34
+ pass
35
+
36
+ def _handle_player_operation(self, data: dict):
37
+ if not data:
38
+ return
39
+ room_state = data.get("state", None)
40
+ if not room_state:
41
+ return
42
+
43
+ prompt = data.get("prompt", {})
44
+ prompt_html = prompt.get("html", "")
45
+ clean_prompt = self._html_to_text(prompt_html)
46
+
47
+ match room_state:
48
+ case "MakeSingleChoice":
49
+ choices: list[dict] = data.get("choices", [])
50
+ choice_type = data.get("choiceType", "")
51
+ choice_indexes = [i for i in range(0, len(choices))]
52
+ selected_choice = random.choice(choice_indexes)
53
+
54
+ if choice_type == "ChooseAuthorReady":
55
+ selected_choice = 1
56
+ if choice_type == "Skip":
57
+ selected_choice = 0
58
+
59
+ if data.get("chosen", None) is None:
60
+ self._client_send({"action": "choose", "choice": selected_choice})
61
+
62
+ case "EnterSingleText":
63
+ if "Write as many topics as you can." in clean_prompt:
64
+ placeholder = data.get("placeholder", "")
65
+ max_length = data.get("maxLength", 42)
66
+ topic = self._generate_topic(placeholder, max_length)
67
+ self._client_send({"action": "write", "entry": topic})
68
+ if "Write your punchline" in clean_prompt or "Write the punchline to this joke" in clean_prompt:
69
+ max_length = data.get("maxLength", 80)
70
+ punchline = self._generate_punchline(clean_prompt, max_length)
71
+ self._client_send({"action": "write", "entry": punchline})
72
+
73
+ def _handle_room_operation(self, data: dict):
74
+ pass
75
+
76
+ def _generate_topic(self, placeholder: str, max_length: int) -> str:
77
+ logger.info("Generating topic...")
78
+ formatted_prompt = _TOPIC_PROMPT_TEMPLATE.format(
79
+ personality=self._personality,
80
+ placeholder=placeholder,
81
+ max_length=max_length,
82
+ )
83
+
84
+ topic = self._chat_model.generate_text(
85
+ formatted_prompt,
86
+ "",
87
+ temperature=self._chat_model._chat_model_temperature,
88
+ top_p=self._chat_model._chat_model_top_p,
89
+ )
90
+ return topic[:max_length]
91
+
92
+ def _generate_punchline(self, prompt: str, max_length: int) -> str:
93
+ logger.info("Generating punchline...")
94
+ formatted_instructions = _PUNCHLINE_INSTRUCTIONS_TEMPLATE.format(
95
+ personality=self._personality,
96
+ prompt=prompt,
97
+ max_length=max_length,
98
+ )
99
+ punchline = self._chat_model.generate_text(
100
+ prompt,
101
+ formatted_instructions,
102
+ temperature=self._chat_model._chat_model_temperature,
103
+ top_p=self._chat_model._chat_model_top_p,
104
+ )
105
+ return punchline[:max_length]
@@ -1,4 +1,4 @@
1
- from abc import ABC, abstractmethod
1
+ from abc import ABC
2
2
 
3
3
  from ai_plays_jackbox.bot.bot_base import JackBoxBotBase
4
4
 
@@ -75,7 +75,11 @@ class Quiplash3Bot(JackBox7BotBase):
75
75
  max_tokens = 32
76
76
  instructions = _FINAL_QUIP_PROMPT_INSTRUCTIONS_TEMPLATE.format(personality=self._personality)
77
77
  quip = self._chat_model.generate_text(
78
- prompt, instructions=instructions, max_tokens=max_tokens, temperature=0.7, top_p=0.7
78
+ prompt,
79
+ instructions,
80
+ max_tokens=max_tokens,
81
+ temperature=self._chat_model._chat_model_temperature,
82
+ top_p=self._chat_model._chat_model_top_p,
79
83
  )
80
84
  return quip
81
85
 
@@ -84,7 +88,7 @@ class Quiplash3Bot(JackBox7BotBase):
84
88
  instructions = _QUIP_CHOICE_PROMPT_INSTRUCTIONS_TEMPLATE.format(prompt=prompt)
85
89
  response = self._chat_model.generate_text(
86
90
  f"Vote for your favorite response. Your options are: {choices_str}",
87
- instructions=instructions,
91
+ instructions,
88
92
  max_tokens=1,
89
93
  )
90
94
  try:
@@ -100,5 +104,5 @@ class Quiplash3Bot(JackBox7BotBase):
100
104
  return choosen_prompt - 1
101
105
 
102
106
  def _choose_random_favorite(self, choices: list[dict]) -> int:
103
- choices = [i for i in range(0, len(choices))]
104
- return random.choice(choices)
107
+ choices_as_ints = [i for i in range(0, len(choices))]
108
+ return random.choice(choices_as_ints)
File without changes
@@ -0,0 +1,20 @@
1
+ from abc import ABC
2
+
3
+ from ai_plays_jackbox.bot.bot_base import JackBoxBotBase
4
+
5
+
6
class JackBox8BotBase(JackBoxBotBase, ABC):
    """Operation-key routing shared by Jackbox Party Pack 8 bots.

    Player-scoped operations are identified by ``player:<player id>`` and
    room-scoped operations by the literal key ``room``.
    """

    @property
    def _player_operation_key(self) -> str:
        return "player:{}".format(self._player_id)

    def _is_player_operation_key(self, operation_key: str) -> bool:
        expected_key = self._player_operation_key
        return expected_key == operation_key

    @property
    def _room_operation_key(self) -> str:
        return "room"

    def _is_room_operation_key(self, operation_key: str) -> bool:
        expected_key = self._room_operation_key
        return expected_key == operation_key
@@ -0,0 +1,205 @@
1
+ import random
2
+ import string
3
+
4
+ from loguru import logger
5
+
6
+ from ai_plays_jackbox.bot.jackbox8.bot_base import JackBox8BotBase
7
+
8
+ _RESPONSE_PROMPT_TEMPLATE = """
9
+ You are playing Job Job. You need response to the given prompt.
10
+
11
+ When generating your response, follow these rules:
12
+ - Your personality is: {personality}
13
+ - Your response must be {max_length} letters or less.
14
+ - Your response must have a minimum of {min_words} words.
15
+ - Do not include quotes in your response.
16
+
17
+ {instruction}
18
+
19
+ Your prompt is:
20
+
21
+ {prompt}
22
+ """
23
+
24
+ _COMPOSITION_PROMPT_TEMPLATE = """
25
+ You are playing Job Job. You must create a response to a interview question using only specific words given.
26
+
27
+ When generating your response, follow these rules:
28
+ - Your personality is: {personality}
29
+ - Your response must only use the allowed words or characters, nothing else
30
+ - If you decide to use a character, you must have it separated by a space from any words
31
+ - You can select a maximum of {max_words} words
32
+
33
+ Your interview question is:
34
+
35
+ {prompt}
36
+
37
+ Your allowed words or characters are:
38
+
39
+ {all_possible_words_str}
40
+ """
41
+
42
+
43
+ class JobJobBot(JackBox8BotBase):
44
+ def __init__(self, *args, **kwargs):
45
+ super().__init__(*args, **kwargs)
46
+
47
+ def _handle_welcome(self, data: dict):
48
+ pass
49
+
50
+ def _handle_player_operation(self, data: dict):
51
+ if not data:
52
+ return
53
+
54
+ kind = data.get("kind", "")
55
+ has_controls = data.get("hasControls", False)
56
+ response_key = data.get("responseKey", "")
57
+ done_key = data.get("doneKey", "")
58
+
59
+ if has_controls:
60
+ if "skip:" in response_key:
61
+ self._object_update(response_key, {"action": "skip"})
62
+ return
63
+
64
+ match kind:
65
+ case "writing":
66
+ instruction = data.get("instruction", "")
67
+ prompt = data.get("prompt", "")
68
+ max_length = data.get("maxLength", 128)
69
+ min_words = data.get("minWords", 5)
70
+ text_key = data.get("textKey", "")
71
+ response = self._generate_response(instruction, prompt, max_length, min_words)
72
+ self._text_update(text_key, response)
73
+ self._object_update(done_key, {"done": True})
74
+
75
+ case "magnets":
76
+ prompt = data.get("prompt", "")
77
+ answer_key = data.get("answerKey", "")
78
+ stash = data.get("stash", [[]])
79
+ max_words = data.get("maxWords", 12)
80
+ composition_list = self._generate_composition_list(prompt, stash, max_words)
81
+ self._object_update(
82
+ answer_key,
83
+ {
84
+ "final": True,
85
+ "text": composition_list,
86
+ },
87
+ )
88
+
89
+ case "resumagents":
90
+ prompt = data.get("prompt", "")
91
+ answer_key = data.get("answerKey", "")
92
+ stash = data.get("stash", [[]])
93
+ max_words = data.get("maxWords", 12)
94
+ max_words_per_answer = data.get("maxWordsPerAnswer", 8)
95
+ num_answers = data.get("numAnswers", 8)
96
+ resume_composition_list = self._generate_resume_composition_list(
97
+ prompt, stash, max_words, max_words_per_answer, num_answers
98
+ )
99
+ self._object_update(
100
+ answer_key,
101
+ {
102
+ "final": True,
103
+ "text": resume_composition_list,
104
+ },
105
+ )
106
+
107
+ case "voting":
108
+ choices: list[dict] = data.get("choices", [])
109
+ choice_indexes = [i for i in range(0, len(choices))]
110
+ selected_choice = random.choice(choice_indexes)
111
+ self._object_update(response_key, {"action": "choose", "choice": selected_choice})
112
+
113
+ def _handle_room_operation(self, data: dict):
114
+ pass
115
+
116
+ def _generate_response(self, instruction: str, prompt: str, max_length: int, min_words: int) -> str:
117
+ formatted_prompt = _RESPONSE_PROMPT_TEMPLATE.format(
118
+ personality=self._personality,
119
+ max_length=max_length,
120
+ min_words=min_words,
121
+ instruction=instruction,
122
+ prompt=prompt,
123
+ )
124
+ response = self._chat_model.generate_text(
125
+ formatted_prompt,
126
+ "",
127
+ temperature=self._chat_model._chat_model_temperature,
128
+ top_p=self._chat_model._chat_model_top_p,
129
+ )
130
+ if len(response) > max_length:
131
+ response = response[: max_length - 1]
132
+ return response
133
+
134
+ def _generate_composition_list(
135
+ self,
136
+ prompt: str,
137
+ stash: list[list[str]],
138
+ max_words: int,
139
+ ) -> list[dict]:
140
+
141
+ possible_word_choices = []
142
+
143
+ for stash_entry in stash:
144
+ for word in stash_entry:
145
+ possible_word_choices.append(word)
146
+
147
+ all_possible_words_str = "\n".join([word for word in possible_word_choices])
148
+ formatted_prompt = _COMPOSITION_PROMPT_TEMPLATE.format(
149
+ personality=self._personality,
150
+ all_possible_words_str=all_possible_words_str,
151
+ max_words=max_words,
152
+ prompt=prompt,
153
+ )
154
+ response = self._chat_model.generate_text(
155
+ formatted_prompt,
156
+ "",
157
+ temperature=self._chat_model._chat_model_temperature,
158
+ top_p=self._chat_model._chat_model_top_p,
159
+ )
160
+
161
+ ## Listen, I know this is isn't the fastest way to search
162
+ ## It's 12 words, bite me with your Big O notation
163
+ composition_list = []
164
+ response_list = response.split(" ")
165
+ for response_word in response_list:
166
+ found_word = False
167
+ response_word = response_word.strip()
168
+ if not all(char in string.punctuation for char in response_word):
169
+ response_word = response_word.translate(str.maketrans("", "", string.punctuation)).lower()
170
+
171
+ if not found_word:
172
+ for stash_index, stash_entry in enumerate(stash):
173
+ for check_word_index, check_word in enumerate(stash_entry):
174
+ if response_word == check_word.lower():
175
+ composition_list.append(
176
+ {
177
+ "index": stash_index,
178
+ "word": check_word_index,
179
+ }
180
+ )
181
+ found_word = True
182
+ break
183
+ if found_word:
184
+ break
185
+
186
+ if not found_word:
187
+ logger.warning(f"Word not found in choices: {response_word}")
188
+
189
+ if len(composition_list) > max_words:
190
+ composition_list = composition_list[: max_words - 1]
191
+ return composition_list
192
+
193
+ def _generate_resume_composition_list(
194
+ self,
195
+ prompt: str,
196
+ stash: list[list[str]],
197
+ max_words: int,
198
+ max_words_per_answers: int,
199
+ num_of_answers: int,
200
+ ) -> list[list[dict]]:
201
+ # TODO Figure this out
202
+ resume_composition_list = []
203
+ for _ in range(0, num_of_answers):
204
+ resume_composition_list.append([{"index": 0, "word": 0}])
205
+ return resume_composition_list
File without changes
@@ -0,0 +1,159 @@
1
+ import random
2
+
3
+ from loguru import logger
4
+
5
+ from ai_plays_jackbox.bot.bot_base import JackBoxBotBase
6
+
7
+ _DRAWING_PROMPT_TEMPLATE = """
8
+ You are playing Drawful 2.
9
+
10
+ Generate an image with the following prompt: {prompt}
11
+
12
+ When generating your response, follow these rules:
13
+ - Your personality is: {personality}
14
+ - Make sure to implement your personality somehow into the drawing, but keep the prompt in mind
15
+ - The image must be a simple sketch
16
+ - The image must have a white background and use black for the lines
17
+ - Avoid intricate details
18
+ """
19
+
20
+
21
+ class Drawful2Bot(JackBoxBotBase):
22
+ _drawing_completed: bool = False
23
+
24
+ def __init__(self, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+
27
+ @property
28
+ def _player_operation_key(self) -> str:
29
+ return f"player:{self._player_id}"
30
+
31
+ def _is_player_operation_key(self, operation_key: str) -> bool:
32
+ return operation_key == self._player_operation_key
33
+
34
+ @property
35
+ def _room_operation_key(self) -> str:
36
+ return "room"
37
+
38
+ def _is_room_operation_key(self, operation_key: str) -> bool:
39
+ return operation_key == self._room_operation_key
40
+
41
+ def _handle_welcome(self, data: dict):
42
+ pass
43
+
44
+ def _handle_player_operation(self, data: dict):
45
+ if not data:
46
+ return
47
+ room_state = data.get("state", None)
48
+ if not room_state:
49
+ return
50
+ prompt = data.get("prompt")
51
+ prompt_text = self._html_to_text(prompt.get("html", "")) if prompt is not None else ""
52
+
53
+ match room_state:
54
+ case "Draw":
55
+ colors = data.get("colors", ["#fb405a", "#7a2259"])
56
+ selected_color = colors[0]
57
+ canvas_height = int(data.get("size", {}).get("height", 320))
58
+ canvas_width = int(data.get("size", {}).get("width", 320))
59
+ lines = self._generate_drawing(prompt_text, canvas_height, canvas_width)
60
+ object_key = data.get("objectKey", "")
61
+ if object_key != "":
62
+ if not self._drawing_completed:
63
+ self._object_update(
64
+ object_key,
65
+ {
66
+ "lines": [{"color": selected_color, "thickness": 1, "points": l} for l in lines],
67
+ "submit": True,
68
+ },
69
+ )
70
+ # This prevents the bot from trying to draw multiple times
71
+ self._drawing_completed = True
72
+
73
+ case "EnterSingleText":
74
+ # We need to reset this once we're entering options
75
+ self._drawing_completed = False
76
+ # Listen, the bot can't see the drawing
77
+ # so they're just going to say something
78
+ text_key = data.get("textKey", "")
79
+ self._text_update(text_key, self._generate_random_response())
80
+
81
+ case "MakeSingleChoice":
82
+ # Bot still can't see the drawing
83
+ # so just pick something
84
+ if data.get("type", "single") == "repeating":
85
+ pass
86
+ choices = data.get("choices", [])
87
+ choices_as_ints = [i for i in range(0, len(choices))]
88
+ selected_choice = random.choice(choices_as_ints)
89
+ self._client_send({"action": "choose", "choice": selected_choice})
90
+
91
+ def _handle_room_operation(self, data: dict):
92
+ pass
93
+
94
+ def _generate_drawing(self, prompt: str, canvas_height: int, canvas_width: int) -> list[str]:
95
+ logger.info("Generating drawing...")
96
+ image_prompt = _DRAWING_PROMPT_TEMPLATE.format(prompt=prompt, personality=self._personality)
97
+ image_bytes = self._chat_model.generate_sketch(
98
+ image_prompt,
99
+ "",
100
+ temperature=self._chat_model._chat_model_temperature,
101
+ top_p=self._chat_model._chat_model_top_p,
102
+ )
103
+ return self._image_bytes_to_polylines(image_bytes, canvas_height, canvas_width)
104
+
105
+ def _generate_random_response(self) -> str:
106
+ possible_responses = [
107
+ "Abstract awkward silence",
108
+ "Abstract existential dread",
109
+ "Abstract late-stage capitalism",
110
+ "Abstract lost hope",
111
+ "Abstract the void",
112
+ "Baby Yoda, but weird",
113
+ "Barbie, but weird",
114
+ "Confused dentist",
115
+ "Confused gym teacher",
116
+ "Confused lawyer",
117
+ "Confused therapist",
118
+ "DJ in trouble",
119
+ "Definitely banana",
120
+ "Definitely blob",
121
+ "Definitely potato",
122
+ "Definitely spaghetti",
123
+ "Taylor Swift, but weird",
124
+ "Waluigi, but weird",
125
+ "banana with feelings",
126
+ "chicken riding a scooter",
127
+ "cloud with feelings",
128
+ "confused gym teacher",
129
+ "confused therapist",
130
+ "dentist in trouble",
131
+ "duck + existential dread",
132
+ "duck riding a scooter",
133
+ "excited janitor",
134
+ "ferret + awkward silence",
135
+ "giraffe + awkward silence",
136
+ "giraffe filing taxes",
137
+ "hamster + existential dread",
138
+ "hamster + lost hope",
139
+ "hamster riding a scooter",
140
+ "janitor in trouble",
141
+ "joyful octopus",
142
+ "lawyer in trouble",
143
+ "llama + awkward silence",
144
+ "llama + late-stage capitalism",
145
+ "lonely dentist",
146
+ "lonely hamster",
147
+ "lonely janitor",
148
+ "lonely pirate",
149
+ "mango with feelings",
150
+ "pirate in trouble",
151
+ "sad DJ",
152
+ "sad hamster",
153
+ "spaghetti with feelings",
154
+ "terrified duck",
155
+ "terrified ferret",
156
+ "terrified lawyer",
157
+ ]
158
+ chosen_response = random.choice(possible_responses)
159
+ return chosen_response
File without changes
@@ -1,6 +1,12 @@
1
1
  import argparse
2
2
 
3
- from ai_plays_jackbox import run
3
+ from ai_plays_jackbox.constants import (
4
+ DEFAULT_NUM_OF_BOTS,
5
+ DEFAULT_TEMPERATURE,
6
+ DEFAULT_TOP_P,
7
+ )
8
+ from ai_plays_jackbox.llm.chat_model_factory import CHAT_MODEL_PROVIDERS
9
+ from ai_plays_jackbox.run import run
4
10
 
5
11
 
6
12
  def _validate_room_code(string_to_check: str) -> str:
@@ -39,7 +45,7 @@ def _validate_top_p(string_to_check: str) -> float:
39
45
  return number_value
40
46
 
41
47
 
42
- def cli():
48
+ def main():
43
49
  parser = argparse.ArgumentParser()
44
50
  parser.add_argument(
45
51
  "--room-code",
@@ -49,46 +55,53 @@ def cli():
49
55
  metavar="WXYZ",
50
56
  )
51
57
  parser.add_argument(
52
- "--chat-model-name",
58
+ "--chat-model-provider",
53
59
  required=True,
54
60
  help="Choose which chat model platform to use",
55
- choices=("ollama", "openai", "gemini"),
61
+ choices=list(CHAT_MODEL_PROVIDERS.keys()),
62
+ type=str,
63
+ )
64
+ parser.add_argument(
65
+ "--chat-model-name",
66
+ required=False,
67
+ help="Choose which chat model to use (Will default to default for provider)",
56
68
  type=str,
57
69
  )
58
70
  parser.add_argument(
59
71
  "--num-of-bots",
60
72
  required=False,
61
- default=4,
62
- help="How many bots to have play (Defaults to 4)",
73
+ default=DEFAULT_NUM_OF_BOTS,
74
+ help="How many bots to have play",
63
75
  type=_validate_num_of_bots,
64
- metavar="4",
76
+ metavar=str(DEFAULT_NUM_OF_BOTS),
65
77
  )
66
78
  parser.add_argument(
67
79
  "--temperature",
68
80
  required=False,
69
- default=0.5,
70
- help="Temperature for Gen AI (Defaults to 0.5)",
81
+ default=DEFAULT_TEMPERATURE,
82
+ help="Temperature for Gen AI",
71
83
  type=_validate_temperature,
72
- metavar="0.5",
84
+ metavar=str(DEFAULT_TEMPERATURE),
73
85
  )
74
86
  parser.add_argument(
75
87
  "--top-p",
76
88
  required=False,
77
- default=0.9,
78
- help="Top P for Gen AI (Defaults to 0.9)",
89
+ default=DEFAULT_TOP_P,
90
+ help="Top P for Gen AI",
79
91
  type=_validate_top_p,
80
- metavar="0.9",
92
+ metavar=str(DEFAULT_TOP_P),
81
93
  )
82
94
  args = parser.parse_args()
83
95
 
84
96
  run(
85
97
  args.room_code.upper(),
86
- num_of_bots=args.num_of_bots,
98
+ args.chat_model_provider,
87
99
  chat_model_name=args.chat_model_name,
100
+ num_of_bots=args.num_of_bots,
88
101
  chat_model_temperature=args.temperature,
89
102
  chat_model_top_p=args.top_p,
90
103
  )
91
104
 
92
105
 
93
106
  if __name__ == "__main__":
94
- cli()
107
+ main()
@@ -1 +1,4 @@
1
1
  ECAST_HOST = "ecast.jackboxgames.com"
2
# Default sampling temperature passed to the chat model.
DEFAULT_TEMPERATURE: float = 0.5
# Default top-p value passed to the chat model.
DEFAULT_TOP_P: float = 0.9
# Default number of bots to have play.
DEFAULT_NUM_OF_BOTS: int = 4
@@ -3,13 +3,20 @@ from typing import Optional
3
3
 
4
4
 
5
5
  class ChatModel(ABC):
6
- _chat_model_temperature: float = 0.5
7
- _chat_model_top_p: float = 0.9
6
+ _model: str
7
+ _chat_model_temperature: float
8
+ _chat_model_top_p: float
8
9
 
9
- def __init__(self, chat_model_temperature: float = 0.5, chat_model_top_p: float = 0.9):
10
+ def __init__(self, model: str, chat_model_temperature: float = 0.5, chat_model_top_p: float = 0.9):
11
+ self._model = model
10
12
  self._chat_model_temperature = chat_model_temperature
11
13
  self._chat_model_top_p = chat_model_top_p
12
14
 
15
+ @classmethod
16
+ @abstractmethod
17
+ def get_default_model(self) -> str:
18
+ pass
19
+
13
20
  @abstractmethod
14
21
  def generate_text(
15
22
  self,
@@ -20,3 +27,13 @@ class ChatModel(ABC):
20
27
  top_p: Optional[float] = None,
21
28
  ) -> str:
22
29
  pass
30
+
31
+ @abstractmethod
32
+ def generate_sketch(
33
+ self,
34
+ prompt: str,
35
+ instructions: str,
36
+ temperature: Optional[float] = None,
37
+ top_p: Optional[float] = None,
38
+ ) -> bytes:
39
+ pass