trianglengin 1.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +0 -0
- tests/conftest.py +108 -0
- tests/core/__init__.py +2 -0
- tests/core/environment/README.md +47 -0
- tests/core/environment/__init__.py +2 -0
- tests/core/environment/test_action_codec.py +50 -0
- tests/core/environment/test_game_state.py +483 -0
- tests/core/environment/test_grid_data.py +205 -0
- tests/core/environment/test_grid_logic.py +362 -0
- tests/core/environment/test_shape_logic.py +171 -0
- tests/core/environment/test_step.py +372 -0
- tests/core/structs/__init__.py +0 -0
- tests/core/structs/test_shape.py +83 -0
- tests/core/structs/test_triangle.py +97 -0
- tests/utils/__init__.py +0 -0
- tests/utils/test_geometry.py +93 -0
- trianglengin/__init__.py +18 -0
- trianglengin/app.py +110 -0
- trianglengin/cli.py +134 -0
- trianglengin/config/__init__.py +9 -0
- trianglengin/config/display_config.py +47 -0
- trianglengin/config/env_config.py +103 -0
- trianglengin/core/__init__.py +8 -0
- trianglengin/core/environment/__init__.py +31 -0
- trianglengin/core/environment/action_codec.py +37 -0
- trianglengin/core/environment/game_state.py +217 -0
- trianglengin/core/environment/grid/README.md +46 -0
- trianglengin/core/environment/grid/__init__.py +18 -0
- trianglengin/core/environment/grid/grid_data.py +140 -0
- trianglengin/core/environment/grid/line_cache.py +189 -0
- trianglengin/core/environment/grid/logic.py +131 -0
- trianglengin/core/environment/logic/__init__.py +3 -0
- trianglengin/core/environment/logic/actions.py +38 -0
- trianglengin/core/environment/logic/step.py +134 -0
- trianglengin/core/environment/shapes/__init__.py +19 -0
- trianglengin/core/environment/shapes/logic.py +84 -0
- trianglengin/core/environment/shapes/templates.py +587 -0
- trianglengin/core/structs/__init__.py +27 -0
- trianglengin/core/structs/constants.py +28 -0
- trianglengin/core/structs/shape.py +61 -0
- trianglengin/core/structs/triangle.py +48 -0
- trianglengin/interaction/README.md +45 -0
- trianglengin/interaction/__init__.py +17 -0
- trianglengin/interaction/debug_mode_handler.py +96 -0
- trianglengin/interaction/event_processor.py +43 -0
- trianglengin/interaction/input_handler.py +82 -0
- trianglengin/interaction/play_mode_handler.py +141 -0
- trianglengin/utils/__init__.py +9 -0
- trianglengin/utils/geometry.py +73 -0
- trianglengin/utils/types.py +10 -0
- trianglengin/visualization/README.md +44 -0
- trianglengin/visualization/__init__.py +61 -0
- trianglengin/visualization/core/README.md +52 -0
- trianglengin/visualization/core/__init__.py +12 -0
- trianglengin/visualization/core/colors.py +117 -0
- trianglengin/visualization/core/coord_mapper.py +73 -0
- trianglengin/visualization/core/fonts.py +55 -0
- trianglengin/visualization/core/layout.py +101 -0
- trianglengin/visualization/core/visualizer.py +232 -0
- trianglengin/visualization/drawing/README.md +45 -0
- trianglengin/visualization/drawing/__init__.py +30 -0
- trianglengin/visualization/drawing/grid.py +156 -0
- trianglengin/visualization/drawing/highlight.py +30 -0
- trianglengin/visualization/drawing/hud.py +39 -0
- trianglengin/visualization/drawing/previews.py +172 -0
- trianglengin/visualization/drawing/shapes.py +36 -0
- trianglengin-1.0.6.dist-info/METADATA +367 -0
- trianglengin-1.0.6.dist-info/RECORD +72 -0
- trianglengin-1.0.6.dist-info/WHEEL +5 -0
- trianglengin-1.0.6.dist-info/entry_points.txt +2 -0
- trianglengin-1.0.6.dist-info/licenses/LICENSE +22 -0
- trianglengin-1.0.6.dist-info/top_level.txt +2 -0
trianglengin/app.py
ADDED
@@ -0,0 +1,110 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
import pygame
|
4
|
+
|
5
|
+
# Use internal imports
|
6
|
+
from . import config as tg_config
|
7
|
+
from . import core as tg_core
|
8
|
+
from . import interaction, visualization
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class Application:
    """Main application integrating visualization and interaction for trianglengin.

    Builds the pygame window, the game state, the visualizer and the input
    handler, and drives them from :meth:`run`.
    """

    def __init__(self, mode: str = "play"):
        """Create the window and all runtime components.

        Args:
            mode: Interaction mode; only ``"play"`` and ``"debug"`` are supported.

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        # Validate before touching pygame so that an unsupported mode does not
        # leave an initialised display (and an open window) behind.
        if mode not in ("play", "debug"):
            logger.error(f"Unsupported application mode: {mode}")
            raise ValueError(f"Unsupported application mode: {mode}")

        self.display_config = tg_config.DisplayConfig()
        self.env_config = tg_config.EnvConfig()
        self.mode = mode

        pygame.init()
        pygame.font.init()
        self.screen = self._setup_screen()
        self.clock = pygame.time.Clock()
        self.fonts = visualization.load_fonts()

        # Game state from the core package, rendered by the visualizer and
        # driven by the input handler.
        self.game_state = tg_core.environment.GameState(self.env_config)
        self.visualizer = visualization.Visualizer(
            self.screen,
            self.display_config,
            self.env_config,
            self.fonts,
        )
        self.input_handler = interaction.InputHandler(
            self.game_state, self.visualizer, self.mode, self.env_config
        )

        self.running = True

    def _setup_screen(self) -> pygame.Surface:
        """Initializes the Pygame screen and sets the window caption."""
        screen = pygame.display.set_mode(
            (
                self.display_config.SCREEN_WIDTH,
                self.display_config.SCREEN_HEIGHT,
            ),
            pygame.RESIZABLE,
        )
        # Use a generic name or make APP_NAME part of trianglengin config later
        pygame.display.set_caption(f"Triangle Engine - {self.mode.capitalize()} Mode")
        return screen

    def _handle_fallback_events(self) -> None:
        """Minimal event handling used only when no input handler is set.

        Handles quit, Escape and window resize. Should not happen in
        play/debug mode, where an input handler always exists.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.running = False
            if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
                self.running = False
            if event.type == pygame.VIDEORESIZE and self.visualizer:
                try:
                    # Clamp to a minimum window size of 320x240.
                    w, h = max(320, event.w), max(240, event.h)
                    self.visualizer.screen = pygame.display.set_mode(
                        (w, h), pygame.RESIZABLE
                    )
                    # Reset the cached layout after the resize.
                    self.visualizer.layout_rects = None
                except pygame.error as e:
                    logger.error(f"Error resizing window: {e}")

    def run(self):
        """Main application loop.

        Ticks the clock, processes input and renders a frame until the input
        handler (or the fallback event handling) reports that the application
        should stop. ``pygame.quit()`` is always called on the way out, even
        when the loop is left through an exception.
        """
        logger.info(f"Starting application in {self.mode} mode.")
        try:
            while self.running:
                self.clock.tick(self.display_config.FPS)

                # Handle input using InputHandler
                if self.input_handler:
                    self.running = self.input_handler.handle_input()
                else:
                    self._handle_fallback_events()
                if not self.running:
                    break

                # Render using Visualizer
                if (
                    self.mode in ["play", "debug"]
                    and self.visualizer
                    and self.game_state
                    and self.input_handler
                ):
                    interaction_render_state = (
                        self.input_handler.get_render_interaction_state()
                    )
                    self.visualizer.render(
                        self.game_state,
                        self.mode,
                        **interaction_render_state,
                    )
                    pygame.display.flip()

            logger.info("Application loop finished.")
        finally:
            # Release pygame resources regardless of how the loop ended.
            pygame.quit()
