mxbiflow 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxbiflow/__init__.py +3 -0
- mxbiflow/assets/__init__.py +5 -0
- mxbiflow/assets/clicker.wav +0 -0
- mxbiflow/config_store.py +68 -0
- mxbiflow/data_logger.py +114 -0
- mxbiflow/default/__init__.py +4 -0
- mxbiflow/default/idle/assets/apple_v1.png +0 -0
- mxbiflow/default/idle/idle.py +57 -0
- mxbiflow/detector_bridge.py +87 -0
- mxbiflow/game.py +84 -0
- mxbiflow/infra/eventbus.py +31 -0
- mxbiflow/main.py +106 -0
- mxbiflow/models/animal.py +130 -0
- mxbiflow/models/reward.py +7 -0
- mxbiflow/models/session.py +145 -0
- mxbiflow/mxbiflow.py +43 -0
- mxbiflow/path.py +41 -0
- mxbiflow/scene/__init__.py +8 -0
- mxbiflow/scene/scene_manager.py +64 -0
- mxbiflow/scene/scene_protocol.py +22 -0
- mxbiflow/scheduler.py +90 -0
- mxbiflow/tasks/GNGSiD/models.py +70 -0
- mxbiflow/tasks/GNGSiD/stages/detect_stage/config.json +116 -0
- mxbiflow/tasks/GNGSiD/stages/detect_stage/detect_stage.py +161 -0
- mxbiflow/tasks/GNGSiD/stages/detect_stage/detect_stage_models.py +65 -0
- mxbiflow/tasks/GNGSiD/stages/discriminate_stage/config.json +70 -0
- mxbiflow/tasks/GNGSiD/stages/discriminate_stage/discriminate_stage.py +173 -0
- mxbiflow/tasks/GNGSiD/stages/discriminate_stage/discriminate_stage_models.py +80 -0
- mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/config.json +83 -0
- mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/size_reduction_models.py +58 -0
- mxbiflow/tasks/GNGSiD/stages/size_reduction_stage/size_reduction_stage.py +149 -0
- mxbiflow/tasks/GNGSiD/tasks/artifacts.py +13 -0
- mxbiflow/tasks/GNGSiD/tasks/detect/models.py +21 -0
- mxbiflow/tasks/GNGSiD/tasks/detect/scene.py +271 -0
- mxbiflow/tasks/GNGSiD/tasks/discriminate/discriminate_models.py +31 -0
- mxbiflow/tasks/GNGSiD/tasks/discriminate/discriminate_scene.py +336 -0
- mxbiflow/tasks/GNGSiD/tasks/touch/touch_models.py +17 -0
- mxbiflow/tasks/GNGSiD/tasks/touch/touch_scene.py +256 -0
- mxbiflow/tasks/GNGSiD/tasks/utils/targets.py +57 -0
- mxbiflow/tasks/cross_modal/bundle_dir.py +553 -0
- mxbiflow/tasks/cross_modal/config.py +41 -0
- mxbiflow/tasks/cross_modal/media.py +61 -0
- mxbiflow/tasks/cross_modal/models.py +57 -0
- mxbiflow/tasks/cross_modal/scene.py +252 -0
- mxbiflow/tasks/cross_modal/stage.py +218 -0
- mxbiflow/tasks/cross_modal/trial_io.py +23 -0
- mxbiflow/tasks/cross_modal/trial_schema.py +113 -0
- mxbiflow/tasks/default/error_task/error_scene.py +53 -0
- mxbiflow/tasks/default/idle_task/assets/apple_v1.png +0 -0
- mxbiflow/tasks/default/idle_task/idle_scene.py +85 -0
- mxbiflow/tasks/default/initial_habituation_training/README.md +188 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/config.csv +7 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/config.json +67 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/initial_habituation_training_stage.py +172 -0
- mxbiflow/tasks/default/initial_habituation_training/stages/models.py +56 -0
- mxbiflow/tasks/default/initial_habituation_training/tasks/stay_to_reward/stay_to_reward.py +244 -0
- mxbiflow/tasks/default/initial_habituation_training/tasks/stay_to_reward/stay_to_reward_models.py +50 -0
- mxbiflow/tasks/task_protocol.py +26 -0
- mxbiflow/tasks/task_table.py +29 -0
- mxbiflow/tasks/two_alternative_choice/assets/starter.py +27 -0
- mxbiflow/tasks/two_alternative_choice/models.py +68 -0
- mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/config.json +118 -0
- mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/size_reduction_models.py +41 -0
- mxbiflow/tasks/two_alternative_choice/stages/size_reduction_stage/size_reduction_stage.py +122 -0
- mxbiflow/tasks/two_alternative_choice/tasks/touch/touch_models.py +19 -0
- mxbiflow/tasks/two_alternative_choice/tasks/touch/touch_scene.py +249 -0
- mxbiflow/timer/__init__.py +3 -0
- mxbiflow/timer/frame_timer.py +47 -0
- mxbiflow/timer/realtime_timer.py +0 -0
- mxbiflow/tmp_email.py +13 -0
- mxbiflow/ui/components/animal.py +87 -0
- mxbiflow/ui/components/baseconfig.py +68 -0
- mxbiflow/ui/components/card.py +18 -0
- mxbiflow/ui/components/device_card/__init__.py +17 -0
- mxbiflow/ui/components/device_card/detector/beambreak_detector_card.py +29 -0
- mxbiflow/ui/components/device_card/detector/fusion_detector.py +45 -0
- mxbiflow/ui/components/device_card/detector/mock_detector_card.py +20 -0
- mxbiflow/ui/components/device_card/detector/rfid_detector.py +40 -0
- mxbiflow/ui/components/device_card/device_card.py +67 -0
- mxbiflow/ui/components/device_card/rewarder/mock_rewarder_card.py +20 -0
- mxbiflow/ui/components/device_card/rewarder/rpi_gpio_rewarder.py +33 -0
- mxbiflow/ui/components/devices.py +183 -0
- mxbiflow/ui/components/dialog/__init__.py +3 -0
- mxbiflow/ui/components/dialog/add_devices_dialog.py +64 -0
- mxbiflow/ui/components/experiment_groups.py +122 -0
- mxbiflow/ui/experiment_panel.py +91 -0
- mxbiflow/ui/mxbi_panel.py +152 -0
- mxbiflow/utils/logger.py +19 -0
- mxbiflow/utils/serial.py +10 -0
- mxbiflow-0.1.1.dist-info/METADATA +168 -0
- mxbiflow-0.1.1.dist-info/RECORD +93 -0
- mxbiflow-0.1.1.dist-info/WHEEL +4 -0
- mxbiflow-0.1.1.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from enum import StrEnum, auto
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CrossModalOutcome(StrEnum):
|
|
7
|
+
CORRECT = auto()
|
|
8
|
+
INCORRECT = auto()
|
|
9
|
+
TIMEOUT = auto()
|
|
10
|
+
ABORTED = auto()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CrossModalResultRecord(BaseModel):
    """Flat per-trial log record for the cross-modal task.

    CrossModalTask._log_trial builds one of these for each trial that was not
    cancelled and writes ``model_dump()`` to both ``save_jsonl`` and
    ``save_csv_row``.

    NOTE(review): ``model_dump()`` follows field declaration order, so
    reordering the fields below presumably reorders the CSV columns —
    confirm against DataLogger before changing.
    """

    # time.time() at the moment the record is built.
    timestamp: float
    # session_state.session_id when that attribute exists, otherwise None.
    session_id: int | None
    subject_id: str
    partner_id: str

    # Trial identifiers, copied from the Trial definition.
    trial_id: str
    trial_number: int

    call_identity_id: str
    call_category: str
    is_partner_call: bool

    other_identity_id: str
    other_category: str

    # Copied from the Trial, whose validator restricts them to "left"/"right".
    partner_side: str
    correct_side: str

    # Stimulus descriptors exactly as stored on the Trial; the paths are the
    # raw trial strings, not the resolved filesystem paths used for loading.
    audio_identity_id: str
    audio_index: int
    audio_path: str

    left_image_identity_id: str
    left_image_index: int
    left_image_path: str

    right_image_identity_id: str
    right_image_index: int
    right_image_path: str

    # "left"/"right" for the touched image; None when no choice was made.
    chosen_side: str | None
    # Image identity shown on the chosen side; None when no choice was made.
    chosen_identity_id: str | None

    outcome: CrossModalOutcome

    # Wall-clock times from time.time(). latency_sec is choice_time minus
    # trial_start_time, and None on timeout or when either time is missing.
    trial_start_time: float | None
    choice_time: float | None
    latency_sec: float | None

    # Touch position in screen coordinates (Tk event x_root / y_root).
    choice_x: int | None
    choice_y: int | None

    aborted: bool
    timeout: bool
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from time import time
|
|
5
|
+
from tkinter import CENTER, Canvas, Event
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from mxbi.tasks.cross_modal.config import CrossModalConfig
|
|
10
|
+
from mxbi.utils.logger import logger
|
|
11
|
+
from mxbi.utils.tkinter.components.canvas_with_border import CanvasWithInnerBorder
|
|
12
|
+
from mxbi.utils.tkinter.components.showdata_widget import ShowDataWidget
|
|
13
|
+
from numpy.typing import NDArray
|
|
14
|
+
from PIL import ImageTk
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from mxbi.models.animal import AnimalState
|
|
18
|
+
from mxbi.models.session import ScreenConfig, SessionState
|
|
19
|
+
from mxbi.tasks.cross_modal.trial_schema import Trial
|
|
20
|
+
from mxbi.theater import Theater
|
|
21
|
+
from PIL.Image import Image as PILImage
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class CrossModalResult:
    """Outcome of one ``CrossModalScene`` run, as returned by ``start()``."""

    # "left"/"right" for the touched image; None on timeout or cancellation.
    chosen_side: str | None
    # True when the trial timer fired before any choice was made.
    timeout: bool
    # True when the touched side equals ``trial.correct_side``.
    feedback: bool
    # True when ``cancel()`` ended the scene.
    cancelled: bool
    # time() when the two images appeared; None if they never appeared.
    trial_start_time: float | None
    # time() of the touch; None when no choice was made.
    choice_time: float | None
    # Screen coordinates of the touch (``event.x_root`` / ``event.y_root``).
    choice_x: int | None
    choice_y: int | None


