mxbiflow-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (93)
  1. mxbiflow/__init__.py +3 -0
  2. mxbiflow/assets/__init__.py +5 -0
  3. mxbiflow/assets/clicker.wav +0 -0
  4. mxbiflow/config_store.py +68 -0
  5. mxbiflow/data_logger.py +114 -0
  6. mxbiflow/default/__init__.py +4 -0
  7. mxbiflow/default/idle/assets/apple_v1.png +0 -0
  8. mxbiflow/default/idle/idle.py +57 -0
  9. mxbiflow/detector_bridge.py +87 -0
  10. mxbiflow/game.py +84 -0
  11. mxbiflow/infra/eventbus.py +31 -0
  12. mxbiflow/main.py +106 -0
  13. mxbiflow/models/animal.py +130 -0
  14. mxbiflow/models/reward.py +7 -0
  15. mxbiflow/models/session.py +145 -0
  16. mxbiflow/mxbiflow.py +43 -0
  17. mxbiflow/path.py +41 -0
  18. mxbiflow/scene/__init__.py +8 -0
  19. mxbiflow/scene/scene_manager.py +64 -0
  20. mxbiflow/scene/scene_protocol.py +22 -0
  21. mxbiflow/scheduler.py +90 -0
  22. mxbiflow/tasks/GNGSiD/models.py +70 -0
  23. mxbiflow/tasks/GNGSiD/stages/detect_stage/config.json +116 -0
  24. mxbiflow/tasks/GNGSiD/stages/detect_stage/detect_stage.py +161 -0
  25. mxbiflow/tasks/GNGSiD/stages/detect_stage/detect_stage_models.py +65 -0
  26. mxbiflow/tasks/GNGSiD/stages/discriminate_stage/config.json +70 -0
  27. mxbiflow/tasks/GNGSiD/stages/discriminate_stage/discriminate_stage.py +173 -0
  28. mxbiflow/tasks/GNGSiD/stages/discriminate_stage/discriminate_stage_models.py +80 -0
  29. mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/config.json +83 -0
  30. mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/size_reduction_models.py +58 -0
  31. mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/size_reduction_stage.py +149 -0
  32. mxbiflow/tasks/GNGSiD/tasks/artifacts.py +13 -0
  33. mxbiflow/tasks/GNGSiD/tasks/detect/models.py +21 -0
  34. mxbiflow/tasks/GNGSiD/tasks/detect/scene.py +271 -0
  35. mxbiflow/tasks/GNGSiD/tasks/discriminate/discriminate_models.py +31 -0
  36. mxbiflow/tasks/GNGSiD/tasks/discriminate/discriminate_scene.py +336 -0
  37. mxbiflow/tasks/GNGSiD/tasks/touch/touch_models.py +17 -0
  38. mxbiflow/tasks/GNGSiD/tasks/touch/touch_scene.py +256 -0
  39. mxbiflow/tasks/GNGSiD/tasks/utils/targets.py +57 -0
  40. mxbiflow/tasks/cross_modal/bundle_dir.py +553 -0
  41. mxbiflow/tasks/cross_modal/config.py +41 -0
  42. mxbiflow/tasks/cross_modal/media.py +61 -0
  43. mxbiflow/tasks/cross_modal/models.py +57 -0
  44. mxbiflow/tasks/cross_modal/scene.py +252 -0
  45. mxbiflow/tasks/cross_modal/stage.py +218 -0
  46. mxbiflow/tasks/cross_modal/trial_io.py +23 -0
  47. mxbiflow/tasks/cross_modal/trial_schema.py +113 -0
  48. mxbiflow/tasks/default/error_task/error_scene.py +53 -0
  49. mxbiflow/tasks/default/idle_task/assets/apple_v1.png +0 -0
  50. mxbiflow/tasks/default/idle_task/idle_scene.py +85 -0
  51. mxbiflow/tasks/default/initial_habituation_training/README.md +188 -0
  52. mxbiflow/tasks/default/initial_habituation_training/stages/config.csv +7 -0
  53. mxbiflow/tasks/default/initial_habituation_training/stages/config.json +67 -0
  54. mxbiflow/tasks/default/initial_habituation_training/stages/initial_habituation_training_stage.py +172 -0
  55. mxbiflow/tasks/default/initial_habituation_training/stages/models.py +56 -0
  56. mxbiflow/tasks/default/initial_habituation_training/tasks/stay_to_reward/stay_to_reward.py +244 -0
  57. mxbiflow/tasks/default/initial_habituation_training/tasks/stay_to_reward/stay_to_reward_models.py +50 -0
  58. mxbiflow/tasks/task_protocol.py +26 -0
  59. mxbiflow/tasks/task_table.py +29 -0
  60. mxbiflow/tasks/two_alternative_choice/assets/starter.py +27 -0
  61. mxbiflow/tasks/two_alternative_choice/models.py +68 -0
  62. mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/config.json +118 -0
  63. mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/size_reduction_models.py +41 -0
  64. mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/size_reduction_stage.py +122 -0
  65. mxbiflow/tasks/two_alternative_choice/tasks/touch/touch_models.py +19 -0
  66. mxbiflow/tasks/two_alternative_choice/tasks/touch/touch_scene.py +249 -0
  67. mxbiflow/timer/__init__.py +3 -0
  68. mxbiflow/timer/frame_timer.py +47 -0
  69. mxbiflow/timer/realtime_timer.py +0 -0
  70. mxbiflow/tmp_email.py +13 -0
  71. mxbiflow/ui/components/animal.py +87 -0
  72. mxbiflow/ui/components/baseconfig.py +68 -0
  73. mxbiflow/ui/components/card.py +18 -0
  74. mxbiflow/ui/components/device_card/__init__.py +17 -0
  75. mxbiflow/ui/components/device_card/detector/beambreak_detector_card.py +29 -0
  76. mxbiflow/ui/components/device_card/detector/fusion_detector.py +45 -0
  77. mxbiflow/ui/components/device_card/detector/mock_detector_card.py +20 -0
  78. mxbiflow/ui/components/device_card/detector/rfid_detector.py +40 -0
  79. mxbiflow/ui/components/device_card/device_card.py +67 -0
  80. mxbiflow/ui/components/device_card/rewarder/mock_rewarder_card.py +20 -0
  81. mxbiflow/ui/components/device_card/rewarder/rpi_gpio_rewarder.py +33 -0
  82. mxbiflow/ui/components/devices.py +183 -0
  83. mxbiflow/ui/components/dialog/__init__.py +3 -0
  84. mxbiflow/ui/components/dialog/add_devices_dialog.py +64 -0
  85. mxbiflow/ui/components/experiment_groups.py +122 -0
  86. mxbiflow/ui/experiment_panel.py +91 -0
  87. mxbiflow/ui/mxbi_panel.py +152 -0
  88. mxbiflow/utils/logger.py +19 -0
  89. mxbiflow/utils/serial.py +10 -0
  90. mxbiflow-0.1.1.dist-info/METADATA +168 -0
  91. mxbiflow-0.1.1.dist-info/RECORD +93 -0
  92. mxbiflow-0.1.1.dist-info/WHEEL +4 -0
  93. mxbiflow-0.1.1.dist-info/entry_points.txt +4 -0