|
trianglengin/cli.py
ADDED
@@ -0,0 +1,134 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
from typing import Annotated
|
4
|
+
|
5
|
+
# Removed torch import
|
6
|
+
import typer
|
7
|
+
|
8
|
+
# Use internal imports
|
9
|
+
from .app import Application
|
10
|
+
from .config import EnvConfig
|
11
|
+
|
12
|
+
# Typer application that hosts the interactive sub-commands defined below.
app = typer.Typer(
    add_completion=False,
    help="Core Triangle Engine - Interactive Modes.",
    name="trianglengin",
)

# Reusable ``--log-level`` / ``-l`` option, shared by every sub-command.
LogLevelOption = Annotated[
    str,
    typer.Option(
        "--log-level",
        "-l",
        case_sensitive=False,
        help="Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
    ),
]

# Reusable ``--seed`` / ``-s`` option, shared by every sub-command.
SeedOption = Annotated[
    int,
    typer.Option("--seed", "-s", help="Random seed for reproducibility."),
]
|
36
|
+
|
37
|
+
|
38
|
+
def setup_logging(log_level_str: str):
    """Configure the root logger from a level name.

    Args:
        log_level_str: One of DEBUG, INFO, WARNING, ERROR or CRITICAL
            (case-insensitive, surrounding whitespace ignored). Any other
            value falls back to INFO and a warning is logged, so a typo is
            visible instead of being silently ignored.
    """
    level_name = log_level_str.strip().upper()
    log_level_map = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    log_level = log_level_map.get(level_name)
    unknown_level = log_level is None
    if unknown_level:
        log_level = logging.INFO

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
        force=True,  # Override existing config
    )
    # Keep external libraries less verbose if needed
    logging.getLogger("pygame").setLevel(logging.WARNING)

    # Report the fallback only after basicConfig so the message is formatted
    # and routed like every other record.
    if unknown_level:
        logging.warning(
            f"Unknown log level {log_level_str!r}; falling back to INFO."
        )
    logging.info(f"Root logger level set to {logging.getLevelName(log_level)}")
