mxbiflow 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxbiflow/__init__.py +3 -0
- mxbiflow/assets/__init__.py +5 -0
- mxbiflow/assets/clicker.wav +0 -0
- mxbiflow/config_store.py +68 -0
- mxbiflow/data_logger.py +114 -0
- mxbiflow/default/__init__.py +4 -0
- mxbiflow/default/idle/assets/apple_v1.png +0 -0
- mxbiflow/default/idle/idle.py +57 -0
- mxbiflow/detector_bridge.py +87 -0
- mxbiflow/game.py +84 -0
- mxbiflow/infra/eventbus.py +31 -0
- mxbiflow/main.py +106 -0
- mxbiflow/models/animal.py +130 -0
- mxbiflow/models/reward.py +7 -0
- mxbiflow/models/session.py +145 -0
- mxbiflow/mxbiflow.py +43 -0
- mxbiflow/path.py +41 -0
- mxbiflow/scene/__init__.py +8 -0
- mxbiflow/scene/scene_manager.py +64 -0
- mxbiflow/scene/scene_protocol.py +22 -0
- mxbiflow/scheduler.py +90 -0
- mxbiflow/tasks/GNGSiD/models.py +70 -0
- mxbiflow/tasks/GNGSiD/stages/detect_stage/config.json +116 -0
- mxbiflow/tasks/GNGSiD/stages/detect_stage/detect_stage.py +161 -0
- mxbiflow/tasks/GNGSiD/stages/detect_stage/detect_stage_models.py +65 -0
- mxbiflow/tasks/GNGSiD/stages/discriminate_stage/config.json +70 -0
- mxbiflow/tasks/GNGSiD/stages/discriminate_stage/discriminate_stage.py +173 -0
- mxbiflow/tasks/GNGSiD/stages/discriminate_stage/discriminate_stage_models.py +80 -0
- mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/config.json +83 -0
- mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/size_reduction_models.py +58 -0
- mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/size_reduction_stage.py +149 -0
- mxbiflow/tasks/GNGSiD/tasks/artifacts.py +13 -0
- mxbiflow/tasks/GNGSiD/tasks/detect/models.py +21 -0
- mxbiflow/tasks/GNGSiD/tasks/detect/scene.py +271 -0
- mxbiflow/tasks/GNGSiD/tasks/discriminate/discriminate_models.py +31 -0
- mxbiflow/tasks/GNGSiD/tasks/discriminate/discriminate_scene.py +336 -0
- mxbiflow/tasks/GNGSiD/tasks/touch/touch_models.py +17 -0
- mxbiflow/tasks/GNGSiD/tasks/touch/touch_scene.py +256 -0
- mxbiflow/tasks/GNGSiD/tasks/utils/targets.py +57 -0
- mxbiflow/tasks/cross_modal/bundle_dir.py +553 -0
- mxbiflow/tasks/cross_modal/config.py +41 -0
- mxbiflow/tasks/cross_modal/media.py +61 -0
- mxbiflow/tasks/cross_modal/models.py +57 -0
- mxbiflow/tasks/cross_modal/scene.py +252 -0
- mxbiflow/tasks/cross_modal/stage.py +218 -0
- mxbiflow/tasks/cross_modal/trial_io.py +23 -0
- mxbiflow/tasks/cross_modal/trial_schema.py +113 -0
- mxbiflow/tasks/default/error_task/error_scene.py +53 -0
- mxbiflow/tasks/default/idle_task/assets/apple_v1.png +0 -0
- mxbiflow/tasks/default/idle_task/idle_scene.py +85 -0
- mxbiflow/tasks/default/initial_habituation_training/README.md +188 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/config.csv +7 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/config.json +67 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/initial_habituation_training_stage.py +172 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/models.py +56 -0
- mxbiflow/tasks/default/initial_habituation_training/tasks/stay_to_reward/stay_to_reward.py +244 -0
- mxbiflow/tasks/default/initial_habituation_training/tasks/stay_to_reward/stay_to_reward_models.py +50 -0
- mxbiflow/tasks/task_protocol.py +26 -0
- mxbiflow/tasks/task_table.py +29 -0
- mxbiflow/tasks/two_alternative_choice/assets/starter.py +27 -0
- mxbiflow/tasks/two_alternative_choice/models.py +68 -0
- mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/config.json +118 -0
- mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/size_reduction_models.py +41 -0
- mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/size_reduction_stage.py +122 -0
- mxbiflow/tasks/two_alternative_choice/tasks/touch/touch_models.py +19 -0
- mxbiflow/tasks/two_alternative_choice/tasks/touch/touch_scene.py +249 -0
- mxbiflow/timer/__init__.py +3 -0
- mxbiflow/timer/frame_timer.py +47 -0
- mxbiflow/timer/realtime_timer.py +0 -0
- mxbiflow/tmp_email.py +13 -0
- mxbiflow/ui/components/animal.py +87 -0
- mxbiflow/ui/components/baseconfig.py +68 -0
- mxbiflow/ui/components/card.py +18 -0
- mxbiflow/ui/components/device_card/__init__.py +17 -0
- mxbiflow/ui/components/device_card/detector/beambreak_detector_card.py +29 -0
- mxbiflow/ui/components/device_card/detector/fusion_detector.py +45 -0
- mxbiflow/ui/components/device_card/detector/mock_detector_card.py +20 -0
- mxbiflow/ui/components/device_card/detector/rfid_detector.py +40 -0
- mxbiflow/ui/components/device_card/device_card.py +67 -0
- mxbiflow/ui/components/device_card/rewarder/mock_rewarder_card.py +20 -0
- mxbiflow/ui/components/device_card/rewarder/rpi_gpio_rewarder.py +33 -0
- mxbiflow/ui/components/devices.py +183 -0
- mxbiflow/ui/components/dialog/__init__.py +3 -0
- mxbiflow/ui/components/dialog/add_devices_dialog.py +64 -0
- mxbiflow/ui/components/experiment_groups.py +122 -0
- mxbiflow/ui/experiment_panel.py +91 -0
- mxbiflow/ui/mxbi_panel.py +152 -0
- mxbiflow/utils/logger.py +19 -0
- mxbiflow/utils/serial.py +10 -0
- mxbiflow-0.1.1.dist-info/METADATA +168 -0
- mxbiflow-0.1.1.dist-info/RECORD +93 -0
- mxbiflow-0.1.1.dist-info/WHEEL +4 -0
- mxbiflow-0.1.1.dist-info/entry_points.txt +4 -0
mxbiflow/__init__.py
ADDED
|
Binary file
|
mxbiflow/config_store.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Generic, TypeVar
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
T = TypeVar("T", bound=BaseModel)