mxbiflow/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .mxbiflow import MXBIFlow, get_mxbiflow
2
+
3
+ __all__ = ["MXBIFlow", "get_mxbiflow"]
mxbiflow/assets/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from pathlib import Path
2
+
3
# Directory of this ``assets`` package; bundled asset files live next to it.
ROOT: Path = Path(__file__).parent

# The bundled ``clicker.wav`` audio file.
ASSET_CLICKER_PATH: Path = ROOT.joinpath("clicker.wav")
mxbiflow/assets/clicker.wav ADDED
Binary file
mxbiflow/config_store.py ADDED
@@ -0,0 +1,68 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Generic, TypeVar
4
+
5
+ from pydantic import BaseModel
6
+
7
T = TypeVar("T", bound=BaseModel)


class ConfigStore(Generic[T]):
    """Load a pydantic model from a JSON file once, expose it, and save it back.

    The file is read in ``__init__``; the parsed model is available through
    :attr:`value`. With ``create_default=True`` a missing or unusable file is
    replaced by ``config_class()`` defaults. An existing but unusable file is
    first moved aside to ``<name>.bak`` so the user's data is not lost. With
    ``create_default=False`` such failures raise ``RuntimeError``.
    """

    def __init__(
        self,
        config_path: Path,
        config_class: type[T],
        *,
        create_default: bool = True,
    ) -> None:
        self._config_path = config_path
        self._config_class = config_class
        self._create_default = create_default
        self._config = self._load_config()

    @property
    def value(self) -> T:
        """The currently loaded (or last saved) config model."""
        return self._config

    def _ensure_config_readable(self) -> None:
        """Raise FileNotFoundError / PermissionError if the file cannot be read."""
        if not self._config_path.exists():
            raise FileNotFoundError(f"Config file {self._config_path} not found")

        if not os.access(self._config_path, os.R_OK):
            raise PermissionError(f"Config file {self._config_path} is not readable")

    def _write(self, config: T) -> None:
        """Write ``config`` as indented JSON, atomically replacing the file.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        target with ``os.replace``. A crash mid-write therefore cannot leave a
        truncated config behind.
        """
        self._config_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self._config_path.with_name(self._config_path.name + ".tmp")
        tmp_path.write_text(config.model_dump_json(indent=4), encoding="utf-8")
        os.replace(tmp_path, self._config_path)

    def _backup_existing_config(self) -> None:
        """Move an existing (unparseable or unreadable) config file to ``<name>.bak``."""
        if self._config_path.exists():
            backup_path = self._config_path.with_name(self._config_path.name + ".bak")
            os.replace(self._config_path, backup_path)

    def _create_default_config(self) -> T:
        """Build ``config_class()`` defaults, persist them and return them."""
        config = self._config_class()
        self._write(config)
        return config

    def _load_config(self) -> T:
        """Read and validate the config file, falling back to defaults if allowed.

        Raises:
            RuntimeError: if loading fails and ``create_default`` is False.
        """
        try:
            self._ensure_config_readable()
            text = self._config_path.read_text(encoding="utf-8")
            return self._config_class.model_validate_json(text)
        except Exception as e:
            if not self._create_default:
                raise RuntimeError(
                    f"Failed to load config file {self._config_path}: {e}"
                ) from e
            # Previously an invalid file (e.g. one typo) was silently overwritten
            # by defaults; keep a copy so the user's settings can be recovered.
            self._backup_existing_config()
            return self._create_default_config()

    def save(self, data: T | None = None) -> None:
        """Persist the config, first replacing it with ``data`` when given.

        Raises:
            RuntimeError: if the file (or its directory) cannot be written.
        """
        if data is not None:
            self._config = data

        try:
            self._write(self._config)
        except Exception as e:
            raise RuntimeError(
                f"Failed to save config file {self._config_path}: {e}"
            ) from e