class CrossModalScene:
    """Tk scene for one cross-modal trial: fixation, audio, two-image choice.

    ``start()`` draws a fixation cross, starts the audio stimulus and, after
    ``timing.fixation_ms``, replaces the cross with the left/right images. It
    then waits until one image is touched, ``timing.trial_timeout_ms``
    elapses, or ``cancel()`` is called. It blocks in the theater's main loop
    until then and returns a ``CrossModalResult``.
    """

    def __init__(
        self,
        theater: "Theater",
        session_state: "SessionState",
        animal_state: "AnimalState",
        screen: "ScreenConfig",
        trial: "Trial",
        cross_modal_config: CrossModalConfig,
        left_image: "PILImage",
        right_image: "PILImage",
        audio_stimulus: "NDArray[np.int16]",
    ) -> None:
        """Store the inputs; no widgets are created until ``start()``.

        Args:
            theater: Provides ``root``, ``mainloop()``, ``aplayer``,
                ``acontroller``, ``reward`` and ``caputre``.
            session_state: Stored for reference; not read by this class.
            animal_state: Supplies the values shown in the data overlay.
            screen: Supplies ``width``/``height`` used for layout.
            trial: Supplies ``correct_side`` and the image identity labels.
            cross_modal_config: Supplies the ``timing.*`` and ``audio.*``
                settings.
            left_image: PIL image for the left choice canvas.
            right_image: PIL image for the right choice canvas. The left
                image's width sets the size of both canvases, so the two
                images are assumed to be equal-sized squares.
            audio_stimulus: Samples passed to ``theater.aplayer.play_stimulus``.
        """
        self._theater = theater
        self._session_state = session_state
        self._animal_state = animal_state
        self._screen = screen
        self._trial = trial
        self._cross_modal_config = cross_modal_config

        self._left_image_pil = left_image
        self._right_image_pil = right_image
        self._audio_stimulus = audio_stimulus

        self._background: CanvasWithInnerBorder | None = None
        self._left_canvas: Canvas | None = None
        self._right_canvas: Canvas | None = None
        self._left_image: ImageTk.PhotoImage | None = None
        self._right_image: ImageTk.PhotoImage | None = None
        self._show_data_widget: ShowDataWidget | None = None

        self._chosen_side: str | None = None
        self._timeout = False
        self._feedback = False
        self._cancelled = False

        self._trial_start_time: float | None = None
        self._choice_time: float | None = None
        self._choice_x: int | None = None
        self._choice_y: int | None = None

        # Tk ``after`` ids of pending fixation/timeout callbacks. They are
        # cancelled on close so they never fire after the canvas that
        # registered them has been destroyed.
        self._after_ids: list[str] = []

    def start(self) -> CrossModalResult:
        """Run the trial in the theater's Tk main loop and return its result.

        If ``cancel()`` was already called, the view is not built and the main
        loop is not entered. ``_on_choice`` and ``_on_timeout`` both ignore a
        cancelled scene, so nothing in this class would ever quit that loop.
        """
        if not self._cancelled:
            self._create_view()
            self._bind_events()
            self._play_audio()
            self._theater.mainloop()
        return CrossModalResult(
            chosen_side=self._chosen_side,
            timeout=self._timeout,
            feedback=self._feedback,
            cancelled=self._cancelled,
            trial_start_time=self._trial_start_time,
            choice_time=self._choice_time,
            choice_x=self._choice_x,
            choice_y=self._choice_y,
        )

    def cancel(self) -> None:
        """Abort the trial: stop audio, close the view and quit the main loop.

        Calling it more than once has no further effect.
        """
        if self._cancelled:
            return
        self._cancelled = True
        self._feedback = False
        self._timeout = False
        self._stop_and_close()

    def _create_view(self) -> None:
        """Build the background, the data overlay and the fixation cross."""
        self._background = CanvasWithInnerBorder(
            master=self._theater.root,
            bg="black",
            width=self._screen.width,
            height=self._screen.height,
            border_width=40,
        )
        self._background.place(relx=0.5, rely=0.5, anchor="center")
        # Keyboard focus is needed for the <r>/<s> bindings in _bind_events.
        self._background.focus_set()

        self._show_data_widget = ShowDataWidget(self._background)
        self._show_data_widget.place(relx=0, rely=1, anchor="sw")
        self._show_data_widget.show_data(
            {
                "name": self._animal_state.name,
                "id": self._animal_state.trial_id,
                "level_id": self._animal_state.current_level_trial_id,
                "level": self._animal_state.level,
                "rewards": 0,
                "correct": self._animal_state.correct_trial,
                "incorrect": 0,
                "timeout": 0,
            }
        )

        self._background.create_text(
            self._screen.width / 2,
            self._screen.height / 2,
            text="+",
            fill="white",
            font=("Helvetica", 40),
        )

        fixation_ms = self._cross_modal_config.timing.fixation_ms
        self._schedule(fixation_ms, self._show_images)

    def _schedule(self, delay_ms: int, callback) -> None:
        """Arm a Tk timer on the background canvas and remember its id.

        ``callback`` takes no arguments. Nothing is scheduled once the
        background has been closed.
        """
        if self._background is None:
            return
        self._after_ids.append(self._background.after(delay_ms, callback))

    def _make_choice_canvas(self, side: str, relx: float, size: int) -> Canvas:
        """Create, place and click-bind a square grey canvas for one side."""
        canvas = Canvas(
            self._background,
            width=size,
            height=size,
            bg="grey",
            highlightthickness=0,
        )
        canvas.place(relx=relx, rely=0.5, anchor=CENTER)
        canvas.bind("<ButtonPress-1>", lambda e: self._on_choice(side, e))
        return canvas

    def _show_images(self) -> None:
        """Replace the fixation cross with the two clickable stimulus images.

        This also stamps the trial start time and arms the response timeout.
        """
        if self._background is None:
            return

        self._background.delete("all")

        # The left image's width sizes both canvases (images assumed square
        # and equal-sized).
        img_size = int(self._left_image_pil.size[0])

        self._left_canvas = self._make_choice_canvas("left", 0.25, img_size)
        self._right_canvas = self._make_choice_canvas("right", 0.75, img_size)

        # Keep the PhotoImage objects on self. Tkinter images are released when
        # their Python object is garbage-collected, which would blank them.
        self._left_image = ImageTk.PhotoImage(self._left_image_pil)
        self._left_canvas.create_image(
            img_size // 2, img_size // 2, image=self._left_image
        )

        self._right_image = ImageTk.PhotoImage(self._right_image_pil)
        self._right_canvas.create_image(
            img_size // 2, img_size // 2, image=self._right_image
        )

        self._background.create_text(
            self._screen.width * 0.25,
            self._screen.height * 0.8,
            text=f"Left: {self._trial.left_image_identity_id}",
            fill="white",
        )
        self._background.create_text(
            self._screen.width * 0.75,
            self._screen.height * 0.8,
            text=f"Right: {self._trial.right_image_identity_id}",
            fill="white",
        )

        self._trial_start_time = time()
        timeout_ms = self._cross_modal_config.timing.trial_timeout_ms
        self._schedule(timeout_ms, self._on_timeout)

    def _bind_events(self) -> None:
        """Bind operator keys: <r> gives a manual reward, <s> triggers a capture."""
        if self._background is None:
            return
        self._background.bind("<r>", lambda e: self._give_manual_reward())
        # NOTE(review): "caputre" must match the Theater method name; if it is
        # a typo there too, rename both together.
        self._background.bind("<s>", lambda e: self._theater.caputre(self._background))

    def _play_audio(self) -> None:
        """Apply the configured volumes (best effort), then start the stimulus."""
        try:
            self._theater.acontroller.set_master_volume(
                self._cross_modal_config.audio.master_volume
            )
            self._theater.acontroller.set_digital_volume(
                self._cross_modal_config.audio.digital_volume
            )
        except Exception:
            logger.exception("Failed to set cross-modal audio volume")

        self._theater.aplayer.play_stimulus(self._audio_stimulus)

    def _on_choice(self, side: str, event: Event) -> None:
        """Record the first touch on either image and end the trial."""
        if self._cancelled or self._chosen_side is not None:
            return
        self._chosen_side = side
        self._choice_time = time()
        self._choice_x = event.x_root
        self._choice_y = event.y_root
        self._timeout = False
        self._feedback = side == self._trial.correct_side
        self._stop_and_close()

    def _on_timeout(self) -> None:
        """End the trial as a timeout unless a choice or cancel came first."""
        if self._cancelled or self._chosen_side is not None:
            return
        self._chosen_side = None
        self._timeout = True
        self._feedback = False
        self._stop_and_close()

    def _cancel_pending_callbacks(self) -> None:
        """Cancel timers scheduled by this scene that have not fired yet."""
        if self._background is not None:
            for after_id in self._after_ids:
                try:
                    self._background.after_cancel(after_id)
                except Exception:
                    logger.debug("Failed to cancel cross-modal scene timer")
        self._after_ids.clear()

    def _stop_and_close(self) -> None:
        """Stop audio, tear down the view and quit the theater's main loop.

        Each step is attempted even if an earlier one fails, so the main loop
        is always asked to exit.
        """
        try:
            self._theater.aplayer.stop()
        except Exception:
            logger.exception("Failed to stop cross-modal audio")

        self._cancel_pending_callbacks()

        if self._background is not None:
            try:
                self._background.destroy()
            except Exception:
                logger.debug("Failed to destroy cross-modal scene background")
            self._background = None

        try:
            self._theater.root.quit()
        except Exception:
            logger.debug("Failed to quit theater main loop")

    def _give_manual_reward(self) -> None:
        """Operator-triggered reward via ``theater.reward.give_reward(500)``."""
        # NOTE(review): the unit of 500 (ms? pulses?) is defined by the
        # rewarder, not here — confirm before changing.
        self._theater.reward.give_reward(500)
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, Final
|
|
4
|
+
|
|
5
|
+
from mxbi.data_logger import DataLogger
|
|
6
|
+
from mxbi.tasks.cross_modal.bundle_dir import CrossModalBundleDir
|
|
7
|
+
from mxbi.tasks.cross_modal.config import CrossModalConfig, load_cross_modal_config
|
|
8
|
+
from mxbi.tasks.cross_modal.media import load_wav_as_int16
|
|
9
|
+
from mxbi.tasks.cross_modal.models import CrossModalOutcome, CrossModalResultRecord
|
|
10
|
+
from mxbi.tasks.cross_modal.scene import CrossModalResult, CrossModalScene
|
|
11
|
+
from mxbi.tasks.cross_modal.trial_io import TrialCursor
|
|
12
|
+
from mxbi.utils.logger import logger
|
|
13
|
+
from PIL import Image
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import numpy as np
|
|
17
|
+
from mxbi.models.animal import AnimalState
|
|
18
|
+
from mxbi.models.session import SessionState
|
|
19
|
+
from mxbi.models.task import Feedback
|
|
20
|
+
from mxbi.theater import Theater
|
|
21
|
+
from numpy.typing import NDArray
|
|
22
|
+
from PIL.Image import Image as PILImage
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CrossModalTask:
    """Run one cross-modal trial per ``start()`` call.

    Construction loads the subject's trial list from the configured bundle.
    It then picks the next trial through a ``TrialCursor``, loads that trial's
    two images and audio, and builds the ``CrossModalScene``. ``start()`` runs
    the scene, logs the result unless it was cancelled, and advances the
    cursor.
    """

    STAGE_NAME: Final[str] = "cross_modal_task"

    def __init__(
        self,
        theater: "Theater",
        session_state: "SessionState",
        animal_state: "AnimalState",
    ) -> None:
        """Select the next trial for this animal and prepare its scene.

        Raises:
            RuntimeError: if no bundle directory is configured, or the bundle
                holds no trials for this subject.
        """
        self._theater = theater
        self._session_state = session_state
        self._animal_state = animal_state
        self._screen = session_state.session_config.screen_type

        root = self._bundle_root_from_config(session_state)
        self._bundle_dir = CrossModalBundleDir.from_dir_path(root)

        subject = self._animal_state.name
        trial_list = self._bundle_dir.load_trials(subject)
        if not trial_list:
            raise RuntimeError(
                f"Bundle contains zero trials for subject '{subject}'."
            )

        self._cursor = TrialCursor(
            bundle_root=self._bundle_dir.root_dir, subject_id=subject
        )
        self._trial_index = self._cursor.next_index(len(trial_list))
        self._trial = trial_list[self._trial_index]

        self._cross_modal_config: CrossModalConfig = load_cross_modal_config()

        self._scene = self._build_scene()

        self._data_logger = DataLogger(
            self._session_state, self._animal_state.name, self.STAGE_NAME
        )
        self._feedback = False

    @staticmethod
    def _bundle_root_from_config(session_state: "SessionState") -> Path:
        """Return the configured bundle directory, user-expanded and absolute."""
        configured = getattr(
            session_state.session_config, "cross_modal_bundle_dir", None
        )
        if not configured:
            raise RuntimeError(
                "No cross-modal bundle directory configured. "
                "Set session_config.cross_modal_bundle_dir via the launcher."
            )
        return Path(configured).expanduser().resolve()

    def _build_scene(self) -> "CrossModalScene":
        """Load the selected trial's media and wrap it in a CrossModalScene."""
        cfg = self._cross_modal_config
        side_px = int(
            min(self._screen.width, self._screen.height) * cfg.visual.image_scale
        )

        resolve = self._bundle_dir.resolve_media_path
        left_path = resolve(self._trial.left_image_path)
        right_path = resolve(self._trial.right_image_path)
        wav_path = resolve(self._trial.audio_path)

        left_img = self._prepare_image(left_path, image_size=side_px)
        right_img = self._prepare_image(right_path, image_size=side_px)

        samples = load_wav_as_int16(
            wav_path,
            rate_policy=cfg.audio.wav_rate_policy,
            gain=cfg.audio.gain,
        )

        return CrossModalScene(
            theater=self._theater,
            session_state=self._session_state,
            animal_state=self._animal_state,
            screen=self._screen,
            trial=self._trial,
            cross_modal_config=cfg,
            left_image=left_img,
            right_image=right_img,
            audio_stimulus=samples,
        )

    def _session_id(self):
        """Return ``session_state.session_id``, or None if it has none."""
        return getattr(self._session_state, "session_id", None)

    def start(self) -> "Feedback":
        """Run the scene, log the trial, advance the cursor, return the feedback."""
        outcome = self._scene.start()
        self._feedback = outcome.feedback

        if not outcome.cancelled:
            self._log_trial(outcome)

        # NOTE(review): the cursor advances even for a cancelled scene, so that
        # trial is neither logged nor repeated — confirm this is intended.
        self._cursor.advance(self._trial_index)

        logger.debug(
            "cross_modal_task: session_id=%s, subject=%s, level=%s, "
            "trial_index=%s, is_partner=%s, feedback=%s",
            self._session_id(),
            self._animal_state.name,
            self._animal_state.level,
            self._trial_index + 1,
            self._trial.is_partner_call,
            self._feedback,
        )
        return self._feedback

    def quit(self) -> None:
        """Cancel the running scene."""
        self._scene.cancel()

    def on_idle(self) -> None:
        """Cancel the running scene."""
        self._scene.cancel()

    def on_return(self) -> None:
        """Cancel the running scene."""
        self._scene.cancel()

    @property
    def condition(self):
        """This task exposes no condition; always None."""
        return None

    @staticmethod
    def _classify_outcome(result: "CrossModalResult") -> "CrossModalOutcome":
        """Map a scene result to an outcome; timeout takes precedence."""
        if result.timeout:
            return CrossModalOutcome.TIMEOUT
        if result.cancelled:
            return CrossModalOutcome.ABORTED
        if result.feedback:
            return CrossModalOutcome.CORRECT
        return CrossModalOutcome.INCORRECT

    @staticmethod
    def _latency_sec(result: "CrossModalResult") -> float | None:
        """Return seconds from image onset to touch, or None when unavailable."""
        started = result.trial_start_time
        touched = result.choice_time
        if started is None or touched is None or result.timeout:
            return None
        return touched - started

    def _chosen_identity(self, side: str | None) -> str | None:
        """Return the image identity shown on ``side``, or None for no choice."""
        by_side = {
            "left": self._trial.left_image_identity_id,
            "right": self._trial.right_image_identity_id,
        }
        return by_side.get(side)

    def _log_trial(self, result: "CrossModalResult") -> None:
        """Write one CrossModalResultRecord to JSONL and CSV (best effort)."""
        try:
            outcome = self._classify_outcome(result)
            latency = self._latency_sec(result)
            chosen_identity = self._chosen_identity(result.chosen_side)
            trial = self._trial

            rec = CrossModalResultRecord(
                timestamp=time.time(),
                session_id=self._session_id(),
                subject_id=trial.subject_id,
                partner_id=trial.partner_id,
                trial_id=trial.trial_id,
                trial_number=trial.trial_number,
                call_identity_id=trial.call_identity_id,
                call_category=trial.call_category,
                is_partner_call=trial.is_partner_call,
                other_identity_id=trial.other_identity_id,
                other_category=trial.other_category,
                partner_side=trial.partner_side,
                correct_side=trial.correct_side,
                audio_identity_id=trial.audio_identity_id,
                audio_index=trial.audio_index,
                audio_path=trial.audio_path,
                left_image_identity_id=trial.left_image_identity_id,
                left_image_index=trial.left_image_index,
                left_image_path=trial.left_image_path,
                right_image_identity_id=trial.right_image_identity_id,
                right_image_index=trial.right_image_index,
                right_image_path=trial.right_image_path,
                chosen_side=result.chosen_side,
                chosen_identity_id=chosen_identity,
                outcome=outcome,
                trial_start_time=result.trial_start_time,
                choice_time=result.choice_time,
                latency_sec=latency,
                choice_x=result.choice_x,
                choice_y=result.choice_y,
                aborted=result.cancelled,
                timeout=result.timeout,
            )

            row = rec.model_dump()
            self._data_logger.save_jsonl(row)
            self._data_logger.save_csv_row(row)
        except Exception:
            logger.exception("Failed to log cross-modal trial")

    @staticmethod
    def _prepare_image(image_path: Path, *, image_size: int) -> "PILImage":
        """Load an image as RGB, rotate it 90° with expansion, and resize it.

        The result is an ``image_size`` x ``image_size`` square (aspect ratio is
        not preserved), resampled with LANCZOS.
        """
        with Image.open(image_path) as source:
            rotated = source.convert("RGB").rotate(90, expand=True)
        return rotated.resize((image_size, image_size), Image.LANCZOS)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