class ConfigStore(Generic[T]):
    """Load, hold and persist one pydantic model stored as a JSON file.

    The file is read once in the constructor.  If it cannot be loaded and
    ``create_default`` is true, a default instance of ``config_class`` is
    written to ``config_path`` and used instead.
    """

    def __init__(
        self,
        config_path: Path,
        config_class: type[T],
        *,
        create_default: bool = True,
    ) -> None:
        """Read ``config_path`` into an instance of ``config_class``.

        Args:
            config_path: Location of the JSON config file.
            config_class: Pydantic model used to validate the file.
            create_default: When true, fall back to ``config_class()`` (and
                write it to disk) if the file cannot be loaded.

        Raises:
            RuntimeError: If loading fails and ``create_default`` is false.
        """
        self._config_path = config_path
        self._config_class = config_class
        self._create_default = create_default
        self._config = self._load_config()

    @property
    def value(self) -> T:
        """The currently held config instance."""
        return self._config

    def _ensure_config_readable(self) -> None:
        """Raise if the config file is missing or not readable."""
        if not self._config_path.exists():
            raise FileNotFoundError(f"Config file {self._config_path} not found")

        if not os.access(self._config_path, os.R_OK):
            raise PermissionError(f"Config file {self._config_path} is not readable")

    def _backup_unloadable_config(self) -> None:
        """Move an existing but unloadable config file to ``<name>.bak``.

        Without this, a hand-edited file containing a typo would be replaced
        by defaults with no way to recover the previous settings.  An older
        ``.bak`` file is overwritten.
        """
        try:
            if not self._config_path.is_file():
                return
            backup_path = self._config_path.with_name(self._config_path.name + ".bak")
            self._config_path.replace(backup_path)
        except OSError:
            # Deliberately best-effort: failing to keep a backup must not
            # prevent falling back to the default configuration.
            return

    def _create_default_config(self) -> T:
        """Instantiate the default config, write it to disk and return it."""
        config = self._config_class()
        self._config_path.parent.mkdir(parents=True, exist_ok=True)
        self._config_path.write_text(
            config.model_dump_json(indent=4),
            encoding="utf-8",
        )
        return config

    def _load_config(self) -> T:
        """Parse the config file, falling back to defaults when allowed."""
        try:
            self._ensure_config_readable()
            text = self._config_path.read_text(encoding="utf-8")
            return self._config_class.model_validate_json(text)
        except Exception as e:
            if self._create_default:
                # Keep whatever was on disk before it gets overwritten.
                self._backup_unloadable_config()
                return self._create_default_config()
            raise RuntimeError(
                f"Failed to load config file {self._config_path}: {e}"
            ) from e

    def save(self, data: T | None = None) -> None:
        """Write the config to disk, optionally replacing it with ``data`` first.

        Raises:
            RuntimeError: If the file cannot be written.
        """
        if data is not None:
            self._config = data

        self._config_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            self._config_path.write_text(
                self._config.model_dump_json(indent=4),
                encoding="utf-8",
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to save config file {self._config_path}: {e}"
            ) from e
