autogaita 1.5.2__py3-none-any.whl → 1.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,359 @@
1
+ # %% imports
2
+ import autogaita.gui.gaita_widgets as gaita_widgets
3
+ import autogaita.gui.gui_utils as gui_utils
4
+ from autogaita.gui.common2D_columninfo_gui import build_column_info_window
5
+ from autogaita.gui.common2D_advanced_config_gui import build_cfg_window
6
+ from autogaita.gui.common2D_run_and_done_gui import build_run_and_done_windows
7
+ from autogaita.gui.first_level_gui_utils import (
8
+ update_config_file,
9
+ extract_cfg_from_json_file,
10
+ )
11
+
12
+ import tkinter as tk
13
+ import customtkinter as ctk
14
+ import os
15
+ import platform
16
+
17
+ # %% global constants
18
+ from autogaita.gui.gui_constants import (
19
+ DLC_FG_COLOR,
20
+ DLC_HOVER_COLOR,
21
+ TEXT_FONT_NAME,
22
+ TEXT_FONT_SIZE,
23
+ WINDOWS_TASKBAR_MAXHEIGHT,
24
+ AUTOGAITA_FOLDER_PATH,
25
+ get_widget_cfg_dict, # function!
26
+ )
27
+
28
# GUI-specific colours, layered on top of the widget cfg shared by all GUIs
FG_COLOR = DLC_FG_COLOR
HOVER_COLOR = DLC_HOVER_COLOR
WIDGET_CFG = get_widget_cfg_dict()
WIDGET_CFG.update({"FG_COLOR": FG_COLOR, "HOVER_COLOR": HOVER_COLOR})

# ........................ gaita-variable related constants ........................
CONFIG_FILE_NAME = "dlc_gui_config.json"

# names of cfg entries grouped by the kind of value they hold
# (presumably used when converting between the json file and the GUI - confirm in
#  first_level_gui_utils)
FLOAT_VARS = ["pixel_to_mm_ratio"]
INT_VARS = [
    "sampling_rate",
    "x_sc_broken_threshold",
    "y_sc_broken_threshold",
    "bin_num",
    "mouse_num",
    "run_num",
    "plot_joint_number",
]
LIST_VARS = [
    "hind_joints",
    "fore_joints",
    "x_standardisation_joint",
    "y_standardisation_joint",
    "beam_hind_jointadd",
    "beam_fore_jointadd",
    "beam_col_left",
    "beam_col_right",
]
DICT_VARS = ["angles"]

# TK_BOOL/STR_VARS are only used for initialising widgets based on cfg file
# (note that numbers are initialised as strings)
TK_BOOL_VARS = [
    "subtract_beam",
    "dont_show_plots",
    "convert_to_mm",
    "x_acceleration",
    "angular_acceleration",
    "save_to_xls",
    "plot_SE",
    "standardise_y_at_SC_level",
    "standardise_y_to_a_joint",
    "standardise_x_coordinates",
    "invert_y_axis",
    "flip_gait_direction",
    "analyse_average_x",
    "legend_outside",
]
TK_STR_VARS = [
    # entries of the config file's results dict
    "mouse_num",
    "run_num",
    "root_dir",
    "sctable_filename",
    "data_string",
    "beam_string",
    "premouse_string",
    "postmouse_string",
    "prerun_string",
    "postrun_string",
    # entries of the config file's cfg dict
    "sampling_rate",
    "pixel_to_mm_ratio",
    "x_sc_broken_threshold",
    "y_sc_broken_threshold",
    "bin_num",
    "plot_joint_number",
    "color_palette",
    "results_dir",
]

