ai-snake-lab 0.1.0__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_snake_lab/AISim.py +243 -86
- ai_snake_lab/ai/AIAgent.py +34 -31
- ai_snake_lab/ai/AITrainer.py +9 -5
- ai_snake_lab/ai/ReplayMemory.py +61 -24
- ai_snake_lab/ai/models/ModelL.py +7 -4
- ai_snake_lab/ai/models/ModelRNN.py +1 -1
- ai_snake_lab/constants/DDb4EPlot.py +20 -0
- ai_snake_lab/constants/DDef.py +2 -2
- ai_snake_lab/constants/DDir.py +5 -2
- ai_snake_lab/constants/DEpsilon.py +1 -1
- ai_snake_lab/constants/DFields.py +7 -1
- ai_snake_lab/constants/DFile.py +3 -2
- ai_snake_lab/constants/DLabels.py +24 -4
- ai_snake_lab/constants/DLayout.py +17 -3
- ai_snake_lab/constants/DModelL.py +5 -1
- ai_snake_lab/constants/DModelLRNN.py +1 -1
- ai_snake_lab/constants/DReplayMemory.py +1 -1
- ai_snake_lab/constants/DSim.py +20 -0
- ai_snake_lab/game/GameBoard.py +36 -22
- ai_snake_lab/game/SnakeGame.py +17 -0
- ai_snake_lab/ui/Db4EPlot.py +160 -0
- ai_snake_lab/utils/AISim.tcss +81 -38
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.4.dist-info}/METADATA +39 -5
- ai_snake_lab-0.4.4.dist-info/RECORD +31 -0
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.4.dist-info}/WHEEL +1 -1
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.4.dist-info/licenses}/LICENSE +2 -0
- ai_snake_lab-0.1.0.dist-info/RECORD +0 -28
- {ai_snake_lab-0.1.0.dist-info → ai_snake_lab-0.4.4.dist-info}/entry_points.txt +0 -0
ai_snake_lab/ai/ReplayMemory.py
CHANGED
@@ -10,10 +10,14 @@ ai/ReplayMemory.py
|
|
10
10
|
This file contains the ReplayMemory class.
|
11
11
|
"""
|
12
12
|
|
13
|
+
import os
|
13
14
|
from collections import deque
|
14
15
|
import random, sys
|
16
|
+
import sqlite3, pickle
|
15
17
|
|
16
18
|
from constants.DReplayMemory import MEM_TYPE
|
19
|
+
from constants.DFile import DFile
|
20
|
+
from constants.DDir import DDir
|
17
21
|
|
18
22
|
|
19
23
|
class ReplayMemory:
|
@@ -27,6 +31,11 @@ class ReplayMemory:
|
|
27
31
|
self.max_states = 15000
|
28
32
|
self.max_shuffle_games = 40
|
29
33
|
self.max_games = 500
|
34
|
+
self.db_file = os.path.join(DDir.AI_SNAKE_LAB, DDir.DB, DFile.REPLAY_DB)
|
35
|
+
|
36
|
+
# Delete the replay memory file, if it exists
|
37
|
+
if os.path.exists(self.db_file):
|
38
|
+
os.remove(self.db_file)
|
30
39
|
|
31
40
|
if self._mem_type == MEM_TYPE.SHUFFLE:
|
32
41
|
# States are stored in a deque and a random sample will be returned
|
@@ -35,35 +44,50 @@ class ReplayMemory:
|
|
35
44
|
elif self._mem_type == MEM_TYPE.RANDOM_GAME:
|
36
45
|
# All of the states for a game are stored, in order, in a deque.
|
37
46
|
# A complete game will be returned
|
38
|
-
self.memories = deque(maxlen=self.max_shuffle_games)
|
39
47
|
self.cur_memory = []
|
40
48
|
|
41
|
-
|
42
|
-
|
43
|
-
|
49
|
+
# Connect to SQLite
|
50
|
+
self.conn = sqlite3.connect(self.db_file, check_same_thread=False)
|
51
|
+
self.cursor = self.conn.cursor()
|
52
|
+
self.init_db()
|
44
53
|
|
45
|
-
def
|
46
|
-
|
54
|
+
def __len__(self):
    """Return the number of entries currently held in ``self.memories``.

    NOTE(review): the RANDOM_GAME setup visible in this change no longer
    assigns ``self.memories`` (finished games are written to SQLite instead),
    so this may raise AttributeError in that mode -- confirm against the rest
    of ``__init__`` and consider delegating to ``get_num_games()``.
    """
    return len(self.memories)
|
47
56
|
|
48
|
-
|
49
|
-
|
50
|
-
|
57
|
+
def append(self, transition):
    """Add a transition to the current game and persist the game when done.

    Transitions are buffered in ``self.cur_memory``. When a transition whose
    ``done`` flag is truthy arrives, the whole buffered game is pickled and
    written as one row of the ``games`` table, then the buffer is reset.

    Args:
        transition: A 5-item sequence; the last item is the ``done`` flag.

    Raises:
        NotImplementedError: If the memory type is not RANDOM_GAME.
        ValueError: If ``transition`` does not contain exactly five items.
    """
    if self._mem_type != MEM_TYPE.RANDOM_GAME:
        raise NotImplementedError(
            "Only RANDOM_GAME memory type is implemented for SQLite backend"
        )

    # Validate the shape *before* buffering so that a malformed transition
    # cannot end up inside the game that is later written to the database.
    _, _, _, _, done = transition
    self.cur_memory.append(transition)

    if not done:
        return

    # The finished game is serialized with pickle (bytes), not JSON.
    serialized = pickle.dumps(self.cur_memory)

    # The connection context manager commits on success and rolls back if
    # the INSERT raises, so a failed write never leaves an open transaction.
    with self.conn:
        self.cursor.execute(
            "INSERT INTO games (transitions) VALUES (?)", (serialized,)
        )
    self.cur_memory = []
|
51
75
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
self.cur_memory.append(transition)
|
56
|
-
state, action, reward, next_state, done = transition
|
57
|
-
if done:
|
58
|
-
self.memories.append(self.cur_memory)
|
59
|
-
self.cur_memory = []
|
76
|
+
def close(self):
    """Close the database connection."""
    # Closes the sqlite3 connection opened in __init__; self.cursor was
    # created from that connection and must not be used afterwards.
    self.conn.close()
|
60
79
|
|
61
80
|
def get_random_game(self):
    """Pick one stored game at random and return its list of transitions.

    Returns:
        The unpickled transitions of a randomly chosen game, or ``False``
        when fewer than ``self.min_games`` games are stored or the chosen
        row cannot be read back.
    """
    self.cursor.execute("SELECT id FROM games")
    game_ids = [record[0] for record in self.cursor.fetchall()]

    # Not enough games collected yet: signal "nothing to train on".
    if len(game_ids) < self.min_games:
        return False

    chosen_id = random.choice(game_ids)
    self.cursor.execute("SELECT transitions FROM games WHERE id=?", (chosen_id,))
    record = self.cursor.fetchone()
    if not record:
        return False

    # The blob was written by this class; it is unpickled as trusted data.
    return pickle.loads(record[0])
|
67
91
|
|
68
92
|
def get_random_states(self):
|
69
93
|
mem_size = len(self.memories)
|
@@ -78,8 +102,21 @@ class ReplayMemory:
|
|
78
102
|
elif self._mem_type == MEM_TYPE.RANDOM_GAME:
|
79
103
|
return self.get_random_game()
|
80
104
|
|
81
|
-
def
|
82
|
-
|
105
|
+
def get_num_games(self):
    """Count the rows of the ``games`` table (one row per stored game)."""
    row = self.cursor.execute("SELECT COUNT(*) FROM games").fetchone()
    (game_count,) = row
    return game_count
|
109
|
+
|
110
|
+
def init_db(self):
    """Create the ``games`` table if it does not exist and commit.

    Each row holds one complete game: an auto-incrementing ``id`` and the
    pickled list of transitions. The payload is raw bytes, so the column is
    declared BLOB (it was previously declared TEXT, which did not match the
    data that ``append`` actually writes).
    """
    self.cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS games (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            transitions BLOB NOT NULL
        )
        """
    )
    self.conn.commit()
