pymast 0.0.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymast/overlap_removal.py CHANGED
@@ -1,548 +1,2385 @@
1
- # -*- coding: utf-8 -*-
2
- '''
3
- Module contains all of the methods and classes required to identify and remove
4
- overlapping detections from radio telemetry data.
5
- '''
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Bout detection and overlapping detection resolution for radio telemetry data.
4
+
5
+ This module provides two main classes for identifying and resolving overlapping
6
+ detections in radio telemetry studies. Overlapping detections occur when:
7
+ 1. Multiple receivers detect the same fish simultaneously (spatial ambiguity)
8
+ 2. Fish movements violate spatial/temporal constraints (impossible transitions)
9
+ 3. Receiver antenna bleed causes detections from the "wrong" direction
10
+
11
+ Core Classes
12
+ ------------
13
+ - **bout**: Detects spatially/temporally clustered detections using DBSCAN
14
+ - **overlap_reduction**: Resolves overlapping detections using signal quality
15
+
16
+ Bout Detection Workflow
17
+ -----------------------
18
+ 1. **DBSCAN Clustering**: Group detections by time and space
19
+ 2. **Bout Assignment**: Label each detection with bout number
20
+ 3. **Presence Matrix**: Create presence/absence by receiver and bout
21
+ 4. **Visualization**: Diagnostic plots for bout length distributions
22
+
23
+ Overlap Resolution Workflow
24
+ ---------------------------
25
+ 1. **Unsupervised Learning**: Compare signal power and posterior probabilities
26
+ 2. **Decision Logic**: Mark weaker overlapping detections
27
+ 3. **Bout Spatial Filter**: Identify bouts with temporal overlap across receivers
28
+ 4. **Write Results**: Store decisions in HDF5 `/overlapping` table
29
+
30
+ Resolution Criteria
31
+ -------------------
32
+ - **Power Comparison**: If both detections have power, keep stronger signal
33
+ - **Posterior Comparison**: If both have classification scores, keep higher posterior
34
+ - **Ambiguous**: Mark as ambiguous if the signals are equal or the criteria are unavailable
35
+ - **Bout Conflicts**: Identify temporally overlapping bouts at different receivers
36
+
37
+ Output Tables
38
+ -------------
39
+ Creates these HDF5 tables:
40
+
41
+ - `/bouts`: Bout summaries (bout_no, start_time, end_time, detection_count)
42
+ - `/presence`: Presence/absence matrix (fish x bout x receiver)
43
+ - `/overlapping`: Detection-level decisions (overlapping=0/1, ambiguous=0/1)
44
+
45
+ Typical Usage
46
+ -------------
47
+ >>> from pymast.overlap_removal import bout, overlap_reduction
48
+ >>>
49
+ >>> # Detect bouts using DBSCAN
50
+ >>> bout_obj = bout(
51
+ ... radio_project=proj,
52
+ ... rec_id='R03',
53
+ ... eps_multiplier=5, # eps = pulse_rate * 5 (~40 s for typical tags)
54
+ ... lag_window=9
55
+ ... )
56
+ >>> bout_obj.presence() # bouts are detected in __init__; this writes them to /presence
57
+ >>>
58
+ >>> # Resolve overlapping detections
59
+ >>> overlap_obj = overlap_reduction(nodes=['R01', 'R02'], edges=[('R01', 'R02')], radio_project=proj)
60
+ >>> overlap_obj.unsupervised()
61
+ >>>
62
+ >>> # Visualize results
63
+ >>> overlap_obj.visualize_overlaps()
64
+ >>> bout_obj.visualize_bout_lengths()
65
+
66
+ Notes
67
+ -----
68
+ - DBSCAN epsilon (pulse_rate * eps_multiplier) controls bout sensitivity
69
+ - eps should roughly match the longest detection gap expected within one residence period
70
+ - min_samples=1 (fixed) treats every detection as a potential bout start
71
+ - Bout spatial filter runs automatically after unsupervised()
72
+ - Power and posterior columns are optional (conditionally written)
73
+
74
+ See Also
75
+ --------
76
+ formatter.time_to_event : Uses presence/absence for TTE analysis
77
+ radio_project : Project database management
78
+ """
6
79
 
7
80
  # import modules required for function dependencies
81
+ import os
82
+ import logging
8
83
  import numpy as np
9
84
  import pandas as pd
85
+ from concurrent.futures import ProcessPoolExecutor
86
+ try:
87
+ from tqdm import tqdm
88
+ except ImportError:
89
+ # tqdm is optional - provide a lightweight passthrough iterator when not installed
90
+ def tqdm(iterable, **kwargs):
91
+ return iterable
92
+
10
93
  from scipy.optimize import curve_fit, minimize
11
94
  from scipy.interpolate import UnivariateSpline
95
+ from scipy.stats import ttest_ind
12
96
  import matplotlib.pyplot as plt
13
97
  import networkx as nx
14
98
  from matplotlib import rcParams
99
+ #from sklearn.cluster import KMeans
100
+ from sklearn.neighbors import KNeighborsClassifier
101
+ from sklearn.cluster import AgglomerativeClustering
102
+ from sklearn.mixture import GaussianMixture
103
+ from sklearn.preprocessing import MinMaxScaler
104
+ import dask.dataframe as dd
105
+ import dask.array as da
106
+ try:
107
+ from dask_ml.cluster import KMeans
108
+ _KMEANS_IMPL = 'dask'
109
+ except ImportError:
110
+ # dask-ml may not be installed in all environments; fall back to scikit-learn
111
+ from sklearn.cluster import KMeans
112
+ _KMEANS_IMPL = 'sklearn'
113
+ from dask import delayed
114
+ import sys
115
+ import matplotlib
116
+ from dask import config
117
+ config.set({"dataframe.convert-string": False})
118
+ from dask.distributed import Client
119
+ #client = Client(processes=False, threads_per_worker=1, memory_limit = '8GB') # Single-threaded mode
120
+ from intervaltree import Interval, IntervalTree
121
+ import gc
122
+ gc.collect()
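The distributed `Client` above is intentionally left commented out; if a local Dask scheduler is wanted for the heavier overlap steps, it can be started with the same settings (a sketch only, mirroring the commented line):

from dask.distributed import Client

# Single process, one thread per worker, capped memory - matches the commented-out configuration above.
client = Client(processes=False, threads_per_worker=1, memory_limit='8GB')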
15
123
 
16
124
  font = {'family': 'serif','size': 6}
17
125
  rcParams['font.size'] = 6
18
126
  rcParams['font.family'] = 'serif'
19
127
 
128
+ # Non-interactive helper: if the environment variable PYMAST_NONINTERACTIVE is set, auto-answer prompts
129
+ _NON_INTERACTIVE = os.environ.get('PYMAST_NONINTERACTIVE', '0') in ('1', 'true', 'True')
130
+
131
+ def _prompt(prompt_text, default=None):
132
+ if _NON_INTERACTIVE:
133
+ return default
134
+ try:
135
+ return input(prompt_text)
136
+ except (EOFError, OSError) as exc:
137
+ raise RuntimeError(
138
+ "Input prompt failed. Set PYMAST_NONINTERACTIVE=1 to use defaults."
139
+ ) from exc
140
+
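Because `_NON_INTERACTIVE` is evaluated at import time, the environment variable has to be set before pymast is imported; a small sketch:

import os

# Must happen before `import pymast` so _prompt() returns its default instead of calling input().
os.environ['PYMAST_NONINTERACTIVE'] = '1'

from pymast.overlap_removal import bout, overlap_reduction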
20
141
  class bout():
21
- '''Python class object to delineate when bouts occur at a receiver.'''
22
- def __init__ (self, radio_project, node, lag_window, time_limit):
23
- self.lag_window = lag_window
24
- self.time_limit = time_limit
142
+ """
143
+ DBSCAN-based bout detection for identifying continuous fish presence periods.
144
+
145
+ Uses density-based spatial clustering (DBSCAN) to group detections into bouts
146
+ based on temporal proximity. Each bout represents a period of continuous or
147
+ near-continuous presence at a receiver.
148
+
149
+ Attributes
150
+ ----------
151
+ db : str
152
+ Path to project HDF5 database
153
+ rec_id : str
154
+ Receiver identifier to process
155
+ eps_multiplier : int
156
+ Multiplier for pulse rate to set DBSCAN epsilon (temporal threshold)
157
+ lag_window : int
158
+ Time window in seconds for lag calculations (legacy parameter)
159
+ tags : pandas.DataFrame
160
+ Tag metadata (freq_code, pulse_rate, tag_type, etc.)
161
+ data : pandas.DataFrame
162
+ Detections for this receiver
163
+ presence_df : pandas.DataFrame
164
+ Bout presence/absence matrix (fish x bout x receiver)
165
+
166
+ Methods
167
+ -------
168
+ presence()
169
+ Write the bout assignments (computed automatically in __init__) to the /presence table
170
+ visualize_bout_lengths()
171
+ Create diagnostic plots showing bout duration distributions
172
+
173
+ Notes
174
+ -----
175
+ - Physics-based epsilon: pulse_rate * eps_multiplier
176
+ - Default eps_multiplier=5 gives ~40-50 seconds for typical tags
177
+ - min_samples=1 treats every detection as potential bout start
178
+ - Bout numbers are unique per fish, not globally
179
+ - Presence matrix tracks which bouts occurred at which receivers
180
+
181
+ Examples
182
+ --------
183
+ >>> from pymast.overlap_removal import bout
184
+ >>> bout_obj = bout(
185
+ ... radio_project=proj,
186
+ ... rec_id='R03',
187
+ ... eps_multiplier=5,
188
+ ... lag_window=9
189
+ ... )
190
+ >>> bout_obj.presence()
191
+ >>> bout_obj.visualize_bout_lengths()
192
+
193
+ See Also
194
+ --------
195
+ overlap_reduction : Resolves overlapping detections between bouts
196
+ formatter.time_to_event : Uses bout presence for TTE analysis
197
+ """
198
+ def __init__(self, radio_project, rec_id, eps_multiplier=5, lag_window=9):
199
+ """
200
+ Initialize bout detection for a specific receiver.
201
+
202
+ Args:
203
+ radio_project: Project object with database and tags
204
+ rec_id (str): Receiver ID to process (e.g., 'R03')
205
+ eps_multiplier (int): Multiplier for pulse rate to set DBSCAN epsilon
206
+ Default 5 = ~40-50 sec for typical tags
207
+ lag_window (int): Time window in seconds for lag calculations
208
+ Default 9 seconds (kept for compatibility, not used in DBSCAN)
209
+ """
210
+ from sklearn.cluster import DBSCAN
211
+
25
212
  self.db = radio_project.db
26
-
27
- # get the receivers associated with this particular network node
28
- recs = radio_project.receivers[radio_project.receivers.node == node]
29
- self.receivers = recs.index # get the unique receivers associated with this node
30
- self.data = pd.DataFrame(columns = ['freq_code','epoch','rec_id']) # set up an empty data frame
31
-
32
- # for every receiver
33
- for i in self.receivers:
34
- # get this receivers data from the classified key
35
- rec_dat = pd.read_hdf(self.db,
36
- 'classified',
37
- where = 'rec_id = %s'%(i))
38
- rec_dat = rec_dat[rec_dat.iter == rec_dat.iter.max()]
39
- rec_dat = rec_dat[rec_dat.test == 1]
40
- rec_dat = rec_dat[['freq_code','epoch','rec_id']]
41
- rec_dat = rec_dat.astype({'freq_code':'object',
42
- 'epoch':'float32',
43
- 'rec_id':'object'})
44
-
45
- self.data = self.data.append(rec_dat)
46
-
47
- # clean up and bin the lengths
48
- self.data.drop_duplicates(keep = 'first', inplace = True)
49
- self.data.sort_values(by = ['freq_code','epoch'], inplace = True)
50
- self.data['det_lag'] = self.data.groupby('freq_code')['epoch'].diff().abs() // lag_window * lag_window
51
- self.data.dropna(axis = 0, inplace = True) # drop Nan from the data
52
- self.node = node
53
- self.fishes = self.data.freq_code.unique()
213
+ self.rec_id = rec_id
214
+ self.eps_multiplier = eps_multiplier
215
+ self.lag_window = lag_window
216
+ self.tags = radio_project.tags
54
217
 
55
- def prompt_for_params(self, model_type):
56
- if model_type == 'two_process':
57
- print("Please enter the values for the initial quantity (y0), the quantity at time t (yt), and the time t.")
58
- y0 = float(input("Enter the initial quantity (y0): "))
59
- yt = float(input("Enter the quantity at time t (yt): "))
60
- t = float(input("Enter the time at which yt is observed (t): "))
218
+ # Load classified data for this receiver
219
+ print(f"[bout] Loading classified data for {rec_id}")
220
+ rec_dat = pd.read_hdf(self.db, 'classified', where=f'rec_id == "{rec_id}"')
221
+ rec_dat = rec_dat[rec_dat.iter == rec_dat.iter.max()]
222
+ rec_dat = rec_dat[rec_dat.test == 1]
223
+ rec_dat = rec_dat[['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id']]
224
+ rec_dat = rec_dat.astype({
225
+ 'freq_code': 'object',
226
+ 'epoch': 'int64',
227
+ 'time_stamp': 'datetime64[ns]',
228
+ 'power': 'float32',
229
+ 'rec_id': 'object'
230
+ })
61
231
 
62
- # Calculate decay rate b1 using the provided y0, yt, and t
63
- b1 = -np.log(yt / y0) / t
232
+ # Clean up
233
+ rec_dat.drop_duplicates(keep='first', inplace=True)
234
+ rec_dat.sort_values(by=['freq_code', 'time_stamp'], inplace=True)
64
235
 
65
- # Assume that the decay rate b2 after the knot is the same as b1 before the knot
66
- # This is a simplifying assumption; you may want to calculate b2 differently based on additional data or domain knowledge
67
- b2 = b1
236
+ self.data = rec_dat
237
+ self.fishes = self.data.freq_code.unique()
68
238
 
69
- # For the two-process model, we'll assume a1 is the initial quantity y0
70
- a1 = y0
239
+ print(f"[bout] Loaded {len(self.data)} detections for {len(self.fishes)} fish")
71
240
 
72
- # We'll calculate a2 such that the function is continuous at the knot
73
- # This means solving the equation a1 * exp(-b1 * k) = a2 * exp(-b2 * k)
74
- # Since we've assumed b1 = b2, this simplifies to a2 = a1 * exp(-b1 * k)
75
- a2 = a1 * np.exp(-b1 * t)
241
+ # Run DBSCAN bout detection immediately
242
+ self._detect_bouts()
76
243
 
77
- return [a1, b1, a2, b2, t]
244
+ def _detect_bouts(self):
245
+ """
246
+ Run DBSCAN temporal clustering to identify bouts.
247
+ Called automatically during __init__.
248
+ """
249
+ from sklearn.cluster import DBSCAN
250
+ import logging
78
251
 
79
- else:
80
- print ("Sorry, we don't yet support that model type")
81
- # get lag frequencies
82
- lags = np.arange(0,self.time_limit,2)
83
- freqs, bins = np.histogram(np.sort(self.data.det_lag),lags)
84
- bins = bins[np.where(freqs > 0)]
85
- freqs = freqs[np.where(freqs > 0)]
86
- log_freqs = np.log(freqs)
87
-
88
- # Plot the raw data
89
- plt.scatter(bins, log_freqs, label='Data')
90
- plt.xlabel('Lag')
91
- plt.ylabel('Lag Frequency')
92
- plt.title('Raw Data for Two-Process Model')
93
- plt.legend()
94
- plt.show()
95
-
96
- # Prompt the user for initial parameters
97
- initial_guess = self.prompt_for_params(model_type = 'two_process')
98
-
99
- # Perform the curve fitting
100
- try:
101
- params, params_covariance = curve_fit(self.two_process_model,
102
- bins,
103
- log_freqs,
104
- p0=initial_guess)
105
-
106
- # Plot the fitted function
107
- plt.plot(bins, self.two_process_model(bins, *params),
108
- label='Fitted function',
109
- color='red')
252
+ logger = logging.getLogger(__name__)
253
+
254
+ print(f"[bout] Running DBSCAN bout detection for {self.rec_id}")
255
+ presence_list = []
256
+
257
+ for fish in self.fishes:
258
+ fish_dat = self.data[self.data.freq_code == fish].copy()
110
259
 
111
- plt.scatter(bins, log_freqs, label='Data')
112
- plt.legend()
113
- plt.xlabel('x')
114
- plt.ylabel('y')
115
- plt.title('Fitted Two-Process Model')
116
- plt.show()
260
+ if len(fish_dat) == 0:
261
+ continue
117
262
 
118
- # Return the fitted parameters
119
- return params
120
- except RuntimeError as e:
121
- print(f"An error occurred during fitting: {e}")
122
- return None
123
-
124
- def find_knot(self, initial_knot_guess):
125
- # get lag frequencies
126
- lags = np.arange(0, self.time_limit, 2)
127
- freqs, bins = np.histogram(np.sort(self.data.det_lag), lags)
128
- bins = bins[:-1][freqs > 0] # Ensure the bins array is the right length
129
- freqs = freqs[freqs > 0]
130
- log_freqs = np.log(freqs)
131
-
132
- # Define a two-segment exponential decay function
133
- def two_exp_decay(x, a1, b1, a2, b2, k):
134
- condlist = [x < k, x >= k]
135
- funclist = [
136
- lambda x: a1 * np.exp(-b1 * x),
137
- lambda x: a2 * np.exp(-b2 * (x - k))
138
- ]
139
- return np.piecewise(x, condlist, funclist)
140
-
141
-
142
- # Objective function for two-segment model
143
- def objective_function(knot, bins, log_freqs):
144
- # Fit the model without bounds on the knot
263
+ # Get pulse rate for this fish
145
264
  try:
146
- params, _ = curve_fit(lambda x, a1, b1, a2, b2: two_exp_decay(x, a1, b1, a2, b2, knot),
147
- bins,
148
- log_freqs,
149
- p0=[log_freqs[0], 0.001,
150
- log_freqs[0], 0.001])
151
- y_fit = two_exp_decay(bins, *params, knot)
152
- error = np.sum((log_freqs - y_fit) ** 2)
153
- return error
154
- except RuntimeError:
155
- return np.inf
156
-
157
- # Use minimize to find the optimal knot
158
- result = minimize(
159
- fun=objective_function,
160
- x0=[initial_knot_guess],
161
- args=(bins, log_freqs),
162
- bounds=[(bins.min(), bins.max())]
163
- )
164
-
165
- # Check if the optimization was successful and extract the results
166
- if result.success:
167
- optimized_knot = result.x[0]
168
-
169
- # Refit with the optimized knot to get all parameters
170
- p0 = [log_freqs[0], 0.001,
171
- log_freqs[0], 0.001,
172
- optimized_knot]
173
-
174
- bounds_lower = [0, 0,
175
- 0, 0,
176
- bins.min()]
177
-
178
- bounds_upper = [np.inf, np.inf,
179
- np.inf, np.inf,
180
- bins.max()]
181
-
182
- optimized_params, _ = curve_fit(
183
- two_exp_decay,
184
- bins,
185
- log_freqs,
186
- p0=p0,
187
- bounds=(bounds_lower, bounds_upper)
188
- )
189
-
190
- # Visualization of the final fit with the estimated knot
191
- plt.figure(figsize=(12, 6))
192
-
193
- plt.scatter(bins,
194
- log_freqs,
195
- label='Data',
196
- alpha=0.6)
265
+ pulse_rate = self.tags.loc[fish, 'pulse_rate']
266
+ except (KeyError, AttributeError):
267
+ pulse_rate = 8.0 # Default if not found
197
268
 
198
- x_range = np.linspace(bins.min(), bins.max(), 1000)
269
+ # Calculate epsilon: pulse_rate * multiplier
270
+ eps = pulse_rate * self.eps_multiplier
199
271
 
200
- plt.plot(x_range,
201
- two_exp_decay(x_range, *optimized_params),
202
- label='Fitted function',
203
- color='red')
272
+ # DBSCAN clustering on epoch (1D temporal data)
273
+ epochs = fish_dat[['epoch']].values
274
+ clustering = DBSCAN(eps=eps, min_samples=1).fit(epochs)
275
+ fish_dat['bout_no'] = clustering.labels_
204
276
 
205
- plt.axvline(x=optimized_knot, color='orange', linestyle='--')#, label=f'Knot at x={optimized_knot:.2f}')
206
- plt.title('Fitted Two-Process Model')
207
- plt.xlabel('Lag Time')
208
- plt.ylabel('Frequency')
209
- plt.legend()
210
- plt.show()
277
+ # Filter out noise points (label = -1, though shouldn't happen with min_samples=1)
278
+ fish_dat = fish_dat[fish_dat.bout_no != -1]
211
279
 
212
- return optimized_params[-1]
280
+ # Assign to each detection
281
+ for idx, row in fish_dat.iterrows():
282
+ presence_list.append({
283
+ 'freq_code': row['freq_code'],
284
+ 'epoch': row['epoch'],
285
+ 'time_stamp': row['time_stamp'],
286
+ 'power': row['power'],
287
+ 'rec_id': row['rec_id'],
288
+ 'bout_no': row['bout_no'],
289
+ 'class': 'study',
290
+ 'det_lag': 0 # Not meaningful for DBSCAN, kept for compatibility
291
+ })
213
292
 
293
+ # Store results
294
+ if presence_list:
295
+ self.presence_df = pd.DataFrame(presence_list)
296
+ print(f"[bout] Detected {self.presence_df.bout_no.nunique()} bouts across {len(self.fishes)} fish")
214
297
  else:
215
- print("Optimization failed:", result.message)
216
- return None
217
-
218
- def find_knots(self, initial_knot_guesses):
219
- # get lag frequencies
220
- lags = np.arange(0, self.time_limit, 2)
221
- freqs, bins = np.histogram(np.sort(self.data.det_lag), lags)
222
- bins = bins[:-1][freqs > 0] # Ensure the bins array is the right length
223
- freqs = freqs[freqs > 0]
224
- log_freqs = np.log(freqs)
225
- # Assuming initial_knot_guesses is a list of two knot positions
226
- # This method should fit a three-process model
227
-
228
- # Define bounds for parameters outside of objective_function
229
- bounds_lower = [0, 0, 0, 0, 0, 0, bins.min(), initial_knot_guesses[0]]
230
- bounds_upper = [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, initial_knot_guesses[1], bins.max()]
231
-
232
- # Define a three-segment exponential decay function
233
- #TODO - is the math correct?
234
- def three_exp_decay(x, a1, b1, a2, b2, a3, b3, k1, k2):
235
- condlist = [x < k1, (k1 <= x) & (x < k2), x >= k2]
236
- funclist = [
237
- lambda x: a1 * np.exp(-b1 * x),
238
- lambda x: a2 * np.exp(-b2 * (x - k1)),
239
- lambda x: a3 * np.exp(-b3 * (x - k2))
240
- ]
241
- return np.piecewise(x, condlist, funclist)
242
-
243
- # Objective function for three-segment model
244
- def objective_function(knots, bins, log_freqs):
245
- # Unpack knots
246
- k1, k2 = knots
247
- # Initial parameter guesses (3 amplitudes, 3 decay rates)
248
- p0 = [log_freqs[0], 0.001, log_freqs[0], 0.001, log_freqs[0], 0.001, k1, k2]
249
-
250
- # Fit the three-segment model
251
- params, _ = curve_fit(three_exp_decay,
252
- bins,
253
- log_freqs,
254
- p0=p0,
255
- bounds=(bounds_lower, bounds_upper)) # Calculate the errors
256
- #TODO - AIC or BIC possible expansion?
257
- y_fit = three_exp_decay(bins, *params)
258
- error = np.sum((log_freqs - y_fit) ** 2)
259
- return error
260
-
261
- # Perform the optimization with the initial guesses
262
- result = minimize(
263
- fun=objective_function,
264
- x0=initial_knot_guesses, # Initial guesses for the knot positions
265
- args=(bins, log_freqs),
266
- bounds=[(bins.min(), initial_knot_guesses[1]), # Ensure k1 < k2
267
- (initial_knot_guesses[0], bins.max())]
268
- )
269
-
270
- # Check if the optimization was successful and extract the results
271
- if result.success:
272
- optimized_knots = result.x # These should be the optimized knots
273
-
274
- # Now we refit the model with the optimized knots to get all the parameters
275
- p0 = [log_freqs[0], 0.001,
276
- log_freqs[0], 0.001,
277
- log_freqs[0], 0.001,
278
- optimized_knots[0], optimized_knots[1]]
279
-
280
- bounds_lower = [0, 0,
281
- 0, 0,
282
- 0, 0,
283
- bins.min(), optimized_knots[0]]
284
-
285
- bounds_upper = [np.inf, np.inf,
286
- np.inf, np.inf,
287
- np.inf, np.inf,
288
- optimized_knots[1], bins.max()]
289
-
290
- optimized_params, _ = curve_fit(
291
- three_exp_decay,
292
- bins,
293
- log_freqs,
294
- p0=p0,
295
- bounds=(bounds_lower, bounds_upper)
296
- )
297
-
298
- # Visualization of the final fit with the estimated knots
299
- plt.figure(figsize=(12, 6))
300
- plt.scatter(bins, log_freqs, label='Data', alpha=0.6)
301
-
302
- # Create a range of x values for plotting the fitted function
303
- x_range = np.linspace(bins.min(), bins.max(), 1000)
304
- plt.plot(x_range,
305
- three_exp_decay(x_range, *optimized_params),
306
- label='Fitted function', color='red')
307
-
308
- plt.axvline(x=optimized_knots[0], color='orange', linestyle='--')
309
- plt.axvline(x=optimized_knots[1], color='green', linestyle='--')
310
-
311
- plt.title('Fitted Three-Process Model')
312
- plt.xlabel('Lag Time')
313
- plt.ylabel('Frequency')
314
- plt.legend()
315
- plt.show()
316
- else:
317
- print("Optimization failed:", result.message)
318
- return None
319
-
320
- return optimized_params[-1]
321
-
322
- def fit_processes(self):
323
- # Step 1: Plot bins vs log frequencies
324
- lags = np.arange(0, self.time_limit, 2)
325
- freqs, bins = np.histogram(np.sort(self.data.det_lag), lags)
326
- bins = bins[:-1][freqs > 0] # Ensure the bins array is the right length
327
- freqs = freqs[freqs > 0]
328
- log_freqs = np.log(freqs)
329
-
330
- plt.figure(figsize=(12, 6))
331
- plt.scatter(bins, log_freqs, label='Log of Frequencies', alpha=0.6)
332
- plt.title('Initial Data Plot')
333
- plt.xlabel('Bins')
334
- plt.ylabel('Log Frequencies')
335
- plt.legend()
298
+ self.presence_df = pd.DataFrame()
299
+ print(f"[bout] No bouts detected for {self.rec_id}")
300
+
301
+ def presence(self):
302
+ """
303
+ Write bout results to /presence table in HDF5.
304
+ Call this after __init__ to save results to database.
305
+ """
306
+ import logging
307
+ logger = logging.getLogger(__name__)
308
+
309
+ if self.presence_df.empty:
310
+ print(f"[bout] No presence data to write for {self.rec_id}")
311
+ return
312
+
313
+ # Prepare data for storage
314
+ presence_df = self.presence_df.astype({
315
+ 'freq_code': 'object',
316
+ 'rec_id': 'object',
317
+ 'epoch': 'int64',
318
+ 'time_stamp': 'datetime64[ns]',
319
+ 'power': 'float32',
320
+ 'bout_no': 'int32',
321
+ 'class': 'object',
322
+ 'det_lag': 'int32'
323
+ })
324
+
325
+ # Write to HDF5
326
+ with pd.HDFStore(self.db, mode='a') as store:
327
+ store.append(
328
+ key='presence',
329
+ value=presence_df[['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'class', 'bout_no', 'det_lag']],
330
+ format='table',
331
+ data_columns=True,
332
+ min_itemsize={'freq_code': 20, 'rec_id': 20, 'class': 20}
333
+ )
334
+
335
+ logger.debug(f"Wrote {len(presence_df)} detections ({presence_df.bout_no.nunique()} bouts) to /presence for {self.rec_id}")
336
+ print(f"[bout] ✓ Wrote {len(presence_df)} detections to database")
337
+
338
+ def visualize_bout_lengths(self, output_dir=None):
339
+ """
340
+ Visualize bout length distributions for this receiver.
341
+
342
+ Creates comprehensive plots showing:
343
+ - Overall distribution of bout lengths
344
+ - Bout lengths by fish
345
+ - Detections vs duration scatter
346
+ - Cumulative distribution
347
+
348
+ Args:
349
+ output_dir (str): Directory to save plots. If None, uses database directory.
350
+ """
351
+ if self.presence_df.empty:
352
+ print(f"[bout] No bout data to visualize for {self.rec_id}")
353
+ return
354
+
355
+ # Calculate bout summaries from presence data
356
+ bout_summary = self.presence_df.groupby(['freq_code', 'bout_no']).agg({
357
+ 'epoch': ['min', 'max', 'count'],
358
+ 'power': 'mean'
359
+ }).reset_index()
360
+ bout_summary.columns = ['freq_code', 'bout_no', 'first_epoch', 'last_epoch', 'num_detections', 'mean_power']
361
+
362
+ bouts = bout_summary.copy()
363
+ bouts['bout_duration_sec'] = bouts['last_epoch'] - bouts['first_epoch']
364
+ bouts['bout_duration_hrs'] = bouts['bout_duration_sec'] / 3600
365
+
366
+ print(f"\n{'='*80}")
367
+ print(f"BOUT LENGTH STATISTICS - {self.rec_id}")
368
+ print(f"{'='*80}")
369
+
370
+ # Overall statistics
371
+ print(f"\nTotal bouts: {len(bouts):,}")
372
+ print(f"\nBout duration statistics (hours):")
373
+ print(bouts['bout_duration_hrs'].describe())
374
+
375
+ # Count by duration categories
376
+ short_bouts = bouts[bouts['bout_duration_sec'] < 300]
377
+ medium_bouts = bouts[(bouts['bout_duration_sec'] >= 300) & (bouts['bout_duration_sec'] < 1800)]
378
+ expected_bouts = bouts[(bouts['bout_duration_sec'] >= 1800) & (bouts['bout_duration_sec'] < 7200)]
379
+ long_bouts = bouts[bouts['bout_duration_sec'] >= 7200]
380
+
381
+ print(f"\nBouts < 5 minutes: {len(short_bouts):,} ({100*len(short_bouts)/len(bouts):.1f}%)")
382
+ print(f"Bouts 5-30 minutes: {len(medium_bouts):,} ({100*len(medium_bouts)/len(bouts):.1f}%)")
383
+ print(f"Bouts 30min-2hrs: {len(expected_bouts):,} ({100*len(expected_bouts)/len(bouts):.1f}%)")
384
+ print(f"Bouts > 2 hours: {len(long_bouts):,} ({100*len(long_bouts)/len(bouts):.1f}%)")
385
+
386
+ # Per-fish statistics
387
+ print(f"\nBout statistics by fish:")
388
+ fish_stats = bouts.groupby('freq_code').agg({
389
+ 'bout_duration_hrs': ['count', 'median', 'mean', 'max'],
390
+ 'num_detections': ['median', 'mean']
391
+ }).round(2)
392
+ fish_stats.columns = ['_'.join(col).strip() for col in fish_stats.columns.values]
393
+ print(fish_stats.head(10))
394
+
395
+ # Sample very short bouts
396
+ if len(short_bouts) > 0:
397
+ print(f"\nSample of very short bouts (< 5 minutes):")
398
+ print(short_bouts[['freq_code', 'num_detections', 'bout_duration_sec', 'bout_duration_hrs']].head(10).to_string())
399
+
400
+ # Create visualizations
401
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
402
+ fig.suptitle(f'Bout Length Analysis - {self.rec_id}', fontsize=14, fontweight='bold')
403
+
404
+ # 1. Overall distribution (log scale)
405
+ ax = axes[0, 0]
406
+ bouts['bout_duration_hrs'].hist(bins=50, ax=ax, edgecolor='black', alpha=0.7, color='steelblue')
407
+ ax.axvline(0.5, color='red', linestyle='--', linewidth=2, label='30 min')
408
+ ax.axvline(1.0, color='orange', linestyle='--', linewidth=2, label='1 hour')
409
+ ax.set_xlabel('Bout Duration (hours)', fontsize=10)
410
+ ax.set_ylabel('Frequency (log scale)', fontsize=10)
411
+ ax.set_title('Distribution of Bout Lengths', fontsize=11)
412
+ ax.set_yscale('log')
413
+ ax.legend()
414
+ ax.grid(True, alpha=0.3)
415
+
416
+ # 2. Zoomed in on short bouts (0-2 hours)
417
+ ax = axes[0, 1]
418
+ short_data = bouts[bouts['bout_duration_hrs'] <= 2]['bout_duration_hrs']
419
+ short_data.hist(bins=40, ax=ax, edgecolor='black', alpha=0.7, color='coral')
420
+ ax.axvline(0.5, color='red', linestyle='--', linewidth=2, label='30 min')
421
+ ax.axvline(1.0, color='orange', linestyle='--', linewidth=2, label='1 hour')
422
+ ax.set_xlabel('Bout Duration (hours)', fontsize=10)
423
+ ax.set_ylabel('Frequency', fontsize=10)
424
+ ax.set_title('Bout Lengths 0-2 Hours (Zoomed)', fontsize=11)
425
+ ax.legend()
426
+ ax.grid(True, alpha=0.3)
427
+
428
+ # 3. Detections per bout vs bout duration
429
+ ax = axes[1, 0]
430
+ sample = bouts.sample(min(5000, len(bouts))) # Sample for performance
431
+ ax.scatter(sample['num_detections'], sample['bout_duration_hrs'],
432
+ alpha=0.4, s=20, color='darkgreen')
433
+ ax.set_xlabel('Number of Detections per Bout', fontsize=10)
434
+ ax.set_ylabel('Bout Duration (hours)', fontsize=10)
435
+ ax.set_title('Detections vs Bout Duration', fontsize=11)
436
+ ax.set_xscale('log')
437
+ ax.set_yscale('log')
438
+ ax.axhline(0.5, color='red', linestyle='--', alpha=0.5, label='30 min')
439
+ ax.axhline(1.0, color='orange', linestyle='--', alpha=0.5, label='1 hour')
440
+ ax.legend()
441
+ ax.grid(True, alpha=0.3)
442
+
443
+ # 4. Cumulative distribution
444
+ ax = axes[1, 1]
445
+ sorted_durations = np.sort(bouts['bout_duration_hrs'])
446
+ cumulative = np.arange(1, len(sorted_durations) + 1) / len(sorted_durations) * 100
447
+ ax.plot(sorted_durations, cumulative, linewidth=2, color='purple')
448
+ ax.axvline(0.5, color='red', linestyle='--', linewidth=2, label='30 min')
449
+ ax.axvline(1.0, color='orange', linestyle='--', linewidth=2, label='1 hour')
450
+ ax.set_xlabel('Bout Duration (hours)', fontsize=10)
451
+ ax.set_ylabel('Cumulative % of Bouts', fontsize=10)
452
+ ax.set_title('Cumulative Distribution', fontsize=11)
453
+ ax.set_xscale('log')
454
+ ax.legend()
455
+ ax.grid(True, alpha=0.3)
456
+
457
+ plt.tight_layout()
458
+
459
+ # Save figure
460
+ if output_dir is None:
461
+ output_dir = os.path.dirname(self.db)
462
+ output_path = os.path.join(output_dir, f'bout_lengths_{self.rec_id}.png')
463
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
464
+ print(f"\n[bout] Saved visualization to: {output_path}")
465
+
336
466
  plt.show()
467
+
468
+ return bouts
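A possible follow-up on the returned summary (column names as built in this method; `bout_obj` and the output directory are illustrative):

bouts = bout_obj.visualize_bout_lengths(output_dir='figures')
long_tail = bouts[bouts['bout_duration_hrs'] > 2]            # the "> 2 hours" category reported above
print(long_tail[['freq_code', 'num_detections', 'bout_duration_hrs']].head())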
469
+
470
+
471
+ class overlap_reduction:
472
+ """
473
+ Resolve overlapping detections using signal quality comparison.
474
+
475
+ Identifies and resolves spatially/temporally overlapping detections by comparing
476
+ signal power and posterior probabilities. Marks weaker detections as overlapping
477
+ to prevent spatial ambiguity in downstream statistical models.
478
+
479
+ Resolution Logic:
480
+ 1. **Power Comparison**: If both detections have power, keep stronger signal
481
+ 2. **Posterior Comparison**: If both have classification scores, keep higher posterior
482
+ 3. **Ambiguous**: Mark if signals equal or criteria unavailable
483
+ 4. **Bout Conflicts**: Identify temporally overlapping bouts at different receivers
484
+
485
+ Attributes
486
+ ----------
487
+ db : str
488
+ Path to project HDF5 database
489
+ project : object
490
+ Radio project instance with database and metadata
491
+ nodes : list
492
+ List of receiver IDs (nodes in network)
493
+ edges : list of tuples
494
+ Directed edges representing receiver relationships
495
+ G : networkx.DiGraph
496
+ Network graph of receiver connections
497
+ node_pres_dict : dict
498
+ Presence data for each receiver (fish x bout)
499
+ node_recap_dict : dict
500
+ Recapture data for each receiver (detections)
501
+
502
+ Methods
503
+ -------
504
+ unsupervised()
505
+ Resolve overlaps using power/posterior comparison, apply bout spatial filter
506
+ visualize_overlaps()
507
+ Create 8-panel diagnostic plots for overlap analysis
508
+
509
+ Notes
510
+ -----
511
+ - Operates on bout-level and detection-level simultaneously
512
+ - Bout spatial filter identifies temporally overlapping bouts (≥50% overlap)
513
+ - Power and posterior columns are optional (conditionally checked)
514
+ - Results written to `/overlapping` table in HDF5 database
515
+ - Visualization includes network structure, temporal patterns, power distributions
337
516
 
338
- # Step 2: Ask user for initial knots
339
- num_knots = int(input("Enter the number of knots (1 for two-process, 2 for three-process): "))
340
- initial_knots = []
341
- for i in range(num_knots):
342
- knot = float(input(f"Enter initial guess for knot {i+1}: "))
343
- initial_knots.append(knot)
344
-
345
- # Step 3 & 4: Determine the number of processes and fit accordingly
346
- if num_knots == 1:
347
- # Fit a two-process model
348
- self.initial_knot_guess = initial_knots[0]
349
- self.find_knot(self.initial_knot_guess)
350
- elif num_knots == 2:
351
- # Fit a three-process model (you will need to implement this method)
352
- self.initial_knot_guesses = initial_knots
353
- self.find_knots(self.initial_knot_guesses)
517
+ Examples
518
+ --------
519
+ >>> from pymast.overlap_removal import overlap_reduction
520
+ >>> overlap_obj = overlap_reduction(
521
+ ... nodes=['R01', 'R02', 'R03'],
522
+ ... edges=[('R01', 'R02'), ('R02', 'R03')],
523
+ ... radio_project=proj
524
+ ... )
525
+ >>> overlap_obj.unsupervised()
526
+ >>> overlap_obj.visualize_overlaps()
527
+
528
+ See Also
529
+ --------
530
+ bout : DBSCAN-based bout detection
531
+ formatter.time_to_event : Uses overlap decisions for filtering
532
+ """
533
+
534
+ def _rec_id_variants(self, node):
535
+ node_str = str(node)
536
+ variants = [node_str]
537
+ if node_str.startswith(('R', 'r')):
538
+ variants.append(node_str[1:])
539
+ variants.append(node_str.lstrip('0'))
540
+ variants.append(node_str.lstrip('R').lstrip('0'))
541
+ digits = ''.join(filter(str.isdigit, node_str))
542
+ if digits:
543
+ variants.append(str(int(digits)))
544
+ seen = set()
545
+ variants_clean = []
546
+ for v in variants:
547
+ if not v:
548
+ continue
549
+ if v not in seen:
550
+ seen.add(v)
551
+ variants_clean.append(v)
552
+ return variants_clean
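For illustration, the variants generated for a zero-padded receiver ID:

>>> overlap_obj._rec_id_variants('R03')
['R03', '03', '3']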
553
+
554
+ def _read_presence(self, node, logger):
555
+ pres_where = f"rec_id == '{node}'"
556
+ used_full_presence_read = False
557
+ try:
558
+ pres_data = pd.read_hdf(
559
+ self.db,
560
+ 'presence',
561
+ columns=['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'bout_no'],
562
+ where=pres_where
563
+ )
564
+ except (TypeError, ValueError):
565
+ # Some stores are fixed-format and don't support column selection - read entire table
566
+ used_full_presence_read = True
567
+ pres_data = pd.read_hdf(self.db, 'presence')
568
+
569
+ if len(pres_data) == 0 and not used_full_presence_read:
570
+ for cand in self._rec_id_variants(node):
571
+ alt_where = f"rec_id == '{cand}'"
572
+ try:
573
+ alt_pres = pd.read_hdf(
574
+ self.db,
575
+ 'presence',
576
+ columns=['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'bout_no'],
577
+ where=alt_where
578
+ )
579
+ if len(alt_pres) > 0:
580
+ pres_data = alt_pres
581
+ logger.info("Node %s: found %d presence rows using alternate WHERE rec_id == '%s'", node, len(pres_data), cand)
582
+ break
583
+ except (TypeError, ValueError):
584
+ logger.debug("Node %s: alternate WHERE '%s' not supported by store", node, alt_where)
585
+ break
586
+ except (KeyError, OSError, ValueError) as e:
587
+ raise RuntimeError(
588
+ f"Failed to read presence for node '{node}' using WHERE '{alt_where}': {e}"
589
+ ) from e
590
+
591
+ return pres_data, pres_where, used_full_presence_read
592
+
593
+ def _read_recap(self, node):
594
+ classified_where = f"(rec_id == '{node}') & (test == 1)"
595
+ try:
596
+ recap_data = pd.read_hdf(
597
+ self.db,
598
+ 'classified',
599
+ columns=[
600
+ 'freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'iter',
601
+ 'test', 'posterior_T', 'posterior_F'
602
+ ],
603
+ where=classified_where
604
+ )
605
+ except (TypeError, ValueError) as e:
606
+ recap_data = pd.read_hdf(self.db, 'classified')
607
+ try:
608
+ recap_data = recap_data.query(classified_where)
609
+ except ValueError as exc:
610
+ raise RuntimeError(
611
+ f"Failed to filter classified data for node '{node}': {exc}"
612
+ ) from exc
613
+ except (KeyError, FileNotFoundError, OSError) as e:
614
+ # If classified isn't available, try recaptures
615
+ try:
616
+ recap_data = pd.read_hdf(self.db, 'recaptures')
617
+ recap_data = recap_data.query(f"rec_id == '{node}'")
618
+ except (KeyError, FileNotFoundError, OSError, ValueError) as exc:
619
+ raise RuntimeError(
620
+ f"Failed to read classified or recaptures data for node '{node}': {exc}"
621
+ ) from exc
622
+ recap_data = recap_data[recap_data['iter'] == recap_data['iter'].max()]
623
+ return recap_data
624
+
625
+ def _summarize_presence(self, node, pres_data, recap_data, pres_where, used_full_presence_read, logger):
626
+ # Ensure presence has a 'power' column by merging power from the
627
+ # classified/recaptures table when available. We don't change how
628
+ # presence is originally created (bouts), we only attach power here
629
+ # for downstream aggregation.
630
+ if 'power' not in pres_data.columns and not recap_data.empty and 'power' in recap_data.columns:
631
+ try:
632
+ pres_data = pres_data.merge(
633
+ recap_data[['freq_code', 'epoch', 'rec_id', 'power']],
634
+ on=['freq_code', 'epoch', 'rec_id'],
635
+ how='left'
636
+ )
637
+ except (KeyError, ValueError, TypeError) as e:
638
+ raise RuntimeError(
639
+ f"Failed to merge power into presence data for node '{node}': {e}"
640
+ ) from e
641
+
642
+ # Group presence data by frequency code and bout, then calculate min, max, and median power
643
+ # Check if power column exists first
644
+ if 'power' in pres_data.columns:
645
+ summarized_data = pres_data.groupby(['freq_code', 'bout_no', 'rec_id']).agg(
646
+ min_epoch=('epoch', 'min'),
647
+ max_epoch=('epoch', 'max'),
648
+ median_power=('power', 'median')
649
+ ).reset_index()
354
650
  else:
355
- print("Invalid number of knots. Please enter 1 or 2.")
356
-
357
- # Fit the model based on the number of knots
358
- optimized_knots = None
359
- if num_knots == 1:
360
- # Fit a two-process model
361
- optimized_knots = self.find_knot(initial_knots[0])
362
- elif num_knots == 2:
363
- # Fit a three-process model
364
- optimized_knots = self.find_knots(initial_knots)
651
+ logger.warning(f'Node {node}: Power column not found in presence data. Run bout detection first to populate power.')
652
+ summarized_data = pres_data.groupby(['freq_code', 'bout_no', 'rec_id']).agg(
653
+ min_epoch=('epoch', 'min'),
654
+ max_epoch=('epoch', 'max')
655
+ ).reset_index()
656
+ summarized_data['median_power'] = None
657
+
658
+ # Log detailed counts so users can see raw vs summarized presence lengths
659
+ raw_count = len(pres_data)
660
+ summarized_count = len(summarized_data)
661
+ logger.info(f"Node {node}: raw presence rows={raw_count}, summarized bouts={summarized_count}")
662
+
663
+ # If we had to read the full presence table, warn (this can be slow and surprising)
664
+ if used_full_presence_read:
665
+ logger.warning(
666
+ "Node %s: had to read entire 'presence' table (fixed-format store); this may be slow and cause large raw counts. WHERE used: %s",
667
+ node,
668
+ pres_where,
669
+ )
670
+
671
+ # If counts are zero or unexpectedly large, include a small sample and the WHERE clause to help debug
672
+ if raw_count == 0 or raw_count > 100000:
673
+ sample_head = pres_data.head(10).to_dict(orient='list')
674
+ logger.debug(
675
+ "Node %s: pres_data sample (up to 10 rows)=%s; WHERE=%s",
676
+ node,
677
+ sample_head,
678
+ pres_where,
679
+ )
680
+
681
+ # If we got zero rows from the column/where read, try a safe in-memory
682
+ # fallback: read the full presence table and match rec_id after
683
+ # normalizing (strip/upper). This can detect formatting mismatches
684
+ # (e.g. numeric vs string rec_id, padding, whitespace).
685
+ if raw_count == 0 and not used_full_presence_read:
686
+ try:
687
+ logger.debug(
688
+ "Node %s: attempting in-memory fallback full-table read to find rec_id matches",
689
+ node,
690
+ )
691
+ full_pres = pd.read_hdf(self.db, 'presence')
692
+ if 'rec_id' in full_pres.columns:
693
+ node_norm = str(node).strip().upper()
694
+ full_pres['_rec_norm'] = full_pres['rec_id'].astype(str).str.strip().str.upper()
695
+ candidate = full_pres[full_pres['_rec_norm'] == node_norm]
696
+ if len(candidate) > 0:
697
+ # select expected columns if present
698
+ cols = [c for c in ['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'bout_no'] if c in candidate.columns]
699
+ pres_data = candidate[cols].copy()
700
+ raw_count = len(pres_data)
701
+ used_full_presence_read = True
702
+ logger.info(
703
+ "Node %s: found %d presence rows after in-memory normalization of rec_id",
704
+ node,
705
+ raw_count,
706
+ )
707
+ else:
708
+ logger.debug("Node %s: in-memory full-table read did not find rec_id matches", node)
709
+ else:
710
+ logger.debug("Node %s: 'rec_id' column not present in full presence table", node)
711
+ except (KeyError, OSError, ValueError) as e:
712
+ raise RuntimeError(
713
+ f"Failed to perform in-memory presence fallback for node '{node}': {e}"
714
+ ) from e
715
+
716
+ return summarized_data, raw_count
717
+
718
+ def _apply_posterior_decision(
719
+ self,
720
+ parent_dat,
721
+ child_dat,
722
+ p_indices,
723
+ c_indices,
724
+ p_power,
725
+ c_power,
726
+ min_detections,
727
+ p_value_threshold,
728
+ effect_size_threshold,
729
+ mean_diff_threshold,
730
+ decisions,
731
+ skip_reasons,
732
+ parent_mark_idx,
733
+ child_mark_idx,
734
+ ):
735
+ p_posteriors = parent_dat.loc[p_indices, 'posterior_T'].values if 'posterior_T' in parent_dat.columns else []
736
+ c_posteriors = child_dat.loc[c_indices, 'posterior_T'].values if 'posterior_T' in child_dat.columns else []
737
+
738
+ if len(p_posteriors) == 0 or len(c_posteriors) == 0:
739
+ decisions['keep_both'] += 1
740
+ skip_reasons['no_posterior_data'] += 1
741
+ return 0
742
+
743
+ p_posteriors = p_posteriors[~np.isnan(p_posteriors)]
744
+ c_posteriors = c_posteriors[~np.isnan(c_posteriors)]
745
+
746
+ if len(p_posteriors) < min_detections or len(c_posteriors) < min_detections:
747
+ decisions['keep_both'] += 1
748
+ skip_reasons['insufficient_after_nan'] += 1
749
+ return 0
750
+
751
+ t_stat, p_value = ttest_ind(p_posteriors, c_posteriors, equal_var=False)
752
+
753
+ mean_diff = np.mean(p_posteriors) - np.mean(c_posteriors)
754
+ n1, n2 = len(p_posteriors), len(c_posteriors)
755
+ var1 = np.var(p_posteriors, ddof=1) if n1 > 1 else 0.0
756
+ var2 = np.var(c_posteriors, ddof=1) if n2 > 1 else 0.0
757
+ pooled_std = np.sqrt(((n1-1)*var1 + (n2-1)*var2) / (n1+n2-2)) if (n1+n2-2) > 0 else 1.0
758
+ cohens_d = mean_diff / pooled_std if pooled_std > 0 else 0.0
759
+
760
+ if p_value < p_value_threshold and abs(cohens_d) >= effect_size_threshold:
761
+ if cohens_d > 0:
762
+ child_mark_idx.extend(c_indices)
763
+ decisions['remove_child'] += 1
764
+ return len(c_indices)
765
+ parent_mark_idx.extend(p_indices)
766
+ decisions['remove_parent'] += 1
767
+ return len(p_indices)
768
+
769
+ p_mean_posterior = np.mean(p_posteriors)
770
+ c_mean_posterior = np.mean(c_posteriors)
771
+
772
+ if not pd.isna(p_power) and not pd.isna(c_power) and (p_power + c_power) > 0:
773
+ p_norm_power = p_power / (p_power + c_power)
774
+ c_norm_power = c_power / (p_power + c_power)
365
775
  else:
366
- print("Invalid number of knots. Please enter 1 or 2.")
367
-
368
- # Return the optimized knot(s)
369
- return optimized_knots
370
-
371
- def presence(self, threshold):
372
- '''Function takes the break point between a continuous presence and new presence,
373
- enumerates the presence number at a receiver and writes the data to the
374
- analysis database.'''
375
- fishes = self.data.freq_code.unique()
376
-
377
- for fish in fishes:
378
- fish_dat = self.data[self.data.freq_code == fish]
379
-
380
- # Vectorized classification
381
- classifications = np.where(fish_dat.det_lag <= threshold, 'within_bout', 'start_new_bout')
382
-
383
- # Generating bout numbers
384
- # Increment bout number each time a new bout starts
385
- bout_changes = np.where(classifications == 'start_new_bout', 1, 0)
386
- bout_no = np.cumsum(bout_changes)
387
-
388
- # Assigning classifications and bout numbers to the dataframe
389
- fish_dat['class'] = classifications
390
- fish_dat['bout_no'] = bout_no
391
-
392
- fish_dat = fish_dat.astype({'freq_code': 'object',
393
- 'epoch': 'float32',
394
- 'rec_id': 'object',
395
- 'class': 'object',
396
- 'bout_no':'int32',
397
- 'det_lag':'int32'})
398
-
399
- # append to hdf5
400
- with pd.HDFStore(self.db, mode='a') as store:
401
- store.append(key = 'presence',
402
- value = fish_dat,
403
- format = 'table',
404
- index = False,
405
- min_itemsize = {'freq_code':20,
406
- 'rec_id':20,
407
- 'class':20},
408
- append = True,
409
- data_columns = True,
410
- chunksize = 1000000)
411
-
412
- print ('bouts classified for fish %s'%(fish))
413
-
776
+ p_norm_power = c_norm_power = 0.5
414
777
 
415
- class overlap_reduction():
416
- '''Python class to reduce redundant detections at overlapping receivers.
417
- More often than not, a large aerial yagi will be placed adjacent to a dipole.
418
- The yagi has a much larger detection area and will pick up fish in either detection
419
- zone. The dipole is limited in its coverage area, therefore if a fish is
420
- currently present at a dipole and is also detected at the yagi, we can safely
421
- assume that the detections at the Yagi are overlapping and we can place the fish
422
- at the dipole antenna. By identifying and removing these overlapping detections
423
- we remove bias in time-to-event modeling when we want to understand movement
424
- from detection areas with limited coverage to those areas with larger aerial
425
- coverages.
426
-
427
- This class object contains a series of methods to identify overlapping detections
428
- and import a table for joining into the project database.'''
778
+ p_score = 0.7 * p_mean_posterior + 0.3 * p_norm_power
779
+ c_score = 0.7 * c_mean_posterior + 0.3 * c_norm_power
780
+ if mean_diff_threshold is not None and abs(p_mean_posterior - c_mean_posterior) < mean_diff_threshold:
781
+ decisions['keep_both'] += 1
782
+ return 0
783
+ if p_score > c_score:
784
+ child_mark_idx.extend(c_indices)
785
+ decisions['remove_child'] += 1
786
+ return len(c_indices)
787
+ parent_mark_idx.extend(p_indices)
788
+ decisions['remove_parent'] += 1
789
+ return len(p_indices)
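A self-contained sketch of the Welch t-test and Cohen's d logic above, with made-up posterior values:

import numpy as np
from scipy.stats import ttest_ind

p_posteriors = np.array([0.95, 0.97, 0.96])   # parent bout posteriors (illustrative)
c_posteriors = np.array([0.80, 0.78, 0.82])   # child bout posteriors (illustrative)

t_stat, p_value = ttest_ind(p_posteriors, c_posteriors, equal_var=False)
mean_diff = p_posteriors.mean() - c_posteriors.mean()
n1, n2 = len(p_posteriors), len(c_posteriors)
pooled_std = np.sqrt(((n1 - 1) * p_posteriors.var(ddof=1) +
                      (n2 - 1) * c_posteriors.var(ddof=1)) / (n1 + n2 - 2))
cohens_d = mean_diff / pooled_std
# Here p_value < 0.05 and cohens_d > 0.3, so the child bout's detections would be marked overlapping.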
429
790
 
430
- def __init__(self, nodes, edges, radio_project):
431
- '''The initialization module imports data and creates a networkx graph object.
791
+ def _apply_power_decision(
792
+ self,
793
+ parent,
794
+ child,
795
+ p_indices,
796
+ c_indices,
797
+ p_power,
798
+ c_power,
799
+ p_posterior,
800
+ c_posterior,
801
+ power_threshold,
802
+ decisions,
803
+ parent_mark_idx,
804
+ child_mark_idx,
805
+ parent_ambiguous_idx,
806
+ child_ambiguous_idx,
807
+ ):
808
+ logger = logging.getLogger(__name__)
809
+ p_ambiguous = 0
810
+ c_ambiguous = 0
432
811
 
433
- The end user supplies a list of nodes, and a list of edges with instructions
434
- on how to connect them and the function does the rest. No knowledge of networkx
435
- is required.
812
+ receivers = getattr(self.project, 'receivers', None)
813
+ parent_rec = None
814
+ child_rec = None
815
+ if receivers is None:
816
+ logger.warning(
817
+ "Receiver metadata not available; using relative power normalization."
818
+ )
819
+ else:
820
+ if parent in receivers.index:
821
+ parent_rec = receivers.loc[parent]
822
+ else:
823
+ logger.warning(
824
+ "Receiver '%s' not found in receiver metadata; using relative normalization.",
825
+ parent,
826
+ )
827
+ if child in receivers.index:
828
+ child_rec = receivers.loc[child]
829
+ else:
830
+ logger.warning(
831
+ "Receiver '%s' not found in receiver metadata; using relative normalization.",
832
+ child,
833
+ )
436
834
 
437
- The nodes and edge relationships should start with the outermost nodes and
438
- eventually end with the inner most node/receiver combinations.
835
+ p_max = getattr(parent_rec, 'max_power', -40) if parent_rec is not None else -40
836
+ p_min = getattr(parent_rec, 'min_power', -100) if parent_rec is not None else -100
837
+ c_max = getattr(child_rec, 'max_power', -40) if child_rec is not None else -40
838
+ c_min = getattr(child_rec, 'min_power', -100) if child_rec is not None else -100
439
839
 
440
- Nodes must be a list of nodes and edges must be a list of tuples.
441
- Edge example: [(1,2),(2,3)],
442
- Edges are always in the format [(from, to)], i.e., [(outer, inner)] or [(parent, child)].'''
443
- self.db = radio_project.db
840
+ detections_added = 0
841
+ if pd.isna(p_power) or pd.isna(c_power):
842
+ if not pd.isna(p_posterior) and not pd.isna(c_posterior):
843
+ posterior_diff = p_posterior - c_posterior
844
+ if abs(posterior_diff) > 0.1:
845
+ if posterior_diff > 0:
846
+ child_mark_idx.extend(c_indices)
847
+ decisions['remove_child'] += 1
848
+ detections_added += len(c_indices)
849
+ else:
850
+ parent_mark_idx.extend(p_indices)
851
+ decisions['remove_parent'] += 1
852
+ detections_added += len(p_indices)
853
+ else:
854
+ p_ambiguous = 1
855
+ c_ambiguous = 1
856
+ decisions['keep_both'] += 1
857
+ else:
858
+ p_ambiguous = 1
859
+ c_ambiguous = 1
860
+ decisions['keep_both'] += 1
861
+ else:
862
+ if parent_rec is None or child_rec is None:
863
+ denom = (p_power + c_power) if (p_power + c_power) != 0 else 1.0
864
+ p_norm = float(p_power) / denom
865
+ c_norm = float(c_power) / denom
866
+ else:
867
+ p_norm = (p_power - p_min) / (p_max - p_min) if (p_max - p_min) != 0 else 0.5
868
+ c_norm = (c_power - c_min) / (c_max - c_min) if (c_max - c_min) != 0 else 0.5
869
+
870
+ p_norm = max(0.0, min(1.0, p_norm))
871
+ c_norm = max(0.0, min(1.0, c_norm))
872
+ power_diff = p_norm - c_norm
873
+
874
+ if power_diff > power_threshold:
875
+ child_mark_idx.extend(c_indices)
876
+ decisions['remove_child'] += 1
877
+ detections_added += len(c_indices)
878
+ elif power_diff < -power_threshold:
879
+ parent_mark_idx.extend(p_indices)
880
+ decisions['remove_parent'] += 1
881
+ detections_added += len(p_indices)
882
+ else:
883
+ if not pd.isna(p_posterior) and not pd.isna(c_posterior):
884
+ posterior_diff = p_posterior - c_posterior
885
+ if abs(posterior_diff) > 0.1:
886
+ if posterior_diff > 0:
887
+ child_mark_idx.extend(c_indices)
888
+ decisions['remove_child'] += 1
889
+ detections_added += len(c_indices)
890
+ else:
891
+ parent_mark_idx.extend(p_indices)
892
+ decisions['remove_parent'] += 1
893
+ detections_added += len(p_indices)
894
+ else:
895
+ p_ambiguous = 1
896
+ c_ambiguous = 1
897
+ decisions['keep_both'] += 1
898
+ else:
899
+ p_ambiguous = 1
900
+ c_ambiguous = 1
901
+ decisions['keep_both'] += 1
902
+
903
+ if p_ambiguous == 1:
904
+ parent_ambiguous_idx.extend(p_indices)
905
+ if c_ambiguous == 1:
906
+ child_ambiguous_idx.extend(c_indices)
444
907
 
445
- # Step 1, create a directed graph from list of edges
908
+ return detections_added
909
+
910
+ def __init__(self, nodes, edges, radio_project):
911
+ """
912
+ Initializes the OverlapReduction class.
913
+
914
+ Args:
915
+ nodes (list): List of nodes (receiver IDs) in the network.
916
+ edges (list of tuples): Directed edges representing relationships between receivers.
917
+ radio_project (object): Object representing the radio project, containing database path.
918
+
919
+ This method reads and filters data from the project database for each node and stores
920
+ the processed data in dictionaries (`node_pres_dict` and `node_recap_dict`).
921
+ """
922
+ logger = logging.getLogger(__name__)
923
+ logger.info("Initializing overlap_reduction")
924
+
925
+ self.db = radio_project.db
926
+ self.project = radio_project
927
+ self.nodes = nodes
928
+ self.edges = edges
446
929
  self.G = nx.DiGraph()
447
930
  self.G.add_edges_from(edges)
931
+ # Initialize dictionaries for presence and recapture data
932
+ self.node_pres_dict = {}
933
+ self.node_recap_dict = {}
448
934
 
449
- # Step 2, import data and create a dictionary of node dataframes
450
- self.node_pres_dict = dict()
451
- self.node_recap_dict = dict()
935
+ logger.info(f" Loading data for {len(nodes)} nodes")
452
936
 
453
- for i in nodes:
454
- #import data and add to node dict
455
- node_recs = radio_project.receivers[radio_project.receivers.node == i]
456
- node_recs = node_recs.index # get the unique receivers associated with this node
457
- pres_data = pd.DataFrame(columns = ['freq_code','epoch','node','rec_id','presence']) # set up an empty data frame
458
- recap_data = pd.DataFrame(columns = ['freq_code','epoch','node','rec_id'])
937
+ # Read and preprocess data for each node
938
+ for node in tqdm(nodes, desc="Loading nodes", unit="node"):
939
+ pres_data, pres_where, used_full_presence_read = self._read_presence(node, logger)
940
+ recap_data = self._read_recap(node)
941
+ summarized_data, raw_count = self._summarize_presence(
942
+ node,
943
+ pres_data,
944
+ recap_data,
945
+ pres_where,
946
+ used_full_presence_read,
947
+ logger
948
+ )
949
+
950
+ # Store the processed data in the dictionaries
951
+ self.node_pres_dict[node] = summarized_data
952
+ # Don't store recap_data - load on-demand per edge (memory efficient)
953
+ self.node_recap_dict[node] = len(recap_data) # Just track count
954
+ logger.debug(f" {node}: {raw_count} presence records, {len(recap_data)} detections")
955
+
956
+ logger.info(f"✓ Data loaded for {len(nodes)} nodes")
957
+
958
+ def _resolve_db_path(self):
959
+ if getattr(self, "project", None) is not None and getattr(self.project, "db", None):
960
+ return self.project.db
961
+ if getattr(self, "db", None):
962
+ return self.db
963
+ raise RuntimeError("Overlap reduction requires a database path on 'project.db' or 'db'.")
964
+
965
+ def unsupervised_removal(self, method='posterior', p_value_threshold=0.05, effect_size_threshold=0.3,
966
+ power_threshold=0.2, min_detections=1, bout_expansion=0, confidence_threshold=None):
967
+ """
968
+ Unsupervised overlap removal supporting multiple methods with statistical testing.
969
+
970
+ Parameters
971
+ ----------
972
+ method : {'posterior', 'power'}
973
+ 'posterior' (default) uses `posterior_T` columns produced by the
974
+ Naive Bayes classifier (recommended for radio telemetry).
975
+ 'power' compares median power in overlapping bouts (fallback).
976
+ p_value_threshold : float, default=0.05
977
+ Maximum p-value for t-test to consider difference statistically significant.
978
+ Only applies when method='posterior'.
979
+ effect_size_threshold : float, default=0.3
980
+ Minimum Cohen's d effect size required (in addition to statistical significance).
981
+ 0.2 = small, 0.5 = medium, 0.8 = large effect. Lower values (0.3) are more
982
+ conservative for radio telemetry where small differences matter.
983
+ power_threshold : float
984
+ Threshold on the difference between min-max normalized median powers
985
+ (parent_norm - child_norm); a relative normalization is used when receiver power bounds are unavailable.
986
+ min_detections : int, default=1
987
+ Minimum number of detections required in a bout for statistical comparison.
988
+ bout_expansion : int, default=0
989
+ Seconds to expand bout windows before/after (0 = no expansion, recommended
990
+ for cleaner movement trajectories).
991
+ """
992
+ logger = logging.getLogger(__name__)
993
+ db_path = self._resolve_db_path()
994
+ # Preserve the confidence_threshold value for use as a posterior mean-difference
995
+ # tiebreaker (tests pass `confidence_threshold` expecting this behavior).
996
+ mean_diff_threshold = None
997
+ if confidence_threshold is not None:
998
+ try:
999
+ mean_diff_threshold = float(confidence_threshold)
1000
+ except (TypeError, ValueError) as exc:
1001
+ raise ValueError(
1002
+ f"confidence_threshold must be numeric, got {confidence_threshold!r}"
1003
+ ) from exc
1004
+ logger.info(f"Starting unsupervised overlap removal (method={method})")
1005
+ # Create an empty '/overlapping' table early so downstream readers/tests
1006
+ # will find the key even if no rows are written during processing.
1007
+ try:
1008
+ with pd.HDFStore(db_path, mode='a') as store:
1009
+ if 'overlapping' not in store.keys():
1010
+ # Create a minimal placeholder row so the key exists reliably.
1011
+ placeholder = pd.DataFrame([{
1012
+ 'freq_code': '__DUMMY__',
1013
+ 'epoch': 0,
1014
+ 'time_stamp': pd.Timestamp('1970-01-01'),
1015
+ 'rec_id': '__DUMMY__',
1016
+ 'overlapping': 0,
1017
+ 'ambiguous_overlap': 0.0,
1018
+ 'power': np.nan,
1019
+ 'posterior_T': np.nan,
1020
+ 'posterior_F': np.nan
1021
+ }])
1022
+ # Cast to the same dtypes used by write_results_to_hdf5 to avoid
1023
+ # PyTables validation errors when later appending rows with the
1024
+ # same schema.
1025
+ placeholder = placeholder.astype({
1026
+ 'freq_code': 'object',
1027
+ 'epoch': 'int32',
1028
+ 'time_stamp': 'datetime64[ns]',
1029
+ 'rec_id': 'object',
1030
+ 'overlapping': 'int32',
1031
+ 'ambiguous_overlap': 'float32',
1032
+ 'power': 'float32',
1033
+ 'posterior_T': 'float32',
1034
+ 'posterior_F': 'float32'
1035
+ })
1036
+ # Use append with a generous min_itemsize so later real values
1037
+ # (longer strings) can be appended without hitting PyTables limits.
1038
+ store.append(key='overlapping', value=placeholder, format='table', data_columns=True,
1039
+ min_itemsize={'freq_code': 50, 'rec_id': 50})
1040
+ except (OSError, ValueError, KeyError, RuntimeError) as exc:
1041
+ raise RuntimeError("Failed to pre-create /overlapping table in HDF5.") from exc
1042
+ overlaps_processed = 0
1043
+ detections_marked = 0
1044
+ decisions = {'remove_parent': 0, 'remove_child': 0, 'keep_both': 0}
1045
+ skip_reasons = {'parent_too_small': 0, 'no_overlap': 0, 'child_too_small': 0,
1046
+ 'no_posterior_data': 0, 'insufficient_after_nan': 0}
1047
+
1048
+ # Precompute per-node, per-bout summaries (indices, posterior means, median power)
1049
+ # and build IntervalTrees per fish for fast overlap queries. This avoids
1050
+ # repeated mean()/median() computations inside the tight edge loops.
1051
+ node_bout_index = {} # node -> fish -> list of bout dicts
1052
+ node_bout_trees = {} # node -> fish -> IntervalTree
1053
+ node_recap_cache = {} # node -> recap DataFrame (cache for edge loop)
1054
+ for node, bouts in self.node_pres_dict.items():
1055
+ # Load recap data and cache it for use in edge loop
1056
+ cached_recaps = self.node_recap_dict.get(node)
1057
+ if isinstance(cached_recaps, pd.DataFrame):
1058
+ recaps = cached_recaps.copy()
1059
+ if 'test' in recaps.columns:
1060
+ recaps = recaps[recaps['test'] == 1]
1061
+ if 'iter' in recaps.columns:
1062
+ recaps = recaps[recaps['iter'] == recaps['iter'].max()]
1063
+ else:
1064
+ try:
1065
+ recaps = pd.read_hdf(
1066
+ db_path,
1067
+ 'classified',
1068
+ where=f"(rec_id == '{node}') & (test == 1)",
1069
+ columns=[
1070
+ 'freq_code',
1071
+ 'epoch',
1072
+ 'time_stamp',
1073
+ 'power',
1074
+ 'rec_id',
1075
+ 'iter',
1076
+ 'test',
1077
+ 'posterior_T',
1078
+ 'posterior_F',
1079
+ ],
1080
+ )
1081
+ if 'iter' not in recaps.columns:
1082
+ raise KeyError("Missing 'iter' column in classified table.")
1083
+ recaps = recaps[recaps['iter'] == recaps['iter'].max()]
1084
+ except (FileNotFoundError, KeyError, OSError, ValueError) as exc:
1085
+ raise RuntimeError(
1086
+ f"Failed to read classified detections for rec_id '{node}': {exc}"
1087
+ ) from exc
1088
+ node_recap_cache[node] = recaps # Cache for later use
459
1089
 
460
- for j in node_recs:
461
- # get presence data and final classifications for this receiver
462
- presence_dat = pd.read_hdf(radio_project.db,'presence', where = 'rec_id = %s'%(j))
463
- presence_dat['node'] = np.repeat(i,len(presence_dat))
464
- class_dat = pd.read_hdf(radio_project.db,'classified', where = 'rec_id = %s'%(j))
465
- class_dat = class_dat[class_dat.iter == class_dat.iter.max()]
466
- class_dat = class_dat[class_dat.test == 1]
467
- class_dat = class_dat[['freq_code', 'epoch', 'rec_id']]
468
- class_dat['node'] = np.repeat(i,len(class_dat))
469
- # append to node specific dataframe
470
- pres_data = pres_data.append(presence_dat)
471
- recap_data = recap_data.append(class_dat)
472
-
473
- # now that we have data, we need to summarize it, use group by to get min ans max epoch by freq code, recID and presence_number
474
- dat = pres_data.groupby(['freq_code','bout_no','node','rec_id'])['epoch'].agg(['min','max'])
475
- dat.reset_index(inplace = True, drop = False)
476
- dat.rename(columns = {'min':'min_epoch','max':'max_epoch'},inplace = True)
477
- self.node_pres_dict[i] = dat
478
- self.node_recap_dict[i] = recap_data
1090
+ node_bout_index[node] = {}
1091
+ node_bout_trees[node] = {}
1092
+ if bouts.empty or recaps.empty:
1093
+ continue
1094
+ # ensure the epoch column is numeric before interval comparisons
1095
+ recaps['epoch'] = pd.to_numeric(recaps['epoch'])
1096
+ for fish_id, fish_bouts in bouts.groupby('freq_code'):
1097
+ r_fish = recaps[recaps['freq_code'] == fish_id]
1098
+ bout_list = []
1099
+ intervals = []
1100
+ for b_idx, bout_row in fish_bouts.reset_index(drop=True).iterrows():
1101
+ min_epoch = bout_row['min_epoch']
1102
+ max_epoch = bout_row['max_epoch']
1103
+ if bout_expansion and bout_expansion > 0:
1104
+ min_epoch = min_epoch - bout_expansion
1105
+ max_epoch = max_epoch + bout_expansion
1106
+
1107
+ if not r_fish.empty:
1108
+ mask = (r_fish['epoch'] >= min_epoch) & (r_fish['epoch'] <= max_epoch)
1109
+ in_df = r_fish.loc[mask]
1110
+ indices = in_df.index.tolist()
1111
+ posterior = in_df['posterior_T'].mean(skipna=True) if 'posterior_T' in in_df.columns else np.nan
1112
+ median_power = in_df['power'].median() if 'power' in in_df.columns else np.nan
1113
+ else:
1114
+ indices = []
1115
+ posterior = np.nan
1116
+ median_power = np.nan
1117
+
1118
+ bout_list.append({'min_epoch': min_epoch, 'max_epoch': max_epoch, 'indices': indices, 'posterior': posterior, 'median_power': median_power})
1119
+ intervals.append((min_epoch, max_epoch, b_idx))
1120
+
1121
+ node_bout_index[node][fish_id] = bout_list
1122
+ # build IntervalTree for this fish (only include intervals with numeric bounds)
1123
+ try:
1124
+ interval_entries = []
1125
+ for a, b, idx in intervals:
1126
+ if pd.isna(a) or pd.isna(b):
1127
+ continue
1128
+ a_int = int(a)
1129
+ b_int = int(b)
1130
+ # IntervalTree does not allow null intervals; expand single-point bouts by 1 second.
1131
+ if b_int <= a_int:
1132
+ b_int = a_int + 1
1133
+ interval_entries.append(Interval(a_int, b_int, idx))
1134
+ tree = IntervalTree(interval_entries)
1135
+ node_bout_trees[node][fish_id] = tree
1136
+ except (TypeError, ValueError) as exc:
1137
+ raise ValueError(
1138
+ f"Invalid bout interval bounds for node '{node}' fish '{fish_id}': {exc}"
1139
+ ) from exc
1140
+
1141
+ # If the posterior-based method was requested, ensure there is at least
1142
+ # some `posterior_T` data available in the cached recapture tables. Tests
1143
+ # and callers rely on this raising an error when posterior data are absent.
1144
+ if method == 'posterior':
1145
+ has_posterior = False
1146
+ for df in node_recap_cache.values():
1147
+ if not df.empty and 'posterior_T' in df.columns:
1148
+ # also ensure there is at least one non-null posterior value
1149
+ if df['posterior_T'].notna().any():
1150
+ has_posterior = True
1151
+ break
1152
+ if not has_posterior:
1153
+ raise ValueError("Method 'posterior' requested but no 'posterior_T' values are available in classified data")
1154
+
1155
+ for edge_idx, (parent, child) in enumerate(tqdm(self.edges, desc="Processing edges", unit="edge")):
1156
+ logger.info(f"Edge {edge_idx+1}/{len(self.edges)}: {parent} → {child}")
1157
+
1158
+ parent_bouts = self.node_pres_dict.get(parent, pd.DataFrame())
479
1159
 
480
- # clean up
481
- del pres_data, recap_data, dat, presence_dat, class_dat
482
- print ("Completed data management process for node %s"%(i))
483
-
484
- # visualize the graph
485
- shells = []
486
- for n in list(self.G.nodes()):
487
- successors = list(self.G.succ[n].keys())
488
- shells.append(successors)
489
-
490
- fig, ax = plt.subplots(1, 1, figsize=(4, 4));
491
- pos= nx.circular_layout(self.G)
492
- nx.draw_networkx_nodes(self.G,pos,list(self.G.nodes()),node_color = 'r',node_size = 400)
493
- nx.draw_networkx_edges(self.G,pos,list(self.G.edges()),edge_color = 'k')
494
- nx.draw_networkx_labels(self.G,pos,font_size=8)
495
- plt.axis('off')
496
- plt.show()
1160
+ # Use cached recap data but make a COPY to avoid cross-edge contamination
1161
+ # Without .copy(), the ambiguous_overlap column carries over between edges
1162
+ parent_dat = node_recap_cache.get(parent, pd.DataFrame()).copy()
1163
+ child_dat = node_recap_cache.get(child, pd.DataFrame()).copy()
1164
+
1165
+ # Quick skip when any required table is empty
1166
+ if parent_bouts.empty or parent_dat.empty or child_dat.empty:
1167
+ logger.debug(f"Skipping {parent}->{child}: empty data")
1168
+ continue
1169
+
1170
+ # Normalize freq_code dtype and pre-split recapture tables by freq_code
1171
+ # to avoid repeated full-DataFrame boolean comparisons inside loops.
1172
+ if 'freq_code' in parent_dat.columns:
1173
+ parent_dat['freq_code'] = parent_dat['freq_code'].astype('object')
1174
+ if 'freq_code' in child_dat.columns:
1175
+ child_dat['freq_code'] = child_dat['freq_code'].astype('object')
1176
+
1177
+ parent_by_fish = {k: v for k, v in parent_dat.groupby('freq_code')} if not parent_dat.empty else {}
1178
+ child_by_fish = {k: v for k, v in child_dat.groupby('freq_code')} if not child_dat.empty else {}
1179
+
1180
+ # Initialize overlapping and ambiguous_overlap columns fresh for each edge
1181
+ # This prevents carryover from previous edges
1182
+ parent_dat['overlapping'] = np.float32(0)
1183
+ parent_dat['ambiguous_overlap'] = np.float32(0)
1184
+ child_dat['overlapping'] = np.float32(0)
1185
+ child_dat['ambiguous_overlap'] = np.float32(0)
1186
+
1187
+ fishes = parent_bouts['freq_code'].unique()
1188
+ logger.debug(f" Processing {len(fishes)} fish for edge {parent}->{child}")
1189
+ print(f" [overlap] {parent}→{child}: processing {len(fishes)} fish")
1190
+
1191
+ # Buffers for indices to mark as overlapping or ambiguous for this edge
1192
+ parent_mark_idx = []
1193
+ child_mark_idx = []
1194
+ parent_ambiguous_idx = []
1195
+ child_ambiguous_idx = []
1196
+
1197
+ for fish_idx, fish_id in enumerate(fishes, 1):
1198
+ # Progress update every 10 fish or for the last fish
1199
+ if fish_idx % 10 == 0 or fish_idx == len(fishes):
1200
+ print(f" [overlap] {parent}→{child}: fish {fish_idx}/{len(fishes)} ({fish_id})", end='\r')
1201
+ # fast access to precomputed bout lists and trees
1202
+ p_bouts = node_bout_index.get(parent, {}).get(fish_id, [])
1203
+ c_tree = node_bout_trees.get(child, {}).get(fish_id, IntervalTree())
1204
+
1205
+ if not p_bouts or c_tree is None:
1206
+ continue
1207
+
1208
+ for p_i, p_info in enumerate(p_bouts):
1209
+ p_indices = p_info['indices']
1210
+ p_conf = p_info['posterior']
1211
+ p_power = p_info['median_power']
1212
+
1213
+ # skip bouts with insufficient detections
1214
+ if (not p_indices) or len(p_indices) < min_detections:
1215
+ decisions['keep_both'] += 1
1216
+ skip_reasons['parent_too_small'] += 1
1217
+ continue
1218
+
1219
+ # query overlapping child bouts via IntervalTree
1220
+ overlaps = c_tree.overlap(int(p_info['min_epoch']), int(p_info['max_epoch']))
1221
+ # Fallback: IntervalTree may miss matches for tiny ranges in some
1222
+ # synthetic/test datasets; perform a manual scan of child bouts
1223
+ # if no overlaps were returned by the tree.
1224
+ if not overlaps:
1225
+ manual_overlaps = []
1226
+ child_bouts_list = node_bout_index.get(child, {}).get(fish_id, [])
1227
+ for c_idx_manual, c_info_manual in enumerate(child_bouts_list):
1228
+ try:
1229
+ c_min = int(c_info_manual.get('min_epoch', -1))
1230
+ c_max = int(c_info_manual.get('max_epoch', -1))
1231
+ except (TypeError, ValueError) as exc:
1232
+ raise ValueError(
1233
+ f"Invalid child bout bounds for node '{child}' fish '{fish_id}'."
1234
+ ) from exc
1235
+ if (c_min <= int(p_info['max_epoch'])) and (c_max >= int(p_info['min_epoch'])):
1236
+ # create a lightweight object with `.data` attribute to mimic Interval
1237
+ class _O:
1238
+ def __init__(self, d):
1239
+ self.data = d
1240
+ manual_overlaps.append(_O(c_idx_manual))
1241
+ overlaps = manual_overlaps
1242
+ if not overlaps:
1243
+ decisions['keep_both'] += 1
1244
+ skip_reasons['no_overlap'] += 1
1245
+ continue
1246
+
1247
+ overlaps_processed += 1
1248
+
1249
+ for iv in overlaps:
1250
+ c_idx = iv.data
1251
+ try:
1252
+ c_info = node_bout_index[child][fish_id][c_idx]
1253
+ except (KeyError, IndexError, TypeError) as exc:
1254
+ raise KeyError(
1255
+ f"Missing child bout index {c_idx} for node '{child}' fish '{fish_id}'."
1256
+ ) from exc
1257
+
1258
+ c_indices = c_info['indices']
1259
+ c_conf = c_info['posterior']
1260
+ c_power = c_info['median_power']
1261
+
1262
+ # require minimum detections on both
1263
+ if (not c_indices) or len(c_indices) < min_detections:
1264
+ decisions['keep_both'] += 1
1265
+ skip_reasons['child_too_small'] += 1
1266
+ continue
1267
+
1268
+ if method == 'posterior':
1269
+ detections_marked += self._apply_posterior_decision(
1270
+ parent_dat,
1271
+ child_dat,
1272
+ p_indices,
1273
+ c_indices,
1274
+ p_power,
1275
+ c_power,
1276
+ min_detections,
1277
+ p_value_threshold,
1278
+ effect_size_threshold,
1279
+ mean_diff_threshold,
1280
+ decisions,
1281
+ skip_reasons,
1282
+ parent_mark_idx,
1283
+ child_mark_idx,
1284
+ )
1285
+ elif method == 'power':
1286
+ p_posterior = p_info.get('posterior', np.nan)
1287
+ c_posterior = c_info.get('posterior', np.nan)
1288
+ detections_marked += self._apply_power_decision(
1289
+ parent,
1290
+ child,
1291
+ p_indices,
1292
+ c_indices,
1293
+ p_power,
1294
+ c_power,
1295
+ p_posterior,
1296
+ c_posterior,
1297
+ power_threshold,
1298
+ decisions,
1299
+ parent_mark_idx,
1300
+ child_mark_idx,
1301
+ parent_ambiguous_idx,
1302
+ child_ambiguous_idx,
1303
+ )
1304
+ else:
1305
+ raise ValueError(f"Unknown method: {method}")
1306
+
1307
+ # After processing all fish/bouts for this edge, bulk-assign overlapping flags
1308
+ print(f"\n [overlap] {parent}→{child}: marking {len(set(parent_mark_idx))} parent + {len(set(child_mark_idx))} child detections as overlapping")
1309
+ print(f" [overlap] {parent}→{child}: marking {len(set(parent_ambiguous_idx))} parent + {len(set(child_ambiguous_idx))} child detections as ambiguous")
1310
+ if parent_mark_idx:
1311
+ parent_dat.loc[sorted(set(parent_mark_idx)), 'overlapping'] = np.float32(1)
1312
+ if child_mark_idx:
1313
+ child_dat.loc[sorted(set(child_mark_idx)), 'overlapping'] = np.float32(1)
1314
+
1315
+ # Bulk-assign ambiguous_overlap flags
1316
+ if parent_ambiguous_idx:
1317
+ parent_dat.loc[sorted(set(parent_ambiguous_idx)), 'ambiguous_overlap'] = np.float32(1)
1318
+ if child_ambiguous_idx:
1319
+ child_dat.loc[sorted(set(child_ambiguous_idx)), 'ambiguous_overlap'] = np.float32(1)
1320
+
1321
+ # Write ONLY the marked detections (overlapping=1 OR ambiguous_overlap=1)
1322
+ # Combine overlapping and ambiguous indices (use set to avoid duplicates)
1323
+ parent_write_idx = sorted(set(parent_mark_idx + parent_ambiguous_idx))
1324
+ child_write_idx = sorted(set(child_mark_idx + child_ambiguous_idx))
1325
+
1326
+ logger.debug(f" Writing results for {parent} and {child} (parent overlapping={len(parent_mark_idx)}, ambiguous={len(parent_ambiguous_idx)}, child overlapping={len(child_mark_idx)}, ambiguous={len(child_ambiguous_idx)})")
1327
+ print(f" [overlap] {parent}→{child}: writing overlapping detections to HDF5...")
1328
+
1329
+ # Only write detections that were marked as overlapping or ambiguous
1330
+ if parent_write_idx:
1331
+ parent_overlapping = parent_dat.loc[parent_write_idx]
1332
+ ambig_count = (parent_overlapping['ambiguous_overlap'] == 1).sum()
1333
+ if ambig_count > 0:
1334
+ print(f" [overlap] {parent}→{child}: writing {ambig_count} parent ambiguous detections")
1335
+ self.write_results_to_hdf5(parent_overlapping)
1336
+ if child_write_idx:
1337
+ child_overlapping = child_dat.loc[child_write_idx]
1338
+ ambig_count = (child_overlapping['ambiguous_overlap'] == 1).sum()
1339
+ if ambig_count > 0:
1340
+ print(f" [overlap] {parent}→{child}: writing {ambig_count} child ambiguous detections")
1341
+ self.write_results_to_hdf5(child_overlapping)
1342
+ print(f" [overlap] ✓ {parent}→{child} complete\n")
1343
+
1344
+ # cleanup
1345
+ del parent_bouts, parent_dat, child_dat
1346
+ gc.collect()
1347
+
1348
+ # Calculate statistics from HDF5 overlapping table
1349
+ logger.info("Calculating final statistics from overlapping table...")
1350
+ try:
1351
+ with pd.HDFStore(db_path, mode='r') as store:
1352
+ if '/overlapping' in store:
1353
+ overlapping_table = store.select('overlapping')
1354
+ total_written = len(overlapping_table)
1355
+ overlapping_count = (overlapping_table['overlapping'] == 1).sum()
1356
+ ambiguous_count = (overlapping_table['ambiguous_overlap'] == 1).sum()
1357
+ unique_fish = overlapping_table['freq_code'].nunique()
1358
+ unique_receivers = overlapping_table['rec_id'].nunique()
1359
+ else:
1360
+ total_written = overlapping_count = ambiguous_count = unique_fish = unique_receivers = 0
1361
+ except (OSError, KeyError, ValueError) as exc:
1362
+ raise RuntimeError(
1363
+ f"Could not read overlapping table for statistics: {exc}"
1364
+ ) from exc
1365
+
1366
+ print("\n" + "="*80)
1367
+ logger.info("✓ Unsupervised overlap removal complete")
1368
+ logger.info(f" Overlapping bouts processed: {overlaps_processed}")
1369
+ logger.info(f" Detections marked as overlapping: {detections_marked}")
1370
+ logger.info(f" Decision breakdown: {decisions}")
1371
+ logger.info(f" Skip reasons: {skip_reasons}")
1372
+ print(f"[overlap] ✓ Complete: {overlaps_processed} overlaps processed")
1373
+ print(f"[overlap] Decisions: remove_parent={decisions['remove_parent']}, remove_child={decisions['remove_child']}, keep_both={decisions['keep_both']}")
1374
+ print(f"[overlap] Skip breakdown: parent_too_small={skip_reasons['parent_too_small']}, child_too_small={skip_reasons['child_too_small']}, no_posterior={skip_reasons['no_posterior_data']}, insufficient_after_nan={skip_reasons['insufficient_after_nan']}")
1375
+ print(f"\n[overlap] Written to HDF5:")
1376
+ print(f" Total detections in /overlapping table: {total_written:,}")
1377
+ print(f" Detections with overlapping=1: {overlapping_count:,} ({100*overlapping_count/total_written if total_written > 0 else 0:.1f}%)")
1378
+ print(f" Detections with ambiguous_overlap=1: {ambiguous_count:,} ({100*ambiguous_count/total_written if total_written > 0 else 0:.1f}%)")
1379
+ print(f" Unique fish affected: {unique_fish}")
1380
+ print(f" Unique receivers affected: {unique_receivers}")
1381
+ print("="*80)
1382
+
1383
+ # Ensure an '/overlapping' key exists in the HDF5 file even if no rows were written.
1384
+ try:
1385
+ with pd.HDFStore(db_path, mode='a') as store:
1386
+ if 'overlapping' not in store.keys():
1387
+ empty_df = pd.DataFrame(columns=['freq_code', 'epoch', 'time_stamp', 'rec_id', 'overlapping', 'ambiguous_overlap', 'power', 'posterior_T', 'posterior_F'])
1388
+ # set dtypes consistent with write_results_to_hdf5
1389
+ empty_df = empty_df.astype({
1390
+ 'freq_code': 'object',
1391
+ 'epoch': 'int32',
1392
+ 'rec_id': 'object',
1393
+ 'overlapping': 'int32',
1394
+ 'ambiguous_overlap': 'float32',
1395
+ 'power': 'float32',
1396
+ 'posterior_T': 'float32',
1397
+ 'posterior_F': 'float32'
1398
+ })
1399
+ # Use put instead of append to ensure an empty table is created
1400
+ store.put(key='overlapping', value=empty_df, format='table', data_columns=True,
1401
+ min_itemsize={'freq_code': 50, 'rec_id': 50})
1402
+ except (OSError, KeyError, ValueError) as exc:
1403
+ raise RuntimeError(
1404
+ f"Failed to ensure /overlapping table exists: {exc}"
1405
+ ) from exc
1406
+
1407
+ # Apply bout-based spatial filter to handle antenna bleed
1408
+ self._apply_bout_spatial_filter()
1409
+
1410
+ def _apply_bout_spatial_filter(self, temporal_overlap_threshold=0.5):
1411
+ """
1412
+ Apply a bout-based spatial logic filter to handle antenna bleed in the /overlapping table.
1413
+
1414
+ When a fish has simultaneous bouts at multiple receivers (temporal overlap),
1415
+ keep the longer/stronger bout and mark the shorter bout as overlapping=1.
1416
+
1417
+ This addresses the problem where powerhouse antennas detect fish on the "wrong"
1418
+ side due to back lobes, reflections, or diffraction.
1419
+
1420
+ Parameters
1421
+ ----------
1422
+ temporal_overlap_threshold : float
1423
+ Fraction of temporal overlap required to consider bouts conflicting (0-1)
1424
+ Default 0.5 = 50% overlap
1425
+ """
1426
+ import logging
1427
+ logger = logging.getLogger(__name__)
1428
+ db_path = self._resolve_db_path()
1429
+
1430
+ print(f"\n{'='*80}")
1431
+ print(f"BOUT-BASED SPATIAL FILTER")
1432
+ print(f"{'='*80}")
1433
+ print(f"[overlap] Resolving antenna bleed using bout strength...")
1434
+
1435
+ # Read overlapping table
1436
+ try:
1437
+ overlapping_data = pd.read_hdf(db_path, key='/overlapping')
1438
+ except (OSError, KeyError, ValueError) as exc:
1439
+ raise RuntimeError(f"Failed to read /overlapping table: {exc}") from exc
1440
+
1441
+ if overlapping_data.empty:
1442
+ print(f"[overlap] No data in overlapping table")
1443
+ return
1444
+
1445
+ # Get bout summaries from presence table
1446
+ try:
1447
+ presence_data = pd.read_hdf(db_path, key='/presence')
1448
+ except KeyError:
1449
+ logger.warning("Presence table missing; skipping bout-based spatial filter.")
1450
+ print("[overlap] /presence table missing; skipping bout-based spatial filter.")
1451
+ return
1452
+ except (OSError, ValueError) as exc:
1453
+ raise RuntimeError(f"Failed to read /presence table: {exc}") from exc
1454
+
1455
+ if presence_data.empty or 'bout_no' not in presence_data.columns:
1456
+ raise ValueError("[overlap] Presence data empty or missing bout_no")
1457
+
1458
+ # Build bout summary: min/max epoch, detection count per bout
1459
+ bout_summary = presence_data.groupby(['freq_code', 'rec_id', 'bout_no']).agg({
1460
+ 'epoch': ['min', 'max', 'count']
1461
+ }).reset_index()
1462
+
1463
+ bout_summary.columns = ['freq_code', 'rec_id', 'bout_no', 'min_epoch', 'max_epoch', 'num_detections']
1464
+ bout_summary['bout_duration'] = bout_summary['max_epoch'] - bout_summary['min_epoch']
1465
+
1466
+ print(f"[overlap] Loaded {len(bout_summary):,} bouts from {bout_summary['freq_code'].nunique()} fish")
1467
+
1468
+ # Track which detections to mark as overlapping
1469
+ detections_to_mark = [] # List of (freq_code, rec_id, bout_no) tuples
1470
+ conflicts_found = 0
1471
+
1472
+ # For each fish, check for temporally overlapping bouts at different receivers
1473
+ for fish in bout_summary['freq_code'].unique():
1474
+ fish_bouts = bout_summary[bout_summary['freq_code'] == fish].copy()
1475
+
1476
+ if len(fish_bouts) < 2:
1477
+ continue # Can't have conflicts with only one bout
1478
+
1479
+ # Compare all pairs of bouts for this fish
1480
+ for i, bout_a in fish_bouts.iterrows():
1481
+ for j, bout_b in fish_bouts.iterrows():
1482
+ if i >= j: # Skip self-comparison and duplicates
1483
+ continue
1484
+
1485
+ # Only consider bouts at different receivers
1486
+ if bout_a['rec_id'] == bout_b['rec_id']:
1487
+ continue
1488
+
1489
+ # Calculate temporal overlap
1490
+ overlap_start = max(bout_a['min_epoch'], bout_b['min_epoch'])
1491
+ overlap_end = min(bout_a['max_epoch'], bout_b['max_epoch'])
1492
+ overlap_duration = max(0, overlap_end - overlap_start)
1493
+
1494
+ # Calculate overlap as fraction of shorter bout
1495
+ min_duration = min(bout_a['bout_duration'], bout_b['bout_duration'])
1496
+
1497
+ if min_duration > 0:
1498
+ overlap_fraction = overlap_duration / min_duration
1499
+ else:
1500
+ overlap_fraction = 0
1501
+
1502
+ # If significant temporal overlap exists, we have a conflict
1503
+ if overlap_fraction >= temporal_overlap_threshold:
1504
+ conflicts_found += 1
1505
+
1506
+ # Decide which bout to mark as overlapping based on:
1507
+ # 1. Number of detections (primary - longer bout is more reliable)
1508
+ # 2. Duration (secondary - longer time = more confidence)
1509
+
1510
+ if bout_a['num_detections'] > bout_b['num_detections']:
1511
+ # Keep A, mark B as overlapping
1512
+ loser = (bout_b['freq_code'], bout_b['rec_id'], bout_b['bout_no'])
1513
+ winner_rec = bout_a['rec_id']
1514
+ loser_rec = bout_b['rec_id']
1515
+ winner_dets = bout_a['num_detections']
1516
+ loser_dets = bout_b['num_detections']
1517
+ elif bout_b['num_detections'] > bout_a['num_detections']:
1518
+ # Keep B, mark A as overlapping
1519
+ loser = (bout_a['freq_code'], bout_a['rec_id'], bout_a['bout_no'])
1520
+ winner_rec = bout_b['rec_id']
1521
+ loser_rec = bout_a['rec_id']
1522
+ winner_dets = bout_b['num_detections']
1523
+ loser_dets = bout_a['num_detections']
1524
+ else:
1525
+ # Same detection count - use duration as tiebreaker
1526
+ if bout_a['bout_duration'] > bout_b['bout_duration']:
1527
+ loser = (bout_b['freq_code'], bout_b['rec_id'], bout_b['bout_no'])
1528
+ winner_rec = bout_a['rec_id']
1529
+ loser_rec = bout_b['rec_id']
1530
+ winner_dets = bout_a['num_detections']
1531
+ loser_dets = bout_b['num_detections']
1532
+ else:
1533
+ loser = (bout_a['freq_code'], bout_a['rec_id'], bout_a['bout_no'])
1534
+ winner_rec = bout_b['rec_id']
1535
+ loser_rec = bout_a['rec_id']
1536
+ winner_dets = bout_b['num_detections']
1537
+ loser_dets = bout_a['num_detections']
1538
+
1539
+ detections_to_mark.append(loser)
1540
+ logger.debug(f" Fish {fish}: {winner_rec} ({winner_dets} dets) vs {loser_rec} ({loser_dets} dets, {overlap_fraction*100:.0f}% overlap) → Marking {loser_rec} as overlapping")
1541
+
1542
+ # Mark conflicting bouts as overlapping=1 in overlapping table
1543
+ if len(detections_to_mark) > 0:
1544
+ # Need to join overlapping_data with presence to get bout_no
1545
+ overlapping_with_bouts = overlapping_data.merge(
1546
+ presence_data[['freq_code', 'rec_id', 'epoch', 'bout_no']],
1547
+ on=['freq_code', 'rec_id', 'epoch'],
1548
+ how='left'
1549
+ )
1550
+
1551
+ # Mark detections from losing bouts as overlapping=1
1552
+ initial_overlapping = (overlapping_with_bouts['overlapping'] == 1).sum()
1553
+
1554
+ for fish, rec, bout in detections_to_mark:
1555
+ mask = (
1556
+ (overlapping_with_bouts['freq_code'] == fish) &
1557
+ (overlapping_with_bouts['rec_id'] == rec) &
1558
+ (overlapping_with_bouts['bout_no'] == bout)
1559
+ )
1560
+ overlapping_with_bouts.loc[mask, 'overlapping'] = 1
1561
+
1562
+ final_overlapping = (overlapping_with_bouts['overlapping'] == 1).sum()
1563
+ newly_marked = final_overlapping - initial_overlapping
1564
+
1565
+ # Drop bout_no before writing back (not needed in overlapping table)
1566
+ overlapping_with_bouts = overlapping_with_bouts.drop(columns=['bout_no'])
1567
+
1568
+ # Write back to HDF5 (replace entire table)
1569
+ with pd.HDFStore(db_path, mode='a') as store:
1570
+ # Remove old table
1571
+ if '/overlapping' in store:
1572
+ store.remove('overlapping')
1573
+
1574
+ # Write updated table
1575
+ store.append(
1576
+ key='overlapping',
1577
+ value=overlapping_with_bouts,
1578
+ format='table',
1579
+ data_columns=True,
1580
+ min_itemsize={'freq_code': 20, 'rec_id': 20}
1581
+ )
1582
+
1583
+ print(f"\n[overlap] Bout spatial filter complete:")
1584
+ print(f" Found {conflicts_found} temporal bout conflicts")
1585
+ print(f" Marked {len(detections_to_mark)} conflicting bouts as overlapping")
1586
+ print(f" Newly marked {newly_marked:,} detections ({newly_marked/len(overlapping_data)*100:.1f}%)")
1587
+ print(f" Total overlapping detections: {final_overlapping:,} ({final_overlapping/len(overlapping_data)*100:.1f}%)")
1588
+
1589
+ logger.info(f"Bout spatial filter marked {newly_marked} additional detections as overlapping")
1590
+ else:
1591
+ print(f"[overlap] No temporally overlapping bouts found across different receivers")
1592
+ logger.info("Bout spatial filter: no conflicts found")
497
1593
 
498
1594
  def nested_doll(self):
499
- '''Function iterates through matching recap data from successors to see if
500
- current recapture row at predeccesor overlaps with successor presence.'''
501
- # create function that we can vectorize over list of epochs (i)
502
- def overlap_check(i, min_epoch, max_epoch):
503
- return np.logical_and(min_epoch >= i, max_epoch < i).any()
504
- for i in self.node_recap_dict:
1595
+ """
1596
+ Identify and mark overlapping detections between parent and child nodes:
+ any detection at a parent receiver whose epoch falls inside a child
+ receiver's presence interval for the same fish is flagged as overlapping.
1597
+ """
1598
+ logger = logging.getLogger(__name__)
1599
+ logger.info("Starting nested_doll overlap detection")
1600
+ logger.info(" Method: Interval-based (conservative)")
1601
+
1602
+ overlaps_found = False
1603
+ overlap_count = 0
1604
+
1605
+ for i in tqdm(self.node_recap_dict, desc="Processing nodes", unit="node"):
505
1606
  fishes = self.node_recap_dict[i].freq_code.unique()
506
1607
 
507
1608
  for j in fishes:
508
- children = self.G.succ[i]
1609
+ children = list(self.G.successors(i))
509
1610
  fish_dat = self.node_recap_dict[i][self.node_recap_dict[i].freq_code == j]
510
- fish_dat['overlapping'] = np.zeros(len(fish_dat))
511
- fish_dat['parent'] = np.repeat('',len(fish_dat))
512
- fish_dat.set_index('epoch', inplace = True, drop = False)
513
- if len(children) > 0: # if there is no child node, who cares? there is no overlapping detections, we are at the middle doll
1611
+ fish_dat['overlapping'] = 0.0
1612
+
1613
+ if len(children) > 0:
514
1614
  for k in children:
515
- child_dat = self.node_pres_dict[i][self.node_pres_dict[i].freq_code == j]
1615
+ child_dat = self.node_pres_dict[k][self.node_pres_dict[k].freq_code == j]
516
1616
  if len(child_dat) > 0:
517
- for l in child_dat.rec_id.unique():
518
- while l != i:
519
- rec_dat = child_dat[child_dat.rec_id == l]
520
- min_epochs = child_dat.min_epoch.values
521
- max_epochs = child_dat.max_epoch.values
522
- for m in fish_dat.epoch.values: # for every row in the fish data
523
- if np.logical_and(min_epochs <= m, max_epochs >m).any(): # if the current epoch is within a presence at a child receiver
524
- print ("Overlap Found, at %s fish %s was recaptured at both %s and %s"%(m,j,i,l))
525
- fish_dat.at[m,'overlapping'] = 1
526
- fish_dat.at[m,'parent'] = i
527
-
528
- fish_dat.reset_index(inplace = True, drop = True)
1617
+ min_epochs = child_dat.min_epoch.values
1618
+ max_epochs = child_dat.max_epoch.values
1619
+
1620
+ fish_epochs = fish_dat.epoch.values
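+ # Broadcast comparison: rows are the child's presence intervals, columns are this
+ # fish's detection epochs at the parent; a detection is flagged when it falls
+ # inside any child interval (min_epoch <= epoch < max_epoch).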
1621
+ overlaps = np.any(
1622
+ (min_epochs[:, None] <= fish_epochs) & (max_epochs[:, None] > fish_epochs), axis=0
1623
+ )
1624
+ overlap_indices = np.where(overlaps)[0]
1625
+ if overlap_indices.size > 0:
1626
+ overlaps_found = True
1627
+ overlap_count += overlap_indices.size
1628
+ fish_dat.loc[overlaps, 'overlapping'] = 1.0
1629
+ #fish_dat.loc[overlaps, 'parent'] = i
1630
+
1631
+ # fish_dat = fish_dat.astype({
1632
+ # 'freq_code': 'object',
1633
+ # 'epoch': 'int32',
1634
+ # 'rec_id': 'object',
1635
+ # 'overlapping': 'int32',
1636
+ # })
1637
+ fish_dat = fish_dat[['freq_code', 'epoch', 'time_stamp', 'rec_id', 'overlapping']]
1638
+ self.write_results_to_hdf5(fish_dat)
1639
+
1640
+ # with pd.HDFStore(self.db, mode='a') as store:
1641
+ # store.append(key='overlapping',
1642
+ # value=fish_dat,
1643
+ # format='table',
1644
+ # index=False,
1645
+ # min_itemsize={'freq_code': 20,
1646
+ # 'rec_id': 20},
1647
+ # append=True,
1648
+ # data_columns=True,
1649
+ # chunksize=1000000)
1650
+
1651
+ if overlaps_found:
1652
+ logger.info(f"✓ Nested doll complete")
1653
+ logger.info(f" Total overlaps found: {overlap_count}")
1654
+ else:
1655
+ logger.info("✓ Nested doll complete - no overlaps found")
1656
+
1657
+ def write_results_to_hdf5(self, df):
1658
+ """
1659
+ Writes the processed DataFrame to the HDF5 database.
1660
+
1661
+ Args:
1662
+ df (DataFrame): The DataFrame containing processed detection data.
1663
+
1664
+ The function appends data to the 'overlapping' table in the HDF5 database, ensuring
1665
+ that each record is written incrementally to minimize memory usage.
1666
+ """
1667
+ logger = logging.getLogger(__name__)
1668
+ # Initialize ambiguous_overlap column if not present
1669
+ if 'ambiguous_overlap' not in df.columns:
1670
+ df['ambiguous_overlap'] = np.float32(0)
1671
+
1672
+ # Determine which columns to write
1673
+ base_columns = ['freq_code', 'epoch', 'time_stamp', 'rec_id', 'overlapping', 'ambiguous_overlap']
1674
+ optional_columns = ['power', 'posterior_T', 'posterior_F']
1675
+
1676
+ columns_to_write = base_columns.copy()
1677
+ for col in optional_columns:
1678
+ if col in df.columns:
1679
+ columns_to_write.append(col)
1680
+
1681
+ # Set data types for base columns
1682
+ dtype_dict = {
1683
+ 'freq_code': 'object',
1684
+ 'epoch': 'int32',
1685
+ 'rec_id': 'object',
1686
+ 'overlapping': 'int32',
1687
+ 'ambiguous_overlap': 'float32',
1688
+ }
1689
+
1690
+ # Add optional column types if present
1691
+ if 'power' in df.columns:
1692
+ dtype_dict['power'] = 'float32'
1693
+ if 'posterior_T' in df.columns:
1694
+ dtype_dict['posterior_T'] = 'float32'
1695
+ if 'posterior_F' in df.columns:
1696
+ dtype_dict['posterior_F'] = 'float32'
1697
+
1698
+ try:
1699
+ df = df.astype(dtype_dict)
1700
+ except (ValueError, TypeError) as exc:
1701
+ raise ValueError(f"Failed to cast overlapping data types: {exc}") from exc
1702
+
1703
+ # To avoid PyTables validation errors when the incoming DataFrame has
1704
+ # a different set of columns or dtypes than an existing `/overlapping`
1705
+ # table, read the existing table (if present), concatenate and replace
1706
+ # it atomically.
1707
+ db_path = self._resolve_db_path()
1708
+ with pd.HDFStore(db_path, mode='a') as store:
1709
+ if '/overlapping' in store:
1710
+ try:
1711
+ existing = store.select('overlapping')
1712
+ except (OSError, KeyError, ValueError) as exc:
1713
+ raise RuntimeError(f"Failed to read /overlapping table: {exc}") from exc
1714
+
1715
+ # Ensure existing and new columns align: add missing cols as NaN
1716
+ for c in columns_to_write:
1717
+ if c not in existing.columns:
1718
+ existing[c] = np.nan
1719
+ for c in existing.columns:
1720
+ if c not in columns_to_write:
1721
+ df[c] = np.nan
1722
+ # Reorder df columns to match final schema
1723
+ final_cols = list(existing.columns)
1724
+ # Concatenate and replace
1725
+ combined = pd.concat([existing, df[final_cols]], ignore_index=True)
1726
+
1727
+ try:
1728
+ store.remove('overlapping')
1729
+ except (OSError, KeyError, ValueError) as exc:
1730
+ raise RuntimeError(f"Failed to remove /overlapping table: {exc}") from exc
1731
+
1732
+ try:
1733
+ store.put(key='overlapping', value=combined, format='table', data_columns=True,
1734
+ min_itemsize={'freq_code': 50, 'rec_id': 50})
1735
+ except (OSError, ValueError) as exc:
1736
+ raise RuntimeError(f"Failed to write /overlapping table: {exc}") from exc
1737
+ else:
1738
+ try:
1739
+ store.append(key='overlapping', value=df[columns_to_write], format='table', data_columns=True,
1740
+ min_itemsize={'freq_code': 50, 'rec_id': 50})
1741
+ except (OSError, ValueError) as exc:
1742
+ raise RuntimeError(f"Failed to append /overlapping table: {exc}") from exc
1743
+ logger.debug(f" Wrote {len(df)} detections to /overlapping (ambiguous: {df['ambiguous_overlap'].sum()})")
1744
+
1745
+
1746
+
1747
+
1748
+
1749
+ # def _plot_kmeans_results(self, combined, centers, fish_id, node_a, node_b, project_dir):
1750
+ # """
1751
+ # Plots and saves the K-means clustering results to the project directory.
1752
+ # """
1753
+ # plt.figure(figsize=(10, 6))
1754
+ # plt.hist(combined['norm_power'], bins=30, alpha=0.5, label='Normalized Power')
1755
+ # plt.axvline(centers[0], color='r', linestyle='dashed', linewidth=2, label='Cluster Center 1')
1756
+ # plt.axvline(centers[1], color='b', linestyle='dashed', linewidth=2, label='Cluster Center 2')
1757
+ # plt.title(f"K-means Clustering between Nodes {node_a} and {node_b}")
1758
+ # plt.xlabel("Normalized Power")
1759
+ # plt.ylabel("Frequency")
1760
+ # plt.legend()
1761
+
1762
+ # output_path = os.path.join(project_dir, 'Output', 'Figures', f'kmeans_nodes_{node_a}_{node_b}.png')
1763
+ # plt.savefig(output_path)
1764
+ # plt.close()
1765
+ # print(f"K-means plot saved")
1766
+
1767
+
1768
+
1769
+ # class overlap_reduction():
1770
+ # def __init__(self, nodes, edges, radio_project, n_clusters=2):
1771
+ # self.db = radio_project.db
1772
+ # self.G = nx.DiGraph()
1773
+ # self.G.add_edges_from(edges)
1774
+
1775
+ # self.node_pres_dict = {}
1776
+ # self.node_recap_dict = {}
1777
+ # self.nodes = nodes
1778
+ # self.edges = edges
1779
+ # self.n_clusters = n_clusters
1780
+
1781
+ # for node in nodes:
1782
+ # pres_data = dd.read_hdf(self.db, 'presence', columns=['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'bout_no'])
1783
+ # recap_data = dd.read_hdf(self.db, 'classified', columns=['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id', 'iter', 'test'])
1784
+
1785
+ # pres_data['epoch'] = ((pres_data['time_stamp'] - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')).astype('int64')
1786
+ # recap_data['epoch'] = ((recap_data['time_stamp'] - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')).astype('int64')
1787
+
1788
+ # pres_data = pres_data[pres_data['rec_id'] == node]
1789
+ # recap_data = recap_data[(recap_data['rec_id'] == node) &
1790
+ # (recap_data['iter'] == recap_data['iter'].max()) &
1791
+ # (recap_data['test'] == 1)]
1792
+ # recap_data = recap_data[['freq_code', 'epoch', 'time_stamp', 'power', 'rec_id']]
1793
+
1794
+ # pres_data = pres_data.compute()
1795
+
1796
+ # summarized_data = pres_data.groupby(['freq_code', 'bout_no', 'rec_id']).agg({
1797
+ # 'epoch': ['min', 'max'],
1798
+ # 'power': 'median'
1799
+ # }).reset_index()
1800
+
1801
+ # summarized_data.columns = ['freq_code', 'bout_no', 'rec_id',
1802
+ # 'min_epoch', 'max_epoch', 'median_power']
1803
+
1804
+ # rec_ids = summarized_data['rec_id'].values
1805
+ # median_powers = summarized_data['median_power'].values
1806
+ # normalized_power = np.zeros_like(median_powers)
1807
+
1808
+ # for rec_id in np.unique(rec_ids):
1809
+ # mask = rec_ids == rec_id
1810
+ # norm_power = median_powers[mask]
1811
+ # normalized_power[mask] = (norm_power - norm_power.min()) / (norm_power.max() - norm_power.min())
1812
+
1813
+ # summarized_data['norm_power'] = normalized_power
1814
+
1815
+ # self.node_pres_dict[node] = dd.from_pandas(summarized_data, npartitions=10)
1816
+ # self.node_recap_dict[node] = recap_data
1817
+ # print(f"Completed data management process for node {node}")
1818
+
1819
+ # # Debugging step to check initialized keys
1820
+ # print("Initialized nodes in node_pres_dict:", list(self.node_pres_dict.keys()))
1821
+
1822
+
1823
+ # def unsupervised_removal(self):
1824
+ # final_classifications = {}
1825
+ # combined_recaps_list = []
1826
+
1827
+ # def process_pair(parent, child):
1828
+ # parent_bouts = self.node_pres_dict[parent]
1829
+ # child_bouts = self.node_pres_dict[child]
1830
+
1831
+ # overlapping = parent_bouts.merge(
1832
+ # child_bouts,
1833
+ # on='freq_code',
1834
+ # suffixes=('_parent', '_child')
1835
+ # ).query('(min_epoch_child <= max_epoch_parent) & (max_epoch_child >= min_epoch_parent)').compute()
1836
+
1837
+ # if overlapping.empty:
1838
+ # return None
1839
+
1840
+ # parent_recaps = self.node_recap_dict[parent].merge(
1841
+ # overlapping[['freq_code', 'min_epoch_parent', 'max_epoch_parent']],
1842
+ # on='freq_code'
1843
+ # ).query('epoch >= min_epoch_parent and epoch <= max_epoch_parent').compute()
1844
+
1845
+ # child_recaps = self.node_recap_dict[child].merge(
1846
+ # overlapping[['freq_code', 'min_epoch_child', 'max_epoch_child']],
1847
+ # on='freq_code'
1848
+ # ).query('epoch >= min_epoch_child and epoch <= max_epoch_child').compute()
1849
+
1850
+ # if parent_recaps.empty or child_recaps.empty:
1851
+ # return None
1852
+
1853
+ # combined_recaps = pd.concat([parent_recaps, child_recaps])
1854
+ # combined_recaps['norm_power'] = (combined_recaps['power'] - combined_recaps['power'].min()) / (combined_recaps['power'].max() - combined_recaps['power'].min())
1855
+ # return combined_recaps
1856
+
1857
+ # # Process receiver pairs in parallel
1858
+ # with ProcessPoolExecutor() as executor:
1859
+ # results = executor.map(lambda pair: process_pair(pair[0], pair[1]), self.edges)
1860
+
1861
+ # for combined_recaps in results:
1862
+ # if combined_recaps is not None:
1863
+ # combined_recaps_list.append(combined_recaps)
1864
+
1865
+ # if combined_recaps_list:
1866
+ # all_combined_recaps = pd.concat(combined_recaps_list, ignore_index=True)
1867
+ # best_bout_mask = self.apply_kmeans(all_combined_recaps)
1868
+
1869
+ # all_combined_recaps['overlapping'] = np.where(best_bout_mask, 0, 1)
1870
+ # for _, rec in all_combined_recaps.iterrows():
1871
+ # key = (rec['freq_code'], rec['epoch'])
1872
+ # if key not in final_classifications:
1873
+ # final_classifications[key] = rec['overlapping']
1874
+ # else:
1875
+ # final_classifications[key] = max(final_classifications[key], rec['overlapping'])
1876
+
1877
+ # final_detections = []
1878
+ # for parent in self.node_pres_dict.keys():
1879
+ # recaps_chunk = self.node_recap_dict[parent].compute()
1880
+ # recaps_chunk['overlapping'] = 1
1881
+
1882
+ # for (freq_code, epoch), overlap_value in final_classifications.items():
1883
+ # recaps_chunk.loc[(recaps_chunk['epoch'] == epoch) & (recaps_chunk['freq_code'] == freq_code), 'overlapping'] = overlap_value
1884
+
1885
+ # final_detections.append(recaps_chunk)
1886
+
1887
+ # final_result = pd.concat(final_detections, ignore_index=True)
1888
+ # final_result['epoch'] = final_result['epoch'].astype('int64')
1889
+
1890
+ # string_columns = final_result.select_dtypes(include=['string']).columns
1891
+ # final_result[string_columns] = final_result[string_columns].astype('object')
1892
+
1893
+ # with pd.HDFStore(self.db, mode='a') as store:
1894
+ # store.append(key='overlapping',
1895
+ # value=final_result,
1896
+ # format='table',
1897
+ # index=False,
1898
+ # min_itemsize={'freq_code': 20, 'rec_id': 20},
1899
+ # append=True,
1900
+ # data_columns=True,
1901
+ # chunksize=1000000)
1902
+
1903
+ # print(f'Processed overlap for all receiver pairs.')
1904
+
1905
+
1906
+ # def apply_kmeans(self, combined_recaps):
1907
+ # """
1908
+ # Applies KMeans clustering to identify 'near' and 'far' clusters.
1909
+ # If KMeans cannot find two distinct clusters, falls back to a simple power comparison.
1910
+ # """
1911
+ # # Convert to NumPy arrays directly from the DataFrame
1912
+ # features = combined_recaps[['norm_power']].values
1913
+
1914
+ # kmeans = KMeans(n_clusters=2, random_state=42, n_init=10)
1915
+ # kmeans.fit(features)
1916
+
1917
+ # # Ensure labels are a NumPy array
1918
+ # labels = np.array(kmeans.labels_)
1919
+
1920
+ # # Check if KMeans found fewer than 2 clusters
1921
+ # if len(np.unique(labels)) < 2:
1922
+ # print("Found fewer than 2 clusters. Falling back to selecting the recapture with the highest power.")
1923
+ # return combined_recaps['power'].values >= combined_recaps['power'].mean()
1924
+
1925
+ # # Determine which cluster corresponds to 'near' based on median power
1926
+ # cluster_medians = combined_recaps.groupby(labels)['norm_power'].median()
1927
+ # near_cluster = cluster_medians.idxmax() # Cluster with the higher median power is 'near'
1928
+
1929
+ # return labels == near_cluster
1930
+
1931
+ # # def unsupervised_removal(self):
1932
+ # # """
1933
+ # # Identifies and removes overlapping detections across receivers using KMeans for clustering.
1934
+ # # Ensures each detection is classified only once, with the most conservative (i.e., 'far') classification.
1935
+ # # """
1936
+ # # final_classifications = {}
1937
+
1938
+ # # for parent, child in self.edges:
1939
+ # # print(f"Processing parent: {parent}")
1940
+
1941
+ # # if parent not in self.node_pres_dict:
1942
+ # # raise KeyError(f"Parent {parent} not found in node_pres_dict. Available keys: {list(self.node_pres_dict.keys())}")
1943
+
1944
+ # # parent_bouts = self.node_pres_dict[parent].compute()
1945
+ # # child_bouts = self.node_pres_dict[child].compute()
1946
+
1947
+ # # # Merge and detect overlaps between parent and child
1948
+ # # overlapping = parent_bouts.merge(
1949
+ # # child_bouts,
1950
+ # # on='freq_code',
1951
+ # # suffixes=('_parent', '_child')
1952
+ # # ).query('(min_epoch_child <= max_epoch_parent) & (max_epoch_child >= min_epoch_parent)')
1953
+
1954
+ # # if not overlapping.empty:
1955
+ # # # Apply KMeans clustering or fallback to greater-than analysis
1956
+ # # best_bout_mask = self.apply_kmeans(overlapping)
1957
+ # # overlapping['overlapping'] = np.where(best_bout_mask, 0, 1)
1958
+
1959
+ # # # Update the final classification for each detection
1960
+ # # for _, bout in overlapping.iterrows():
1961
+ # # key = (bout['freq_code'], bout['min_epoch_parent'], bout['max_epoch_parent'])
1962
+ # # if key not in final_classifications:
1963
+ # # final_classifications[key] = bout['overlapping']
1964
+ # # else:
1965
+ # # final_classifications[key] = max(final_classifications[key], bout['overlapping'])
1966
+
1967
+ # # # Prepare final result based on the most conservative classification
1968
+ # # final_detections = []
1969
+ # # for parent in self.node_pres_dict.keys():
1970
+ # # recaps_chunk = self.node_recap_dict[parent].compute()
1971
+
1972
+ # # # Initialize 'overlapping' column as 1 (conservative)
1973
+ # # recaps_chunk['overlapping'] = 1
1974
+
1975
+ # # # Update based on the final classifications
1976
+ # # for (freq_code, min_epoch, max_epoch), overlap_value in final_classifications.items():
1977
+ # # in_bout = (recaps_chunk['epoch'] >= min_epoch) & (recaps_chunk['epoch'] <= max_epoch) & (recaps_chunk['freq_code'] == freq_code)
1978
+ # # recaps_chunk.loc[in_bout, 'overlapping'] = overlap_value
1979
+
1980
+ # # final_detections.append(recaps_chunk)
1981
+
1982
+ # # # Combine all detections
1983
+ # # final_result = pd.concat(final_detections, ignore_index=True)
1984
+ # # final_result['epoch'] = final_result['epoch'].astype('int64')
1985
+
1986
+ # # # Convert StringDtype columns to object dtype
1987
+ # # string_columns = final_result.select_dtypes(include=['string']).columns
1988
+ # # final_result[string_columns] = final_result[string_columns].astype('object')
1989
+
1990
+ # # # Save the final results to the HDF5 store
1991
+ # # with pd.HDFStore(self.db, mode='a') as store:
1992
+ # # store.append(key='overlapping',
1993
+ # # value=final_result,
1994
+ # # format='table',
1995
+ # # index=False,
1996
+ # # min_itemsize={'freq_code': 20, 'rec_id': 20},
1997
+ # # append=True,
1998
+ # # data_columns=True,
1999
+ # # chunksize=1000000)
2000
+
2001
+ # # print(f'Processed overlap for all receiver pairs.')
2002
+
2003
+ # def _plot_kmeans_results(self, combined, centers, fish_id, node_a, node_b, project_dir):
2004
+ # """
2005
+ # Plots and saves the K-means clustering results to the project directory.
2006
+ # """
2007
+ # plt.figure(figsize=(10, 6))
2008
+ # plt.hist(combined['norm_power'], bins=30, alpha=0.5, label='Normalized Power')
2009
+ # plt.axvline(centers[0], color='r', linestyle='dashed', linewidth=2, label='Cluster Center 1')
2010
+ # plt.axvline(centers[1], color='b', linestyle='dashed', linewidth=2, label='Cluster Center 2')
2011
+ # plt.title(f"K-means Clustering between Nodes {node_a} and {node_b}")
2012
+ # plt.xlabel("Normalized Power")
2013
+ # plt.ylabel("Frequency")
2014
+ # plt.legend()
2015
+
2016
+ # output_path = os.path.join(project_dir, 'Output', 'Figures', f'kmeans_nodes_{node_a}_{node_b}.png')
2017
+ # plt.savefig(output_path)
2018
+ # plt.close()
2019
+ # print(f"K-means plot saved")
2020
+
2021
+ # def nested_doll(self):
2022
+ # """
2023
+ # Identify and mark overlapping detections between parent and child nodes.
2024
+ # """
2025
+ # overlaps_found = False
2026
+ # overlap_count = 0
2027
+
2028
+ # for i in self.node_recap_dict:
2029
+ # fishes = self.node_recap_dict[i].freq_code.unique().compute()
2030
+
2031
+ # for j in fishes:
2032
+ # children = list(self.G.successors(i))
2033
+ # fish_dat = self.node_recap_dict[i][self.node_recap_dict[i].freq_code == j].compute().copy()
2034
+ # fish_dat['overlapping'] = 0
2035
+ # fish_dat['parent'] = ''
2036
+
2037
+ # if len(children) > 0:
2038
+ # for k in children:
2039
+ # child_dat = self.node_pres_dict[k][self.node_pres_dict[k].freq_code == j].compute()
2040
+ # if len(child_dat) > 0:
2041
+ # min_epochs = child_dat.min_epoch.values
2042
+ # max_epochs = child_dat.max_epoch.values
2043
+
2044
+ # fish_epochs = fish_dat.epoch.values
2045
+ # overlaps = np.any(
2046
+ # (min_epochs[:, None] <= fish_epochs) & (max_epochs[:, None] > fish_epochs), axis=0
2047
+ # )
2048
+ # overlap_indices = np.where(overlaps)[0]
2049
+ # if overlap_indices.size > 0:
2050
+ # overlaps_found = True
2051
+ # overlap_count += overlap_indices.size
2052
+ # fish_dat.loc[overlaps, 'overlapping'] = 1
2053
+ # fish_dat.loc[overlaps, 'parent'] = i
2054
+
2055
+ # fish_dat = fish_dat.astype({
2056
+ # 'freq_code': 'object',
2057
+ # 'epoch': 'int32',
2058
+ # 'rec_id': 'object',
2059
+ # 'node': 'object',
2060
+ # 'overlapping': 'int32',
2061
+ # 'parent': 'object'
2062
+ # })
2063
+
2064
+ # with pd.HDFStore(self.db, mode='a') as store:
2065
+ # store.append(key='overlapping',
2066
+ # value=fish_dat,
2067
+ # format='table',
2068
+ # index=False,
2069
+ # min_itemsize={'freq_code': 20,
2070
+ # 'rec_id': 20,
2071
+ # 'parent': 20},
2072
+ # append=True,
2073
+ # data_columns=True,
2074
+ # chunksize=1000000)
2075
+
2076
+ # if overlaps_found:
2077
+ # print(f"Overlaps were found and processed. Total number of overlaps: {overlap_count}.")
2078
+ # else:
2079
+ # print("No overlaps were found.")
2080
+
2081
+ def visualize_overlaps(self, output_dir=None):
2082
+ """
2083
+ Visualize overlap patterns, decisions, and network structure.
2084
+
2085
+ Creates comprehensive plots showing:
2086
+ - Network graph of receiver relationships with overlap counts
2087
+ - Decision breakdown (remove_parent, remove_child, keep_both)
2088
+ - Overlap distribution by receiver and fish
2089
+ - Temporal patterns of overlaps
2090
+ - Power distributions for overlapping vs non-overlapping detections
2091
+
2092
+ Args:
2093
+ output_dir (str): Directory to save plots. If None, uses database directory.
2094
+ """
2095
+ import logging
2096
+ logger = logging.getLogger(__name__)
2097
+
2098
+ print(f"\n{'='*80}")
2099
+ print("OVERLAP VISUALIZATION")
2100
+ print(f"{'='*80}")
2101
+
2102
+ # Load overlapping data
2103
+ try:
2104
+ overlapping = pd.read_hdf(self._resolve_db_path(), key='/overlapping')
2105
+ except (OSError, KeyError, ValueError) as exc:
2106
+ raise RuntimeError(f"Error loading overlapping data: {exc}") from exc
2107
+ print(f"Loaded {len(overlapping):,} detections from /overlapping table")
2108
+
2109
+ if overlapping.empty:
2110
+ print("No overlap data to visualize")
2111
+ return
2112
+
2113
+ # Calculate statistics
2114
+ total_detections = len(overlapping)
2115
+ overlapping_count = overlapping['overlapping'].sum() if 'overlapping' in overlapping.columns else 0
2116
+ ambiguous_count = overlapping['ambiguous_overlap'].sum() if 'ambiguous_overlap' in overlapping.columns else 0
2117
+
2118
+ print(f"\nOverlap Statistics:")
2119
+ print(f" Total detections: {total_detections:,}")
2120
+ print(f" Overlapping detections: {overlapping_count:,} ({100*overlapping_count/total_detections:.1f}%)")
2121
+ print(f" Ambiguous overlaps: {ambiguous_count:,} ({100*ambiguous_count/total_detections:.1f}%)")
2122
+ print(f" Unique fish: {overlapping['freq_code'].nunique()}")
2123
+ print(f" Unique receivers: {overlapping['rec_id'].nunique()}")
2124
+
2125
+ # Create figure with subplots
2126
+ fig = plt.figure(figsize=(18, 12))
2127
+ gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
2128
+
2129
+ # 1. Network graph showing receiver relationships
2130
+ ax1 = fig.add_subplot(gs[0, :2])
2131
+ self._plot_network_graph(ax1)
2132
+
2133
+ # 2. Decision breakdown pie chart
2134
+ ax2 = fig.add_subplot(gs[0, 2])
2135
+ self._plot_decision_breakdown(ax2, overlapping)
2136
+
2137
+ # 3. Overlaps by receiver (bar chart)
2138
+ ax3 = fig.add_subplot(gs[1, 0])
2139
+ self._plot_overlaps_by_receiver(ax3, overlapping)
2140
+
2141
+ # 4. Overlaps by fish (top 15)
2142
+ ax4 = fig.add_subplot(gs[1, 1])
2143
+ self._plot_overlaps_by_fish(ax4, overlapping)
2144
+
2145
+ # 5. Temporal pattern of overlaps
2146
+ ax5 = fig.add_subplot(gs[1, 2])
2147
+ self._plot_temporal_patterns(ax5, overlapping)
2148
+
2149
+ # 6. Power distribution comparison
2150
+ ax6 = fig.add_subplot(gs[2, 0])
2151
+ self._plot_power_distributions(ax6, overlapping)
2152
+
2153
+ # 7. Detection count per fish
2154
+ ax7 = fig.add_subplot(gs[2, 1])
2155
+ self._plot_detection_counts(ax7, overlapping)
2156
+
2157
+ # 8. Overlap percentage by receiver pair
2158
+ ax8 = fig.add_subplot(gs[2, 2])
2159
+ self._plot_receiver_pair_heatmap(ax8, overlapping)
2160
+
2161
+ fig.suptitle('Overlap Removal Analysis', fontsize=16, fontweight='bold')
2162
+
2163
+ # Save figure
2164
+ if output_dir is None:
2165
+ output_dir = os.path.dirname(self._resolve_db_path())
2166
+ output_path = os.path.join(output_dir, 'overlap_analysis.png')
2167
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
2168
+ print(f"\n[overlap] Saved visualization to: {output_path}")
2169
+
2170
+ plt.show()
2171
+
2172
+ def _plot_network_graph(self, ax):
2173
+ """Plot the receiver network graph with edge weights showing overlap counts."""
2174
+ pos = nx.spring_layout(self.G, seed=42, k=0.5, iterations=50)
2175
+
2176
+ # Draw nodes
2177
+ nx.draw_networkx_nodes(self.G, pos, node_color='lightblue',
2178
+ node_size=1000, alpha=0.9, ax=ax)
2179
+
2180
+ # Draw edges with varying thickness based on overlap count
2181
+ edges = self.G.edges()
2182
+ if len(edges) > 0:
2183
+ nx.draw_networkx_edges(self.G, pos, width=2, alpha=0.6,
2184
+ edge_color='gray', arrows=True,
2185
+ arrowsize=20, ax=ax)
2186
+
2187
+ # Draw labels
2188
+ nx.draw_networkx_labels(self.G, pos, font_size=10, font_weight='bold', ax=ax)
2189
+
2190
+ ax.set_title('Receiver Network Structure', fontsize=12, fontweight='bold')
2191
+ ax.axis('off')
2192
+
2193
+ def _plot_decision_breakdown(self, ax, overlapping):
2194
+ """Plot pie chart of overlap decisions."""
2195
+ # Count decisions from the data
2196
+ decisions = {
2197
+ 'Overlapping': int(overlapping['overlapping'].sum()) if 'overlapping' in overlapping.columns else 0,
2198
+ 'Ambiguous': int(overlapping['ambiguous_overlap'].sum()) if 'ambiguous_overlap' in overlapping.columns else 0,
2199
+ 'Clean': int((overlapping['overlapping'] == 0).sum()) if 'overlapping' in overlapping.columns else len(overlapping)
2200
+ }
2201
+
2202
+ # Filter out zero counts
2203
+ decisions = {k: v for k, v in decisions.items() if v > 0}
2204
+
2205
+ if decisions:
2206
+ colors = {'Overlapping': '#ff6b6b', 'Ambiguous': '#ffd93d', 'Clean': '#6bcf7f'}
2207
+ ax.pie(decisions.values(), labels=decisions.keys(), autopct='%1.1f%%',
2208
+ colors=[colors.get(k, 'gray') for k in decisions.keys()],
2209
+ startangle=90)
2210
+ ax.set_title('Detection Categories', fontsize=12, fontweight='bold')
2211
+ else:
2212
+ ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
2213
+ ax.set_title('Detection Categories', fontsize=12, fontweight='bold')
2214
+
2215
+ def _plot_overlaps_by_receiver(self, ax, overlapping):
2216
+ """Bar chart of overlap counts by receiver."""
2217
+ if 'overlapping' in overlapping.columns:
2218
+ overlap_by_rec = overlapping[overlapping['overlapping'] == 1].groupby('rec_id').size().sort_values(ascending=False)
2219
+
2220
+ if len(overlap_by_rec) > 0:
2221
+ overlap_by_rec.plot(kind='bar', ax=ax, color='steelblue', edgecolor='black')
2222
+ ax.set_xlabel('Receiver ID', fontsize=10)
2223
+ ax.set_ylabel('Overlapping Detections', fontsize=10)
2224
+ ax.set_title('Overlaps by Receiver', fontsize=12, fontweight='bold')
2225
+ ax.tick_params(axis='x', rotation=45)
2226
+ ax.grid(True, alpha=0.3, axis='y')
2227
+ else:
2228
+ ax.text(0.5, 0.5, 'No overlapping detections', ha='center', va='center', transform=ax.transAxes)
2229
+ ax.set_title('Overlaps by Receiver', fontsize=12, fontweight='bold')
2230
+ else:
2231
+ ax.text(0.5, 0.5, 'No overlap data', ha='center', va='center', transform=ax.transAxes)
2232
+ ax.set_title('Overlaps by Receiver', fontsize=12, fontweight='bold')
2233
+
2234
+ def _plot_overlaps_by_fish(self, ax, overlapping):
2235
+ """Bar chart of overlap counts by fish (top 15)."""
2236
+ if 'overlapping' in overlapping.columns:
2237
+ overlap_by_fish = overlapping[overlapping['overlapping'] == 1].groupby('freq_code').size().sort_values(ascending=False).head(15)
2238
+
2239
+ if len(overlap_by_fish) > 0:
2240
+ overlap_by_fish.plot(kind='barh', ax=ax, color='coral', edgecolor='black')
2241
+ ax.set_xlabel('Overlapping Detections', fontsize=10)
2242
+ ax.set_ylabel('Fish ID', fontsize=10)
2243
+ ax.set_title('Top 15 Fish with Overlaps', fontsize=12, fontweight='bold')
2244
+ ax.grid(True, alpha=0.3, axis='x')
2245
+ else:
2246
+ ax.text(0.5, 0.5, 'No overlapping detections', ha='center', va='center', transform=ax.transAxes)
2247
+ ax.set_title('Top 15 Fish with Overlaps', fontsize=12, fontweight='bold')
2248
+ else:
2249
+ ax.text(0.5, 0.5, 'No overlap data', ha='center', va='center', transform=ax.transAxes)
2250
+ ax.set_title('Top 15 Fish with Overlaps', fontsize=12, fontweight='bold')
2251
+
2252
+ def _plot_temporal_patterns(self, ax, overlapping):
2253
+ """Plot posterior ratio distributions to see if weaker classifications correlate with overlaps."""
2254
+ if 'posterior_T' in overlapping.columns and 'posterior_F' in overlapping.columns and 'overlapping' in overlapping.columns:
2255
+ # Calculate posterior ratio (T/F) - higher = stronger classification
2256
+ overlapping_copy = overlapping.copy()
2257
+ overlapping_copy['posterior_ratio'] = overlapping_copy['posterior_T'] / (overlapping_copy['posterior_F'] + 1e-10)
2258
+
2259
+ overlap_ratio = overlapping_copy[overlapping_copy['overlapping'] == 1]['posterior_ratio'].dropna()
2260
+ clean_ratio = overlapping_copy[overlapping_copy['overlapping'] == 0]['posterior_ratio'].dropna()
2261
+
2262
+ if len(overlap_ratio) > 0 and len(clean_ratio) > 0:
2263
+ # Use log scale for ratio
2264
+ overlap_log = np.log10(overlap_ratio + 1e-10)
2265
+ clean_log = np.log10(clean_ratio + 1e-10)
2266
+
2267
+ ax.hist([clean_log, overlap_log], bins=30, label=['Clean', 'Overlapping'],
2268
+ color=['lightblue', 'salmon'], alpha=0.7, edgecolor='black')
2269
+ ax.set_xlabel('log10(Posterior_T / Posterior_F)', fontsize=10)
2270
+ ax.set_ylabel('Frequency', fontsize=10)
2271
+ ax.set_title('Classification Strength: Overlapping vs Clean', fontsize=12, fontweight='bold')
2272
+ ax.legend()
2273
+ ax.grid(True, alpha=0.3, axis='y')
529
2274
 
530
- fish_dat = fish_dat.astype({'freq_code': 'object',
531
- 'epoch': 'int32',
532
- 'rec_id': 'object',
533
- 'node': 'object',
534
- 'overlapping':'int32',
535
- 'parent':'object'})
2275
+ # Add statistics
2276
+ median_overlap = np.median(overlap_log)
2277
+ median_clean = np.median(clean_log)
2278
+ ax.axvline(median_clean, color='blue', linestyle='--', alpha=0.7, linewidth=2, label=f'Clean median: {median_clean:.2f}')
2279
+ ax.axvline(median_overlap, color='red', linestyle='--', alpha=0.7, linewidth=2, label=f'Overlap median: {median_overlap:.2f}')
2280
+ ax.legend()
536
2281
 
537
- # append to hdf5
538
- with pd.HDFStore(self.db, mode='a') as store:
539
- store.append(key = 'overlapping',
540
- value = fish_dat,
541
- format = 'table',
542
- index = False,
543
- min_itemsize = {'freq_code':20,
544
- 'rec_id':20,
545
- 'parent':20},
546
- append = True,
547
- data_columns = True,
548
- chunksize = 1000000)
2282
+ print(f"\n[overlap] Posterior ratio analysis:")
2283
+ print(f" Clean detections - median log10(T/F): {median_clean:.3f} (ratio: {10**median_clean:.2f})")
2284
+ print(f" Overlapping detections - median log10(T/F): {median_overlap:.3f} (ratio: {10**median_overlap:.2f})")
2285
+ if median_overlap < median_clean:
2286
+ print(" ✓ Overlapping detections have WEAKER classifications (lower T/F ratio)")
2287
+ else:
2288
+ print(" ⚠ Overlapping detections do NOT have weaker classifications")
2289
+ else:
2290
+ ax.text(0.5, 0.5, 'Insufficient posterior data', ha='center', va='center', transform=ax.transAxes)
2291
+ ax.set_title('Classification Strength', fontsize=12, fontweight='bold')
2292
+ else:
2293
+ # Fallback to temporal patterns if no posterior data
2294
+ if 'time_stamp' in overlapping.columns and 'overlapping' in overlapping.columns:
2295
+ overlap_data = overlapping[overlapping['overlapping'] == 1].copy()
2296
+
2297
+ if len(overlap_data) > 0:
2298
+ overlap_data['date'] = pd.to_datetime(overlap_data['time_stamp']).dt.date
2299
+ daily_overlaps = overlap_data.groupby('date').size()
2300
+
2301
+ daily_overlaps.plot(ax=ax, color='darkgreen', linewidth=2)
2302
+ ax.set_xlabel('Date', fontsize=10)
2303
+ ax.set_ylabel('Overlapping Detections', fontsize=10)
2304
+ ax.set_title('Overlaps Over Time', fontsize=12, fontweight='bold')
2305
+ ax.grid(True, alpha=0.3)
2306
+ ax.tick_params(axis='x', rotation=45)
2307
+ else:
2308
+ ax.text(0.5, 0.5, 'No overlapping detections', ha='center', va='center', transform=ax.transAxes)
2309
+ ax.set_title('Overlaps Over Time', fontsize=12, fontweight='bold')
2310
+ else:
2311
+ ax.text(0.5, 0.5, 'No temporal data', ha='center', va='center', transform=ax.transAxes)
2312
+ ax.set_title('Overlaps Over Time', fontsize=12, fontweight='bold')
2313
+
2314
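For reference, the classification-strength axis plotted above is log10(posterior_T / posterior_F); a toy calculation with made-up posteriors, using the same 1e-10 guards against division by zero and log of zero as the helper:

import numpy as np

posterior_T, posterior_F = 0.98, 0.02                               # confident detection
strong = np.log10(posterior_T / (posterior_F + 1e-10) + 1e-10)      # ~1.69

posterior_T, posterior_F = 0.55, 0.45                               # borderline detection
weak = np.log10(posterior_T / (posterior_F + 1e-10) + 1e-10)        # ~0.09

# If the overlapping histogram sits to the left of the clean histogram,
# the flagged detections tend to be the weaker classifications.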
+ def _plot_power_distributions(self, ax, overlapping):
2315
+ """Compare power distributions for overlapping vs non-overlapping detections."""
2316
+ if 'power' in overlapping.columns and 'overlapping' in overlapping.columns:
2317
+ overlap_power = overlapping[overlapping['overlapping'] == 1]['power'].dropna()
2318
+ clean_power = overlapping[overlapping['overlapping'] == 0]['power'].dropna()
2319
+
2320
+ if len(overlap_power) > 0 and len(clean_power) > 0:
2321
+ ax.hist([clean_power, overlap_power], bins=30, label=['Clean', 'Overlapping'],
2322
+ color=['lightblue', 'salmon'], alpha=0.7, edgecolor='black')
2323
+ ax.set_xlabel('Power (dB)', fontsize=10)
2324
+ ax.set_ylabel('Frequency', fontsize=10)
2325
+ ax.set_title('Power Distribution', fontsize=12, fontweight='bold')
2326
+ ax.legend()
2327
+ ax.grid(True, alpha=0.3, axis='y')
2328
+ else:
2329
+ ax.text(0.5, 0.5, 'Insufficient power data', ha='center', va='center', transform=ax.transAxes)
2330
+ ax.set_title('Power Distribution', fontsize=12, fontweight='bold')
2331
+ else:
2332
+ ax.text(0.5, 0.5, 'No power data', ha='center', va='center', transform=ax.transAxes)
2333
+ ax.set_title('Power Distribution', fontsize=12, fontweight='bold')
2334
+
2335
+ def _plot_detection_counts(self, ax, overlapping):
2336
+ """Plot detection counts: total, overlapping, ambiguous."""
2337
+ categories = ['Total', 'Overlapping', 'Ambiguous', 'Clean']
2338
+ counts = [
2339
+ len(overlapping),
2340
+ int(overlapping['overlapping'].sum()) if 'overlapping' in overlapping.columns else 0,
2341
+ int(overlapping['ambiguous_overlap'].sum()) if 'ambiguous_overlap' in overlapping.columns else 0,
2342
+ int((overlapping['overlapping'] == 0).sum()) if 'overlapping' in overlapping.columns else len(overlapping)
2343
+ ]
2344
+
2345
+ colors = ['#4a90e2', '#ff6b6b', '#ffd93d', '#6bcf7f']
2346
+ bars = ax.bar(categories, counts, color=colors, edgecolor='black', alpha=0.8)
2347
+
2348
+ # Add value labels on bars
2349
+ for bar in bars:
2350
+ height = bar.get_height()
2351
+ ax.text(bar.get_x() + bar.get_width()/2., height,
2352
+ f'{int(height):,}',
2353
+ ha='center', va='bottom', fontsize=9)
2354
+
2355
+ ax.set_ylabel('Count', fontsize=10)
2356
+ ax.set_title('Detection Counts', fontsize=12, fontweight='bold')
2357
+ ax.tick_params(axis='x', rotation=45)
2358
+ ax.grid(True, alpha=0.3, axis='y')
2359
+ # log scale keeps the much larger total-count bar readable; zero-count categories draw no bar
+ ax.set_yscale('log')
2360
+
2361
+ def _plot_receiver_pair_heatmap(self, ax, overlapping):
2362
+ """Heatmap showing overlap percentages between receiver pairs."""
2363
+ if 'overlapping' in overlapping.columns:
2364
+ # This is simplified - would need parent-child relationship data for full heatmap
2365
+ overlap_by_rec = overlapping.groupby('rec_id').agg({
2366
+ 'overlapping': 'sum',
2367
+ 'freq_code': 'count'
2368
+ })
2369
+ overlap_by_rec['pct_overlap'] = 100 * overlap_by_rec['overlapping'] / overlap_by_rec['freq_code']
2370
+
2371
+ receivers = overlap_by_rec.index.tolist()
2372
+ pct_values = overlap_by_rec['pct_overlap'].values
2373
+
2374
+ if len(receivers) > 0:
2375
+ ax.barh(receivers, pct_values, color='purple', alpha=0.7, edgecolor='black')
2376
+ ax.set_xlabel('% Detections Overlapping', fontsize=10)
2377
+ ax.set_ylabel('Receiver ID', fontsize=10)
2378
+ ax.set_title('Overlap % by Receiver', fontsize=12, fontweight='bold')
2379
+ ax.grid(True, alpha=0.3, axis='x')
2380
+ else:
2381
+ ax.text(0.5, 0.5, 'No overlap data', ha='center', va='center', transform=ax.transAxes)
2382
+ ax.set_title('Overlap % by Receiver', fontsize=12, fontweight='bold')
2383
+ else:
2384
+ ax.text(0.5, 0.5, 'No overlap data', ha='center', va='center', transform=ax.transAxes)
2385
+ ax.set_title('Overlap % by Receiver', fontsize=12, fontweight='bold')