ai-snake-lab 0.1.0__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_snake_lab/AISim.py CHANGED
@@ -10,44 +10,74 @@ AISim.py
10
10
 
11
11
  import threading
12
12
  import time
13
- import sys
13
+ import sys, os
14
+ from datetime import datetime, timedelta
14
15
 
15
16
  from textual.app import App, ComposeResult
16
- from textual.widgets import Label, Input, Button
17
+ from textual.widgets import Label, Input, Button, Static
17
18
  from textual.containers import Vertical, Horizontal
18
19
  from textual.reactive import var
20
+ from textual.theme import Theme
19
21
 
20
- from constants.DDef import DDef
21
- from constants.DEpsilon import DEpsilon
22
- from constants.DFields import DField
23
- from constants.DFile import DFile
24
- from constants.DLayout import DLayout
25
- from constants.DLabels import DLabel
26
- from constants.DReplayMemory import MEM_TYPE
22
+ from ai_snake_lab.constants.DDef import DDef
23
+ from ai_snake_lab.constants.DEpsilon import DEpsilon
24
+ from ai_snake_lab.constants.DFields import DField
25
+ from ai_snake_lab.constants.DFile import DFile
26
+ from ai_snake_lab.constants.DLayout import DLayout
27
+ from ai_snake_lab.constants.DLabels import DLabel
28
+ from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
29
+ from ai_snake_lab.constants.DDir import DDir
30
+ from ai_snake_lab.constants.DDb4EPlot import Plot
27
31
 
28
- from ai.AIAgent import AIAgent
29
- from ai.EpsilonAlgo import EpsilonAlgo
30
- from game.GameBoard import GameBoard
31
- from game.SnakeGame import SnakeGame
32
+
33
+ from ai_snake_lab.ai.AIAgent import AIAgent
34
+ from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
35
+ from ai_snake_lab.game.GameBoard import GameBoard
36
+ from ai_snake_lab.game.SnakeGame import SnakeGame
37
+ from ai_snake_lab.ui.Db4EPlot import Db4EPlot
32
38
 
33
39
  RANDOM_SEED = 1970
34
40
 
41
+ snake_lab_theme = Theme(
42
+ name="db4e",
43
+ primary="#88C0D0",
44
+ secondary="#1f6a83ff",
45
+ accent="#B48EAD",
46
+ foreground="#31b8e6",
47
+ background="black",
48
+ success="#A3BE8C",
49
+ warning="#EBCB8B",
50
+ error="#BF616A",
51
+ surface="black",
52
+ panel="#000000",
53
+ dark=True,
54
+ variables={
55
+ "block-cursor-text-style": "none",
56
+ "footer-key-foreground": "#88C0D0",
57
+ "input-selection-background": "#81a1c1 35%",
58
+ },
59
+ )
60
+
35
61
 
36
62
  class AISim(App):
37
63
  """A Textual app that has an AI Agent playing the Snake Game."""
38
64
 
39
65
  TITLE = DDef.APP_TITLE
40
- CSS_PATH = DFile.CSS_PATH
66
+ CSS_PATH = os.path.join(DDir.UTILS, DFile.CSS_FILE)
41
67
 
42
68
  ## Runtime values
43
69
  # Current epsilon value (degrades in real-time)
44
70
  cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
45
71
  # Current memory type
46
72
  cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
47
- # Number of stored memories
48
- cur_num_memories_widget = Label("N/A", id=DLayout.NUM_MEMORIES)
49
- # Runtime move delay value
73
+ # Current model type
74
+ cur_model_type_widget = Label("N/A", id=DLayout.CUR_MODEL_TYPE)
75
+ # Time delay between moves
50
76
  cur_move_delay = DDef.MOVE_DELAY
77
+ # Number of stored games in the ReplayMemory
78
+ cur_num_games_widget = Label("N/A", id=DLayout.NUM_GAMES)
79
+ # Elapsed time
80
+ cur_runtime_widget = Label("N/A", id=DLayout.RUNTIME)
51
81
 
52
82
  # Initial Settings for Epsilon