|
mxbiflow/data_logger.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import json
|
|
3
|
+
import sys
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from enum import StrEnum
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .utils.logger import logger
|
|
9
|
+
|
|
10
|
+
now = datetime.now()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DataLoggerType(StrEnum):
|
|
14
|
+
JSONL = "jsonl"
|
|
15
|
+
JSON = "json"
|
|
16
|
+
CSV = "csv"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DataLogger:
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
path: Path,
|
|
23
|
+
session_id: int,
|
|
24
|
+
monkey: str,
|
|
25
|
+
filename: str,
|
|
26
|
+
type: DataLoggerType = DataLoggerType.JSONL,
|
|
27
|
+
) -> None:
|
|
28
|
+
self._path = path
|
|
29
|
+
self._session_id = session_id
|
|
30
|
+
self._monkey = monkey
|
|
31
|
+
self._filename = filename
|
|
32
|
+
self._type = type
|
|
33
|
+
|
|
34
|
+
self._data_dir = self._ensure_data_dir()
|
|
35
|
+
self._data_path = self._get_path(f".{self._type.value}")
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def path(self) -> Path:
|
|
39
|
+
return self._data_path
|
|
40
|
+
|
|
41
|
+
def _ensure_data_dir(self) -> Path:
|
|
42
|
+
date_path = Path(f"{now.year}{now.month:02d}{now.day:02d}")
|
|
43
|
+
session_path = Path(f"{self._session_id}")
|
|
44
|
+
monkey_path = Path(f"{self._monkey}")
|
|
45
|
+
|
|
46
|
+
base_dir = self._path / date_path / session_path / monkey_path
|
|
47
|
+
|
|
48
|
+
try:
|
|
49
|
+
base_dir.mkdir(parents=True, exist_ok=True)
|
|
50
|
+
return base_dir
|
|
51
|
+
except Exception as e:
|
|
52
|
+
logger.error(f"failed to create {base_dir}: {e}")
|
|
53
|
+
sys.exit(1)
|
|
54
|
+
|
|
55
|
+
def _get_path(self, suffix: str) -> Path:
|
|
56
|
+
return self._data_dir / f"{self._filename}{suffix}"
|
|
57
|
+
|
|
58
|
+
def save(self, data: dict) -> None:
|
|
59
|
+
match self._type:
|
|
60
|
+
case DataLoggerType.JSONL:
|
|
61
|
+
self._save_jsonl(data)
|
|
62
|
+
case DataLoggerType.JSON:
|
|
63
|
+
self._save_json(data)
|
|
64
|
+
case DataLoggerType.CSV:
|
|
65
|
+
self.save_csv_row(data)
|
|
66
|
+
|
|
67
|
+
def _save_jsonl(self, data: dict) -> None:
|
|
68
|
+
try:
|
|
69
|
+
json_line = json.dumps(data, ensure_ascii=False)
|
|
70
|
+
|
|
71
|
+
with open(self._data_path, "a", encoding="utf-8") as f:
|
|
72
|
+
f.write(json_line + "\n")
|
|
73
|
+
|
|
74
|
+
except TypeError as e:
|
|
75
|
+
logger.error(f"Data is not JSON serializable: {e}")
|
|
76
|
+
raise
|
|
77
|
+
except IOError as e:
|
|
78
|
+
logger.error(f"Failed to write to file {self._data_path}: {e}")
|
|
79
|
+
raise
|
|
80
|
+
except Exception as e:
|
|
81
|
+
logger.error(f"Unexpected error while writing data: {e}")
|
|
82
|
+
raise
|
|
83
|
+
|
|
84
|
+
def _save_json(self, data: dict) -> None:
|
|
85
|
+
try:
|
|
86
|
+
self._data_path.parent.mkdir(parents=True, exist_ok=True)
|
|
87
|
+
with open(self._data_path, "w", encoding="utf-8") as f:
|
|
88
|
+
json.dump(data, f, ensure_ascii=False, indent=2)
|
|
89
|
+
except TypeError as e:
|
|
90
|
+
logger.error(f"Data is not JSON serializable: {e}")
|
|
91
|
+
raise
|
|
92
|
+
except IOError as e:
|
|
93
|
+
logger.error(f"Failed to write to file {self._data_path}: {e}")
|
|
94
|
+
raise
|
|
95
|
+
except Exception as e:
|
|
96
|
+
logger.error(f"Unexpected error while writing JSON data: {e}")
|
|
97
|
+
raise
|
|
98
|
+
|
|
99
|
+
def save_csv_row(self, data: dict) -> None:
|
|
100
|
+
csv_path = self._get_path(".csv")
|
|
101
|
+
try:
|
|
102
|
+
csv_path.parent.mkdir(parents=True, exist_ok=True)
|
|
103
|
+
file_exists = csv_path.exists() and csv_path.stat().st_size > 0
|
|
104
|
+
|
|
105
|
+
fieldnames = sorted(data.keys())
|
|
106
|
+
|
|
107
|
+
with csv_path.open("a", newline="", encoding="utf-8") as f:
|
|
108
|
+
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
|
109
|
+
if not file_exists:
|
|
110
|
+
writer.writeheader()
|
|
111
|
+
writer.writerow({k: data.get(k, "") for k in fieldnames})
|
|
112
|
+
except Exception as e:
|
|
113
|
+
logger.error(f"Failed to write CSV row to {csv_path}: {e}")
|
|
114
|
+
raise
|
|
Binary file
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from random import choice
|
|
4
|
+
|
|
5
|
+
from pygame import Event, Rect, Surface, image, transform
|
|
6
|
+
|
|
7
|
+
from mxbiflow import get_mxbiflow
|
|
8
|
+
from mxbiflow.scene.scene_protocol import SceneProtocol
|
|
9
|
+
|
|
10
|
+
# Folder named "assets" located next to this module.
ASSETS_PATH = Path(__file__).parent.joinpath("assets")


