ai-snake-lab 0.1.0__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_snake_lab/AISim.py +243 -86
- ai_snake_lab/ai/AIAgent.py +34 -31
- ai_snake_lab/ai/AITrainer.py +9 -5
- ai_snake_lab/ai/ReplayMemory.py +61 -24
- ai_snake_lab/ai/models/ModelL.py +7 -4
- ai_snake_lab/ai/models/ModelRNN.py +1 -1
- ai_snake_lab/constants/DDb4EPlot.py +20 -0
- ai_snake_lab/constants/DDef.py +1 -1
- ai_snake_lab/constants/DDir.py +4 -1
- ai_snake_lab/constants/DFields.py +6 -0
- ai_snake_lab/constants/DFile.py +2 -1
- ai_snake_lab/constants/DLabels.py +24 -4
- ai_snake_lab/constants/DLayout.py +16 -2
- ai_snake_lab/constants/DModelL.py +4 -0
- ai_snake_lab/constants/DSim.py +20 -0
- ai_snake_lab/game/GameBoard.py +36 -22
- ai_snake_lab/game/SnakeGame.py +17 -0
- ai_snake_lab/ui/Db4EPlot.py +160 -0
- ai_snake_lab/utils/AISim.tcss +81 -38
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info}/METADATA +39 -5
- ai_snake_lab-0.4.3.dist-info/RECORD +31 -0
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info}/WHEEL +1 -1
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info/licenses}/LICENSE +2 -0
- ai_snake_lab-0.1.0.dist-info/RECORD +0 -28
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.3.dist-info}/entry_points.txt +0 -0
ai_snake_lab/AISim.py
CHANGED
@@ -10,44 +10,74 @@ AISim.py
|
|
10
10
|
|
11
11
|
import threading
|
12
12
|
import time
|
13
|
-
import sys
|
13
|
+
import sys, os
|
14
|
+
from datetime import datetime, timedelta
|
14
15
|
|
15
16
|
from textual.app import App, ComposeResult
|
16
|
-
from textual.widgets import Label, Input, Button
|
17
|
+
from textual.widgets import Label, Input, Button, Static
|
17
18
|
from textual.containers import Vertical, Horizontal
|
18
19
|
from textual.reactive import var
|
20
|
+
from textual.theme import Theme
|
19
21
|
|
20
|
-
from constants.DDef import DDef
|
21
|
-
from constants.DEpsilon import DEpsilon
|
22
|
-
from constants.DFields import DField
|
23
|
-
from constants.DFile import DFile
|
24
|
-
from constants.DLayout import DLayout
|
25
|
-
from constants.DLabels import DLabel
|
26
|
-
from constants.DReplayMemory import MEM_TYPE
|
22
|
+
from ai_snake_lab.constants.DDef import DDef
|
23
|
+
from ai_snake_lab.constants.DEpsilon import DEpsilon
|
24
|
+
from ai_snake_lab.constants.DFields import DField
|
25
|
+
from ai_snake_lab.constants.DFile import DFile
|
26
|
+
from ai_snake_lab.constants.DLayout import DLayout
|
27
|
+
from ai_snake_lab.constants.DLabels import DLabel
|
28
|
+
from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
|
29
|
+
from ai_snake_lab.constants.DDir import DDir
|
30
|
+
from ai_snake_lab.constants.DDb4EPlot import Plot
|
27
31
|
|
28
|
-
|
29
|
-
from ai.
|
30
|
-
from
|
31
|
-
from game.
|
32
|
+
|
33
|
+
from ai_snake_lab.ai.AIAgent import AIAgent
|
34
|
+
from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
|
35
|
+
from ai_snake_lab.game.GameBoard import GameBoard
|
36
|
+
from ai_snake_lab.game.SnakeGame import SnakeGame
|
37
|
+
from ai_snake_lab.ui.Db4EPlot import Db4EPlot
|
32
38
|
|
33
39
|
RANDOM_SEED = 1970
|
34
40
|
|
41
|
+
snake_lab_theme = Theme(
|
42
|
+
name="db4e",
|
43
|
+
primary="#88C0D0",
|
44
|
+
secondary="#1f6a83ff",
|
45
|
+
accent="#B48EAD",
|
46
|
+
foreground="#31b8e6",
|
47
|
+
background="black",
|
48
|
+
success="#A3BE8C",
|
49
|
+
warning="#EBCB8B",
|
50
|
+
error="#BF616A",
|
51
|
+
surface="black",
|
52
|
+
panel="#000000",
|
53
|
+
dark=True,
|
54
|
+
variables={
|
55
|
+
"block-cursor-text-style": "none",
|
56
|
+
"footer-key-foreground": "#88C0D0",
|
57
|
+
"input-selection-background": "#81a1c1 35%",
|
58
|
+
},
|
59
|
+
)
|
60
|
+
|
35
61
|
|
36
62
|
class AISim(App):
|
37
63
|
"""A Textual app that has an AI Agent playing the Snake Game."""
|
38
64
|
|
39
65
|
TITLE = DDef.APP_TITLE
|
40
|
-
CSS_PATH = DFile.
|
66
|
+
CSS_PATH = os.path.join(DDir.UTILS, DFile.CSS_FILE)
|
41
67
|
|
42
68
|
## Runtime values
|
43
69
|
# Current epsilon value (degrades in real-time)
|
44
70
|
cur_epsilon_widget = Label("N/A", id=DLayout.CUR_EPSILON)
|
45
71
|
# Current memory type
|
46
72
|
cur_mem_type_widget = Label("N/A", id=DLayout.CUR_MEM_TYPE)
|
47
|
-
#
|
48
|
-
|
49
|
-
#
|
73
|
+
# Current model type
|
74
|
+
cur_model_type_widget = Label("N/A", id=DLayout.CUR_MODEL_TYPE)
|
75
|
+
# Time delay between moves
|
50
76
|
cur_move_delay = DDef.MOVE_DELAY
|
77
|
+
# Number of stored games in the ReplayMemory
|
78
|
+
cur_num_games_widget = Label("N/A", id=DLayout.NUM_GAMES)
|
79
|
+
# Elapsed time
|
80
|
+
cur_runtime_widget = Label("N/A", id=DLayout.RUNTIME)
|
51
81
|
|
52
82
|
# Initial Settings for Epsilon
|
53
83
|
initial_epsilon_input = Input(
|
@@ -77,23 +107,42 @@ class AISim(App):
|
|
77
107
|
|
78
108
|
# Buttons
|
79
109
|
pause_button = Button(label=DLabel.PAUSE, id=DLayout.BUTTON_PAUSE, compact=True)
|
110
|
+
restart_button = Button(
|
111
|
+
label=DLabel.RESTART, id=DLayout.BUTTON_RESTART, compact=True
|
112
|
+
)
|
80
113
|
start_button = Button(label=DLabel.START, id=DLayout.BUTTON_START, compact=True)
|
81
114
|
quit_button = Button(label=DLabel.QUIT, id=DLayout.BUTTON_QUIT, compact=True)
|
82
|
-
|
115
|
+
defaults_button = Button(
|
116
|
+
label=DLabel.DEFAULTS, id=DLayout.BUTTON_DEFAULTS, compact=True
|
117
|
+
)
|
83
118
|
update_button = Button(label=DLabel.UPDATE, id=DLayout.BUTTON_UPDATE, compact=True)
|
84
119
|
|
120
|
+
# A dictionary to hold runtime statistics
|
121
|
+
stats = {
|
122
|
+
DField.GAME_SCORE: {
|
123
|
+
DField.GAME_NUM: [],
|
124
|
+
DField.GAME_SCORE: [],
|
125
|
+
}
|
126
|
+
}
|
127
|
+
|
128
|
+
game_score_plot = Db4EPlot(
|
129
|
+
title=DLabel.GAME_SCORE, id=DLayout.GAME_SCORE_PLOT, thin_method=Plot.SLIDING
|
130
|
+
)
|
131
|
+
|
85
132
|
def __init__(self) -> None:
|
86
133
|
super().__init__()
|
87
134
|
self.game_board = GameBoard(20, id=DLayout.GAME_BOARD)
|
88
135
|
self.snake_game = SnakeGame(game_board=self.game_board, id=DLayout.GAME_BOARD)
|
89
136
|
self.epsilon_algo = EpsilonAlgo(seed=RANDOM_SEED)
|
90
137
|
self.agent = AIAgent(self.epsilon_algo, seed=RANDOM_SEED)
|
91
|
-
self.
|
92
|
-
|
93
|
-
self.
|
138
|
+
self.cur_state = DField.STOPPED
|
139
|
+
self.game_score_plot._x_label = DLabel.GAME_NUM
|
140
|
+
self.game_score_plot._y_label = DLabel.GAME_SCORE
|
94
141
|
|
95
142
|
# Setup the simulator in a background thread
|
96
143
|
self.stop_event = threading.Event()
|
144
|
+
self.pause_event = threading.Event()
|
145
|
+
self.running = DField.STOPPED
|
97
146
|
self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
|
98
147
|
|
99
148
|
async def action_quit(self) -> None:
|
@@ -105,70 +154,95 @@ class AISim(App):
|
|
105
154
|
|
106
155
|
def compose(self) -> ComposeResult:
|
107
156
|
"""Create child widgets for the app."""
|
157
|
+
|
158
|
+
# Title bar
|
108
159
|
yield Label(DDef.APP_TITLE, id=DLayout.TITLE)
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
),
|
117
|
-
self.initial_epsilon_input,
|
118
|
-
),
|
119
|
-
Horizontal(
|
120
|
-
Label(
|
121
|
-
f"{DLabel.EPSILON_DECAY} : ",
|
122
|
-
classes=DLayout.LABEL_SETTINGS,
|
123
|
-
),
|
124
|
-
self.epsilon_decay_input,
|
125
|
-
),
|
126
|
-
Horizontal(
|
127
|
-
Label(
|
128
|
-
f"{DLabel.EPSILON_MIN} : ", classes=DLayout.LABEL_SETTINGS
|
129
|
-
),
|
130
|
-
self.epsilon_min_input,
|
131
|
-
),
|
132
|
-
Horizontal(
|
133
|
-
Label(
|
134
|
-
f"{DLabel.MOVE_DELAY} : ",
|
135
|
-
classes=DLayout.LABEL_SETTINGS,
|
136
|
-
),
|
137
|
-
self.move_delay_input,
|
138
|
-
),
|
139
|
-
id=DLayout.SETTINGS_BOX,
|
160
|
+
|
161
|
+
# Configuration Settings
|
162
|
+
yield Vertical(
|
163
|
+
Horizontal(
|
164
|
+
Label(
|
165
|
+
f"{DLabel.EPSILON_INITIAL}",
|
166
|
+
classes=DLayout.LABEL_SETTINGS,
|
140
167
|
),
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
),
|
148
|
-
id=DLayout.BUTTON_ROW,
|
168
|
+
self.initial_epsilon_input,
|
169
|
+
),
|
170
|
+
Horizontal(
|
171
|
+
Label(
|
172
|
+
f"{DLabel.EPSILON_DECAY}",
|
173
|
+
classes=DLayout.LABEL_SETTINGS,
|
149
174
|
),
|
175
|
+
self.epsilon_decay_input,
|
150
176
|
),
|
151
|
-
|
152
|
-
|
153
|
-
|
177
|
+
Horizontal(
|
178
|
+
Label(f"{DLabel.EPSILON_MIN}", classes=DLayout.LABEL_SETTINGS),
|
179
|
+
self.epsilon_min_input,
|
154
180
|
),
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
181
|
+
Horizontal(
|
182
|
+
Label(
|
183
|
+
f"{DLabel.MOVE_DELAY}",
|
184
|
+
classes=DLayout.LABEL_SETTINGS,
|
159
185
|
),
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
186
|
+
self.move_delay_input,
|
187
|
+
),
|
188
|
+
id=DLayout.SETTINGS_BOX,
|
189
|
+
)
|
190
|
+
|
191
|
+
# The Snake Game
|
192
|
+
yield Vertical(
|
193
|
+
self.game_board,
|
194
|
+
id=DLayout.GAME_BOX,
|
195
|
+
)
|
196
|
+
|
197
|
+
# Runtime values
|
198
|
+
yield Vertical(
|
199
|
+
Horizontal(
|
200
|
+
Label(f"{DLabel.EPSILON}", classes=DLayout.LABEL),
|
201
|
+
self.cur_epsilon_widget,
|
202
|
+
),
|
203
|
+
Horizontal(
|
204
|
+
Label(f"{DLabel.MEM_TYPE}", classes=DLayout.LABEL),
|
205
|
+
self.cur_mem_type_widget,
|
206
|
+
),
|
207
|
+
Horizontal(
|
208
|
+
Label(f"{DLabel.STORED_GAMES}", classes=DLayout.LABEL),
|
209
|
+
self.cur_num_games_widget,
|
210
|
+
),
|
211
|
+
Horizontal(
|
212
|
+
Label(f"{DLabel.MODEL_TYPE}", classes=DLayout.LABEL),
|
213
|
+
self.cur_model_type_widget,
|
214
|
+
),
|
215
|
+
Horizontal(
|
216
|
+
Label(f"{DLabel.RUNTIME}", classes=DLayout.LABEL),
|
217
|
+
self.cur_runtime_widget,
|
169
218
|
),
|
219
|
+
id=DLayout.RUNTIME_BOX,
|
170
220
|
)
|
171
221
|
|
222
|
+
# Buttons
|
223
|
+
yield Vertical(
|
224
|
+
Horizontal(
|
225
|
+
self.start_button,
|
226
|
+
self.pause_button,
|
227
|
+
self.quit_button,
|
228
|
+
classes=DLayout.BUTTON_ROW,
|
229
|
+
),
|
230
|
+
Horizontal(
|
231
|
+
self.defaults_button,
|
232
|
+
self.update_button,
|
233
|
+
self.restart_button,
|
234
|
+
classes=DLayout.BUTTON_ROW,
|
235
|
+
),
|
236
|
+
)
|
237
|
+
|
238
|
+
# Empty fillers
|
239
|
+
yield Static(id=DLayout.FILLER_1)
|
240
|
+
yield Static(id=DLayout.FILLER_2)
|
241
|
+
yield Static(id=DLayout.FILLER_3)
|
242
|
+
|
243
|
+
# The game score plot
|
244
|
+
yield self.game_score_plot
|
245
|
+
|
172
246
|
def on_mount(self):
|
173
247
|
self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
|
174
248
|
self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
|
@@ -177,16 +251,21 @@ class AISim(App):
|
|
177
251
|
settings_box = self.query_one(f"#{DLayout.SETTINGS_BOX}", Vertical)
|
178
252
|
settings_box.border_title = DLabel.SETTINGS
|
179
253
|
runtime_box = self.query_one(f"#{DLayout.RUNTIME_BOX}", Vertical)
|
180
|
-
runtime_box.border_title = DLabel.
|
254
|
+
runtime_box.border_title = DLabel.RUNTIME_VALUES
|
181
255
|
self.cur_mem_type_widget.update(
|
182
256
|
MEM_TYPE.MEM_TYPE_TABLE[self.agent.memory.mem_type()]
|
183
257
|
)
|
184
|
-
self.
|
258
|
+
self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
|
185
259
|
# Initial state is that the app is stopped
|
186
260
|
self.add_class(DField.STOPPED)
|
261
|
+
# Register the theme
|
262
|
+
self.register_theme(snake_lab_theme)
|
263
|
+
|
264
|
+
# Set the app's theme
|
265
|
+
self.theme = "db4e"
|
187
266
|
|
188
267
|
def on_quit(self):
|
189
|
-
if self.running ==
|
268
|
+
if self.running == DField.RUNNING:
|
190
269
|
self.stop_event.set()
|
191
270
|
if self.simulator_thread.is_alive():
|
192
271
|
self.simulator_thread.join()
|
@@ -194,22 +273,69 @@ class AISim(App):
|
|
194
273
|
|
195
274
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
196
275
|
button_id = event.button.id
|
276
|
+
|
277
|
+
# Pause button was pressed
|
278
|
+
if button_id == DLayout.BUTTON_PAUSE:
|
279
|
+
self.pause_event.set()
|
280
|
+
self.running = DField.PAUSED
|
281
|
+
self.remove_class(DField.RUNNING)
|
282
|
+
self.add_class(DField.PAUSED)
|
283
|
+
self.cur_move_delay = float(self.move_delay_input.value)
|
284
|
+
self.cur_model_type_widget.update(self.agent.model_type())
|
285
|
+
|
286
|
+
# Restart button was pressed
|
287
|
+
elif button_id == DLayout.BUTTON_RESTART:
|
288
|
+
self.running = DField.STOPPED
|
289
|
+
self.add_class(DField.STOPPED)
|
290
|
+
self.remove_class(DField.PAUSED)
|
291
|
+
|
292
|
+
# Signal thread to stop
|
293
|
+
self.stop_event.set()
|
294
|
+
# Unpause so we can exit cleanly
|
295
|
+
self.pause_event.clear()
|
296
|
+
# Join the old thread
|
297
|
+
if self.simulator_thread.is_alive():
|
298
|
+
self.simulator_thread.join(timeout=2)
|
299
|
+
|
300
|
+
# Reset the game and the UI
|
301
|
+
self.snake_game.reset()
|
302
|
+
score = 0
|
303
|
+
highscore = 0
|
304
|
+
self.epoch = 1
|
305
|
+
game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
|
306
|
+
game_box.border_title = ""
|
307
|
+
game_box.border_subtitle = ""
|
308
|
+
|
309
|
+
# Recreate events and get a new thread
|
310
|
+
self.stop_event = threading.Event()
|
311
|
+
self.pause_event = threading.Event()
|
312
|
+
self.simulator_thread = threading.Thread(target=self.start_sim, daemon=True)
|
313
|
+
|
197
314
|
# Start button was pressed
|
198
|
-
|
199
|
-
self.
|
200
|
-
|
315
|
+
elif button_id == DLayout.BUTTON_START:
|
316
|
+
if self.running == DField.STOPPED:
|
317
|
+
self.start_thread()
|
318
|
+
elif self.running == DField.PAUSED:
|
319
|
+
self.pause_event.clear()
|
320
|
+
self.pause_event.clear()
|
321
|
+
self.running = DField.RUNNING
|
201
322
|
self.add_class(DField.RUNNING)
|
202
323
|
self.remove_class(DField.STOPPED)
|
324
|
+
self.remove_class(DField.PAUSED)
|
203
325
|
self.cur_move_delay = float(self.move_delay_input.value)
|
326
|
+
self.cur_model_type_widget.update(self.agent.model_type())
|
327
|
+
|
204
328
|
# Defaults button was pressed
|
205
|
-
elif button_id == DLayout.
|
329
|
+
elif button_id == DLayout.BUTTON_DEFAULTS:
|
206
330
|
self.initial_epsilon_input.value = str(DEpsilon.EPSILON_INITIAL)
|
207
331
|
self.epsilon_decay_input.value = str(DEpsilon.EPSILON_DECAY)
|
208
332
|
self.epsilon_min_input.value = str(DEpsilon.EPSILON_MIN)
|
209
333
|
self.move_delay_input.value = str(DDef.MOVE_DELAY)
|
334
|
+
|
210
335
|
# Quit button was pressed
|
211
336
|
elif button_id == DLayout.BUTTON_QUIT:
|
212
337
|
self.on_quit()
|
338
|
+
|
213
339
|
# Update button was pressed
|
214
340
|
elif button_id == DLayout.BUTTON_UPDATE:
|
215
341
|
self.cur_move_delay = float(self.move_delay_input.value)
|
@@ -224,8 +350,13 @@ class AISim(App):
|
|
224
350
|
self.epoch = 1
|
225
351
|
game_box = self.query_one(f"#{DLayout.GAME_BOX}", Vertical)
|
226
352
|
game_box.border_title = f"{DLabel.GAME} #{self.epoch}"
|
353
|
+
start_time = datetime.now()
|
227
354
|
|
228
355
|
while not self.stop_event.is_set():
|
356
|
+
if self.pause_event.is_set():
|
357
|
+
self.pause_event.wait()
|
358
|
+
time.sleep(0.2)
|
359
|
+
continue
|
229
360
|
# The actual training loop...
|
230
361
|
old_state = game_board.get_state()
|
231
362
|
move = agent.get_move(old_state)
|
@@ -261,14 +392,40 @@ class AISim(App):
|
|
261
392
|
else:
|
262
393
|
self.cur_epsilon_widget.update(str(round(cur_epsilon, 4)))
|
263
394
|
# Update the number of stored memories
|
264
|
-
self.
|
265
|
-
|
266
|
-
)
|
395
|
+
self.cur_num_games_widget.update(str(self.agent.memory.get_num_games()))
|
396
|
+
# Update the stats object
|
397
|
+
self.stats[DField.GAME_SCORE][DField.GAME_NUM].append(self.epoch)
|
398
|
+
self.stats[DField.GAME_SCORE][DField.GAME_SCORE].append(score)
|
399
|
+
# Update the plot object
|
400
|
+
self.game_score_plot.add_data(self.epoch, score)
|
401
|
+
self.game_score_plot.db4e_plot()
|
402
|
+
# Update the runtime widget
|
403
|
+
elapsed_secs = (datetime.now() - start_time).total_seconds()
|
404
|
+
runtime = minutes_to_uptime(elapsed_secs)
|
405
|
+
self.cur_runtime_widget.update(runtime)
|
267
406
|
|
268
407
|
def start_thread(self):
|
269
408
|
self.simulator_thread.start()
|
270
409
|
|
271
410
|
|
411
|
+
def minutes_to_uptime(seconds: int):
|
412
|
+
# Return a string like:
|
413
|
+
# 45s
|
414
|
+
# 1d 7h 32m
|
415
|
+
days, minutes = divmod(int(seconds), 86400)
|
416
|
+
hours, minutes = divmod(minutes, 3600)
|
417
|
+
minutes, seconds = divmod(minutes, 60)
|
418
|
+
|
419
|
+
if days > 0:
|
420
|
+
return f"{days}d {hours}h {minutes}m"
|
421
|
+
elif hours > 0:
|
422
|
+
return f"{hours}h {minutes}m"
|
423
|
+
elif minutes > 0:
|
424
|
+
return f"{minutes}m {seconds}s"
|
425
|
+
else:
|
426
|
+
return f"{seconds}s"
|
427
|
+
|
428
|
+
|
272
429
|
if __name__ == "__main__":
|
273
430
|
app = AISim()
|
274
431
|
app.run()
|
ai_snake_lab/ai/AIAgent.py
CHANGED
@@ -9,13 +9,14 @@ ai/Agent.py
|
|
9
9
|
"""
|
10
10
|
|
11
11
|
import torch
|
12
|
-
from ai.EpsilonAlgo import EpsilonAlgo
|
13
|
-
from ai.ReplayMemory import ReplayMemory
|
14
|
-
from ai.AITrainer import AITrainer
|
15
|
-
from ai.models.ModelL import ModelL
|
16
|
-
from ai.models.ModelRNN import ModelRNN
|
12
|
+
from ai_snake_lab.ai.EpsilonAlgo import EpsilonAlgo
|
13
|
+
from ai_snake_lab.ai.ReplayMemory import ReplayMemory
|
14
|
+
from ai_snake_lab.ai.AITrainer import AITrainer
|
15
|
+
from ai_snake_lab.ai.models.ModelL import ModelL
|
16
|
+
from ai_snake_lab.ai.models.ModelRNN import ModelRNN
|
17
17
|
|
18
|
-
from constants.DReplayMemory import MEM_TYPE
|
18
|
+
from ai_snake_lab.constants.DReplayMemory import MEM_TYPE
|
19
|
+
from ai_snake_lab.constants.DLabels import DLabel
|
19
20
|
|
20
21
|
|
21
22
|
class AIAgent:
|
@@ -23,16 +24,13 @@ class AIAgent:
|
|
23
24
|
def __init__(self, epsilon_algo: EpsilonAlgo, seed: int):
|
24
25
|
self.epsilon_algo = epsilon_algo
|
25
26
|
self.memory = ReplayMemory(seed=seed)
|
26
|
-
self.
|
27
|
-
|
28
|
-
self.trainer = AITrainer(self.
|
27
|
+
# self._model = ModelL(seed=seed)
|
28
|
+
self._model = ModelRNN(seed=seed)
|
29
|
+
self.trainer = AITrainer(model=self._model)
|
29
30
|
|
30
|
-
if type(self.
|
31
|
+
if type(self._model) == ModelRNN:
|
31
32
|
self.memory.mem_type(MEM_TYPE.RANDOM_GAME)
|
32
33
|
|
33
|
-
def get_model(self):
|
34
|
-
return self.model
|
35
|
-
|
36
34
|
def get_move(self, state):
|
37
35
|
random_move = self.epsilon_algo.get_move() # Explore with epsilon
|
38
36
|
if random_move != False:
|
@@ -42,7 +40,7 @@ class AIAgent:
|
|
42
40
|
final_move = [0, 0, 0]
|
43
41
|
if type(state) != torch.Tensor:
|
44
42
|
state = torch.tensor(state, dtype=torch.float) # Convert to a tensor
|
45
|
-
prediction = self.
|
43
|
+
prediction = self._model(state) # Get the prediction
|
46
44
|
move = torch.argmax(prediction).item() # Select the move with the highest value
|
47
45
|
final_move[move] = 1 # Set the move
|
48
46
|
return final_move # Return
|
@@ -50,6 +48,15 @@ class AIAgent:
|
|
50
48
|
def get_optimizer(self):
|
51
49
|
return self.trainer.get_optimizer()
|
52
50
|
|
51
|
+
def model_type(self):
|
52
|
+
if type(self._model) == ModelL:
|
53
|
+
return DLabel.MODEL_LINEAR
|
54
|
+
elif type(self._model) == ModelRNN:
|
55
|
+
return DLabel.MODEL_RNN
|
56
|
+
|
57
|
+
def model(self):
|
58
|
+
return self._model
|
59
|
+
|
53
60
|
def played_game(self, score):
|
54
61
|
self.epsilon_algo.played_game()
|
55
62
|
|
@@ -57,27 +64,23 @@ class AIAgent:
|
|
57
64
|
# Store the state, action, reward, next_state, and done in memory
|
58
65
|
self.memory.append((state, action, reward, next_state, done))
|
59
66
|
|
60
|
-
def set_model(self, model):
|
61
|
-
self.model = model
|
62
|
-
|
63
67
|
def set_optimizer(self, optimizer):
|
64
68
|
self.trainer.set_optimizer(optimizer)
|
65
69
|
|
66
70
|
def train_long_memory(self):
|
67
|
-
#
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
for state, action, reward, next_state, done in
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
for state, action, reward, next_state, done in memory[0]:
|
71
|
+
# Train on at most max_games stored games
|
72
|
+
max_games = 2
|
73
|
+
# Get a random full game
|
74
|
+
while max_games > 0:
|
75
|
+
max_games -= 1
|
76
|
+
game = self.memory.get_random_game()
|
77
|
+
if not game:
|
78
|
+
return # no games to train on yet
|
79
|
+
|
80
|
+
for count, (state, action, reward, next_state, done) in enumerate(
|
81
|
+
game, start=1
|
82
|
+
):
|
83
|
+
# print(f"Move #{count}: {action}")
|
81
84
|
self.trainer.train_step(state, action, reward, next_state, [done])
|
82
85
|
|
83
86
|
def train_short_memory(self, state, action, reward, next_state, done):
|
ai_snake_lab/ai/AITrainer.py
CHANGED
@@ -15,23 +15,27 @@ import numpy as np
|
|
15
15
|
import time
|
16
16
|
import sys
|
17
17
|
|
18
|
-
from ai.models.ModelL import ModelL
|
19
|
-
from ai.models.ModelRNN import ModelRNN
|
18
|
+
from ai_snake_lab.ai.models.ModelL import ModelL
|
19
|
+
from ai_snake_lab.ai.models.ModelRNN import ModelRNN
|
20
20
|
|
21
|
-
from constants.DModelL import DModelL
|
22
|
-
from constants.DModelLRNN import DModelRNN
|
21
|
+
from ai_snake_lab.constants.DModelL import DModelL
|
22
|
+
from ai_snake_lab.constants.DModelLRNN import DModelRNN
|
23
23
|
|
24
24
|
|
25
25
|
class AITrainer:
|
26
26
|
|
27
27
|
def __init__(self, model):
|
28
28
|
torch.manual_seed(1970)
|
29
|
-
|
29
|
+
|
30
30
|
# The learning rate needs to be adjusted for the model type
|
31
31
|
if type(model) == ModelL:
|
32
32
|
learning_rate = DModelL.LEARNING_RATE
|
33
33
|
elif type(model) == ModelRNN:
|
34
34
|
learning_rate = DModelRNN.LEARNING_RATE
|
35
|
+
else:
|
36
|
+
raise ValueError(f"Unknown model type: {type(model)}")
|
37
|
+
|
38
|
+
self.model = model
|
35
39
|
self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
|
36
40
|
self.criterion = nn.MSELoss()
|
37
41
|
self.gamma = 0.9
|