# bundle of the above, handed to the shared run-and-done window builder
GUI_SPECIFIC_VARS = dict(
    CONFIG_FILE_NAME=CONFIG_FILE_NAME,
    FLOAT_VARS=FLOAT_VARS,
    INT_VARS=INT_VARS,
    LIST_VARS=LIST_VARS,
    DICT_VARS=DICT_VARS,
    TK_BOOL_VARS=TK_BOOL_VARS,
    TK_STR_VARS=TK_STR_VARS,
)
105
+
106
+
107
+ # %% An important Note
108
+ # I am using a global variable called cfg because I need its info to be shared
109
+ # between root and advanced_cfg windows. This is not the object-oriented way
110
+ # that one would do this typically. However, it works as expected since:
111
+ # 1) cfg's values are never modified except @ initialisation & by widgets
112
+ # 2) cfg's values are shared for analysis of a single and multiple video(s)
113
+ # 3) cfg's values are passed to all functions that need them
114
+ # 4) and (IMPORTANTLY!) just before running either (i.e. single/multi) analysis, cfg's
115
+ # and result's values are unpacked and assigned to "this_" dicts that are passed
116
+ # to the runanalysis local function. Hence, from that point onwards, only
117
+ # "this_" dicts are used, never cfg or result dicts themselves.
118
+ # ==> see donewindow for point 4)
119
+
120
+ # %%............................ MAIN PROGRAM ................................
121
+
122
+
123
def run_dlc_gui():
    """Build the DLC GaitA main window and run its Tk main loop.

    Shows an error dialog and exits if CONFIG_FILE_NAME is missing from
    AUTOGAITA_FOLDER_PATH. Otherwise the config json is loaded into the
    module-global ``cfg`` (global so that the root window and the other windows
    built from it share the same values). The function then lays out the
    main-configuration, advanced-configuration and run-analysis widgets on a
    CustomTkinter root window and blocks in ``root.mainloop()`` until that window
    is destroyed.
    """
    # "import tkinter" does NOT load the messagebox submodule, so tk.messagebox is
    # only available if some other module happened to import it => import it here
    import sys
    from tkinter import messagebox

    # ..........................................................................
    # ...................... root window initialisation .......................
    # ..........................................................................
    # Check for config file
    config_file_path = os.path.join(AUTOGAITA_FOLDER_PATH, CONFIG_FILE_NAME)
    if not os.path.isfile(config_file_path):
        config_file_error_msg = (
            f"{CONFIG_FILE_NAME} file not found in autogaita folder.\n"
            "Confirm that the file exists and is named correctly.\n"
            "If not, download it again from the GitHub repository."
        )
        messagebox.showerror(title="Config File Error", message=config_file_error_msg)
        # sys.exit instead of the site-provided exit(), which is not guaranteed to
        # exist (e.g. "python -S" or frozen applications)
        sys.exit()

    # CustomTkinter vars
    ctk.set_appearance_mode("dark")  # Modes: system (default), light, dark
    ctk.set_default_color_theme("green")  # Themes: blue , dark-blue, green
    # root
    root = ctk.CTk()
    # make window pretty
    screen_width = root.winfo_screenwidth()  # width of the screen
    screen_height = root.winfo_screenheight()  # height of the screen
    if platform.system() == "Windows":  # adjust for taskbar in windows only
        screen_height -= WINDOWS_TASKBAR_MAXHEIGHT
    # create root dimensions (based on fullscreen) to pass to other window-functions l8r
    w, h, x, y = screen_width, screen_height, 0, 0
    root_dimensions = (w, h, x, y)
    # set the dimensions of the screen and where it is placed
    # => have it half-wide starting at 1/4 of screen's width (dont change w & x!)
    root.geometry(f"{int(screen_width / 2)}x{screen_height}+{int(screen_width / 4)}+0")
    root.title("DLC GaitA")
    gui_utils.fix_window_after_its_creation(root)
    gui_utils.configure_the_icon(root)

    # ..................... load cfg dict from config .....................
    # use the values in the config json file for the results dictionary
    global cfg
    cfg = extract_cfg_from_json_file(
        root,
        AUTOGAITA_FOLDER_PATH,
        CONFIG_FILE_NAME,
        LIST_VARS,
        DICT_VARS,
        TK_STR_VARS,
        TK_BOOL_VARS,
    )

    # ............................... header ..........................................
    # main configuration header
    main_cfg_header_label = gaita_widgets.header_label(
        root,
        "Main Configuration",
        WIDGET_CFG,
    )
    main_cfg_header_label.grid(row=0, column=0, columnspan=3, sticky="nsew")

    # ............................ main cfg section ..................................
    # sampling rate
    samprate_label, samprate_entry = gaita_widgets.label_and_entry_pair(
        root,
        "Sampling rate of videos in Hertz (frames/second):",
        cfg["sampling_rate"],
        WIDGET_CFG,
    )
    samprate_label.grid(row=1, column=0, columnspan=2, sticky="w")
    samprate_entry.grid(row=1, column=2, sticky="w")

    # convert pixel to mm - checkbox
    convert_checkbox = gaita_widgets.checkbox(
        root,
        "Convert pixels to millimetres:",
        cfg["convert_to_mm"],
        WIDGET_CFG,
    )
    # ratio_entry is created below - the lambda only looks it up when clicked
    convert_checkbox.configure(
        command=lambda: gui_utils.change_widget_state_based_on_checkbox(
            cfg, "convert_to_mm", ratio_entry
        ),
    )
    convert_checkbox.grid(row=2, column=0, columnspan=2, sticky="w")

    # ratio label
    ratio_entry = ctk.CTkEntry(
        root,
        textvariable=cfg["pixel_to_mm_ratio"],
        font=(TEXT_FONT_NAME, TEXT_FONT_SIZE),
    )
    ratio_entry.grid(row=2, column=1, sticky="e")
    ratio_right_string = "pixels = 1 mm"
    ratio_right_label = ctk.CTkLabel(
        root, text=ratio_right_string, font=(TEXT_FONT_NAME, TEXT_FONT_SIZE)
    )
    ratio_right_label.grid(row=2, column=2, sticky="w")
    # to initialise the widget correctly, run this function once
    gui_utils.change_widget_state_based_on_checkbox(cfg, "convert_to_mm", ratio_entry)

    # subtract beam
    subtract_beam_checkbox = gaita_widgets.checkbox(
        root,
        "Standardise y-coordinates to baseline height (requires to be tracked)",
        cfg["subtract_beam"],
        WIDGET_CFG,
    )
    subtract_beam_checkbox.grid(row=3, column=0, columnspan=3, sticky="w")

    # flip gait direction
    flip_gait_direction_box = gaita_widgets.checkbox(
        root,
        "Adjust x-coordinates to follow direction of movement",
        cfg["flip_gait_direction"],
        WIDGET_CFG,
    )
    flip_gait_direction_box.grid(row=4, column=0, columnspan=3, sticky="w")

    # plot plots to python
    showplots_checkbox = gaita_widgets.checkbox(
        root,
        "Don't show plots in Figure GUI (save only)",
        cfg["dont_show_plots"],
        WIDGET_CFG,
    )
    showplots_checkbox.grid(row=5, column=0, columnspan=2, sticky="w")

    # bin number of SC normalisation
    bin_num_label, bin_num_entry = gaita_widgets.label_and_entry_pair(
        root,
        "Number of bins used to normalise the step cycle:",
        cfg["bin_num"],
        WIDGET_CFG,
    )
    bin_num_label.grid(row=6, column=0, columnspan=2, sticky="w")
    bin_num_entry.grid(row=6, column=2, sticky="w")

    # empty label 1 (for spacing)
    empty_label_one = ctk.CTkLabel(root, text="")
    empty_label_one.grid(row=7, column=0)

    # .......................... advanced cfg section ................................
    # advanced header string
    advanced_cfg_header_label = gaita_widgets.header_label(
        root,
        "Advanced Configuration",
        WIDGET_CFG,
    )
    advanced_cfg_header_label.grid(row=8, column=0, columnspan=3, sticky="nsew")

    # column name information window
    column_info_button = gaita_widgets.header_button(
        root, "Customise Joints & Angles", WIDGET_CFG
    )
    column_info_button.configure(
        command=lambda: build_column_info_window(root, cfg, WIDGET_CFG, root_dimensions)
    )
    column_info_button.grid(row=9, column=0, columnspan=3)

    # advanced cfg
    cfg_window_button = gaita_widgets.header_button(
        root, "Advanced Configuration", WIDGET_CFG
    )
    cfg_window_button.configure(
        command=lambda: build_cfg_window(root, cfg, WIDGET_CFG, root_dimensions)
    )
    cfg_window_button.grid(row=10, column=0, columnspan=3)

    # empty label 2 (for spacing)
    empty_label_two = ctk.CTkLabel(root, text="")
    empty_label_two.grid(row=11, column=0)

    # run analysis label
    runheader_label = gaita_widgets.header_label(root, "Run Analysis", WIDGET_CFG)
    runheader_label.grid(row=12, column=0, columnspan=3, sticky="nsew")

    # single gaita button
    onevid_button = gaita_widgets.header_button(root, "One Video", WIDGET_CFG)
    onevid_button.configure(
        command=lambda: build_run_and_done_windows(
            "DLC", "single", root, cfg, WIDGET_CFG, GUI_SPECIFIC_VARS, root_dimensions
        )
    )
    onevid_button.grid(row=13, column=1, sticky="ew")

    # multi gaita button
    multivid_button = gaita_widgets.header_button(root, "Batch Analysis", WIDGET_CFG)
    multivid_button.configure(
        command=lambda: build_run_and_done_windows(
            "DLC", "multi", root, cfg, WIDGET_CFG, GUI_SPECIFIC_VARS, root_dimensions
        )
    )
    multivid_button.grid(row=14, column=1, sticky="ew")

    # empty label 3 (for spacing)
    empty_label_three = ctk.CTkLabel(root, text="")
    empty_label_three.grid(row=15, column=0)

    # close & exit button
    def save_cfg_and_exit():
        """Store cfg in the config file, hide root and destroy it 5 s later."""
        # results variable is only defined later in populate_run_window()
        # therefore only cfg settings will be updated
        update_config_file(
            "results dict not defined yet",
            cfg,
            AUTOGAITA_FOLDER_PATH,
            CONFIG_FILE_NAME,
            LIST_VARS,
            DICT_VARS,
            TK_STR_VARS,
            TK_BOOL_VARS,
        )
        root.withdraw()
        root.after(5000, root.destroy)

    exit_button = gaita_widgets.exit_button(root, WIDGET_CFG)
    exit_button.configure(command=save_cfg_and_exit)
    exit_button.grid(row=16, column=0, columnspan=3)

    # ......................... widget configuration ...............................
    # first maximise everything according to sticky
    # => Silent_Creme is some undocumented option that makes stuff uniform
    # see: https://stackoverflow.com/questions/45847313/tkinter-grid-rowconfigure-weight-doesnt-work
    root.columnconfigure(list(range(3)), weight=1, uniform="Silent_Creme")
    root.rowconfigure(list(range(17)), weight=1, uniform="Silent_Creme")

    # then un-maximise main config rows to have them grouped together
    root.rowconfigure(list(range(1, 7)), weight=0)

    # main loop
    root.mainloop()