|
58
|
+
|
59
|
+
|
60
|
+
def run_interactive_mode(mode: str, seed: int, log_level: str):
    """Runs the interactive application.

    Configures logging, seeds the random number generators, validates the
    default ``EnvConfig`` and then runs the ``Application`` until it exits.
    This function always terminates the process through ``sys.exit``.

    Args:
        mode: Application mode passed to ``Application``.
        seed: Seed for Python's ``random`` module and, when NumPy is
            installed, NumPy's global random state.
        log_level: Level name passed to ``setup_logging``.
    """
    setup_logging(log_level)
    logger = logging.getLogger(__name__)
    logger.info(f"Running Triangle Engine in {mode.capitalize()} mode...")

    # --- Seeding ---
    import random

    try:
        random.seed(seed)
        try:
            import numpy as np
        except ImportError:
            logger.warning("NumPy not found. Skipping NumPy seed.")
        else:
            # np.random.default_rng(seed) only builds a new Generator; calling
            # it and discarding the result seeds nothing. np.random.seed()
            # seeds the global state used by the np.random.* functions. It
            # only accepts values in [0, 2**32 - 1], hence the modulo.
            np.random.seed(seed % (2**32))
            logger.debug("NumPy global random state seeded.")
        logger.info(f"Set random seeds to {seed}")
    except Exception as e:
        # Seeding is best-effort; a failure must not prevent the game from starting.
        logger.error(f"Error setting seeds: {e}")
    # --- End seeding ---

    # Validate EnvConfig
    try:
        _ = EnvConfig()
        logger.info("EnvConfig validated.")
    except Exception as e:
        logger.critical(f"EnvConfig validation failed: {e}", exc_info=True)
        sys.exit(1)

    try:
        app_instance = Application(mode=mode)
        app_instance.run()
    except ImportError as e:
        logger.error(f"Runtime ImportError: {e}")
        logger.error(
            "Please ensure all dependencies (including pygame) are installed for trianglengin."
        )
        sys.exit(1)
    except Exception as e:
        logger.critical(f"An unhandled error occurred: {e}", exc_info=True)
        sys.exit(1)

    logger.info("Exiting.")
    sys.exit(0)