|
83
120
|
|
84
121
|
def mem_type(self, mem_type=None):
|
85
122
|
if mem_type is not None:
|
ai_snake_lab/ai/models/ModelL.py
CHANGED
@@ -12,15 +12,18 @@ import torch
|
|
12
12
|
import torch.nn as nn
|
13
13
|
import torch.nn.functional as F
|
14
14
|
|
15
|
+
from ai_snake_lab.constants.DSim import DSim
|
16
|
+
from ai_snake_lab.constants.DModelL import DModelL
|
17
|
+
|
15
18
|
|
16
19
|
class ModelL(nn.Module):
|
17
20
|
def __init__(self, seed: int):
|
18
21
|
super(ModelL, self).__init__()
|
19
22
|
torch.manual_seed(seed)
|
20
|
-
input_size =
|
21
|
-
hidden_size =
|
22
|
-
output_size =
|
23
|
-
p_value =
|
23
|
+
input_size = DSim.STATE_SIZE # Size of the "state" as tracked by the GameBoard
|
24
|
+
hidden_size = DModelL.HIDDEN_SIZE
|
25
|
+
output_size = DSim.OUTPUT_SIZE
|
26
|
+
p_value = DModelL.P_VALUE
|
24
27
|
self.input_block = nn.Sequential(
|
25
28
|
nn.Linear(input_size, hidden_size),
|
26
29
|
nn.ReLU(),
|
@@ -0,0 +1,20 @@
|
|
1
|
+
"""
|
2
|
+
constants/DDb4EPlot.py
|
3
|
+
|
4
|
+
AI Snake Game Simulator
|
5
|
+
Author: Nadim-Daniel Ghaznavi
|
6
|
+
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
|
7
|
+
GitHub: https://github.com/NadimGhaznavi/ai
|
8
|
+
License: GPL 3.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
|
+
|
13
|
+
|
14
|
+
class Plot(ConstGroup):
    """Db4EPlot Constants"""

    # NOTE(review): presumably the names of the plot's data-thinning modes
    # (these are not simulation loop states) -- confirm against ui/Db4EPlot.py.
    AVERAGE: str = "average"
    SLIDING: str = "sliding"
    # NOTE(review): presumably the cap on the number of plotted points -- confirm.
    MAX_DATA_POINTS: int = 200
|
ai_snake_lab/constants/DDef.py
CHANGED
@@ -8,11 +8,11 @@ constants/DDef.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from utils.ConstGroup import ConstGroup
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
12
|
|
13
13
|
|
14
14
|
class DDef(ConstGroup):
|
15
15
|
"""Defaults"""
|
16
16
|
|
17
|
-
APP_TITLE: str = "AI Snake Game
|
17
|
+
APP_TITLE: str = "AI Snake Game Lab"
|
18
18
|
MOVE_DELAY: float = 0.0
|
ai_snake_lab/constants/DDir.py
CHANGED
@@ -8,9 +8,12 @@ constants/DDir.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from utils.ConstGroup import ConstGroup
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
12
|
|
13
13
|
|
14
14
|
class DDir(ConstGroup):
    """Directories"""

    # Directory-name components. ReplayMemory joins AI_SNAKE_LAB and DB with
    # DFile.REPLAY_DB to build the path of the replay-memory database.
    AI_SNAKE_LAB: str = "ai_snake_lab"
    DB: str = "db"
    # NOTE(review): presumably the directory that holds AISim.tcss -- confirm.
    UTILS: str = "utils"
|
@@ -8,11 +8,17 @@ constants/DFields.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from utils.ConstGroup import ConstGroup
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
12
|
|
13
13
|
|
14
14
|
class DField(ConstGroup):
    """Fields: string identifiers for loop states and stats-dictionary keys."""

    # Simulation loop states
    PAUSED: str = "paused"
    RUNNING: str = "running"
    STOPPED: str = "stopped"

    # Stats dictionary keys
    GAME_SCORE: str = "game_score"
    GAME_NUM: str = "game_num"
|
ai_snake_lab/constants/DFile.py
CHANGED
@@ -8,10 +8,11 @@ constants/DFile.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from utils.ConstGroup import ConstGroup
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
12
|
|
13
13
|
|
14
14
|
class DFile(ConstGroup):
    """Files"""

    # Stylesheet file name (shipped in the package as utils/AISim.tcss).
    CSS_FILE: str = "AISim.tcss"
    # File name of the SQLite replay-memory database; ReplayMemory removes any
    # existing copy when it is constructed and then recreates it.
    REPLAY_DB: str = "replay_mem.db"
|
@@ -8,27 +8,47 @@ constants/DLabels.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from
|
11
|
+
from ai_snake_lab.ai.models.ModelL import ModelL
|
12
|
+
from ai_snake_lab.ai.models.ModelRNN import ModelRNN
|
13
|
+
|
14
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
15
|
|
13
16
|
|
14
17
|
class DLabel(ConstGroup):
    """Labels: human-readable strings shown in the user interface."""

    AVERAGE: str = "Average"
    CURRENT: str = "Current"
    CURRENT_EPSILON: str = "Current Epsilon"
    DEFAULTS: str = "Defaults"
    EPSILON: str = "Epsilon"
    EPSILON_DECAY: str = "Epsilon Decay"
    EPSILON_INITIAL: str = "Initial Epsilon"
    EPSILON_MIN: str = "Minimum Epsilon"
    GAME: str = "Game"
    GAMES: str = "Games"
    GAME_SCORE: str = "Game Score"
    GAME_NUM: str = "Game Number"
    HIGHSCORE: str = "Highscore"
    MEM_TYPE: str = "Memory Type"
    MIN_EPSILON: str = "Minimum Epsilon"
    MODEL_LINEAR: str = "Linear"
    MODEL_RNN: str = "RNN"
    MODEL_TYPE: str = "Model Type"
    MOVE_DELAY: str = "Move Delay"
    PAUSE: str = "Pause"
    QUIT: str = "Quit"
    # RESTART used to be assigned a second time further down with the same
    # value; the redundant duplicate has been removed.
    RESTART: str = "Restart"
    RUNTIME: str = "Runtime"
    RUNTIME_VALUES: str = "Runtime Values"
    SCORE: str = "Score"
    SETTINGS: str = "Configuration Settings"
    START: str = "Start"
    STORED_GAMES: str = "Stored Games"
    UPDATE: str = "Update"

    # Maps a model to its display label.
    # NOTE(review): the keys are inconsistent -- the linear model is keyed by
    # str(ModelL) while the RNN model is keyed by the ModelRNN class object.
    # Left unchanged because the lookups are not visible here; confirm which
    # form callers use and make both keys agree.
    MODEL_TYPE_TABLE: dict = {
        str(ModelL): MODEL_LINEAR,
        ModelRNN: MODEL_RNN,
    }
|
@@ -8,22 +8,35 @@ constants/DLayout.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from utils.ConstGroup import ConstGroup
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
12
|
|
13
13
|
|
14
14
|
class DLayout(ConstGroup):
|
15
15
|
"""Layout"""
|
16
16
|
|
17
|
+
BUTTON_BOX: str = "button_box"
|
17
18
|
BUTTON_PAUSE: str = "button_pause"
|
18
19
|
BUTTON_QUIT: str = "button_quit"
|
20
|
+
BUTTON_RESTART: str = "button_restart"
|
19
21
|
BUTTON_ROW: str = "button_row"
|
20
22
|
BUTTON_START: str = "button_start"
|
21
|
-
|
23
|
+
BUTTON_DEFAULTS: str = "button_defaults"
|
22
24
|
BUTTON_UPDATE: str = "button_update"
|
23
25
|
CUR_EPSILON: str = "cur_epsilon"
|
24
26
|
CUR_MEM_TYPE: str = "cur_mem_type"
|
27
|
+
CUR_MODEL_TYPE: str = "cur_model_type"
|
28
|
+
FILLER_1: str = "filler_1"
|
29
|
+
FILLER_2: str = "filler_2"
|
30
|
+
FILLER_3: str = "filler_3"
|
31
|
+
FILLER_4: str = "filler_4"
|
32
|
+
FILLER_5: str = "filler_5"
|
33
|
+
FILLER_6: str = "filler_6"
|
34
|
+
FILLER_7: str = "filler_7"
|
35
|
+
FILLER_8: str = "filler_8"
|
25
36
|
GAME_BOARD: str = "game_board"
|
26
37
|
GAME_BOX: str = "game_box"
|
38
|
+
GAME_SCORE: str = "game_score"
|
39
|
+
GAME_SCORE_PLOT: str = "game_score_plot"
|
27
40
|
EPSILON_DECAY: str = "epsilon_decay"
|
28
41
|
EPSILON_INITIAL: str = "initial_epsilon"
|
29
42
|
EPSILON_MIN: str = "epsilon_min"
|
@@ -31,8 +44,9 @@ class DLayout(ConstGroup):
|
|
31
44
|
LABEL: str = "label"
|
32
45
|
LABEL_SETTINGS: str = "label_settings"
|
33
46
|
MOVE_DELAY: str = "move_delay"
|
34
|
-
|
47
|
+
NUM_GAMES: str = "num_games"
|
35
48
|
RUNTIME_BOX: str = "runtime_box"
|
49
|
+
RUNTIME: str = "runtime"
|
36
50
|
SCORE: str = "score"
|
37
51
|
SETTINGS_BOX: str = "settings_box"
|
38
52
|
TITLE: str = "title"
|
@@ -8,10 +8,14 @@ constants/DModelL.py
|
|
8
8
|
License: GPL 3.0
|
9
9
|
"""
|
10
10
|
|
11
|
-
from utils.ConstGroup import ConstGroup
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
12
|
|
13
13
|
|
14
14
|
class DModelL(ConstGroup):
    """Linear Model Defaults"""

    # NOTE(review): presumably the optimizer learning rate for the linear
    # model -- its consumer is not visible here; confirm.
    LEARNING_RATE: float = 0.000009
    # The number of nodes in the hidden layer (read by ModelL as hidden_size)
    HIDDEN_SIZE: int = 170
    # The dropout value, 0.2 represents 20% (read by ModelL as p_value)
    P_VALUE: float = 0.2
|
@@ -0,0 +1,20 @@
|
|
1
|
+
"""
|
2
|
+
constants/DSim.py
|
3
|
+
|
4
|
+
AI Snake Game Simulator
|
5
|
+
Author: Nadim-Daniel Ghaznavi
|
6
|
+
Copyright: (c) 2024-2025 Nadim-Daniel Ghaznavi
|
7
|
+
GitHub: https://github.com/NadimGhaznavi/ai
|
8
|
+
License: GPL 3.0
|
9
|
+
"""
|
10
|
+
|
11
|
+
from ai_snake_lab.utils.ConstGroup import ConstGroup
|
12
|
+
|
13
|
+
|
14
|
+
class DSim(ConstGroup):
    """Simulation Constants

    ModelL reads STATE_SIZE as its input layer size and OUTPUT_SIZE as its
    output layer size.
    """

    # Size of the statemap, this is from the GameBoard class
    STATE_SIZE: int = 30
    # The number of "choices" the snake has: go forward, left or right.
    OUTPUT_SIZE: int = 3
|
ai_snake_lab/game/GameBoard.py
CHANGED
@@ -68,6 +68,19 @@ class GameBoard(ScrollView):
|
|
68
68
|
def board_size(self) -> int:
|
69
69
|
return self._board_size
|
70
70
|
|
71
|
+
def get_binary(self, bits_needed, some_int):
    """Return ``some_int`` as a list of 0/1 ints, most significant bit first.

    This is used in the state map, the get_state() function, to encode the
    snake length.

    Args:
        bits_needed: Minimum length of the result; shorter binary forms are
            left-padded with zeros.
        some_int: Non-negative value to encode; it is coerced with ``int()``.

    Returns:
        A list of ints (0 or 1). If the value needs more than ``bits_needed``
        bits the full binary form is returned, so the list is longer than
        ``bits_needed``. Every element is an int; previously the overflow
        digits were left as strings.

    Raises:
        ValueError: If ``some_int`` is negative.
    """
    value = int(some_int)
    if value < 0:
        raise ValueError(f"get_binary() needs a non-negative value, got {value}")
    # zfill left-pads to bits_needed and leaves longer strings untouched.
    return [int(bit) for bit in format(value, "b").zfill(bits_needed)]
|
83
|
+
|
71
84
|
def get_state(self):
|
72
85
|
|
73
86
|
head = self.snake_head
|
@@ -80,7 +93,7 @@ class GameBoard(ScrollView):
|
|
80
93
|
dir_r = direction == Direction.RIGHT
|
81
94
|
dir_u = direction == Direction.UP
|
82
95
|
dir_d = direction == Direction.DOWN
|
83
|
-
|
96
|
+
slb = self.get_binary(7, len(self.snake_body))
|
84
97
|
state = [
|
85
98
|
# 1. Snake collision straight ahead
|
86
99
|
(dir_r and self.is_snake_collision(point_r))
|
@@ -97,47 +110,48 @@ class GameBoard(ScrollView):
|
|
97
110
|
or (dir_u and self.is_snake_collision(point_l))
|
98
111
|
or (dir_r and self.is_snake_collision(point_u))
|
99
112
|
or (dir_l and self.is_snake_collision(point_d)),
|
100
|
-
# 4.
|
101
|
-
0,
|
102
|
-
# 5. Wall collision straight ahead
|
113
|
+
# 4. Wall collision straight ahead
|
103
114
|
(dir_r and self.is_wall_collision(point_r))
|
104
115
|
or (dir_l and self.is_wall_collision(point_l))
|
105
116
|
or (dir_u and self.is_wall_collision(point_u))
|
106
117
|
or (dir_d and self.is_wall_collision(point_d)),
|
107
|
-
#
|
118
|
+
# 5. Wall collision to the right
|
108
119
|
(dir_u and self.is_wall_collision(point_r))
|
109
120
|
or (dir_d and self.is_wall_collision(point_l))
|
110
121
|
or (dir_l and self.is_wall_collision(point_u))
|
111
122
|
or (dir_r and self.is_wall_collision(point_d)),
|
112
|
-
#
|
123
|
+
# 6. Wall collision to the left
|
113
124
|
(dir_d and self.is_wall_collision(point_r))
|
114
125
|
or (dir_u and self.is_wall_collision(point_l))
|
115
126
|
or (dir_r and self.is_wall_collision(point_u))
|
116
127
|
or (dir_l and self.is_wall_collision(point_d)),
|
117
|
-
#
|
118
|
-
0,
|
119
|
-
# 9, 10, 11, 12. Last move direction
|
128
|
+
# 7 - 10. Last move direction
|
120
129
|
dir_l,
|
121
130
|
dir_r,
|
122
131
|
dir_u,
|
123
132
|
dir_d,
|
124
|
-
#
|
125
|
-
|
126
|
-
|
127
|
-
self.food.
|
128
|
-
self.food.
|
129
|
-
self.food.
|
130
|
-
self.food.y > self.snake_head.y, # Food down
|
131
|
-
self.food.x == self.snake_head.x,
|
133
|
+
# 11 - 19. Food location
|
134
|
+
self.food.x < self.snake_head.x, # 11. Food left
|
135
|
+
self.food.x > self.snake_head.x, # 12. Food right
|
136
|
+
self.food.y < self.snake_head.y, # 13. Food up
|
137
|
+
self.food.y > self.snake_head.y, # 14. Food down
|
138
|
+
self.food.x == self.snake_head.x, # 15.
|
132
139
|
self.food.x == self.snake_head.x
|
133
|
-
and self.food.y > self.snake_head.y, # Food ahead
|
140
|
+
and self.food.y > self.snake_head.y, # 16. Food ahead
|
134
141
|
self.food.x == self.snake_head.x
|
135
|
-
and self.food.y < self.snake_head.y, # Food behind
|
136
|
-
self.food.y == self.snake_head.y,
|
142
|
+
and self.food.y < self.snake_head.y, # 17. Food behind
|
137
143
|
self.food.y == self.snake_head.y
|
138
|
-
and self.food.x > self.snake_head.x, # Food above
|
144
|
+
and self.food.x > self.snake_head.x, # 18. Food above
|
139
145
|
self.food.y == self.snake_head.y
|
140
|
-
and self.food.x < self.snake_head.x, # Food below
|
146
|
+
and self.food.x < self.snake_head.x, # 19. Food below
|
147
|
+
# 20 - 26. Snake length in binary
|
148
|
+
slb[0],
|
149
|
+
slb[1],
|
150
|
+
slb[2],
|
151
|
+
slb[3],
|
152
|
+
slb[4],
|
153
|
+
slb[5],
|
154
|
+
slb[6],
|
141
155
|
]
|
142
156
|
|
143
157
|
# 27, 28, 29 and 30. Previous direction of the snake
|
ai_snake_lab/game/SnakeGame.py
CHANGED
@@ -54,6 +54,9 @@ class SnakeGame:
|
|
54
54
|
self.game_board.update_snake(snake=self.snake, direction=self.direction)
|
55
55
|
self.game_board.update_food(food=self.food)
|
56
56
|
|
57
|
+
# Track the distance from the snake head to the food to feed the reward system
|
58
|
+
self.distance_to_food = self.game_board.board_size() // 2
|
59
|
+
|
57
60
|
# The current game score
|
58
61
|
self.game_score = 0
|
59
62
|
|
@@ -126,6 +129,12 @@ class SnakeGame:
|
|
126
129
|
game_over = True
|
127
130
|
reward = -10
|
128
131
|
|
132
|
+
# Set a negative reward if the snake head is adjacent to the snake body.
|
133
|
+
# This is to discourage snake collisions.
|
134
|
+
for segment in self.snake[1:]:
|
135
|
+
if abs(self.head.x - segment.x) < 2 and abs(self.head.y - segment.y) < 2:
|
136
|
+
reward -= -1
|
137
|
+
|
129
138
|
if game_over == True:
|
130
139
|
# Game is over: Snake or wall collision or exceeded max moves
|
131
140
|
self.game_reward += reward
|
@@ -142,6 +151,14 @@ class SnakeGame:
|
|
142
151
|
else:
|
143
152
|
self.snake.pop()
|
144
153
|
|
154
|
+
## 5. See if we're closer to the food than the last move, or further away
|
155
|
+
cur_distance = abs(self.head.x - self.food.x) + abs(self.head.y - self.food.y)
|
156
|
+
if cur_distance < self.distance_to_food:
|
157
|
+
reward += 2
|
158
|
+
elif cur_distance > self.distance_to_food:
|
159
|
+
reward -= 2
|
160
|
+
self.distance_to_food = cur_distance
|
161
|
+
|
145
162
|
self.game_reward += reward
|
146
163
|
self.game_board.update_snake(snake=self.snake, direction=self.direction)
|
147
164
|
self.game_board.update_food(food=self.food)
|