# Process-wide map of (resolved bundle root, subject id) -> next trial
# position. It lives only in memory, so every key starts again at 0 after a
# process restart.
_TRIAL_INDEX_CACHE: dict[tuple[Path, str], int] = {}


@dataclass(frozen=True)
class TrialCursor:
    """Round-robin position in one subject's trial list within one bundle."""

    bundle_root: Path
    subject_id: str

    def _cache_key(self) -> tuple[Path, str]:
        """Key into the shared cache; the root is resolved so aliases coincide."""
        return self.bundle_root.resolve(), self.subject_id

    def next_index(self, trial_count: int) -> int:
        """Return the index of the next trial, wrapped into ``range(trial_count)``.

        Raises:
            ValueError: if ``trial_count`` is not positive.
        """
        if trial_count <= 0:
            raise ValueError("trial_count must be > 0")
        stored = _TRIAL_INDEX_CACHE.get(self._cache_key(), 0)
        return stored % trial_count

    def advance(self, last_index: int) -> None:
        """Remember that ``last_index`` was used so the next call moves past it."""
        _TRIAL_INDEX_CACHE[self._cache_key()] = last_index + 1
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Any, Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, field_validator, model_validator
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Trial(BaseModel):
    """One trial definition loaded from a cross-modal bundle.

    Accepts either keys matching the snake_case fields below, or the bundle's
    camelCase layout (detected by a ``trialId`` key). In the camelCase layout
    the audio and the two images are nested objects with ``identityId`` /
    ``exemplarIndex`` / ``path`` keys.
    """

    trial_id: str
    trial_number: int
    subject_id: str
    partner_id: str

    call_identity_id: str
    call_category: str
    is_partner_call: bool

    other_identity_id: str
    other_category: str

    partner_side: Literal["left", "right"]
    correct_side: Literal["left", "right"]

    audio_identity_id: str
    audio_index: int
    audio_path: str

    left_image_identity_id: str
    left_image_index: int
    left_image_path: str

    right_image_identity_id: str
    right_image_index: int
    right_image_path: str

    seed: str

    @model_validator(mode="before")
    @classmethod
    def _from_bundle_trial(cls, values: Any) -> Any:
        """Flatten a camelCase bundle trial dict into this model's field names.

        Non-dict input, and dicts without a ``trialId`` key, pass through
        unchanged. A missing or non-dict nested stimulus object is treated as
        ``{}``, so its fields map to None and then fail normal field
        validation.
        """
        if not isinstance(values, dict):
            return values

        if "trialId" not in values:
            return values

        audio = values.get("audio") if isinstance(values.get("audio"), dict) else {}
        left_image = (
            values.get("leftImage") if isinstance(values.get("leftImage"), dict) else {}
        )
        right_image = (
            values.get("rightImage")
            if isinstance(values.get("rightImage"), dict)
            else {}
        )

        return {
            "trial_id": values.get("trialId"),
            "trial_number": values.get("trialNumber"),
            "subject_id": values.get("subjectId"),
            "partner_id": values.get("partnerId"),
            "call_identity_id": values.get("callIdentityId"),
            "call_category": values.get("callCategory"),
            "is_partner_call": values.get("isPartnerCall"),
            "other_identity_id": values.get("otherIdentityId"),
            "other_category": values.get("otherCategory"),
            "partner_side": values.get("partnerSide"),
            "correct_side": values.get("correctSide"),
            "audio_identity_id": audio.get("identityId"),
            "audio_index": audio.get("exemplarIndex"),
            "audio_path": audio.get("path"),
            "left_image_identity_id": left_image.get("identityId"),
            "left_image_index": left_image.get("exemplarIndex"),
            "left_image_path": left_image.get("path"),
            "right_image_identity_id": right_image.get("identityId"),
            "right_image_index": right_image.get("exemplarIndex"),
            "right_image_path": right_image.get("path"),
            "seed": values.get("seed"),
        }

    @field_validator("partner_side", "correct_side", mode="before")
    @classmethod
    def _normalize_sides(cls, v: Any) -> Any:
        """Strip and lowercase side strings before the Literal check.

        This must run in ``before`` mode. An ``after`` validator only sees
        values that already passed ``Literal["left", "right"]``, so inputs
        such as ``"Left"`` or ``" right "`` would be rejected rather than
        normalized. Non-string values are returned untouched so the Literal
        validator rejects them with its standard error instead of an
        AttributeError escaping from ``.strip()``.
        """
        if not isinstance(v, str):
            return v
        lv = v.strip().lower()
        if lv not in {"left", "right"}:
            raise ValueError("side must be 'left' or 'right'")
        return lv

    @field_validator("is_partner_call", mode="before")
    @classmethod
    def _boolify(cls, v: Any) -> Any:
        """Coerce bools, ints and common yes/no strings to bool.

        Unrecognized strings raise ValueError. Any other type falls through
        and yields None, which the ``bool`` field presumably rejects.
        """
        if isinstance(v, bool):
            return v
        if isinstance(v, int):
            return bool(v)
        if isinstance(v, str):
            lv = v.strip().lower()
            if lv in {"true", "t", "yes", "y", "1"}:
                return True
            if lv in {"false", "f", "no", "n", "0"}:
                return False
            raise ValueError("is_partner_call must be boolean-like")

    def audio_path_obj(self, base: Path | None = None) -> Path:
        """Return ``audio_path`` as a Path, joined onto ``base`` if it is relative."""
        p = Path(self.audio_path)
        return (base / p) if base and not p.is_absolute() else p

    def left_image_path_obj(self, base: Path | None = None) -> Path:
        """Return ``left_image_path`` as a Path, joined onto ``base`` if it is relative."""
        p = Path(self.left_image_path)
        return (base / p) if base and not p.is_absolute() else p

    def right_image_path_obj(self, base: Path | None = None) -> Path:
        """Return ``right_image_path`` as a Path, joined onto ``base`` if it is relative."""
        p = Path(self.right_image_path)
        return (base / p) if base and not p.is_absolute() else p
|