53
83
  initial_epsilon_input = Input(
@@ -77,23 +107,42 @@ class AISim(App):
77
107
 
78
108
  # Buttons
79
109
  pause_button = Button(label=DLabel.PAUSE, id=DLayout.BUTTON_PAUSE, compact=True)
110
+ restart_button = Button(
111
+ label=DLabel.RESTART, id=DLayout.BUTTON_RESTART, compact=True
112
+ )
80
113
  start_button = Button(label=DLabel.START, id=DLayout.BUTTON_START, compact=True)
81
114
  quit_button = Button(label=DLabel.QUIT, id=DLayout.BUTTON_QUIT, compact=True)
82
- reset_button = Button(label=DLabel.RESET, id=DLayout.BUTTON_RESET, compact=True)
115
+ defaults_button = Button(
116
+ label=DLabel.DEFAULTS, id=DLayout.BUTTON_DEFAULTS, compact=True
117
+ )
83
118
  update_button = Button(label=DLabel.UPDATE, id=DLayout.BUTTON_UPDATE, compact=True)
84
119
 
120
+ # A dictionary to hold runtime statistics
121
+ stats = {
122
+ DField.GAME_SCORE: {
123
+ DField.GAME_NUM: [],
124
+ DField.GAME_SCORE: [],
125
+ }
126
+ }
127
+
128
+ game_score_plot = Db4EPlot(
129
+ title=DLabel.GAME_SCORE, id=DLayout.GAME_SCORE_PLOT, thin_method=Plot.SLIDING
130
+ )
131
+
85
132
  def __init__(self) -> None:
86
133
  super().__init__()
87
134
  self.game_board = GameBoard(20, id=DLayout.GAME_BOARD)
88
135
  self.snake_game = SnakeGame(game_board=self.game_board, id=DLayout.GAME_BOARD)
89
136
  self.epsilon_algo = EpsilonAlgo(seed=RANDOM_SEED)
90
137
  self.agent = AIAgent(self.epsilon_algo, seed=RANDOM_SEED)
91
- self.running = False
92
-
93
- self.score = Label("Game: 0, Highscore: 0, Score: 0")
138
+ self.cur_state = DField.STOPPED
139
+ self.game_score_plot._x_label = DLabel.GAME_NUM
140
+ self.game_score_plot._y_label = DLabel.GAME_SCORE
94
141
 
95
142
  # Setup the simulator in a background thread
96
143
  self.stop_event = threading.Event()
144
+ self.pause_event = threading.Event()
145
+ self.running = DField.STOPPED
97
146
  self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
98
147
 
99
148
  async def action_quit(self) -> None:
@@ -105,70 +154,95 @@ class AISim(App):
105
154
 
106
155
  def compose(self) -> ComposeResult:
107
156
  """Create child widgets for the app."""
157
+
158
+ # Title bar
108
159
  yield Label(DDef.APP_TITLE, id=DLayout.TITLE)
109
- yield Horizontal(
110
- Vertical(
111
- Vertical(
112
- Horizontal(
113
- Label(
114
- f"{DLabel.EPSILON_INITIAL} : ",
115
- classes=DLayout.LABEL_SETTINGS,
116
- ),
117
- self.initial_epsilon_input,
118
- ),
119
- Horizontal(
120
- Label(
121
- f"{DLabel.EPSILON_DECAY} : ",
122
- classes=DLayout.LABEL_SETTINGS,
123
- ),
124
- self.epsilon_decay_input,
125
- ),
126
- Horizontal(
127
- Label(
128
- f"{DLabel.EPSILON_MIN} : ", classes=DLayout.LABEL_SETTINGS
129
- ),
130
- self.epsilon_min_input,
131
- ),
132
- Horizontal(
133
- Label(
134
- f"{DLabel.MOVE_DELAY} : ",
135
- classes=DLayout.LABEL_SETTINGS,
136
- ),
137
- self.move_delay_input,
138
- ),
139
- id=DLayout.SETTINGS_BOX,
160
+
161
+ # Configuration Settings
162
+ yield Vertical(
163
+ Horizontal(
164
+ Label(
165
+ f"{DLabel.EPSILON_INITIAL}",
166
+ classes=DLayout.LABEL_SETTINGS,
140
167
  ),
141
- Vertical(
142
- Horizontal(
143
- self.start_button,
144
- self.reset_button,
145
- self.update_button,
146
- self.quit_button,
147
- ),
148
- id=DLayout.BUTTON_ROW,
168
+ self.initial_epsilon_input,
169
+ ),
170
+ Horizontal(
171
+ Label(
172
+ f"{DLabel.EPSILON_DECAY}",
173
+ classes=DLayout.LABEL_SETTINGS,
149
174
  ),
175
+ self.epsilon_decay_input,
150
176
  ),
151
- Vertical(
152
- self.game_board,
153
- id=DLayout.GAME_BOX,
177
+ Horizontal(
178
+ Label(f"{DLabel.EPSILON_MIN}", classes=DLayout.LABEL_SETTINGS),
179
+ self.epsilon_min_input,
154
180
  ),
155
- Vertical(
156
- Horizontal(
157
- Label(f"{DLabel.EPSILON} : ", classes=DLayout.LABEL),
158
- self.cur_epsilon_widget,
181
+ Horizontal(
182
+ Label(
183
+ f"{DLabel.MOVE_DELAY}",
184
+ classes=DLayout.LABEL_SETTINGS,
159
185
  ),
160
- Horizontal(
161
- Label(f"{DLabel.MEM_TYPE} : ", classes=DLayout.LABEL),
162
- self.cur_mem_type_widget,
163
- ),
164
- Horizontal(
165
- Label(f"{DLabel.MEMORIES} : ", classes=DLayout.LABEL),
166
- self.cur_num_memories_widget,
167
- ),
168
- id=DLayout.RUNTIME_BOX,
186
+ self.move_delay_input,
187
+ ),
188
+ id=DLayout.SETTINGS_BOX,
189
+ )
190
+
191
+ # The Snake Game
192
+ yield Vertical(
193
+ self.game_board,
194
+ id=DLayout.GAME_BOX,
195
+ )
196
+
197
+ # Runtime values
198
+ yield Vertical(
199
+ Horizontal(
200
+ Label(f"{DLabel.EPSILON}", classes=DLayout.LABEL),
201
+ self.cur_epsilon_widget,
202
+ ),
203
+ Horizontal(
204
+ Label(f"{DLabel.MEM_TYPE}", classes=DLayout.LABEL),
205
+ self.cur_mem_type_widget,
206
+ ),
207
+ Horizontal(
208
+ Label(f"{DLabel.STORED_GAMES}", classes=DLayout.LABEL),
209
+ self.cur_num_games_widget,
210
+ ),
211
+ Horizontal(
212
+ Label(f"{DLabel.MODEL_TYPE}", classes=DLayout.LABEL),
213
+ self.cur_model_type_widget,
214
+ ),
215
+ Horizontal(
216
+ Label(f"{DLabel.RUNTIME}", classes=DLayout.LABEL),
217
+ self.cur_runtime_widget,
169
218
  ),
219
+ id=DLayout.RUNTIME_BOX,
170
220
  )
171
221
 
222
+ # Buttons
223
+ yield Vertical(
224
+ Horizontal(
225
+ self.start_button,
226
+ self.pause_button,
227
+ self.quit_button,
228
+ classes=DLayout.BUTTON_ROW,
229
+ ),
230
+ Horizontal(
231
+ self.defaults_button,
232
+ self.update_button,
233
+ self.restart_button,
234
+ classes=DLayout.BUTTON_ROW,
235
+ ),
236
+ )
237
+
238
+ # Empty fillers
239
+ yield Static(id=DLayout.FILLER_1)
240
+ yield Static(id=DLayout.FILLER_2)
241
+ yield Static(id=DLayout.FILLER_3)
242
+
243
+ # The game score plot
244
+ yield self.game_score_plot
245
+
172
246
  def on_mount(self):
173
247
  self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
174
248
  self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
@@ -177,16 +251,21 @@ class AISim(App):
177
251
  settings_box = self.query_one(f"#{DLayout.SETTINGS_BOX}", Vertical)
178
252
  settings_box.border_title = DLabel.SETTINGS
179
253
  runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
180
- runtime_box.border_title = DLabel.RUNTIME
254
+ runtime_box.border_title = DLabel.RUNTIME_VALUES
181
255
  self.cur_mem_type_widget.update(
182
256
  MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
183
257
  )
184
- self.cur_num_memories_widget.update(str(self.agent.memory.get_num_memories()))
258
+ self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
185
259
  # Initial state is that the app is stopped
186
260
  self.add_class(DField.STOPPED)
261
+ # Register the theme
262
+ self.register_theme(snake_lab_theme)
263
+
264
+ # Set the app's theme
265
+ self.theme = "db4e"
187
266
 
188
267
  def on_quit(self):
189
- if self.running == True:
268
+ if self.running == DField.RUNNING:
190
269
  self.stop_event.set()
191
270
  if self.simulator_thread.is_alive():
192
271
  self.simulator_thread.join()
@@ -194,22 +273,69 @@ class AISim(App):
194
273
 
195
274
  def on_button_pressed(self, event: Button.Pressed) -> None:
196
275
  button_id = event.button.id
276
+
277
+ # Pause button was pressed
278
+ if button_id == DLayout.BUTTON_PAUSE:
279
+ self.pause_event.set()
280
+ self.running = DField.PAUSED
281
+ self.remove_class(DField.RUNNING)
282
+ self.add_class(DField.PAUSED)
283
+ self.cur_move_delay = float(self.move_delay_input.value)
284
+ self.cur_model_type_widget.update(self.agent.model_type())
285
+
286
+ # Restart button was pressed
287
+ elif button_id == DLayout.BUTTON_RESTART:
288
+ self.running = DField.STOPPED
289
+ self.add_class(DField.STOPPED)
290
+ self.remove_class(DField.PAUSED)
291
+
292
+ # Signal thread to stop
293
+ self.stop_event.set()
294
+ # Unpause so we can exit cleanly
295
+ self.pause_event.clear()
296
+ # Join the old thread
297
+ if self.simulator_thread.is_alive():
298
+ self.simulator_thread.join(timeout=2)
299
+
300
+ # Reset the game and the UI
301
+ self.snake_game.reset()
302
+ score = 0
303
+ highscore = 0
304
+ self.epoch = 1
305
+ game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
306
+ game_box.border_title = ""
307
+ game_box.border_subtitle = ""
308
+
309
+ # Recreate events and get a new thread
310
+ self.stop_event = threading.Event()
311
+ self.pause_event = threading.Event()
312
+ self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
313
+
197
314
  # Start button was pressed
198
- if button_id == DLayout.BUTTON_START:
199
- self.start_thread()
200
- self.running = True
315
+ elif button_id == DLayout.BUTTON_START:
316
+ if self.running == DField.STOPPED:
317
+ self.start_thread()
318
+ elif self.running == DField.PAUSED:
319
+ self.pause_event.clear()
320
+ self.pause_event.clear()
321
+ self.running = DField.RUNNING
201
322
  self.add_class(DField.RUNNING)
202
323
  self.remove_class(DField.STOPPED)
324
+ self.remove_class(DField.PAUSED)
203
325
  self.cur_move_delay = float(self.move_delay_input.value)
326
+ self.cur_model_type_widget.update(self.agent.model_type())
327
+
204
328
  # Reset button was pressed
205
- elif button_id == DLayout.BUTTON_RESET:
329
+ elif button_id == DLayout.BUTTON_DEFAULTS:
206
330
  self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
207
331
  self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
208
332
  self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
209
333
  self.move_delay_input.value = str(DDef.MOVE_DELAY)
334
+
210
335
  # Quit button was pressed
211
336
  elif button_id == DLayout.BUTTON_QUIT:
212
337
  self.on_quit()
338
+
213
339
  # Update button was pressed
214
340
  elif button_id == DLayout.BUTTON_UPDATE:
215
341
  self.cur_move_delay = float(self.move_delay_input.value)
@@ -224,8 +350,13 @@ class AISim(App):
224
350
  self.epoch = 1
225
351
  game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
226
352
  game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
353
+ start_time = datetime.now()
227
354
 
228
355
  while not self.stop_event.is_set():
356
+ if self.pause_event.is_set():
357
+ self.pause_event.wait()
358
+ time.sleep(0.2)
359
+ continue
229
360
  # The actual training loop...
230
361
  old_state = game_board.get_state()
231
362
  move = agent.get_move(old_state)
@@ -261,14 +392,40 @@ class AISim(App):
261
392
  else:
262
393
  self.cur_epsilon_widget.update(str(round(cur_epsilon, 4)))
263
394
  # Update the number of stored games
264
- self.cur_num_memories_widget.update(
265
- str(self.agent.memory.get_num_memories())
266
- )
395
+ self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
396
+ # Update the stats object
397
+ self.stats[DField.GAME_SCORE][DField.GAME_NUM].append(self.epoch)
398
+ self.stats[DField.GAME_SCORE][DField.GAME_SCORE].append(score)
399
+ # Update the plot object
400
+ self.game_score_plot.add_data(self.epoch, score)
401
+ self.game_score_plot.db4e_plot()
402
+ # Update the runtime widget
403
+ elapsed_secs = (datetime.now() - start_time).total_seconds()
404
+ runtime = minutes_to_uptime(elapsed_secs)
405
+ self.cur_runtime_widget.update(runtime)
267
406
 
268
407
  def start_thread(self):
269
408
  self.simulator_thread.start()
270
409
 
271
410
 
411
+ def minutes_to_uptime(seconds: int):
412
+ # Return a string like:
413
+ # 45s
414
+ # 1d 7h 32m
415
+ days, minutes = divmod(int(seconds), 86400)
416
+ hours, minutes = divmod(minutes, 3600)
417
+ minutes, seconds = divmod(minutes, 60)
418
+
419
+ if days > 0:
420
+ return f"{days}d {hours}h {minutes}m"
421
+ elif hours > 0:
422
+ return f"{hours}h {minutes}m"
423
+ elif minutes > 0:
424
+ return f"{minutes}m {seconds}s"
425
+ else:
426
+ return f"{seconds}s"
427
+
428
+
272
429
  if __name__ == "__main__":
273
430
  app = AISim()
274
431
  app.run()
@@ -9,13 +9,14 @@ ai/Agent.py
9
9
  """
10
10
 
11
11
  import torch
12
- from ai.EpsilonAlgo import EpsilonAlgo
13
- from ai.ReplayMemory import ReplayMemory
14
- from ai.AITrainer import AITrainer
15
- from ai.models.ModelL import ModelL
16
- from ai.models.ModelRNN import ModelRNN
12
+ from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
13
+ from ai_snake_lab.ai.ReplayMemory import ReplayMemory
14
+ from ai_snake_lab.ai.AITrainer import AITrainer
15
+ from ai_snake_lab.ai.models.ModelL import ModelL
16
+ from ai_snake_lab.ai.models.ModelRNN import ModelRNN
17
17
 
18
- from constants.DReplayMemory import MEM_TYPE
18
+ from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
19
+ from ai_snake_lab.constants.DLabels import DLabel
19
20
 
20
21
 
21
22
  class AIAgent:
@@ -23,16 +24,13 @@ class AIAgent:
23
24
  def __init__(self, epsilon_algo: EpsilonAlgo, seed: int):
24
25
  self.epsilon_algo = epsilon_algo
25
26
  self.memory = ReplayMemory(seed=seed)
26
- self.model = ModelL(seed=seed)
27
- # self.model = ModelRNN(seed=seed)
28
- self.trainer = AITrainer(self.model)
27
+ # self._model = ModelL(seed=seed)
28
+ self._model = ModelRNN(seed=seed)
29
+ self.trainer = AITrainer(model=self._model)
29
30
 
30
- if type(self.model) == ModelRNN:
31
+ if type(self._model) == ModelRNN:
31
32
  self.memory.mem_type(MEM_TYPE.RANDOM_GAME)
32
33
 
33
- def get_model(self):
34
- return self.model
35
-
36
34
  def get_move(self, state):
37
35
  random_move = self.epsilon_algo.get_move() # Explore with epsilon
38
36
  if random_move != False:
@@ -42,7 +40,7 @@ class AIAgent:
42
40
  final_move = [0, 0, 0]
43
41
  if type(state) != torch.Tensor:
44
42
  state = torch.tensor(state, dtype=torch.float) # Convert to a tensor
45
- prediction = self.model(state) # Get the prediction
43
+ prediction = self._model(state) # Get the prediction
46
44
  move = torch.argmax(prediction).item() # Select the move with the highest value
47
45
  final_move[move] = 1 # Set the move
48
46
  return final_move # Return
@@ -50,6 +48,15 @@ class AIAgent:
50
48
  def get_optimizer(self):
51
49
  return self.trainer.get_optimizer()
52
50
 
51
+ def model_type(self):
52
+ if type(self._model) == ModelL:
53
+ return DLabel.MODEL_LINEAR
54
+ elif type(self._model) == ModelRNN:
55
+ return DLabel.MODEL_RNN
56
+
57
+ def model(self):
58
+ return self._model
59
+
53
60
  def played_game(self, score):
54
61
  self.epsilon_algo.played_game()
55
62
 
@@ -57,27 +64,23 @@ class AIAgent:
57
64
  # Store the state, action, reward, next_state, and done in memory
58
65
  self.memory.append((state, action, reward, next_state, done))
59
66
 
60
- def set_model(self, model):
61
- self.model = model
62
-
63
67
  def set_optimizer(self, optimizer):
64
68
  self.trainer.set_optimizer(optimizer)
65
69
 
66
70
  def train_long_memory(self):
67
- # Get the states, actions, rewards, next_states, and dones from the mini_sample
68
- memory = self.memory.get_memory()
69
- memory_type = self.memory.mem_type()
70
-
71
- if type(self.model) == ModelRNN:
72
- for state, action, reward, next_state, done in memory[0]:
73
- self.trainer.train_step(state, action, reward, next_state, [done])
74
-
75
- elif memory_type == MEM_TYPE.SHUFFLE:
76
- for state, action, reward, next_state, done in memory:
77
- self.trainer.train_step(state, action, reward, next_state, [done])
78
-
79
- else:
80
- for state, action, reward, next_state, done in memory[0]:
71
+ # Train on up to max_games randomly selected games
72
+ max_games = 2
73
+ # Get a random full game
74
+ while max_games > 0:
75
+ max_games -= 1
76
+ game = self.memory.get_random_game()
77
+ if not game:
78
+ return # no games to train on yet
79
+
80
+ for count, (state, action, reward, next_state, done) in enumerate(
81
+ game, start=1
82
+ ):
83
+ # print(f"Move #{count}: {action}")
81
84
  self.trainer.train_step(state, action, reward, next_state, [done])
82
85
 
83
86
  def train_short_memory(self, state, action, reward, next_state, done):
@@ -15,23 +15,27 @@ import numpy as np
15
15
  import time
16
16
  import sys
17
17
 
18
- from ai.models.ModelL import ModelL
19
- from ai.models.ModelRNN import ModelRNN
18
+ from ai_snake_lab.ai.models.ModelL import ModelL
19
+ from ai_snake_lab.ai.models.ModelRNN import ModelRNN
20
20
 
21
- from constants.DModelL import DModelL
22
- from constants.DModelLRNN import DModelRNN
21
+ from ai_snake_lab.constants.DModelL import DModelL
22
+ from ai_snake_lab.constants.DModelLRNN import DModelRNN
23
23
 
24
24
 
25
25
  class AITrainer:
26
26
 
27
27
  def __init__(self, model):
28
28
  torch.manual_seed(1970)
29
- self.model = model
29
+
30
30
  # The learning rate needs to be adjusted for the model type
31
31
  if type(model) == ModelL:
32
32
  learning_rate = DModelL.LEARNING_RATE
33
33
  elif type(model) == ModelRNN:
34
34
  learning_rate = DModelRNN.LEARNING_RATE
35
+ else:
36
+ raise ValueError(f"Unknown model type: {type(model)}")
37
+
38
+ self.model = model
35
39
  self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
36
40
  self.criterion = nn.MSELoss()
37
41
  self.gamma = 0.9