@dataclass
class Asset:
    """An image surface together with the rectangle it is drawn into."""

    image: Surface  # surface passed to blit()
    rect: Rect  # destination rectangle passed to blit()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class IDLE:
    """Idle scene that shows one randomly chosen PNG from ``ASSETS_PATH``.

    The image is scaled to a square and centred at three quarters of the
    screen width and half of the screen height.
    """

    _running: bool

    def __init__(self) -> None:
        """Load every PNG asset and pick the one to display.

        Raises:
            IndexError: If the assets folder contains no ``*.png`` file
                (raised by ``random.choice`` on an empty list).
        """
        self._mxbiflow = get_mxbiflow()
        # Assigned here so that `running` can be read before `start()`;
        # previously the attribute only existed once start()/quit() had run.
        self._running = False

        self._screen_size = self._mxbiflow.mxbi.screen_size
        self._pos = ((self._screen_size.width // 4) * 3, self._screen_size.height // 2)
        self._vstimulus_size = self._screen_size.width // 2 * 0.75

        self._assets = [self._load_asset(path) for path in ASSETS_PATH.glob("*.png")]

        self._asset = choice(self._assets)

    def _load_asset(self, path: Path) -> Asset:
        """Load ``path``, scale it to the stimulus size and centre it at ``_pos``."""
        scaled = transform.scale(
            image.load(path).convert_alpha(),
            (self._vstimulus_size, self._vstimulus_size),
        )
        return Asset(scaled, scaled.get_rect(center=self._pos))

    def start(self) -> None:
        """Mark the scene as running."""
        self._running = True

    def quit(self) -> None:
        """Mark the scene as stopped."""
        self._running = False

    @property
    def running(self) -> bool:
        """Whether ``start()`` has been called without a later ``quit()``."""
        return self._running

    def handle_event(self, event: Event) -> None: ...

    def update(self, dt_s: float) -> None: ...

    def draw(self, screen: Surface) -> None:
        """Blit the chosen asset onto ``screen``."""
        screen.blit(self._asset.image, self._asset.rect)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from queue import Empty, SimpleQueue
|
|
3
|
+
|
|
4
|
+
import pygame
|
|
5
|
+
from pygame import Event, event
|
|
6
|
+
from pymxbi.detector import MockDetector
|
|
7
|
+
from pymxbi.detector.detector import DetectionResult, Detector, DetectorEvent
|
|
8
|
+
|
|
9
|
+
# pygame event type under which detector messages are posted.
EVT_DETECTOR = pygame.USEREVENT + 1


@dataclass(frozen=True)
class DetectorMsg:
    """Immutable payload attached to an ``EVT_DETECTOR`` pygame event."""

    kind: DetectorEvent  # which detector event occurred
    animal: str | None  # animal id or name reported by the detector, if any
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DetectorBridge:
    """Relay detector callbacks to the pygame event queue.

    Detector callbacks only enqueue a ``DetectorMsg``; the messages are turned
    into pygame events when ``emit_pygame_event`` is called.
    """

    # Number keys that simulate the presence of the animal with that index.
    _PRESENCE_KEYS = {
        pygame.K_0: 0,
        pygame.K_1: 1,
        pygame.K_2: 2,
        pygame.K_3: 3,
        pygame.K_4: 4,
        pygame.K_5: 5,
    }

    def __init__(self, detector: Detector) -> None:
        """Wrap ``detector``; nothing is started until ``start()`` is called."""
        self._detector = detector
        self._q = SimpleQueue()
        self._started = False

    def start(self) -> None:
        """Begin detection and register the three callbacks (first call only)."""
        if self._started:
            return
        self._started = True

        self._detector.begin()
        callbacks = (
            (DetectorEvent.ANIMAL_ENTERED, self._emit_entered),
            (DetectorEvent.ANIMAL_LEFT, self._emit_left),
            (DetectorEvent.FAULT_DETECTED, self._emit_fault),
        )
        for detector_event, callback in callbacks:
            self._detector.register_event(detector_event, callback)

    def _emit(self, kind: DetectorEvent, animal: str | None) -> None:
        """Queue a message for the next ``emit_pygame_event`` call."""
        self._q.put(DetectorMsg(kind=kind, animal=animal))

    def _forward(self, kind: DetectorEvent, detection_result: DetectionResult) -> None:
        """Queue ``kind`` with the animal id, or the name when the id is falsy."""
        animal = detection_result.animal_id or detection_result.animal_name
        self._emit(kind, animal)

    def _emit_entered(self, detection_result: DetectionResult) -> None:
        """Callback for ``ANIMAL_ENTERED``."""
        self._forward(DetectorEvent.ANIMAL_ENTERED, detection_result)

    def _emit_left(self, detection_result: DetectionResult) -> None:
        """Callback for ``ANIMAL_LEFT``."""
        self._forward(DetectorEvent.ANIMAL_LEFT, detection_result)

    def _emit_fault(self, detection_result: DetectionResult) -> None:
        """Callback for ``FAULT_DETECTED``."""
        self._forward(DetectorEvent.FAULT_DETECTED, detection_result)

    def emit_pygame_event(self) -> None:
        """Post every queued message as an ``EVT_DETECTOR`` pygame event."""
        while True:
            try:
                queued = self._q.get_nowait()
            except Empty:
                return

            event.post(event.Event(EVT_DETECTOR, msg=queued))

    def manaul_emit(self, animal_idx: int | None = None) -> None:
        """Simulate a detection; does nothing unless the detector is a MockDetector.

        ``None`` simulates the animal leaving, an index simulates that animal
        being present.  The method name is kept exactly as published.
        """
        if not isinstance(self._detector, MockDetector):
            return

        if animal_idx is None:
            self._detector.animal_left()
        else:
            self._detector.animal_present(animal_idx)

    def handle_event(self, event: Event) -> None:
        """Map key presses to simulated detections: 0-5 = present, L = left."""
        if event.type != pygame.KEYDOWN:
            return

        pressed = event.key
        if pressed == pygame.K_l:
            self.manaul_emit()
        elif pressed in self._PRESENCE_KEYS:
            self.manaul_emit(self._PRESENCE_KEYS[pressed])
|
mxbiflow/game.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import pygame
|
|
2
|
+
from pygame import Event
|
|
3
|
+
from pymxbi import MXBI
|
|
4
|
+
|
|
5
|
+
from .detector_bridge import DetectorBridge
|
|
6
|
+
from .models.session import Session
|
|
7
|
+
from .mxbiflow import MXBIFlow, set_mxbiflow
|
|
8
|
+
from .scene import SceneManager
|
|
9
|
+
from .scheduler import Scheduler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Game:
|
|
13
|
+
def __init__(
|
|
14
|
+
self,
|
|
15
|
+
session: Session,
|
|
16
|
+
scene_manager: SceneManager,
|
|
17
|
+
detector_bridge: DetectorBridge,
|
|
18
|
+
mxbi: MXBI,
|
|
19
|
+
) -> None:
|
|
20
|
+
pygame.init()
|
|
21
|
+
|
|
22
|
+
self._scene_manager = scene_manager
|
|
23
|
+
self._session = session
|
|
24
|
+
self._mxbi = mxbi
|
|
25
|
+
self._aplayer = self._mxbi.aplayer
|
|
26
|
+
|
|
27
|
+
self._detector_binder = detector_bridge
|
|
28
|
+
self._detector_binder.start()
|
|
29
|
+
|
|
30
|
+
self._scheduler = Scheduler(self._session, self._scene_manager)
|
|
31
|
+
|
|
32
|
+
self._mxbiflow = MXBIFlow(self._session, self._mxbi)
|
|
33
|
+
set_mxbiflow(self._mxbiflow)
|
|
34
|
+
|
|
35
|
+
self._screen = pygame.display.set_mode(
|
|
36
|
+
(self._mxbi.screen_size.width, self._mxbi.screen_size.height)
|
|
37
|
+
)
|
|
38
|
+
self._clock = pygame.time.Clock()
|
|
39
|
+
self._running = True
|
|
40
|
+
|
|
41
|
+
def play(self) -> None:
|
|
42
|
+
while self._running:
|
|
43
|
+
dt = self._clock.tick(60) / 1000.0
|
|
44
|
+
|
|
45
|
+
self._detector_binder.emit_pygame_event()
|
|
46
|
+
|
|
47
|
+
for event in pygame.event.get():
|
|
48
|
+
self._handle_event(event)
|
|
49
|
+
self._scheduler.handle_event(event)
|
|
50
|
+
self._scene_manager.handle_event(event)
|
|
51
|
+
self._detector_binder.handle_event(event)
|
|
52
|
+
|
|
53
|
+
self._scene_manager.update(dt)
|
|
54
|
+
self._aplayer.update()
|
|
55
|
+
self._scheduler.update()
|
|
56
|
+
self._mxbiflow.update()
|
|
57
|
+
|
|
58
|
+
self._screen.fill((0, 0, 0))
|
|
59
|
+
|
|
60
|
+
self._scene_manager.draw(self._screen)
|
|
61
|
+
self._scene_manager.apply_pending()
|
|
62
|
+
|
|
63
|
+
pygame.display.flip()
|
|
64
|
+
|
|
65
|
+
self.quit()
|
|
66
|
+
|
|
67
|
+
def _handle_event(self, event: Event) -> None:
|
|
68
|
+
match event.type:
|
|
69
|
+
case pygame.QUIT:
|
|
70
|
+
self._running = False
|
|
71
|
+
case pygame.KEYDOWN:
|
|
72
|
+
self._handle_keyboard_event(event)
|
|
73
|
+
|
|
74
|
+
def _handle_keyboard_event(self, event: Event) -> None:
|
|
75
|
+
match event.key:
|
|
76
|
+
case pygame.K_ESCAPE:
|
|
77
|
+
self._running = False
|
|
78
|
+
case pygame.K_q:
|
|
79
|
+
self._running = False
|
|
80
|
+
|
|
81
|
+
def quit(self) -> None:
|
|
82
|
+
if self._scene_manager.current is not None:
|
|
83
|
+
self._scene_manager.current.quit()
|
|
84
|
+
pygame.quit()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EventBus:
    """Minimal publish/subscribe hub keyed by event name.

    Handlers are zero-argument callables, called in subscription order.
    """

    def __init__(self) -> None:
        self._events_dict: dict[str, list[Callable]] = {}

    def subscribe(self, event: str, handler: Callable) -> None:
        """Register ``handler`` for ``event``.

        Any number of handlers may subscribe to the same event.  Subscribing
        the same handler twice has no additional effect.
        """
        handlers = self._events_dict.setdefault(event, [])
        if handler not in handlers:
            handlers.append(handler)

    def unsubscribe(self, event: str, handler: Callable) -> None:
        """Remove ``handler`` from ``event``; unknown events or handlers are ignored."""
        handlers = self._events_dict.get(event)
        if handlers is None:
            return

        try:
            handlers.remove(handler)
        except ValueError:
            # Not subscribed: treated like the unknown-event case above.
            return

    def publish(self, event: str) -> None:
        """Call every handler subscribed to ``event``."""
        # Iterate over a snapshot so a handler may subscribe or unsubscribe
        # while the event is being dispatched.
        for handler in tuple(self._events_dict.get(event, ())):
            handler()

    def clear(self) -> None:
        """Drop every subscription."""
        self._events_dict.clear()


# Shared module-level bus instance.
event_bus = EventBus()
|
mxbiflow/main.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from pymxbi import MXBI
|
|
2
|
+
|
|
3
|
+
from .game import Game
|
|
4
|
+
from .models.session import Session
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main():
    """Entry point: configure, build the device, create the session, play."""
    run_config()

    hardware = build_mxbi()
    current_session = init_session()

    init_mxbiflow(hardware, current_session).play()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def run_config():
    """Show the two Qt configuration panels and block until the app exits.

    Accepting the MXBI panel shows the experiment panel; accepting the
    experiment panel quits the Qt application.
    """
    import sys

    from PySide6.QtWidgets import QApplication

    from .ui.experiment_panel import ExperimentPanel
    from .ui.mxbi_panel import MXBIPanel

    qt_app = QApplication(sys.argv)

    hardware_panel = MXBIPanel()
    session_panel = ExperimentPanel()

    # Chain the panels: first one opens the second, second one ends the app.
    hardware_panel.accepted.connect(session_panel.show)
    session_panel.accepted.connect(qt_app.quit)

    hardware_panel.show()

    qt_app.exec()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def build_mxbi() -> MXBI:
    """Build the MXBI device from stored config and register the session's animals.

    Returns:
        The object returned by ``pymxbi.build_mxbi``, after the RFID-to-name
        mapping of the configured animals has been registered on it.
    """
    from loguru import logger
    from pymxbi import MXBIModel

    # Aliased so the pymxbi factory does not shadow this function's own name.
    from pymxbi import build_mxbi as assemble_mxbi

    from .config_store import ConfigStore
    from .models.session import SessionConfig
    from .path import MXBI_CONFIG_PATH, SESSION_CONFIG_PATH

    hardware_config = ConfigStore(MXBI_CONFIG_PATH, MXBIModel).value
    session_config = ConfigStore(SESSION_CONFIG_PATH, SessionConfig).value

    device = assemble_mxbi(hardware_config, logger)

    names_by_rfid = {}
    for animal in session_config.animals:
        names_by_rfid[animal.rfid_id] = animal.name
    device.register_animal(names_by_rfid)

    return device
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def init_session() -> Session:
    """Create and start a Session from the stored session config.

    Each configured animal becomes an ``Animal`` whose current stage is set
    from the configured stage name and level.  The session id comes from
    ``DailySessionIdStore`` and is printed after the session has started.
    """
    from .config_store import ConfigStore
    from .models.animal import Animal, StageState
    from .models.session import DailySessionIdStore, Session, SessionConfig
    from .path import SESSION_CONFIG_PATH, SESSION_COUNTER_PATH

    def make_animal(entry) -> Animal:
        """Build one Animal and set its current stage from ``entry``."""
        initial_stage = StageState(stage_name=entry.stage, level=entry.level)
        animal = Animal(rfid_id=entry.rfid_id, name=entry.name)
        animal.set_current_stage(initial_stage)
        return animal

    session_config = ConfigStore(SESSION_CONFIG_PATH, SessionConfig).value
    id_store = DailySessionIdStore(SESSION_COUNTER_PATH)

    animals_by_name: dict[str, Animal] = {}
    for entry in session_config.animals:
        animals_by_name[entry.name] = make_animal(entry)

    new_session = Session(
        session_id=id_store.session_id,
        experimenter=session_config.experimenter,
        reward_type=session_config.reward_type,
        send_email=False,
        sync_data=False,
        note=session_config.note,
        animals=animals_by_name,
    )
    new_session.start()
    print(new_session.session_id)

    return new_session
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def init_mxbiflow(mxbi, session) -> Game:
    """Register the scenes, persist them to ``STAGE_PATH`` and build the Game.

    Args:
        mxbi: Device object; its ``detector`` is wrapped in a DetectorBridge.
        session: Session passed through to ``Game``.
    """
    from .default import IDLE, Habituarion
    from .detector_bridge import DetectorBridge

    # NOTE(review): the wheel's file listing places GNGSiD under
    # mxbiflow/tasks/, not mxbiflow/GNGSiD - confirm this import resolves.
    from .GNGSiD import SizeReduction
    from .path import STAGE_PATH
    from .scene import SceneManager

    manager = SceneManager()
    for scene_class in (Habituarion, IDLE, SizeReduction):
        manager.register(scene_class)
    manager.persist(STAGE_PATH)

    bridge = DetectorBridge(mxbi.detector)
    return Game(session, manager, bridge, mxbi)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, PrivateAttr
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AnimalConfig(BaseModel):
    """Immutable per-animal settings: identity plus starting stage and level."""

    # All fields are frozen, so they cannot be reassigned after creation.
    rfid_id: str = Field("", frozen=True)
    name: str = Field("mock", frozen=True)
    stage: str = Field("idle", frozen=True)
    level: int = Field(0, ge=0, frozen=True)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AnimalBaseInfo(BaseModel):
    """Snapshot of an animal's counters, as produced by ``Animal.base_info``."""

    animal: str  # animal name
    trial_id: int  # animal-wide trial counter
    level: int  # level of the current stage
    level_trial_id: int  # trial counter within the current level
    animal_session_id: int  # id of the current animal session
    animal_session_trial_id: int  # trial counter within that animal session
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class StageState(BaseModel):
    """Progress of one animal within one stage: level and trial counters."""

    stage_name: str
    stage_trial_id: int = Field(default=0, ge=0)

    level: int = Field(default=0, ge=0)
    level_trial_id: int = Field(default=0, ge=0)

    def level_up(self):
        """Move to the next level and restart the per-level trial counter."""
        self.level_trial_id = 0
        self.level += 1

    def level_down(self):
        """Move to the previous level and restart the per-level trial counter.

        The level never drops below 0, matching the ``ge=0`` constraint
        declared on the field.
        """
        self.level_trial_id = 0
        self.level = max(0, self.level - 1)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AnimalSessionState(BaseModel):
    """One animal session: its id, start/end timestamps and trial counter."""

    session_id: int = Field(ge=0)
    # time.time() value taken when the instance is created.
    start_at: float = Field(default_factory=lambda: time())
    # Stays None until an end time is assigned.
    end_at: float | None = None
    trial_id: int = Field(0, ge=0)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Animal(BaseModel):
    """An animal with its trial counters, per-stage state and animal sessions."""

    rfid_id: str = Field(frozen=True)
    name: str = Field(frozen=True)

    trial_id: int = Field(default=0, ge=0)

    # Name of the active stage; key into `_stages`.
    _current_stage: str = PrivateAttr(default="idle")
    _stages: dict[str, StageState] = PrivateAttr(default_factory=dict)
    _current_animal_session: AnimalSessionState | None = PrivateAttr(default=None)
    _sessions: list[AnimalSessionState] = PrivateAttr(default_factory=list)

    def add_trial(self) -> None:
        """Increment the animal, stage, level and animal-session trial counters.

        Both preconditions are checked before any counter changes, so a
        failure leaves every counter untouched.

        Raises:
            ValueError: If no animal session is started or the current stage
                is unknown.
        """
        session = self._current_animal_session
        if session is None:
            raise ValueError("Animal session is not started")
        stage = self.current_stage

        self.trial_id += 1
        stage.stage_trial_id += 1
        stage.level_trial_id += 1
        session.trial_id += 1

    @property
    def current_animal_session(self) -> AnimalSessionState | None:
        """The animal session in progress, or None."""
        return self._current_animal_session

    def start_animal_session(self):
        """Open a new animal session, numbered from 1 in creation order.

        Raises:
            ValueError: If an animal session is already in progress.
        """
        if self._current_animal_session is not None:
            raise ValueError("Animal session is already started")

        session = AnimalSessionState(session_id=len(self._sessions) + 1)
        self._current_animal_session = session
        self._sessions.append(session)

    def end_animal_session(self):
        """Stamp the end time on the current animal session and close it.

        Raises:
            ValueError: If no animal session is in progress.
        """
        if self._current_animal_session is None:
            raise ValueError("Animal session is not started")

        self._current_animal_session.end_at = time()
        self._current_animal_session = None

    @property
    def current_stage(self) -> StageState:
        """State of the active stage.

        Raises:
            ValueError: If no state is stored under the active stage name.
        """
        key = self._current_stage

        state = self._stages.get(key)
        if state is None:
            raise ValueError(f"Unknown stage: {key}")

        return state

    def set_current_stage(self, stage: StageState | str) -> None:
        """Make ``stage`` the active stage.

        A string is turned into a fresh ``StageState``.  If a state is
        already stored under that stage name it is kept and the one passed
        in is discarded.
        """
        if isinstance(stage, str):
            stage = StageState(stage_name=stage)

        self._current_stage = stage.stage_name
        self._stages.setdefault(stage.stage_name, stage)

    def clear_current_stage(self) -> None:
        """Reset the active stage name to ``"idle"``."""
        self._current_stage = "idle"

    def level_up(self) -> int:
        """Raise the level of the active stage and return the new level."""
        self.current_stage.level_up()
        return self.current_stage.level

    def level_down(self) -> int:
        """Lower the level of the active stage and return the new level."""
        self.current_stage.level_down()
        return self.current_stage.level

    @property
    def base_info(self) -> AnimalBaseInfo:
        """Snapshot of the current counters.

        Raises:
            ValueError: If no animal session is in progress or the current
                stage is unknown.
        """
        session = self.current_animal_session
        if session is None:
            raise ValueError("Animal session is not started")

        stage = self.current_stage
        return AnimalBaseInfo(
            animal=self.name,
            trial_id=self.trial_id,
            level=stage.level,
            level_trial_id=stage.level_trial_id,
            animal_session_id=session.session_id,
            animal_session_trial_id=session.trial_id,
        )
|