mxbiflow/data_logger.py ADDED
@@ -0,0 +1,114 @@
1
+ import csv
2
+ import json
3
+ import sys
4
+ from datetime import datetime
5
+ from enum import StrEnum
6
+ from pathlib import Path
7
+
8
+ from .utils.logger import logger
9
+
10
+ now = datetime.now()
11
+
12
+
13
+ class DataLoggerType(StrEnum):
14
+ JSONL = "jsonl"
15
+ JSON = "json"
16
+ CSV = "csv"
17
+
18
+
19
+ class DataLogger:
20
+ def __init__(
21
+ self,
22
+ path: Path,
23
+ session_id: int,
24
+ monkey: str,
25
+ filename: str,
26
+ type: DataLoggerType = DataLoggerType.JSONL,
27
+ ) -> None:
28
+ self._path = path
29
+ self._session_id = session_id
30
+ self._monkey = monkey
31
+ self._filename = filename
32
+ self._type = type
33
+
34
+ self._data_dir = self._ensure_data_dir()
35
+ self._data_path = self._get_path(f".{self._type.value}")
36
+
37
+ @property
38
+ def path(self) -> Path:
39
+ return self._data_path
40
+
41
+ def _ensure_data_dir(self) -> Path:
42
+ date_path = Path(f"{now.year}{now.month:02d}{now.day:02d}")
43
+ session_path = Path(f"{self._session_id}")
44
+ monkey_path = Path(f"{self._monkey}")
45
+
46
+ base_dir = self._path / date_path / session_path / monkey_path
47
+
48
+ try:
49
+ base_dir.mkdir(parents=True, exist_ok=True)
50
+ return base_dir
51
+ except Exception as e:
52
+ logger.error(f"failed to create {base_dir}: {e}")
53
+ sys.exit(1)
54
+
55
+ def _get_path(self, suffix: str) -> Path:
56
+ return self._data_dir / f"{self._filename}{suffix}"
57
+
58
+ def save(self, data: dict) -> None:
59
+ match self._type:
60
+ case DataLoggerType.JSONL:
61
+ self._save_jsonl(data)
62
+ case DataLoggerType.JSON:
63
+ self._save_json(data)
64
+ case DataLoggerType.CSV:
65
+ self.save_csv_row(data)
66
+
67
+ def _save_jsonl(self, data: dict) -> None:
68
+ try:
69
+ json_line = json.dumps(data, ensure_ascii=False)
70
+
71
+ with open(self._data_path, "a", encoding="utf-8") as f:
72
+ f.write(json_line + "\n")
73
+
74
+ except TypeError as e:
75
+ logger.error(f"Data is not JSON serializable: {e}")
76
+ raise
77
+ except IOError as e:
78
+ logger.error(f"Failed to write to file {self._data_path}: {e}")
79
+ raise
80
+ except Exception as e:
81
+ logger.error(f"Unexpected error while writing data: {e}")
82
+ raise
83
+
84
+ def _save_json(self, data: dict) -> None:
85
+ try:
86
+ self._data_path.parent.mkdir(parents=True, exist_ok=True)
87
+ with open(self._data_path, "w", encoding="utf-8") as f:
88
+ json.dump(data, f, ensure_ascii=False, indent=2)
89
+ except TypeError as e:
90
+ logger.error(f"Data is not JSON serializable: {e}")
91
+ raise
92
+ except IOError as e:
93
+ logger.error(f"Failed to write to file {self._data_path}: {e}")
94
+ raise
95
+ except Exception as e:
96
+ logger.error(f"Unexpected error while writing JSON data: {e}")
97
+ raise
98
+
99
+ def save_csv_row(self, data: dict) -> None:
100
+ csv_path = self._get_path(".csv")
101
+ try:
102
+ csv_path.parent.mkdir(parents=True, exist_ok=True)
103
+ file_exists = csv_path.exists() and csv_path.stat().st_size > 0
104
+
105
+ fieldnames = sorted(data.keys())
106
+
107
+ with csv_path.open("a", newline="", encoding="utf-8") as f:
108
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
109
+ if not file_exists:
110
+ writer.writeheader()
111
+ writer.writerow({k: data.get(k, "") for k in fieldnames})
112
+ except Exception as e:
113
+ logger.error(f"Failed to write CSV row to {csv_path}: {e}")
114
+ raise
mxbiflow/default/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .habituation.habituarion import Habituarion
2
+ from .idle.idle import IDLE
3
+
4
+ __all__ = ["Habituarion", "IDLE"]
mxbiflow/default/idle/idle.py ADDED
@@ -0,0 +1,57 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from random import choice
4
+
5
+ from pygame import Event, Rect, Surface, image, transform
6
+
7
+ from mxbiflow import get_mxbiflow
8
+ from mxbiflow.scene.scene_protocol import SceneProtocol
9
+
10
ASSETS_PATH = Path(__file__).parent / "assets"


