grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1021 @@
+ ########################################################
+ # Cell and nucleus boundary registration utilities.
+ #
+ # This module includes single-threaded and parallel variants used by the GRASP
+ # preprocessing pipeline.
+ ########################################################
+ 
+ import math
+ from functools import partial
+ from math import pi
+ from multiprocessing import Pool, cpu_count
+ from typing import Any, cast
+ 
+ import numpy as np
+ import pandas as pd
+ from scipy.interpolate import interp1d, splprep, splev
+ from tqdm import tqdm
+ 
+ 
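+ # Typical end-to-end flow, sketched for orientation (illustrative only; the
+ # variable names `transcripts_df`, `masks_df`, `nuclei`, and `cell_ids` are
+ # hypothetical, not part of this module):
+ #
+ #   nuclei_dense = enhance_boundary_resolution(nuclei, min_points_per_cell=50)
+ #   ntanbin_dict = specify_ntanbin(cell_list_dict, masks_df, type_list)
+ #   cells, nuclei_reg, radii, stats = (
+ #       register_cells_and_nuclei_parallel_chunked_constrained(
+ #           transcripts_df, cell_ids, masks_df, nuclei_dense, ntanbin_dict
+ #       )
+ #   )
+ 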
+ def interpolate_boundary_points(
+     cell_boundary_dict, target_points_per_cell=100, method="spline", smooth_factor=0
+ ):
+     """Interpolate boundary points to increase boundary resolution.
+ 
+     Args:
+         cell_boundary_dict: Mapping {cell_id: DataFrame with columns x,y}.
+         target_points_per_cell: Target number of points after interpolation.
+         method: "spline" or "linear".
+         smooth_factor: Spline smoothing factor (0 means interpolate through all points).
+ 
+     Returns:
+         Mapping {cell_id: DataFrame with interpolated x,y}.
+     """
+     interpolated_boundary = {}
+ 
+     print(f"Interpolating boundary points for {len(cell_boundary_dict)} cells...")
+ 
+     for cell_id, boundary_df in tqdm(
+         cell_boundary_dict.items(), desc="Boundary interpolation"
+     ):
+         try:
+             cell_method = method
+ 
+             # Original points.
+             x_points = boundary_df["x"].values
+             y_points = boundary_df["y"].values
+ 
+             x_new = None
+             y_new = None
+ 
+             # Need at least 3 points.
+             if len(x_points) < 3:
+                 print(
+                     f"WARNING: cell {cell_id} has too few boundary points "
+                     f"({len(x_points)}); skipping interpolation"
+                 )
+                 interpolated_boundary[cell_id] = boundary_df.copy()
+                 continue
+ 
+             # Already dense enough; keep as-is.
+             if len(x_points) >= target_points_per_cell:
+                 interpolated_boundary[cell_id] = boundary_df.copy()
+                 continue
+ 
+             # Ensure a closed boundary.
+             if not (
+                 np.isclose(x_points[0], x_points[-1])
+                 and np.isclose(y_points[0], y_points[-1])
+             ):
+                 x_points = np.append(x_points, x_points[0])
+                 y_points = np.append(y_points, y_points[0])
+ 
+             if cell_method == "spline":
+                 # Spline interpolation.
+                 distances = np.sqrt(np.diff(x_points) ** 2 + np.diff(y_points) ** 2)
+                 distances = np.insert(distances, 0, 0)
+                 cumulative_distance = np.cumsum(distances)
+ 
+                 # Avoid duplicated points.
+                 unique_indices = np.unique(cumulative_distance, return_index=True)[1]
+                 if len(unique_indices) < 3:
+                     print(
+                         f"WARNING: cell {cell_id} has too few unique points; "
+                         "falling back to linear interpolation"
+                     )
+                     cell_method = "linear"
+                 else:
+                     x_unique = x_points[unique_indices]
+                     y_unique = y_points[unique_indices]
+ 
+                     try:
+                         # Periodic parametric spline through the boundary points.
+                         tck, _ = splprep(
+                             [x_unique, y_unique], s=smooth_factor, per=True
+                         )
+ 
+                         u_new = np.linspace(
+                             0, 1, target_points_per_cell, endpoint=False
+                         )
+ 
+                         x_new, y_new = splev(u_new, tck)
+ 
+                     except Exception as e:
+                         print(
+                             f"WARNING: cell {cell_id} spline interpolation failed: {e}; "
+                             "falling back to linear interpolation"
+                         )
+                         cell_method = "linear"
+ 
+             if cell_method == "linear":
+                 # Linear interpolation parameterized by arc length.
+                 distances = np.sqrt(np.diff(x_points) ** 2 + np.diff(y_points) ** 2)
+                 distances = np.insert(distances, 0, 0)
+                 cumulative_distance = np.cumsum(distances)
+ 
+                 if cumulative_distance[-1] > 0:
+                     normalized_distance = cumulative_distance / cumulative_distance[-1]
+                 else:
+                     normalized_distance = cumulative_distance
+ 
+                 t_new = np.linspace(0, 1, target_points_per_cell, endpoint=False)
+ 
+                 try:
+                     x_interp = interp1d(
+                         normalized_distance,
+                         x_points,
+                         kind="linear",
+                         assume_sorted=True,
+                         bounds_error=False,
+                         fill_value=cast(Any, "extrapolate"),
+                     )
+                     y_interp = interp1d(
+                         normalized_distance,
+                         y_points,
+                         kind="linear",
+                         assume_sorted=True,
+                         bounds_error=False,
+                         fill_value=cast(Any, "extrapolate"),
+                     )
+ 
+                     x_new = x_interp(t_new)
+                     y_new = y_interp(t_new)
+ 
+                 except Exception as e:
+                     print(
+                         f"WARNING: cell {cell_id} linear interpolation failed: {e}; "
+                         "keeping original points"
+                     )
+                     interpolated_boundary[cell_id] = boundary_df.copy()
+                     continue
+ 
+             if x_new is None or y_new is None:
+                 interpolated_boundary[cell_id] = boundary_df.copy()
+                 continue
+ 
+             interpolated_boundary[cell_id] = pd.DataFrame({"x": x_new, "y": y_new})
+ 
+         except Exception as e:
+             print(f"ERROR: failed to process cell {cell_id}: {e}")
+             interpolated_boundary[cell_id] = boundary_df.copy()
+ 
+     original_points = sum(len(df) for df in cell_boundary_dict.values())
+     new_points = sum(len(df) for df in interpolated_boundary.values())
+ 
+     print("Interpolation complete:")
+     print(f"  total_original_points: {original_points}")
+     print(f"  total_interpolated_points: {new_points}")
+     print(
+         f"  mean_original_points_per_cell: "
+         f"{original_points / max(len(cell_boundary_dict), 1):.1f}"
+     )
+     print(
+         f"  mean_interpolated_points_per_cell: "
+         f"{new_points / max(len(interpolated_boundary), 1):.1f}"
+     )
+ 
+     return interpolated_boundary
+ 
+ 
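+ # Illustrative usage (hypothetical toy boundary; real boundaries have many
+ # more points):
+ #
+ #   square = pd.DataFrame({"x": [0, 10, 10, 0], "y": [0, 0, 10, 10]})
+ #   dense = interpolate_boundary_points(
+ #       {"cell_1": square}, target_points_per_cell=100
+ #   )
+ #   len(dense["cell_1"])  # -> 100 points along a smooth closed curve
+ 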
+ def enhance_boundary_resolution(
+     cell_boundary_dict, min_points_per_cell=50, adaptive=True
+ ):
+     """Increase boundary resolution for each cell boundary.
+ 
+     If adaptive=True, the target point count is derived from the boundary
+     perimeter; otherwise a fixed minimum is used.
+     """
+     enhanced_boundary = {}
+ 
+     for cell_id, boundary_df in cell_boundary_dict.items():
+         current_points = len(boundary_df)
+ 
+         if adaptive:
+             x_points = boundary_df["x"].values
+             y_points = boundary_df["y"].values
+ 
+             # Perimeter of the closed polygon through the boundary points.
+             perimeter = 0
+             for i in range(len(x_points)):
+                 next_i = (i + 1) % len(x_points)
+                 perimeter += np.sqrt(
+                     (x_points[next_i] - x_points[i]) ** 2
+                     + (y_points[next_i] - y_points[i]) ** 2
+                 )
+ 
+             # Aim for roughly one point per 2.5 length units of perimeter.
+             target_points = max(min_points_per_cell, int(perimeter / 2.5))
+         else:
+             target_points = min_points_per_cell
+ 
+         if current_points >= target_points:
+             enhanced_boundary[cell_id] = boundary_df.copy()
+         else:
+             # Interpolate to increase boundary resolution.
+             temp_dict = {cell_id: boundary_df}
+             interpolated = interpolate_boundary_points(
+                 temp_dict, target_points_per_cell=target_points, method="spline"
+             )
+             enhanced_boundary[cell_id] = interpolated[cell_id]
+ 
+     return enhanced_boundary
+ 
+ 
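+ # Worked example of the adaptive target: a boundary with perimeter ~500
+ # length units targets max(50, int(500 / 2.5)) = 200 points, i.e. roughly one
+ # point per 2.5 units of perimeter.
+ 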
+ def register_cells(
+     data_df, cell_list_all, cell_mask_df, ntanbin_dict, epsilon=1e-10, nc_demo=None
+ ):
+     if nc_demo is None:
+         nc_demo = len(cell_list_all)
+     dict_registered = {}
+     cell_radii = {}  # Per-cell maximum radius.
+     df = data_df.copy()  # Copy the original data.
+     df_gbC = df.groupby("cell", observed=False)  # Group by `cell`.
+     for c in tqdm(cell_list_all[:nc_demo], desc="Processing cells"):
+         df_c = df_gbC.get_group(c).copy()  # Per-cell transcript table.
+         t = df_c.type.iloc[0]  # Cell type of cell c.
+         mask_df_c = cell_mask_df[cell_mask_df.cell == c].copy()  # Mask df for cell c.
+         center_c = [
+             int(df_c.centerX.iloc[0]),
+             int(df_c.centerY.iloc[0]),
+         ]  # Nuclear center of cell c.
+         # The full circle (2*pi) is split into ntanbin_dict[t] * 4 angle bins.
+         delta_tanbin = (2 * math.pi) / (ntanbin_dict[t] * 4)
+         # Add centered coordinates and the angle arctan(|y/x|) for df_c and mask_df_c.
+         df_c["x_c"] = df_c.x - center_c[0]
+         df_c["y_c"] = df_c.y - center_c[1]
+         df_c["d_c"] = (df_c.x_c ** 2 + df_c.y_c ** 2) ** 0.5
+         df_c["arctan"] = np.absolute(np.arctan(df_c.y_c / (df_c.x_c + epsilon)))
+         mask_df_c["x_c"] = mask_df_c.x - center_c[0]
+         mask_df_c["y_c"] = mask_df_c.y - center_c[1]
+         mask_df_c["d_c"] = (mask_df_c.x_c ** 2 + mask_df_c.y_c ** 2) ** 0.5
+         mask_df_c["arctan"] = np.absolute(
+             np.arctan(mask_df_c.y_c / (mask_df_c.x_c + epsilon))
+         )
+         # In each quadrant, find dismax_c for each tanbin interval using mask_df_c.
+         mask_df_c_q_dict = {}
+         mask_df_c_q_dict["0"] = mask_df_c[(mask_df_c.x_c >= 0) & (mask_df_c.y_c >= 0)]
+         mask_df_c_q_dict["1"] = mask_df_c[(mask_df_c.x_c <= 0) & (mask_df_c.y_c >= 0)]
+         mask_df_c_q_dict["2"] = mask_df_c[(mask_df_c.x_c <= 0) & (mask_df_c.y_c <= 0)]
+         mask_df_c_q_dict["3"] = mask_df_c[(mask_df_c.x_c >= 0) & (mask_df_c.y_c <= 0)]
+         # Compute dismax_c: the maximum mask radius per (angle bin, quadrant).
+         dismax_c_mat = np.zeros((ntanbin_dict[t], 4))
+         for q in range(4):  # In each of the 4 quadrants.
+             mask_df_c_q = mask_df_c_q_dict[str(q)].copy()
+             # arctan_idx runs from 0 to ntanbin_dict[t] - 1; clip the boundary
+             # case arctan == pi/2, which would otherwise index out of range.
+             mask_df_c_q["arctan_idx"] = np.minimum(
+                 (mask_df_c_q.arctan / delta_tanbin).astype(int), ntanbin_dict[t] - 1
+             )
+             # groupby sorts by arctan_idx, so index/value pairs stay aligned.
+             max_distances = mask_df_c_q.groupby("arctan_idx")["d_c"].max()
+             dismax_c_mat[max_distances.index.to_numpy(), q] = max_distances.values
+ 
+         # For df_c, look up the max radius of each point's angle bin via dismax_c.
+         df_c_q_dict = {}
+         df_c_q_dict["0"] = df_c[(df_c.x_c >= 0) & (df_c.y_c >= 0)]
+         df_c_q_dict["1"] = df_c[(df_c.x_c <= 0) & (df_c.y_c >= 0)]
+         df_c_q_dict["2"] = df_c[(df_c.x_c <= 0) & (df_c.y_c <= 0)]
+         df_c_q_dict["3"] = df_c[(df_c.x_c >= 0) & (df_c.y_c <= 0)]
+         d_c_maxc_dict = {}
+         for q in range(4):  # In each of the 4 quadrants.
+             df_c_q = df_c_q_dict[str(q)].copy()
+             d_c_maxc_q = np.zeros(len(df_c_q))
+             df_c_q["arctan_idx"] = np.minimum(
+                 (df_c_q.arctan / delta_tanbin).astype(int), ntanbin_dict[t] - 1
+             )  # arctan_idx from 0 to ntanbin_dict[t] - 1.
+             for ai in range(ntanbin_dict[t]):
+                 d_c_maxc_q[df_c_q.arctan_idx.values == ai] = dismax_c_mat[ai, q]
+             d_c_maxc_dict[str(q)] = d_c_maxc_q
+         d_c_maxc = np.zeros(len(df_c))
+         d_c_maxc[(df_c.x_c >= 0) & (df_c.y_c >= 0)] = d_c_maxc_dict["0"]
+         d_c_maxc[(df_c.x_c <= 0) & (df_c.y_c >= 0)] = d_c_maxc_dict["1"]
+         d_c_maxc[(df_c.x_c <= 0) & (df_c.y_c <= 0)] = d_c_maxc_dict["2"]
+         d_c_maxc[(df_c.x_c >= 0) & (df_c.y_c <= 0)] = d_c_maxc_dict["3"]
+         df_c["d_c_maxc"] = d_c_maxc
+ 
+         # Scale the centered coordinates by the boundary radius of each bin.
+         d_c_s = df_c.d_c / (df_c.d_c_maxc + epsilon)
+         x_c_s = df_c.x_c * (d_c_s / (df_c.d_c + epsilon))
+         y_c_s = df_c.y_c * (d_c_s / (df_c.d_c + epsilon))
+         df_c["x_c_s"] = x_c_s
+         df_c["y_c_s"] = y_c_s
+         df_c["d_c_s"] = d_c_s
+ 
+         # Store per-cell maximum radius.
+         cell_radii[c] = np.max(df_c["d_c_maxc"])
+ 
+         dict_registered[c] = df_c
+         del df_c
+     # Concatenate into one DataFrame.
+     df_registered = pd.concat(list(dict_registered.values()))
+     print(f"Number of cells registered: {len(dict_registered)}")
+     return df_registered, cell_radii
+ 
+ 
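+ # Worked example of the normalization (illustrative numbers): a transcript at
+ # distance d_c = 12 from the nuclear center, whose angle bin has boundary
+ # radius d_c_maxc = 24, gets d_c_s = 12 / 24 = 0.5; (x_c_s, y_c_s) is
+ # (x_c, y_c) rescaled to length 0.5. Registered cells therefore live in the
+ # unit disk, with the cell boundary mapped to radius 1.
+ 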
+ def specify_ntanbin(
+     cell_list_dict,
+     cell_mask_df,
+     type_list,
+     nc4ntanbin=10,
+     high_res=200,
+     max_ntanbin=25,
+     input_ntanbin_dict=None,
+     min_bp=5,
+     min_ntanbin_error=3,
+ ):
+     ntanbin_dict = {}
+     if input_ntanbin_dict is not None:
+         # Use a user-supplied ntanbin per cell type.
+         ntanbin_dict = input_ntanbin_dict
+     else:
+         # Compute ntanbin for each cell type.
+         for t in type_list:
+             # Specify ntanbin based on the cell segmentation mask/boundary:
+             # randomly sample nc4ntanbin cells, with replacement.
+             cell_list_sampled = np.random.choice(
+                 cell_list_dict[t], nc4ntanbin, replace=True
+             )
+             cell_mask_df_sampled = cell_mask_df[
+                 cell_mask_df.cell.isin(cell_list_sampled)
+             ]
+             # Count the unique x and y coordinates of the sampled cells.
+             nxu_sampled = []
+             nyu_sampled = []
+             for c in cell_list_sampled:
+                 mask_c = cell_mask_df_sampled[cell_mask_df_sampled.cell == c]
+                 nxu_sampled.append(mask_c.x.nunique())
+                 nyu_sampled.append(mask_c.y.nunique())
+ 
+             # Specify ntanbin for pi/2 (one quadrant).
+             if np.mean(nxu_sampled) > high_res and np.mean(nyu_sampled) > high_res:
+                 # Resolution is very high; use the maximum bin count.
+                 ntanbin = max_ntanbin
+             else:
+                 # Require at least min_bp boundary points in each tanbin.
+                 # Note: `nxu_sampled + nyu_sampled` concatenates the two lists,
+                 # so the mean is taken over both the x and y counts.
+                 theta = 2 * np.arctan(min_bp / np.mean(nxu_sampled + nyu_sampled))
+                 ntanbin = np.ceil((pi / 2) / theta)
+                 if ntanbin < min_ntanbin_error:
+                     print(
+                         f"Cell type {t} failed: resolution is not high enough "
+                         "to support the analysis"
+                     )
+                     ntanbin = min_ntanbin_error
+             # Assign.
+             ntanbin_dict[t] = int(ntanbin)
+     return ntanbin_dict
+ 
+ 
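+ # Illustrative arithmetic (hypothetical numbers): with min_bp = 5 and a mean
+ # of 100 unique boundary coordinates, theta = 2 * arctan(5 / 100) ~ 0.0999 rad,
+ # so ntanbin = ceil((pi / 2) / 0.0999) = ceil(15.7) = 16 bins per quadrant.
+ 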
+ def process_chunk_cell(chunk, df_gbC, cell_mask_df, ntanbin_dict, epsilon):
+     results = []
+     for c in chunk:
+         df_c = df_gbC.get_group(c).copy()  # Per-cell transcript table.
+         t = df_c.type.iloc[0]  # Cell type.
+         mask_df_c = cell_mask_df[
+             cell_mask_df.cell == c
+         ].copy()  # Mask df for cell c.
+         center_c = [
+             int(df_c.centerX.iloc[0]),
+             int(df_c.centerY.iloc[0]),
+         ]  # Nuclear center of cell c.
+         delta_tanbin = (2 * math.pi) / (
+             ntanbin_dict[t] * 4
+         )  # The full circle (2*pi) is split into ntanbin_dict[t] * 4 bins.
+         mask_df_c["x_c"] = mask_df_c.x - center_c[0]
+         mask_df_c["y_c"] = mask_df_c.y - center_c[1]
+         mask_df_c["d_c"] = (mask_df_c.x_c ** 2 + mask_df_c.y_c ** 2) ** 0.5
+         mask_df_c["arctan"] = np.absolute(
+             np.arctan(mask_df_c.y_c / (mask_df_c.x_c + epsilon))
+         )
+         # Split mask points into 4 quadrants.
+         mask_df_c_q_dict = {
+             "0": mask_df_c[(mask_df_c.x_c >= 0) & (mask_df_c.y_c >= 0)],
+             "1": mask_df_c[(mask_df_c.x_c <= 0) & (mask_df_c.y_c >= 0)],
+             "2": mask_df_c[(mask_df_c.x_c <= 0) & (mask_df_c.y_c <= 0)],
+             "3": mask_df_c[(mask_df_c.x_c >= 0) & (mask_df_c.y_c <= 0)],
+         }
+         # Compute dismax_c: the max mask radius per (angle bin, quadrant).
+         dismax_c_mat = np.zeros((ntanbin_dict[t], 4))  # Shape: (n_bins, 4 quadrants).
+         for q in range(4):  # In each of the 4 quadrants.
+             mask_df_c_q = mask_df_c_q_dict[str(q)].copy()  # Work on a copy.
+             if len(mask_df_c_q) > 0:
+                 mask_df_c_q["arctan_idx"] = (
+                     mask_df_c_q["arctan"] / delta_tanbin
+                 ).astype(int)  # arctan_idx from 0 to ntanbin_dict[t] - 1.
+                 # Ensure arctan_idx stays in range.
+                 mask_df_c_q["arctan_idx"] = np.minimum(
+                     mask_df_c_q["arctan_idx"], ntanbin_dict[t] - 1
+                 )
+                 max_distances = mask_df_c_q.groupby("arctan_idx")["d_c"].max()
+                 if not max_distances.empty:
+                     dismax_c_mat[max_distances.index.to_numpy(), q] = (
+                         max_distances.values
+                     )
+ 
+         # Fill missing max radius values.
+         for q in range(4):
+             for ai in range(ntanbin_dict[t]):
+                 if dismax_c_mat[ai, q] == 0:
+                     # Use the mean of the non-zero bins in this quadrant.
+                     nonzero_bins = [
+                         i for i in range(ntanbin_dict[t]) if dismax_c_mat[i, q] > 0
+                     ]
+                     if nonzero_bins:
+                         dismax_c_mat[ai, q] = np.mean(
+                             [dismax_c_mat[i, q] for i in nonzero_bins]
+                         )
+                     else:
+                         # Fall back to the max radius within the quadrant.
+                         max_q = np.max(dismax_c_mat[:, q])
+                         if max_q > 0:
+                             dismax_c_mat[ai, q] = max_q
+                         else:
+                             # If the whole quadrant is empty, use the global max.
+                             max_all = np.max(dismax_c_mat)
+                             if max_all > 0:
+                                 dismax_c_mat[ai, q] = max_all
+                             else:
+                                 # If everything is empty, use a large default.
+                                 dismax_c_mat[ai, q] = 100
+ 
+         # dismax_c_mat stores max radii for each (angle bin, quadrant).
+         # Add centered coordinates and the angle to the x-axis for df_c.
+         df_c["x_c"] = df_c.x - center_c[0]  # Relative to cell center.
+         df_c["y_c"] = df_c.y - center_c[1]
+         df_c["d_c"] = (df_c.x_c ** 2 + df_c.y_c ** 2) ** 0.5  # Distance to center.
+         df_c["arctan"] = np.absolute(
+             np.arctan(df_c.y_c / (df_c.x_c + epsilon))
+         )  # Angle to the x-axis.
+ 
+         # Normalize coordinates.
+         df_c_registered = normalize_dataset(
+             df_c,
+             dismax_c_mat,
+             delta_tanbin,
+             ntanbin_dict,
+             t,
+             epsilon,
+             is_nucleus=False,
+             clip_to_cell=True,
+         )
+         cell_radius = df_c_registered["d_c_maxc"].max()  # Per-cell maximum radius.
+         results.append((df_c_registered, cell_radius))
+     return results
+ 
+ 
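+ # Note: process_chunk_cell is an internal worker; register_cells_parallel_chunked
+ # (below) maps it over chunks of cell ids via Pool.imap.
+ 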
+ def register_cells_parallel_chunked(
+     data_df,
+     cell_list_all,
+     cell_mask_df,
+     ntanbin_dict,
+     epsilon=1e-10,
+     nc_demo=None,
+     chunk_size=5,
+ ):
+     if nc_demo is None:
+         nc_demo = len(cell_list_all)
+     df_gbC = data_df.groupby("cell", observed=False)  # Group by cell.
+     chunks = list(chunk_list(cell_list_all[:nc_demo], chunk_size))  # Split into chunks.
+     # Leave some CPU for the system, but always use at least one worker.
+     pool = Pool(processes=max(1, cpu_count() - 2))
+     process_chunk_partial = partial(
+         process_chunk_cell,
+         df_gbC=df_gbC,
+         cell_mask_df=cell_mask_df,
+         ntanbin_dict=ntanbin_dict,
+         epsilon=epsilon,
+     )
+     results = list(
+         tqdm(
+             pool.imap(process_chunk_partial, chunks),
+             total=len(chunks),
+             desc="Processing chunks in parallel",
+         )
+     )  # Parallel processing.
+     pool.close()  # Close pool.
+     pool.join()  # Wait for all workers.
+     all_cell_dfs = []  # Aggregate results.
+     all_radii = {}
+     for result_chunk in results:
+         for df_c_registered, cell_radius in result_chunk:
+             all_cell_dfs.append(df_c_registered)
+             # Store per-cell radius, keyed by cell id.
+             all_radii[df_c_registered["cell"].iloc[0]] = cell_radius
+     cell_df_registered = pd.concat(all_cell_dfs)
+     return cell_df_registered, all_radii
+ 
+ 
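+ # Illustrative call (hypothetical inputs): `transcripts_df` needs columns
+ # cell/type/x/y/centerX/centerY, and `masks_df` needs columns cell/x/y.
+ #
+ #   registered, radii = register_cells_parallel_chunked(
+ #       transcripts_df, cell_ids, masks_df, ntanbin_dict, chunk_size=5
+ #   )
+ #   registered[["cell", "x_c_s", "y_c_s", "d_c_s"]]  # normalized coordinates
+ 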
+ def chunk_list(data_list, chunk_size):
+     """Yield successive chunks of `data_list` of length `chunk_size`."""
+     for i in range(0, len(data_list), chunk_size):
+         yield data_list[i : i + chunk_size]
+ 
+ 
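+ # Example: list(chunk_list([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]
+ 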
+ def process_chunk(
+     chunk,
+     df_gbC,
+     cell_mask_df,
+     nuclear_boundary,
+     ntanbin_dict,
+     epsilon,
+     clip_to_cell=True,
+     remove_outliers=False,
+     verbose=False,
+ ):
+     """Process a chunk of cells and (optionally) their nucleus boundaries.
+ 
+     Parameters
+     ----------
+     chunk:
+         List of cell IDs to process.
+     df_gbC:
+         data_df grouped by cell (DataFrameGroupBy).
+     cell_mask_df:
+         Cell boundary mask points.
+     nuclear_boundary:
+         Mapping {cell_id: DataFrame with nucleus boundary points}.
+     ntanbin_dict:
+         Mapping {cell_type: number_of_angle_bins_per_quadrant}.
+     epsilon:
+         Small constant for numerical stability.
+     clip_to_cell:
+         If True, clip normalized distances (d_c_s) to <= 1.
+     remove_outliers:
+         If True, drop nucleus points that exceed the cell boundary.
+     verbose:
+         If True, print additional warnings and statistics.
+ 
+     Returns
+     -------
+     list
+         List of tuples: (df_c_registered, nuclear_boundary_c_registered, cell_radius).
+     """
+     results = []
+     for c in chunk:
+         try:
+             df_c = df_gbC.get_group(c).copy()  # Per-cell transcript table.
+         except KeyError:
+             if verbose:
+                 print(f"Warning: Cell {c} not found in data_df")
+             continue
+ 
+         t = df_c.type.iloc[0]  # Cell type.
+ 
+         # Try different lookup strategies (cell id types may differ).
+         mask_df_c = cell_mask_df[cell_mask_df.cell == c].copy()
+         if len(mask_df_c) == 0:
+             # Try casting the mask's cell ids to the type of c.
+             if isinstance(c, str):
+                 if verbose:
+                     print(f"Casting mask cell ids to str for cell {c}")
+                 mask_df_c = cell_mask_df[cell_mask_df.cell.astype(str) == c].copy()
+             elif isinstance(c, (int, np.integer)):
+                 if verbose:
+                     print(f"Casting mask cell ids to int for cell {c}")
+                 mask_df_c = cell_mask_df[cell_mask_df.cell.astype(int) == c].copy()
+ 
+         if len(mask_df_c) == 0:
+             if verbose:
+                 print(f"Warning: No mask points found for cell {c}")
+             continue
+ 
+         try:
+             nuclear_boundary_c = nuclear_boundary[
+                 c
+             ].copy()  # Current cell nucleus boundary.
+         except KeyError:
+             if verbose:
+                 print(f"Warning: No nuclear boundary found for cell {c}")
+             continue
+ 
+         center_c = [
+             int(df_c.centerX.iloc[0]),
+             int(df_c.centerY.iloc[0]),
+         ]  # Nuclear center of cell c.
+         delta_tanbin = (2 * math.pi) / (
+             ntanbin_dict[t] * 4
+         )  # The full circle (2*pi) is split into ntanbin_dict[t] * 4 bins.
+ 
+         # Precompute mask coordinates relative to the cell center.
+         mask_df_c["x_c"] = mask_df_c.x - center_c[0]
+         mask_df_c["y_c"] = mask_df_c.y - center_c[1]
+         mask_df_c["d_c"] = (mask_df_c.x_c ** 2 + mask_df_c.y_c ** 2) ** 0.5
+         mask_df_c["arctan"] = np.absolute(
+             np.arctan(mask_df_c.y_c / (mask_df_c.x_c + epsilon))
+         )
+ 
+         # Split mask points into 4 quadrants.
+         mask_df_c_q_dict = {
+             "0": mask_df_c[(mask_df_c.x_c >= 0) & (mask_df_c.y_c >= 0)],
+             "1": mask_df_c[(mask_df_c.x_c <= 0) & (mask_df_c.y_c >= 0)],
+             "2": mask_df_c[(mask_df_c.x_c <= 0) & (mask_df_c.y_c <= 0)],
+             "3": mask_df_c[(mask_df_c.x_c >= 0) & (mask_df_c.y_c <= 0)],
+         }
+ 
+         # For each angle bin (per quadrant), compute the maximum radius.
+         dismax_c_mat = np.zeros((ntanbin_dict[t], 4))  # Shape: (n_bins, 4 quadrants).
+         for q in range(4):  # In each of the 4 quadrants.
+             mask_df_c_q = mask_df_c_q_dict[str(q)].copy()  # Work on a copy.
+             if len(mask_df_c_q) > 0:
+                 mask_df_c_q["arctan_idx"] = (
+                     mask_df_c_q["arctan"] / delta_tanbin
+                 ).astype(int)  # arctan_idx from 0 to ntanbin_dict[t] - 1.
+                 # Ensure arctan_idx stays in range.
+                 mask_df_c_q["arctan_idx"] = np.minimum(
+                     mask_df_c_q["arctan_idx"], ntanbin_dict[t] - 1
+                 )
+                 max_distances = mask_df_c_q.groupby("arctan_idx")["d_c"].max()
+                 if not max_distances.empty:
+                     # groupby sorts by arctan_idx, so index/value pairs align.
+                     dismax_c_mat[max_distances.index.to_numpy(), q] = (
+                         max_distances.values
+                     )
+ 
+         # Fill missing max radius values.
+         fill_zero_indices = np.where(dismax_c_mat == 0)
+         if len(fill_zero_indices[0]) > 0:
+             for ai, q in zip(fill_zero_indices[0], fill_zero_indices[1]):
+                 # Find the nearest non-zero bins (wrapping around the quadrant).
+                 neighbors = []
+                 for offset in range(1, ntanbin_dict[t]):
+                     ai_before = (ai - offset) % ntanbin_dict[t]
+                     ai_after = (ai + offset) % ntanbin_dict[t]
+                     if dismax_c_mat[ai_before, q] > 0:
+                         neighbors.append(dismax_c_mat[ai_before, q])
+                     if dismax_c_mat[ai_after, q] > 0:
+                         neighbors.append(dismax_c_mat[ai_after, q])
+                     if neighbors:  # Stop once we find any non-zero neighbor.
+                         break
+ 
+                 if neighbors:
+                     dismax_c_mat[ai, q] = np.mean(neighbors)
+                 else:
+                     # If no neighbors, use the mean of all non-zero values in this quadrant.
+                     nonzero_in_q = dismax_c_mat[:, q][dismax_c_mat[:, q] > 0]
+                     if len(nonzero_in_q) > 0:
+                         dismax_c_mat[ai, q] = np.mean(nonzero_in_q)
+                     else:
+                         # If the quadrant is empty, use the mean of all non-zero values.
+                         all_nonzero = dismax_c_mat[dismax_c_mat > 0]
+                         if len(all_nonzero) > 0:
+                             dismax_c_mat[ai, q] = np.mean(all_nonzero)
+                         else:
+                             # If everything is empty, fall back to a heuristic:
+                             # 1.5x the max transcript distance from the center.
+                             # (df_c["d_c"] has not been computed yet at this
+                             # point, so derive it from the raw coordinates.)
+                             dismax_c_mat[ai, q] = (
+                                 np.sqrt(
+                                     (df_c.x - center_c[0]) ** 2
+                                     + (df_c.y - center_c[1]) ** 2
+                                 ).max()
+                                 * 1.5
+                             )
+ 
+         # dismax_c_mat stores max radii for each (angle bin, quadrant).
+         # Add centered coordinates and the angle to the x-axis for df_c.
+         df_c["x_c"] = df_c.x - center_c[0]  # Relative to cell center.
+         df_c["y_c"] = df_c.y - center_c[1]
+         df_c["d_c"] = (df_c.x_c ** 2 + df_c.y_c ** 2) ** 0.5  # Distance to center.
+         df_c["arctan"] = np.absolute(
+             np.arctan(df_c.y_c / (df_c.x_c + epsilon))
+         )  # Angle to the x-axis.
+         # Nucleus boundary points relative to the cell center.
+         nuclear_boundary_c["x_c"] = nuclear_boundary_c.x - center_c[0]
+         nuclear_boundary_c["y_c"] = nuclear_boundary_c.y - center_c[1]
+         nuclear_boundary_c["d_c"] = (
+             nuclear_boundary_c.x_c**2 + nuclear_boundary_c.y_c**2
+         ) ** 0.5
+         nuclear_boundary_c["arctan"] = np.abs(
+             np.arctan(nuclear_boundary_c.y_c / (nuclear_boundary_c.x_c + epsilon))
+         )
+ 
+         # Normalize cell and nucleus-boundary data.
+         df_c_registered = normalize_dataset(
+             df_c,
+             dismax_c_mat,
+             delta_tanbin,
+             ntanbin_dict,
+             t,
+             epsilon,
+             is_nucleus=False,
+             clip_to_cell=True,
+             remove_outliers=False,
+         )
+         nuclear_boundary_c_registered = normalize_dataset(
+             nuclear_boundary_c,
+             dismax_c_mat,
+             delta_tanbin,
+             ntanbin_dict,
+             t,
+             epsilon,
+             is_nucleus=True,
+             clip_to_cell=clip_to_cell,
+             remove_outliers=remove_outliers,
+         )
+         nuclear_boundary_c_registered["cell"] = c
+ 
+         # Compute the fraction of nucleus points exceeding the boundary.
+         exceed_percent = 0
+         if "exceeds_boundary" in nuclear_boundary_c_registered.columns:
+             exceed_percent = (
+                 nuclear_boundary_c_registered["exceeds_boundary"].mean() * 100
+             )
+             if exceed_percent > 0 and verbose:
+                 print(
+                     f"Cell {c}: {exceed_percent:.2f}% of nuclear boundary points exceed cell boundary"
+                 )
+ 
+         cell_radius = df_c_registered["d_c_maxc"].max()  # Per-cell maximum radius.
+ 
+         results.append((df_c_registered, nuclear_boundary_c_registered, cell_radius))
+ 
+     return results
+ 
+ 
+ def register_cells_and_nuclei_parallel_chunked(
+     data_df,
+     cell_list_all,
+     cell_mask_df,
+     nuclear_boundary,
+     ntanbin_dict,
+     epsilon=1e-10,
+     nc_demo=None,
+     chunk_size=2,
+     clip_to_cell=True,
+     remove_outliers=False,
+     verbose=False,
+ ):
+     """Register cells and nucleus boundaries in parallel (chunked).
+ 
+     This function uses multiprocessing to process cells in chunks.
+ 
+     Returns
+     -------
+     cell_df_registered:
+         Registered transcript table for all processed cells.
+     nuclear_boundary_df_registered:
+         Registered nucleus boundary points for all processed cells.
+     all_radii:
+         Mapping {cell_id: cell_radius}.
+     """
+     if nc_demo is None:
+         nc_demo = len(cell_list_all)
+     df_gbC = data_df.groupby("cell", observed=False)  # Group by cell.
+     chunks = list(chunk_list(cell_list_all[:nc_demo], chunk_size))  # Split into chunks.
+     # Cap the worker count, and always use at least one worker.
+     pool = Pool(processes=max(1, min(4, cpu_count() - 2)))
+     process_chunk_partial = partial(
+         process_chunk,
+         df_gbC=df_gbC,
+         cell_mask_df=cell_mask_df,
+         nuclear_boundary=nuclear_boundary,
+         ntanbin_dict=ntanbin_dict,
+         epsilon=epsilon,
+         clip_to_cell=clip_to_cell,
+         remove_outliers=remove_outliers,
+         verbose=verbose,
+     )
+     results = list(
+         tqdm(
+             pool.imap(process_chunk_partial, chunks),
+             total=len(chunks),
+             desc="Processing chunks in parallel",
+         )
+     )  # Parallel processing.
+     pool.close()  # Close pool.
+     pool.join()  # Wait for all workers.
+     all_cell_dfs = []  # Aggregate results.
+     all_nuclear_dfs = []
+     all_radii = {}
+     for result_chunk in results:
+         for df_c_registered, nuclear_boundary_c_registered, cell_radius in result_chunk:
+             all_cell_dfs.append(df_c_registered)
+             all_nuclear_dfs.append(nuclear_boundary_c_registered)
+             # Store per-cell radius, keyed by cell id.
+             all_radii[df_c_registered["cell"].iloc[0]] = cell_radius
+     cell_df_registered = pd.concat(all_cell_dfs)
+     nuclear_boundary_df_registered = pd.concat(all_nuclear_dfs)
+     return cell_df_registered, nuclear_boundary_df_registered, all_radii
+ 
+ 
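+ # Illustrative call (hypothetical inputs), mirroring the cell-only variant but
+ # additionally passing a {cell_id: boundary DataFrame} mapping for the nuclei:
+ #
+ #   cells, nuclei_reg, radii = register_cells_and_nuclei_parallel_chunked(
+ #       transcripts_df, cell_ids, masks_df, nuclei_dense, ntanbin_dict,
+ #       clip_to_cell=True, remove_outliers=False,
+ #   )
+ 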
+ def register_cells_and_nuclei_parallel_chunked_constrained(
+     data_df,
+     cell_list_all,
+     cell_mask_df,
+     nuclear_boundary,
+     ntanbin_dict,
+     epsilon=1e-10,
+     nc_demo=None,
+     chunk_size=5,
+     clip_to_cell=True,
+     remove_outliers=False,
+     verbose=True,
+ ):
+     """Chunked parallel registration with nucleus boundary constraint.
+ 
+     Compared to register_cells_and_nuclei_parallel_chunked, this variant also
+     returns per-cell statistics about nucleus points exceeding the cell boundary.
+     """
+     if nc_demo is None:
+         nc_demo = len(cell_list_all)
+ 
+     # Validate inputs first.
+     mask_cells = set(cell_mask_df["cell"].unique())
+     missing_cells_mask = [c for c in cell_list_all[:nc_demo] if c not in mask_cells]
+     missing_cells_nuclear = [
+         c for c in cell_list_all[:nc_demo] if c not in nuclear_boundary
+     ]
+ 
+     if missing_cells_mask or missing_cells_nuclear:
+         print(f"Warning: Found {len(missing_cells_mask)} cells missing in mask_df")
+         print(
+             f"Warning: Found {len(missing_cells_nuclear)} cells missing in nuclear_boundary"
+         )
+ 
+         # Filter out cells with missing inputs.
+         valid_cells = [
+             c
+             for c in cell_list_all[:nc_demo]
+             if c in mask_cells and c in nuclear_boundary
+         ]
+         print(f"Proceeding with {len(valid_cells)} valid cells (originally {nc_demo})")
+         cell_list_for_processing = valid_cells
+     else:
+         cell_list_for_processing = cell_list_all[:nc_demo]
+ 
+     # Group input table and create processing chunks.
+     df_gbC = data_df.groupby("cell", observed=False)
+     chunks = list(chunk_list(cell_list_for_processing, chunk_size))
+ 
+     # Create multiprocessing pool (capped worker count).
+     pool = Pool(processes=max(1, min(4, cpu_count() - 2)))
+     process_chunk_partial = partial(
+         process_chunk,
+         df_gbC=df_gbC,
+         cell_mask_df=cell_mask_df,
+         nuclear_boundary=nuclear_boundary,
+         ntanbin_dict=ntanbin_dict,
+         epsilon=epsilon,
+         clip_to_cell=clip_to_cell,
+         remove_outliers=remove_outliers,
+         verbose=verbose,
+     )
+ 
+     # Parallel processing.
+     results = list(
+         tqdm(
+             pool.imap(process_chunk_partial, chunks),
+             total=len(chunks),
+             desc="Processing chunks in parallel",
+         )
+     )
+ 
+     pool.close()
+     pool.join()
+ 
+     # Aggregate results.
+     all_cell_dfs = []
+     all_nuclear_dfs = []
+     all_radii = {}
+     all_nuclear_stats = []
+ 
+     for result_chunk in results:
+         for df_c_registered, nuclear_boundary_c_registered, cell_radius in result_chunk:
+             all_cell_dfs.append(df_c_registered)
+             all_nuclear_dfs.append(nuclear_boundary_c_registered)
+             all_radii[df_c_registered["cell"].iloc[0]] = cell_radius
+             # Per-cell nuclear boundary stats.
+             exceed_percent = 0.0
+             exceed_count = 0
+             num_points = int(len(nuclear_boundary_c_registered))
+             if (
+                 num_points > 0
+                 and "exceeds_boundary" in nuclear_boundary_c_registered.columns
+             ):
+                 exceed_series = nuclear_boundary_c_registered["exceeds_boundary"]
+                 exceed_percent = float(exceed_series.mean()) * 100.0
+                 exceed_count = int(exceed_series.sum())
+ 
+             all_nuclear_stats.append(
+                 {
+                     "cell": df_c_registered["cell"].iloc[0],
+                     "exceed_percent": exceed_percent,
+                     "exceed_count": exceed_count,
+                     "num_nuclear_points": num_points,
+                 }
+             )
+ 
+     cell_df_registered = pd.concat(all_cell_dfs)
+     nuclear_boundary_df_registered = pd.concat(all_nuclear_dfs)
+     cell_nuclear_stats = pd.DataFrame(all_nuclear_stats)
+ 
+     # Print summary stats.
+     if verbose:
+         cells_with_exceeding_nucleus = cell_nuclear_stats[
+             cell_nuclear_stats["exceed_percent"] > 0
+         ]
+         if not cells_with_exceeding_nucleus.empty:
+             mean_exceed = cells_with_exceeding_nucleus["exceed_percent"].mean()
+             max_exceed = cells_with_exceeding_nucleus["exceed_percent"].max()
+             print(
+                 f"\nFound {len(cells_with_exceeding_nucleus)} cells with nucleus exceeding cell boundary"
+             )
+             print(f"Average exceed percentage: {mean_exceed:.2f}%")
+             print(f"Maximum exceed percentage: {max_exceed:.2f}%")
+             print(
+                 f"Exceeding points were {'clipped' if clip_to_cell else 'left as-is'}"
+             )
+ 
+     return (
+         cell_df_registered,
+         nuclear_boundary_df_registered,
+         all_radii,
+         cell_nuclear_stats,
+     )
+ 
+ 
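+ # Illustrative inspection of the extra per-cell statistics (hypothetical
+ # inputs):
+ #
+ #   cells, nuclei_reg, radii, stats = (
+ #       register_cells_and_nuclei_parallel_chunked_constrained(
+ #           transcripts_df, cell_ids, masks_df, nuclei_dense, ntanbin_dict
+ #       )
+ #   )
+ #   stats.sort_values("exceed_percent", ascending=False).head()
+ #   # columns: cell, exceed_percent, exceed_count, num_nuclear_points
+ 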
+ def normalize_dataset(
+     dataset,
+     dismax_c_mat,
+     delta_tanbin,
+     ntanbin_dict,
+     t,
+     epsilon=1e-10,
+     is_nucleus=False,
+     clip_to_cell=True,
+     remove_outliers=False,
+ ):
+     """Normalize points by the cell boundary (angle-binned max radius).
+ 
+     Given per-angle-bin maximum radii (dismax_c_mat), compute the normalized
+     radius (d_c_s) and normalized coordinates (x_c_s, y_c_s).
+     """
+     dataset_normalized = dataset.assign(
+         d_c_maxc=np.zeros(len(dataset)),
+         d_c_s=np.zeros(len(dataset)),
+         x_c_s=np.zeros(len(dataset)),
+         y_c_s=np.zeros(len(dataset)),
+     )
+ 
+     # Track points exceeding the cell boundary.
+     if is_nucleus:
+         dataset_normalized["exceeds_boundary"] = False
+ 
+     # Quadrant masks (points on an axis fall into two quadrants; the later
+     # quadrant wins, matching the original assignment order).
+     quadrant_masks = {
+         0: (dataset.x_c >= 0) & (dataset.y_c >= 0),
+         1: (dataset.x_c <= 0) & (dataset.y_c >= 0),
+         2: (dataset.x_c <= 0) & (dataset.y_c <= 0),
+         3: (dataset.x_c >= 0) & (dataset.y_c <= 0),
+     }
+ 
+     for q in range(4):
+         dataset_q = dataset[quadrant_masks[q]].copy()
+ 
+         if len(dataset_q) > 0:
+             dataset_q["arctan_idx"] = (dataset_q["arctan"] / delta_tanbin).astype(int)
+ 
+             # Ensure arctan_idx stays in range.
+             dataset_q["arctan_idx"] = np.minimum(
+                 dataset_q["arctan_idx"], ntanbin_dict[t] - 1
+             )
+ 
+             # Look up the boundary radius of each point's angle bin.
+             dataset_normalized.loc[dataset_q.index, "d_c_maxc"] = dismax_c_mat[
+                 dataset_q["arctan_idx"].values, q
+             ]
+ 
+     # Normalized radial distance.
+     dataset_normalized["d_c_s"] = dataset["d_c"] / (
+         dataset_normalized["d_c_maxc"] + epsilon
+     )
+ 
+     # If nucleus points, mark those exceeding the boundary.
+     if is_nucleus:
+         dataset_normalized["exceeds_boundary"] = dataset_normalized["d_c_s"] > 1
+ 
+     # Optionally remove out-of-bound nucleus points.
+     if remove_outliers and is_nucleus:
+         dataset_normalized = dataset_normalized[~dataset_normalized["exceeds_boundary"]]
+ 
+     # Optionally clip to the cell boundary.
+     if clip_to_cell:
+         dataset_normalized["d_c_s"] = np.minimum(dataset_normalized["d_c_s"], 1.0)
+ 
+     # Normalized coordinates.
+     dataset_normalized["x_c_s"] = (
+         dataset["x_c"] * dataset_normalized["d_c_s"] / (dataset["d_c"] + epsilon)
+     )
+     dataset_normalized["y_c_s"] = (
+         dataset["y_c"] * dataset_normalized["d_c_s"] / (dataset["d_c"] + epsilon)
+     )
+ 
+     return dataset_normalized
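+ 
+ # Minimal self-contained sketch of normalize_dataset (hypothetical numbers):
+ # one angle bin per quadrant, boundary radius 10 everywhere, one point at
+ # (3, 4):
+ #
+ #   pts = pd.DataFrame({"x_c": [3.0], "y_c": [4.0], "d_c": [5.0],
+ #                       "arctan": [np.arctan(4 / 3)]})
+ #   out = normalize_dataset(pts, np.full((1, 4), 10.0), np.pi / 2, {"T": 1}, "T")
+ #   # out.d_c_s ~= 0.5 and (out.x_c_s, out.y_c_s) ~= (0.3, 0.4)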