|
113
|
+
|
114
|
+
|
115
|
+
@app.command()
def play(
    log_level: LogLevelOption = "INFO",
    seed: SeedOption = 42,
):
    """Run the game in interactive Play mode."""
    # The docstring above is the CLI help text; the shared runner does the work.
    run_interactive_mode(
        mode="play",
        log_level=log_level,
        seed=seed,
    )
|
122
|
+
|
123
|
+
|
124
|
+
@app.command()
def debug(
    log_level: LogLevelOption = "DEBUG",  # Debug mode logs verbosely by default
    seed: SeedOption = 42,
):
    """Run the game in interactive Debug mode."""
    # The docstring above is the CLI help text; the shared runner does the work.
    run_interactive_mode(
        mode="debug",
        log_level=log_level,
        seed=seed,
    )
|
131
|
+
|
132
|
+
|
133
|
+
if __name__ == "__main__":
|
134
|
+
app()
|
@@ -0,0 +1,9 @@
|
|
1
|
+
# trianglengin/config/__init__.py
|
2
|
+
"""
|
3
|
+
Shared configuration models for the Triangle Engine.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from .display_config import DEFAULT_DISPLAY_CONFIG, DisplayConfig
|
7
|
+
from .env_config import EnvConfig
|
8
|
+
|
9
|
+
__all__ = ["EnvConfig", "DisplayConfig", "DEFAULT_DISPLAY_CONFIG"]
|
@@ -0,0 +1,47 @@
|
|
1
|
+
# trianglengin/config/display_config.py
|
2
|
+
"""
|
3
|
+
Configuration specific to display and visualization settings.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import pygame
|
7
|
+
from pydantic import BaseModel, Field
|
8
|
+
|
9
|
+
# Initialize Pygame font module if not already done (safe to call multiple times)
|
10
|
+
pygame.font.init()


def _load_default_debug_font() -> pygame.font.Font:
    """Return the 12pt "monospace" system font, or pygame's default font at 15pt.

    In a real app, this might load from files or use system fonts more robustly.
    """
    try:
        return pygame.font.SysFont("monospace", 12)
    except Exception:
        # Fallback default pygame font
        return pygame.font.Font(None, 15)


# Loaded once at import time and used as the default for DisplayConfig.DEBUG_FONT.
DEBUG_FONT_DEFAULT = _load_default_debug_font()
|
18
|
+
|
19
|
+
|
20
|
+
class DisplayConfig(BaseModel):
    """Configuration for visualization display settings."""

    # Pydantic v2 style configuration (replaces the deprecated nested
    # ``class Config``). Arbitrary types are required because
    # ``pygame.font.Font`` is not a type Pydantic knows how to validate.
    model_config = {"arbitrary_types_allowed": True}

    # Screen and Layout
    SCREEN_WIDTH: int = Field(default=1024, gt=0)
    SCREEN_HEIGHT: int = Field(default=768, gt=0)
    FPS: int = Field(default=60, gt=0)
    PADDING: int = Field(default=10, ge=0)
    HUD_HEIGHT: int = Field(default=30, ge=0)
    PREVIEW_AREA_WIDTH: int = Field(default=150, ge=50)
    PREVIEW_PADDING: int = Field(default=5, ge=0)
    PREVIEW_INNER_PADDING: int = Field(default=3, ge=0)
    PREVIEW_BORDER_WIDTH: int = Field(default=1, ge=0)
    PREVIEW_SELECTED_BORDER_WIDTH: int = Field(default=3, ge=0)

    # Fonts: the font object pre-loaded at module import is used as the
    # default. Consider default_factory=lambda: pygame.font.SysFont(...) if a
    # fresh font object per instance is ever needed.
    DEBUG_FONT: pygame.font.Font = Field(default=DEBUG_FONT_DEFAULT)