@dataclass
class Asset:
    """An image already scaled for display, plus the rect it is blitted at."""

    image: Surface
    rect: Rect


class IDLE:
    """Idle scene that draws one randomly chosen PNG from ``ASSETS_PATH``.

    Every PNG is scaled to a square of ``0.75 * (screen width // 2)`` pixels
    and centred at (3/4 of the screen width, 1/2 of the screen height). One
    of them is picked at construction. The scene ignores input and has no
    per-frame logic.
    """

    _running: bool

    def __init__(self) -> None:
        self._mxbiflow = get_mxbiflow()

        self._screen_size = self._mxbiflow.mxbi.screen_size
        self._pos = ((self._screen_size.width // 4) * 3, self._screen_size.height // 2)
        # pygame sizes are whole pixels; the 0.75 factor alone would give a float.
        self._vstimulus_size = int(self._screen_size.width // 2 * 0.75)
        # Report "not running" before start() instead of raising AttributeError.
        self._running = False

        self._assets = [
            self._load_asset(path) for path in sorted(ASSETS_PATH.glob("*.png"))
        ]
        if not self._assets:
            raise FileNotFoundError(f"No *.png assets found in {ASSETS_PATH}")

        self._asset = choice(self._assets)

    def _load_asset(self, path: Path) -> Asset:
        """Load ``path``, scale it to the stimulus size and centre it at ``self._pos``."""
        surface = transform.scale(
            image.load(path).convert_alpha(),
            (self._vstimulus_size, self._vstimulus_size),
        )
        return Asset(surface, surface.get_rect(center=self._pos))

    def start(self) -> None:
        """Mark the scene as running."""
        self._running = True

    def quit(self) -> None:
        """Mark the scene as stopped."""
        self._running = False

    @property
    def running(self) -> bool:
        """Whether :meth:`start` was called more recently than :meth:`quit`."""
        return self._running

    def handle_event(self, event: Event) -> None: ...

    def update(self, dt_s: float) -> None: ...

    def draw(self, screen: Surface) -> None:
        """Blit the chosen asset onto ``screen``."""
        screen.blit(self._asset.image, self._asset.rect)
mxbiflow/detector_bridge.py ADDED
@@ -0,0 +1,87 @@
1
+ from dataclasses import dataclass
2
+ from queue import Empty, SimpleQueue
3
+
4
+ import pygame
5
+ from pygame import Event, event
6
+ from pymxbi.detector import MockDetector
7
+ from pymxbi.detector.detector import DetectionResult, Detector, DetectorEvent
8
+
9
EVT_DETECTOR = pygame.USEREVENT + 1


@dataclass(frozen=True)
class DetectorMsg:
    """One detector notification, posted to pygame as ``EVT_DETECTOR`` (``event.msg``)."""

    kind: DetectorEvent
    # The detection's animal_id if truthy, otherwise its animal_name (may be None).
    animal: str | None


class DetectorBridge:
    """Forward detector callbacks into the pygame event loop.

    Detector callbacks only enqueue a :class:`DetectorMsg` on a ``SimpleQueue``,
    which is thread-safe. They presumably run on the detector's own thread;
    confirm against pymxbi. :meth:`emit_pygame_event`, called from the game
    loop, drains the queue and posts one ``EVT_DETECTOR`` pygame event per
    message. For a ``MockDetector``, keys 0-5 simulate that animal index being
    present and ``l`` simulates the animal leaving.
    """

    # Keys that simulate "animal <index> present" on a MockDetector.
    _MOCK_PRESENT_KEYS: dict[int, int] = {
        pygame.K_0: 0,
        pygame.K_1: 1,
        pygame.K_2: 2,
        pygame.K_3: 3,
        pygame.K_4: 4,
        pygame.K_5: 5,
    }

    def __init__(self, detector: Detector) -> None:
        self._detector = detector
        self._q: SimpleQueue[DetectorMsg] = SimpleQueue()
        self._started = False

    def start(self) -> None:
        """Begin the detector and subscribe to its events; repeated calls are no-ops."""
        if self._started:
            return
        self._started = True

        # NOTE(review): callbacks are registered after begin(); if the detector
        # can fire events as soon as it begins, those early events are lost.
        # Confirm against pymxbi's Detector before reordering.
        self._detector.begin()
        self._detector.register_event(DetectorEvent.ANIMAL_ENTERED, self._emit_entered)
        self._detector.register_event(DetectorEvent.ANIMAL_LEFT, self._emit_left)
        self._detector.register_event(DetectorEvent.FAULT_DETECTED, self._emit_fault)

    @staticmethod
    def _animal_of(detection_result: DetectionResult) -> str | None:
        """Prefer the animal id, falling back to the animal name."""
        return detection_result.animal_id or detection_result.animal_name

    def _emit(self, kind: DetectorEvent, animal: str | None) -> None:
        """Queue a message for the next :meth:`emit_pygame_event` call."""
        self._q.put(DetectorMsg(kind=kind, animal=animal))

    def _emit_entered(self, detection_result: DetectionResult) -> None:
        self._emit(DetectorEvent.ANIMAL_ENTERED, self._animal_of(detection_result))

    def _emit_left(self, detection_result: DetectionResult) -> None:
        self._emit(DetectorEvent.ANIMAL_LEFT, self._animal_of(detection_result))

    def _emit_fault(self, detection_result: DetectionResult) -> None:
        self._emit(DetectorEvent.FAULT_DETECTED, self._animal_of(detection_result))

    def emit_pygame_event(self) -> None:
        """Post every queued message as an ``EVT_DETECTOR`` pygame event."""
        while True:
            try:
                msg = self._q.get_nowait()
            except Empty:
                break

            event.post(event.Event(EVT_DETECTOR, msg=msg))

    def manual_emit(self, animal_idx: int | None = None) -> None:
        """Drive a MockDetector by hand; does nothing for real detectors.

        ``animal_idx=None`` reports the animal leaving; otherwise the animal
        with that index is reported present.
        """
        if isinstance(self._detector, MockDetector):
            if animal_idx is None:
                self._detector.animal_left()
            else:
                self._detector.animal_present(animal_idx)

    # Backward-compatible alias for the original (misspelled) method name.
    manaul_emit = manual_emit

    def handle_event(self, event: Event) -> None:
        """Map debug key presses to :meth:`manual_emit` (0-5: present, l: left)."""
        if event.type != pygame.KEYDOWN:
            return
        if event.key == pygame.K_l:
            self.manual_emit()
        elif event.key in self._MOCK_PRESENT_KEYS:
            self.manual_emit(self._MOCK_PRESENT_KEYS[event.key])
mxbiflow/game.py ADDED
@@ -0,0 +1,84 @@
1
+ import pygame
2
+ from pygame import Event
3
+ from pymxbi import MXBI
4
+
5
+ from .detector_bridge import DetectorBridge
6
+ from .models.session import Session
7
+ from .mxbiflow import MXBIFlow, set_mxbiflow
8
+ from .scene import SceneManager
9
+ from .scheduler import Scheduler
10
+
11
+
12
+ class Game:
13
+ def __init__(
14
+ self,
15
+ session: Session,
16
+ scene_manager: SceneManager,
17
+ detector_bridge: DetectorBridge,
18
+ mxbi: MXBI,
19
+ ) -> None:
20
+ pygame.init()
21
+
22
+ self._scene_manager = scene_manager
23
+ self._session = session
24
+ self._mxbi = mxbi
25
+ self._aplayer = self._mxbi.aplayer
26
+
27
+ self._detector_binder = detector_bridge
28
+ self._detector_binder.start()
29
+
30
+ self._scheduler = Scheduler(self._session, self._scene_manager)
31
+
32
+ self._mxbiflow = MXBIFlow(self._session, self._mxbi)
33
+ set_mxbiflow(self._mxbiflow)
34
+
35
+ self._screen = pygame.display.set_mode(
36
+ (self._mxbi.screen_size.width, self._mxbi.screen_size.height)
37
+ )
38
+ self._clock = pygame.time.Clock()
39
+ self._running = True
40
+
41
+ def play(self) -> None:
42
+ while self._running:
43
+ dt = self._clock.tick(60) / 1000.0
44
+
45
+ self._detector_binder.emit_pygame_event()
46
+
47
+ for event in pygame.event.get():
48
+ self._handle_event(event)
49
+ self._scheduler.handle_event(event)
50
+ self._scene_manager.handle_event(event)
51
+ self._detector_binder.handle_event(event)
52
+
53
+ self._scene_manager.update(dt)
54
+ self._aplayer.update()
55
+ self._scheduler.update()
56
+ self._mxbiflow.update()
57
+
58
+ self._screen.fill((0, 0, 0))
59
+
60
+ self._scene_manager.draw(self._screen)
61
+ self._scene_manager.apply_pending()
62
+
63
+ pygame.display.flip()
64
+
65
+ self.quit()
66
+
67
+ def _handle_event(self, event: Event) -> None:
68
+ match event.type:
69
+ case pygame.QUIT:
70
+ self._running = False
71
+ case pygame.KEYDOWN:
72
+ self._handle_keyboard_event(event)
73
+
74
+ def _handle_keyboard_event(self, event: Event) -> None:
75
+ match event.key:
76
+ case pygame.K_ESCAPE:
77
+ self._running = False
78
+ case pygame.K_q:
79
+ self._running = False
80
+
81
+ def quit(self) -> None:
82
+ if self._scene_manager.current is not None:
83
+ self._scene_manager.current.quit()
84
+ pygame.quit()
mxbiflow/infra/eventbus.py ADDED
@@ -0,0 +1,31 @@
1
+ from typing import Callable
2
+
3
+
4
class EventBus:
    """Minimal synchronous publish/subscribe dispatcher keyed by event name."""

    def __init__(self) -> None:
        self._events_dict: dict[str, list[Callable]] = {}

    def subscribe(self, event: str, handler: Callable) -> None:
        """Register ``handler`` for ``event``.

        Handlers are called in subscription order. Subscribing the same handler
        to the same event twice is a no-op. (Previously, every handler after the
        first one for an event was silently ignored.)
        """
        handlers = self._events_dict.setdefault(event, [])
        if handler not in handlers:
            handlers.append(handler)

    def unsubscribe(self, event: str, handler: Callable) -> None:
        """Remove ``handler`` from ``event``; unknown events or handlers are ignored."""
        handlers = self._events_dict.get(event)
        if not handlers or handler not in handlers:
            return

        handlers.remove(handler)
        if not handlers:
            del self._events_dict[event]

    def publish(self, event: str, *args, **kwargs) -> None:
        """Call every handler of ``event`` with ``*args`` / ``**kwargs``.

        Publishing an event nobody subscribed to does nothing. A snapshot of the
        handler list is iterated, so handlers may (un)subscribe during dispatch
        without other handlers being skipped.
        """
        for handler in tuple(self._events_dict.get(event, ())):
            handler(*args, **kwargs)

    def clear(self) -> None:
        """Drop all subscriptions."""
        self._events_dict.clear()


# Process-wide shared bus instance.
event_bus = EventBus()
mxbiflow/main.py ADDED
@@ -0,0 +1,106 @@
1
+ from pymxbi import MXBI
2
+
3
+ from .game import Game
4
+ from .models.session import Session
5
+
6
+
7
def main():
    """Run the config dialogs, then build the rig and session and play the game."""
    run_config()
    # Arguments evaluate left to right: the rig is built before the session starts.
    game = init_mxbiflow(build_mxbi(), init_session())
    game.play()
16
+
17
+
18
def run_config():
    """Show the MXBI panel, then the experiment panel, in a blocking Qt loop.

    Accepting the MXBI panel shows the experiment panel; accepting the
    experiment panel quits the Qt application so :func:`main` can continue.
    """
    import sys

    from PySide6.QtWidgets import QApplication

    from .ui.experiment_panel import ExperimentPanel
    from .ui.mxbi_panel import MXBIPanel

    qt_app = QApplication(sys.argv)

    panel_mxbi = MXBIPanel()
    panel_experiment = ExperimentPanel()
    panel_mxbi.accepted.connect(panel_experiment.show)
    panel_experiment.accepted.connect(qt_app.quit)

    panel_mxbi.show()

    # NOTE(review): exec()'s result is ignored, so main() presumably continues
    # even if the windows are closed without accepting. Confirm this is intended.
    qt_app.exec()
36
+
37
+
38
def build_mxbi() -> MXBI:
    """Build the MXBI rig from its stored config and register the session's animals."""
    from loguru import logger
    from pymxbi import MXBIModel
    from pymxbi import build_mxbi as pymxbi_build_mxbi

    from .config_store import ConfigStore
    from .models.session import SessionConfig
    from .path import MXBI_CONFIG_PATH, SESSION_CONFIG_PATH

    rig_config = ConfigStore(MXBI_CONFIG_PATH, MXBIModel).value
    session_config = ConfigStore(SESSION_CONFIG_PATH, SessionConfig).value

    rig = pymxbi_build_mxbi(rig_config, logger)

    # Map each configured RFID id to the animal's name.
    rfid_to_name = {entry.rfid_id: entry.name for entry in session_config.animals}
    rig.register_animal(rfid_to_name)

    return rig
55
+
56
+
57
def init_session() -> Session:
    """Create and start this run's :class:`Session` from the stored session config.

    Each configured animal becomes an :class:`Animal` whose current stage and
    level come from its config entry. Animals are keyed by name. The session
    id is taken from ``DailySessionIdStore``.
    """
    from loguru import logger

    from .config_store import ConfigStore
    from .models.animal import Animal, StageState
    from .models.session import DailySessionIdStore, SessionConfig
    from .path import SESSION_CONFIG_PATH, SESSION_COUNTER_PATH

    session_config = ConfigStore(SESSION_CONFIG_PATH, SessionConfig).value
    store = DailySessionIdStore(SESSION_COUNTER_PATH)

    animal_dict: dict[str, Animal] = {}
    for animal_config in session_config.animals:
        if animal_config.name in animal_dict:
            # Keyed by name: a duplicate silently replaces the earlier animal.
            logger.warning(
                f"Duplicate animal name {animal_config.name!r} in session config; "
                "the later entry replaces the earlier one"
            )
        train_state = StageState(
            stage_name=animal_config.stage, level=animal_config.level
        )
        animal_state = Animal(
            rfid_id=animal_config.rfid_id,
            name=animal_config.name,
        )
        animal_state.set_current_stage(train_state)
        animal_dict[animal_config.name] = animal_state

    session = Session(
        session_id=store.session_id,
        experimenter=session_config.experimenter,
        reward_type=session_config.reward_type,
        send_email=False,
        sync_data=False,
        note=session_config.note,
        animals=animal_dict,
    )
    session.start()
    logger.info(f"Session started: session_id={session.session_id}")

    return session
91
+
92
+
93
def init_mxbiflow(mxbi, session) -> Game:
    """Register the scenes, bridge the rig's detector and build the :class:`Game`."""
    # NOTE(review): the 0.1.1 wheel ships neither mxbiflow/GNGSiD nor
    # mxbiflow/default/habituation, so these imports look like they would
    # fail at runtime. Confirm the intended module paths.
    from .default import IDLE, Habituarion
    from .detector_bridge import DetectorBridge
    from .GNGSiD import SizeReduction
    from .path import STAGE_PATH
    from .scene import SceneManager

    manager = SceneManager()
    for scene_cls in (Habituarion, IDLE, SizeReduction):
        manager.register(scene_cls)
    manager.persist(STAGE_PATH)

    bridge = DetectorBridge(mxbi.detector)
    return Game(session, manager, bridge, mxbi)
mxbiflow/models/animal.py ADDED
@@ -0,0 +1,130 @@
1
+ from time import time
2
+
3
+ from pydantic import BaseModel, Field, PrivateAttr
4
+
5
+
6
class AnimalConfig(BaseModel):
    """One animal entry of the session config; every field is frozen."""

    # RFID identifier; main.py maps it to ``name`` when registering animals.
    rfid_id: str = Field(frozen=True, default="")
    name: str = Field(frozen=True, default="mock")
    # Stage name and level the animal starts this run in.
    stage: str = Field(frozen=True, default="idle")
    level: int = Field(frozen=True, default=0, ge=0)


class AnimalBaseInfo(BaseModel):
    """Per-trial identifiers of an animal, as produced by ``Animal.base_info``."""

    animal: str
    trial_id: int
    level: int
    level_trial_id: int
    animal_session_id: int
    animal_session_trial_id: int
20
+
21
+
22
class StageState(BaseModel):
    """Progress of one animal within one training stage.

    ``stage_trial_id`` counts trials run in this stage. ``level_trial_id``
    counts trials since the level last changed.
    """

    stage_name: str
    stage_trial_id: int = Field(default=0, ge=0)

    level: int = Field(default=0, ge=0)
    level_trial_id: int = Field(default=0, ge=0)

    def level_up(self):
        """Move up one level and restart the per-level trial counter."""
        self.level_trial_id = 0
        self.level += 1

    def level_down(self):
        """Move down one level and restart the per-level trial counter.

        At level 0 this does nothing. ``level`` is declared ``ge=0``, but the
        constraint is only checked on validation, not on plain attribute
        assignment. Decrementing past 0 would therefore silently store a
        negative level.
        """
        if self.level <= 0:
            return
        self.level_trial_id = 0
        self.level -= 1
36
+
37
+
38
class AnimalSessionState(BaseModel):
    """One session of an animal: its id, start/end timestamps and trial count."""

    session_id: int = Field(ge=0)
    # Wall-clock timestamps from time.time(); end_at stays None while active.
    start_at: float = Field(default_factory=time)
    end_at: float | None = None
    trial_id: int = Field(default=0, ge=0)
43
+
44
+
45
class Animal(BaseModel):
    """Runtime training state of one animal.

    Tracks three things:
    * a trial counter across all stages and sessions;
    * per-stage progress, as :class:`StageState` keyed by stage name;
    * the animal's sessions, of which at most one is active at a time.
    """

    rfid_id: str = Field(frozen=True)
    name: str = Field(frozen=True)

    # Trials counted across every stage and session of this animal.
    trial_id: int = Field(default=0, ge=0)

    # Name of the active stage; it must be a key of _stages for current_stage to work.
    _current_stage: str = PrivateAttr(default="idle")
    _stages: dict[str, StageState] = PrivateAttr(default_factory=dict)
    _current_animal_session: AnimalSessionState | None = PrivateAttr(default=None)
    _sessions: list[AnimalSessionState] = PrivateAttr(default_factory=list)

    def _require_session(self) -> AnimalSessionState:
        """Return the active animal session.

        Raises:
            ValueError: if no animal session is active.
        """
        if self._current_animal_session is None:
            raise ValueError("Animal session is not started")
        return self._current_animal_session

    def add_trial(self) -> None:
        """Count one trial on the animal, its current stage/level and active session.

        Both preconditions are checked before any counter changes, so a failed
        call leaves the state untouched. The old version could fail after
        partially incrementing, and its check used ``assert``, which ``-O``
        strips.

        Raises:
            ValueError: if no animal session is active or the current stage is unknown.
        """
        session = self._require_session()
        stage = self.current_stage

        self.trial_id += 1
        stage.stage_trial_id += 1
        stage.level_trial_id += 1
        session.trial_id += 1

    @property
    def current_animal_session(self) -> AnimalSessionState | None:
        """The active animal session, or None when no session is active."""
        return self._current_animal_session

    def start_animal_session(self):
        """Open a new animal session; ids are 1, 2, 3, ... in start order.

        Raises:
            ValueError: if a session is already active.
        """
        if self._current_animal_session is not None:
            raise ValueError("Animal session is already started")

        session = AnimalSessionState(session_id=len(self._sessions) + 1)
        self._current_animal_session = session
        self._sessions.append(session)

    def end_animal_session(self):
        """Stamp the active session's end time and deactivate it.

        Raises:
            ValueError: if no session is active.
        """
        session = self._require_session()
        session.end_at = time()
        self._current_animal_session = None

    @property
    def current_stage(self) -> StageState:
        """State of the current stage.

        Raises:
            ValueError: if the current stage name has no recorded StageState.
                For example, after clear_current_stage() when no "idle" stage
                was ever set.
        """
        state = self._stages.get(self._current_stage)
        if state is None:
            raise ValueError(f"Unknown stage: {self._current_stage}")
        return state

    def set_current_stage(self, stage: StageState | str) -> None:
        """Make ``stage`` current, registering it the first time its name is seen.

        A plain name is wrapped in a fresh StageState. If a stage with that name
        already exists, the stored state (and its progress) is kept and the
        passed one is ignored.
        """
        if isinstance(stage, str):
            stage = StageState(stage_name=stage)

        self._current_stage = stage.stage_name
        self._stages.setdefault(stage.stage_name, stage)

    def clear_current_stage(self) -> None:
        """Switch the current stage name back to "idle" (registers nothing)."""
        self._current_stage = "idle"

    def level_up(self) -> int:
        """Level up the current stage and return its new level."""
        self.current_stage.level_up()
        return self.current_stage.level

    def level_down(self) -> int:
        """Call StageState.level_down on the current stage and return its level."""
        self.current_stage.level_down()
        return self.current_stage.level

    @property
    def base_info(self) -> AnimalBaseInfo:
        """Snapshot of the identifiers attached to each trial record.

        Raises:
            ValueError: if no session is active or the current stage is unknown.
        """
        session = self._require_session()
        stage = self.current_stage

        return AnimalBaseInfo(
            animal=self.name,
            trial_id=self.trial_id,
            level=stage.level,
            level_trial_id=stage.level_trial_id,
            animal_session_id=session.session_id,
            animal_session_trial_id=session.trial_id,
        )