355
+
356
+
357
+ # %% what happens if we hit run
358
if __name__ == "__main__":
    # launch the GUI only when this file is executed directly, never on import
    run_dlc_gui()
@@ -0,0 +1,303 @@
1
+ # %% imports
2
+ from autogaita.resources.utils import write_issues_to_textfile
3
+ from autogaita.common2D.common2D_1_preparation import (
4
+ check_and_expand_cfg,
5
+ flip_mouse_body,
6
+ )
7
+ import os
8
+ import shutil
9
+ import json
10
+ import pandas as pd
11
+ import numpy as np
12
+ import h5py
13
+ import pdb
14
+
15
+ # %% constants
16
+ from autogaita.resources.constants import (
17
+ TIME_COL,
18
+ ISSUES_TXT_FILENAME,
19
+ CONFIG_JSON_FILENAME,
20
+ )
21
+
22
+
23
+ # %% workflow step #1 - preparation
24
+
25
+
26
def some_prep(info, folderinfo, cfg):
    """Preparation of the data & cfg file for later analyses.

    Copies this ID's SLEAP .h5 data into a fresh results folder, imports it into a
    DataFrame, validates/expands cfg via check_and_expand_cfg, stores the relevant
    cfg values as CONFIG_JSON_FILENAME in the part of results_dir preceding this
    ID's name, and standardises/converts the coordinates.

    Parameters
    ----------
    info : dict
        Needs "name" (ID of this video) and "results_dir".
    folderinfo : dict
        Needs "data_string" and "beam_string" (plus "root_dir", read by
        move_data_to_folders).
    cfg : dict
        Analysis configuration. NOTE: cfg["subtract_beam"] is overwritten with
        False in-place (see IMPORTANT NOTE below).

    Returns
    -------
    pandas.DataFrame or None
        Prepared data with TIME_COL and "Flipped" as first columns, or None if a
        critical error occurred (the error is printed and written to the issues
        textfile).
    """

    # ............................ unpack stuff ......................................
    # => DON'T unpack (joint) cfg-keys that are tested later by check_and_expand_cfg
    # SLEAP-specific NOTE
    # => I commented out vars that we dont need but might need in the future
    name = info["name"]
    results_dir = info["results_dir"]
    data_string = folderinfo["data_string"]
    beam_string = folderinfo["beam_string"]
    sampling_rate = cfg["sampling_rate"]

    # VERY VERY VERY IMPORTANT NOTE
    # => subtract_beam is hardcoded to False until I have data that allows me to test
    #    it properly (same for gait direction flipping @ end of this function)
    cfg["subtract_beam"] = False
    subtract_beam = cfg["subtract_beam"]
    convert_to_mm = cfg["convert_to_mm"]
    pixel_to_mm_ratio = cfg["pixel_to_mm_ratio"]
    standardise_y_at_SC_level = cfg["standardise_y_at_SC_level"]
    # invert_y_axis = cfg["invert_y_axis"]
    flip_gait_direction = cfg["flip_gait_direction"]
    analyse_average_x = cfg["analyse_average_x"]
    standardise_x_coordinates = cfg["standardise_x_coordinates"]
    standardise_y_to_a_joint = cfg["standardise_y_to_a_joint"]

    # ............................. move data ........................................
    # => see if we can delete a previous run's results folder if existent. if not,
    #    it's a bit ugly since we only update results if filenames match...
    # => for example if angle acceleration not wanted in current run, but was stored
    #    in previous run, the previous run's figure is in the folder
    # => inform the user and leave this as is
    # => only rmtree is guarded so that errors while copying the data are not
    #    misreported as "unable to remove" (and the copy is not retried)
    unable_to_rm_resdir_error = ""
    if os.path.exists(results_dir):
        try:
            shutil.rmtree(results_dir)
        except OSError:
            unable_to_rm_resdir_error = (
                "\n***********\n! WARNING !\n***********\n"
                + "Unable to remove previous Results subfolder of ID: "
                + name
                + "!\n Results will only be updated if filenames match!"
            )
    move_data_to_folders(info, folderinfo)

    # ....... initialise Issues.txt & quick check for file existence .................
    # Issues.txt - delete if saved in a previous run
    issues_txt_path = os.path.join(results_dir, ISSUES_TXT_FILENAME)
    if os.path.exists(issues_txt_path):
        os.remove(issues_txt_path)
    # report the rmtree warning only AFTER the stale Issues.txt was removed - written
    # before, it would be deleted right away (assuming write_issues_to_textfile
    # writes to results_dir's Issues.txt - confirm in autogaita.resources.utils)
    if unable_to_rm_resdir_error:
        print(unable_to_rm_resdir_error)
        write_issues_to_textfile(unable_to_rm_resdir_error, info)
    # read data & beam
    if not os.listdir(results_dir):
        no_files_error = (
            "\n******************\n! CRITICAL ERROR !\n******************\n"
            + "Unable to identify ANY RELEVANT FILES for "
            + name
        )
        write_issues_to_textfile(no_files_error, info)
        print(no_files_error)
        return

    # ............................ import data .......................................
    # initialise dfs for user error handling
    datadf = pd.DataFrame(data=None)
    datadf_duplicate_error = ""
    beamdf = pd.DataFrame(data=None)
    beamdf_duplicate_error = ""
    # loop through folder and import data
    for filename in os.listdir(results_dir):
        if name + data_string + ".h5" in filename:
            if datadf.empty:
                datadf = h5_to_df(results_dir, filename)
            else:
                datadf_duplicate_error = (
                    "\n******************\n! CRITICAL ERROR !\n******************\n"
                    + "Multiple DATA .h5 files found for "
                    + name
                    + "!\nPlease make sure to only have one data file per ID."
                )
        if subtract_beam and name + beam_string + ".h5" in filename:
            if beamdf.empty:
                beamdf = h5_to_df(results_dir, filename)
            else:
                beamdf_duplicate_error = (
                    "\n******************\n! CRITICAL ERROR !\n******************\n"
                    + "Multiple BEAM .h5 files found for "
                    + name
                    + "!\nPlease make sure to only have one beam file per ID."
                )
    # handle errors now
    import_error_message = ""
    if datadf_duplicate_error:
        import_error_message += datadf_duplicate_error
    if datadf.empty:
        import_error_message += (
            "\n******************\n! CRITICAL ERROR !\n******************\n"
            + "No DATA .h5 file found for "
            + name
            + "!\nTry again!"
        )
    if subtract_beam:
        if beamdf_duplicate_error:
            import_error_message += beamdf_duplicate_error
        if beamdf.empty:
            import_error_message += (
                "\n******************\n! CRITICAL ERROR !\n******************\n"
                + "No BEAM .h5 file found for "
                + name
                + "!\nTry again!"
            )
    if import_error_message:
        write_issues_to_textfile(import_error_message, info)
        print(import_error_message)
        return  # make sure to stop execution if there is an issue!
    # create "data" as floats and depending on whether we subtracted beam or not
    if subtract_beam:
        data = pd.concat([datadf, beamdf], axis=1)
    else:
        data = datadf.copy(deep=True)
    data = data.astype(float)

    # ................ final data checks, conversions & additions ....................
    # IMPORTANT - MAIN TESTS OF USER-INPUT VALIDITY OCCUR HERE!
    # => UNPACK VARS FROM CFG THAT ARE TESTED BY check_and_expand HERE, NOT EARLIER!
    cfg = check_and_expand_cfg(data, cfg, info)
    if cfg is None:  # some critical error occured
        return
    hind_joints = cfg["hind_joints"]
    fore_joints = cfg["fore_joints"]
    angles = cfg["angles"]
    beam_hind_jointadd = cfg["beam_hind_jointadd"]
    beam_fore_jointadd = cfg["beam_fore_jointadd"]
    # only needed by check_gait_direction, which is disabled (see IMPORTANT NOTE)
    direction_joint = cfg["direction_joint"]
    # important to unpack to vars and not to cfg since cfg is overwritten in multiruns!
    x_standardisation_joint = cfg["x_standardisation_joint"][0]
    y_standardisation_joint = cfg["y_standardisation_joint"][0]
    # store config json file @ group path
    # !!! NU - do this @ mouse path!
    # => rsplit so that an earlier occurrence of name in the path (e.g. in a parent
    #    folder's name) does not truncate the path too early
    group_path = results_dir.rsplit(name, 1)[0]
    config_json_path = os.path.join(group_path, CONFIG_JSON_FILENAME)
    config_vars_to_json = {
        "sampling_rate": sampling_rate,
        "convert_to_mm": convert_to_mm,
        "standardise_y_at_SC_level": standardise_y_at_SC_level,
        "analyse_average_x": analyse_average_x,
        "standardise_x_coordinates": standardise_x_coordinates,
        "x_standardisation_joint": x_standardisation_joint,
        "standardise_y_to_a_joint": standardise_y_to_a_joint,
        "y_standardisation_joint": y_standardisation_joint,
        "hind_joints": hind_joints,
        "fore_joints": fore_joints,
        "angles": angles,
        "tracking_software": "SLEAP",
    }
    # note - using "w" will overwrite/truncate file, thus no need to remove it if exists
    with open(config_json_path, "w") as config_json_file:
        json.dump(config_vars_to_json, config_json_file, indent=4)
    # if we don't have a beam to subtract, standardise y to a joint's or global ymin = 0
    if not subtract_beam:
        y_cols = [col for col in data.columns if col.endswith("y")]
        if standardise_y_to_a_joint:
            y_min = data[y_standardisation_joint + "y"].min()
        else:
            y_min = data[y_cols].min().min()
        data[y_cols] -= y_min
    # convert pixels to millimeters
    if convert_to_mm:
        for column in data.columns:
            # if might be unnecessary but I'm cautious as I don't know SLEAP data much
            if column.endswith("x") or column.endswith("y"):
                data[column] = data[column] / pixel_to_mm_ratio

    # IMPORTANT NOTE
    # => I keep gait direction flipping commented out until receiving data that allows
    #    me to test it properly
    # => Note that subtract_beam is hardcoded to False above for the same reason

    # quick warning if cfg is set to not flip gait direction but to standardise x
    if not flip_gait_direction and standardise_x_coordinates:
        message = (
            "\n***********\n! WARNING !\n***********\n"
            + "You are standardising x-coordinates without standardising the direction "
            + "of gait (e.g. all walking from right to left)."
            + "\nThis can be correct if you are doing things like treadmill walking "
            + "but can lead to unexpected behaviour otherwise!"
            + "\nMake sure you know what you are doing!"
        )
        print(message)
        write_issues_to_textfile(message, info)

    # check gait direction
    # data = check_gait_direction(data, direction_joint, flip_gait_direction, info)
    data["Flipped"] = False  # because of IMPORTANT NOTE above

    # subtract the beam from the joints to standardise y
    # => bc. we simulate that all mice run from left to right, we can write:
    #    (note that we also flip beam x columns, but never y-columns!)
    # => & bc. we multiply y values by *-1 earlier, it's a neg_num - - neg_num
    #    pushing it towards zero.
    # => using list(set()) to ensure that we don't have duplicate values (if users
    #    should have provided them in both cfg vars by mistake)
    # => beam_col_left and right is provided by users
    if subtract_beam:
        # note beam_col_left/right are always lists in cfg!
        beam_col_left = cfg["beam_col_left"][0]
        beam_col_right = cfg["beam_col_right"][0]
        for joint in list(set(hind_joints + beam_hind_jointadd)):
            data[joint + "y"] = data[joint + "y"] - data[beam_col_left + "y"]
        for joint in list(set(fore_joints + beam_fore_jointadd)):
            data[joint + "y"] = data[joint + "y"] - data[beam_col_right + "y"]
        data.drop(columns=list(beamdf.columns), inplace=True)  # beam not needed anymore
    # add Time
    data[TIME_COL] = data.index * (1 / sampling_rate)
    # reorder the columns we added
    cols = [TIME_COL, "Flipped"]
    data = data[cols + [c for c in data.columns if c not in cols]]
    return data