|
44
|
+
|
45
|
+
|
46
|
+
# Optional: Create a default instance for easy import elsewhere
|
47
|
+
DEFAULT_DISPLAY_CONFIG = DisplayConfig()
|
@@ -0,0 +1,103 @@
|
|
1
|
+
# File: trianglengin/config/env_config.py
|
2
|
+
from pydantic import (
|
3
|
+
BaseModel,
|
4
|
+
Field,
|
5
|
+
computed_field,
|
6
|
+
field_validator,
|
7
|
+
model_validator,
|
8
|
+
)
|
9
|
+
|
10
|
+
|
11
|
+
class EnvConfig(BaseModel):
    """Configuration for the game environment (Pydantic model)."""

    # Grid dimensions.
    ROWS: int = Field(default=8, gt=0)
    COLS: int = Field(default=15, gt=0)
    # One (start_col, end_col) pair per row. The validators below require
    # 0 <= start_col < end_col <= COLS and exactly ROWS entries.
    PLAYABLE_RANGE_PER_ROW: list[tuple[int, int]] = Field(
        default=[
            (3, 12),  # 9 cols, centered in 15
            (2, 13),  # 11 cols
            (1, 14),  # 13 cols
            (0, 15),  # 15 cols
            (0, 15),  # 15 cols
            (1, 14),  # 13 cols
            (2, 13),  # 11 cols
            (3, 12),  # 9 cols
        ]
    )
    NUM_SHAPE_SLOTS: int = Field(default=3, gt=0)

    # --- Reward System Constants (v3) ---
    REWARD_PER_PLACED_TRIANGLE: float = Field(default=0.01)
    REWARD_PER_CLEARED_TRIANGLE: float = Field(default=0.5)
    REWARD_PER_STEP_ALIVE: float = Field(default=0.005)
    PENALTY_GAME_OVER: float = Field(default=-10.0)
    # --- End Reward System Constants ---

    @field_validator("PLAYABLE_RANGE_PER_ROW")
    @classmethod
    def check_playable_range_length(
        cls, v: list[tuple[int, int]], info
    ) -> list[tuple[int, int]]:
        """Validates an explicitly supplied PLAYABLE_RANGE_PER_ROW.

        Raises:
            ValueError: If the number of ranges differs from ROWS, or a range
                violates 0 <= start_col < end_col <= COLS.
        """
        # In Pydantic v2 the already-validated earlier fields are in info.data.
        data = info.data

        rows = data.get("ROWS")
        cols = data.get("COLS")

        # ROWS/COLS are missing when they failed their own validation; that
        # error is reported by Pydantic, so nothing more can be checked here.
        if rows is None or cols is None:
            return v

        if len(v) != rows:
            raise ValueError(
                f"PLAYABLE_RANGE_PER_ROW length ({len(v)}) must equal ROWS ({rows})"
            )

        for i, (start, end) in enumerate(v):
            if not (0 <= start < cols):
                raise ValueError(
                    f"Row {i}: start_col ({start}) out of bounds [0, {cols})."
                )
            # Zero-width ranges (start == end) are rejected by this check.
            if not (start < end <= cols):
                raise ValueError(
                    f"Row {i}: end_col ({end}) invalid. Must be > start_col ({start}) and <= COLS ({cols})."
                )

        return v

    @model_validator(mode="after")
    def check_cols_sufficient_for_ranges(self) -> "EnvConfig":
        """Ensure ROWS and COLS are consistent with the specified ranges.

        The field validator above does not run when PLAYABLE_RANGE_PER_ROW is
        left at its default, so the cross-field checks are repeated here to
        catch e.g. ``EnvConfig(ROWS=5)`` or ``EnvConfig(COLS=10)`` combined
        with the default ranges.

        Raises:
            ValueError: If the number of ranges differs from ROWS or any
                end_col exceeds COLS.
        """
        ranges = self.PLAYABLE_RANGE_PER_ROW

        if len(ranges) != self.ROWS:
            raise ValueError(
                f"PLAYABLE_RANGE_PER_ROW length ({len(ranges)}) must equal ROWS ({self.ROWS})"
            )

        max_end_col = max((end for _, end in ranges), default=0)
        if max_end_col > self.COLS:
            raise ValueError(
                f"COLS ({self.COLS}) must be >= the maximum end_col in PLAYABLE_RANGE_PER_ROW ({max_end_col})"
            )
        return self

    @computed_field  # type: ignore[misc]
    @property
    def ACTION_DIM(self) -> int:
        """Total number of possible actions (shape_slot * row * col)."""
        return self.NUM_SHAPE_SLOTS * self.ROWS * self.COLS
|
100
|
+
|
101
|
+
|
102
|
+
# Ensure model is rebuilt after computed_field definition
|
103
|
+
EnvConfig.model_rebuild(force=True)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# File: trianglengin/core/environment/__init__.py
|
2
|
+
"""
|
3
|
+
Environment module defining the game rules, state, actions, and logic.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from ...config import EnvConfig
|
7
|
+
from .action_codec import decode_action, encode_action
|
8
|
+
from .game_state import GameState
|
9
|
+
from .grid import logic as GridLogic
|
10
|
+
from .grid.grid_data import GridData
|
11
|
+
from .logic.actions import get_valid_actions
|
12
|
+
from .logic.step import calculate_reward, execute_placement
|
13
|
+
from .shapes import logic as ShapeLogic
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
# Core
|
17
|
+
"GameState",
|
18
|
+
"encode_action",
|
19
|
+
"decode_action",
|
20
|
+
# Grid
|
21
|
+
"GridData",
|
22
|
+
"GridLogic",
|
23
|
+
# Shapes
|
24
|
+
"ShapeLogic",
|
25
|
+
# Logic
|
26
|
+
"get_valid_actions",
|
27
|
+
"execute_placement",
|
28
|
+
"calculate_reward",
|
29
|
+
# Config
|
30
|
+
"EnvConfig",
|
31
|
+
]
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from typing import TypeAlias
|
2
|
+
|
3
|
+
from ...config import EnvConfig
|
4
|
+
|
5
|
+
ActionType: TypeAlias = int
|
6
|
+
|
7
|
+
|
8
|
+
def encode_action(shape_idx: int, r: int, c: int, config: EnvConfig) -> ActionType:
    """Pack a (shape_idx, r, c) triple into one flat action index.

    The index is ``shape_idx * ROWS * COLS + r * COLS + c``.

    Raises:
        ValueError: If any component lies outside ``[0, limit)`` for its
            respective limit (NUM_SHAPE_SLOTS, ROWS, COLS).
    """
    # (label, value, exclusive upper bound), checked in this order.
    bounds = (
        ("shape", shape_idx, config.NUM_SHAPE_SLOTS),
        ("row", r, config.ROWS),
        ("column", c, config.COLS),
    )
    for label, value, limit in bounds:
        if not (0 <= value < limit):
            raise ValueError(f"Invalid {label} index: {value}, must be < {limit}")

    cells_per_slot = config.ROWS * config.COLS
    return shape_idx * cells_per_slot + r * config.COLS + c
|
21
|
+
|
22
|
+
|
23
|
+
def decode_action(action_index: ActionType, config: EnvConfig) -> tuple[int, int, int]:
    """Unpack a flat action index into its (shape_idx, r, c) components.

    Raises:
        ValueError: If ``action_index`` is outside ``[0, config.ACTION_DIM)``.
    """
    total_actions = int(config.ACTION_DIM)  # type: ignore[call-overload]
    if not (0 <= action_index < total_actions):
        raise ValueError(
            f"Invalid action index: {action_index}, must be < {total_actions}"
        )

    # Peel off the shape slot first, then split the remaining cell index
    # into row and column.
    shape_idx, cell = divmod(action_index, config.ROWS * config.COLS)
    r, c = divmod(cell, config.COLS)
    return shape_idx, r, c
|