255
+
256
+
257
+ # .............................. helper functions ....................................
258
+
259
+
260
def move_data_to_folders(info, folderinfo):
    """Copy this ID's data .h5 file(s) from root_dir into its results_dir.

    results_dir (and any missing parents) is created if needed. It may already
    exist, e.g. when some_prep could not remove a previous run's results folder;
    a plain os.makedirs would then raise FileExistsError, hence exist_ok=True.

    Every file in folderinfo["root_dir"] whose name contains
    info["name"] + folderinfo["data_string"] + ".h5" is copied with shutil.copy2
    (overwriting a same-named file already in results_dir).
    """
    # unpack
    name = info["name"]
    results_dir = info["results_dir"]
    root_dir = folderinfo["root_dir"]
    data_string = folderinfo["data_string"]
    os.makedirs(results_dir, exist_ok=True)
    # copy h5 files
    data_file_id = name + data_string + ".h5"
    for filename in os.listdir(root_dir):
        if data_file_id in filename:
            shutil.copy2(
                os.path.join(root_dir, filename),
                os.path.join(results_dir, filename),
            )
275
+
276
+
277
def h5_to_df(results_dir, filename):
    """Read a SLEAP .h5 file into the wide DataFrame format used by gaita.

    One "<node name> x" and one "<node name> y" column is produced per entry of
    the file's "node_names" dataset (in that order). The index runs from 0 to
    n_frames - 1, and only index 0 of the last axis of the transposed "tracks"
    array is used.
    """
    h5_path = os.path.join(results_dir, filename)
    with h5py.File(h5_path, "r") as h5_file:
        # after .T the axes are presumably (frames, nodes, xy, tracks) as in
        # SLEAP's analysis export - TODO confirm; [:] materialises the arrays so
        # they stay valid after the file is closed
        tracks = h5_file["tracks"][:].T
        nodes = [raw_name.decode() for raw_name in h5_file["node_names"][:]]
    columns = {}
    for node_pos, node in enumerate(nodes):
        columns[node + " x"] = tracks[:, node_pos, 0, 0]
        columns[node + " y"] = tracks[:, node_pos, 1, 0]
    return pd.DataFrame(columns, index=np.arange(np.shape(tracks)[0]))
288
+
289
+
290
def check_gait_direction(data, direction_joint, flip_gait_direction, info):
    """Check direction of gait - reverse it if needed.

    Gait counts as right-to-left when the median x of ``direction_joint`` over the
    first half of the rows is larger than over the second half. Only then, and
    only if ``flip_gait_direction`` is set, is the data mirrored via
    flip_mouse_body and "Flipped" set to True; otherwise "Flipped" stays False.
    Note that the "Flipped" column is added to the passed DataFrame in-place.
    """
    data["Flipped"] = False

    x_coords = data[direction_joint + "x"]
    midpoint = len(data) // 2
    first_half_median = np.median(x_coords.iloc[:midpoint])
    second_half_median = np.median(x_coords.iloc[midpoint:])
    # a larger median early on means that the mouse ran from right to left
    runs_right_to_left = first_half_median > second_half_median
    if not (runs_right_to_left and flip_gait_direction):
        return data

    data = flip_mouse_body(data, info)
    data["Flipped"] = True
    return data