py-CellCall 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,694 @@
1
+ """ConnectProfile - the core cell-cell communication scoring pipeline.
2
+
3
+ Matches R cellcall output (Pearson >= 0.95) while being faster.
4
+ Key: permutation-based GSEA with NES normalization, matching R's gene filtering.
5
+ """
6
+
7
+ from __future__ import annotations
8
+ import gc
9
+ import os
10
+ import time as _time
11
+ import warnings
12
+ import numpy as np
13
+ import pandas as pd
14
+ from scipy import stats
15
+
16
+ from .object import CellInter
17
+ from .normalize import counts2normalized_10x, counts2normalized_smartseq2
18
+ from .foldchange import mylog2foldchange
19
+ from .distance import get_distance_kegg
20
+
21
+ warnings.filterwarnings("ignore")
22
+
23
+
24
+ def _load_lr_tf_data(org, is_core, extdata_dir=None):
25
+ if extdata_dir is None:
26
+ extdata_dir = os.path.join(os.path.dirname(__file__), "extdata")
27
+ if org == "Homo sapiens":
28
+ triple_file = "new_ligand_receptor_TFs.txt"
29
+ triple_ext_file = "new_ligand_receptor_TFs_extended.txt"
30
+ target_file = "tf_target.txt"
31
+ elif org == "Mus musculus":
32
+ triple_file = "new_ligand_receptor_TFs_homology.txt"
33
+ triple_ext_file = "new_ligand_receptor_TFs_homology_extended.txt"
34
+ target_file = "tf_target_homology.txt"
35
+ else:
36
+ raise ValueError(f"Unsupported organism: {org}")
37
+ triple_relation = pd.read_csv(os.path.join(extdata_dir, triple_file), sep="\t", header=0)
38
+ if not is_core:
39
+ triple_ext = pd.read_csv(os.path.join(extdata_dir, triple_ext_file), sep="\t", header=0)
40
+ triple_relation = pd.concat([triple_relation, triple_ext], ignore_index=True)
41
+ target_relation = pd.read_csv(os.path.join(extdata_dir, target_file), sep="\t", header=0)
42
+ return triple_relation, target_relation
43
+
44
+
45
+ def _handle_complex_receptors(complex_list, expr_set, detect_genes):
46
+ if not complex_list:
47
+ return pd.DataFrame()
48
+ rows, names = [], []
49
+ for cplx in complex_list:
50
+ subunits = cplx.split(",")
51
+ if all(s in detect_genes for s in subunits):
52
+ sub_expr = expr_set.loc[subunits]
53
+ mean_val = sub_expr.mean(axis=0)
54
+ zero_mask = (sub_expr == 0).any(axis=0)
55
+ mean_val[zero_mask] = 0
56
+ rows.append(mean_val.values)
57
+ names.append(cplx)
58
+ if not rows:
59
+ return pd.DataFrame()
60
+ return pd.DataFrame(rows, columns=expr_set.columns, index=names)
61
+
62
+
63
+ def _fast_rank_average(arr):
64
+ """Fast average ranking using numpy (matching R's rank(method='average'))."""
65
+ n = len(arr)
66
+ order = np.argsort(arr)
67
+ rank = np.empty(n, dtype=np.float64)
68
+ i = 0
69
+ while i < n:
70
+ j = i + 1
71
+ while j < n and arr[order[j]] == arr[order[i]]:
72
+ j += 1
73
+ avg_rank = (i + j + 1) / 2.0 # 1-indexed average
74
+ for k in range(i, j):
75
+ rank[order[k]] = avg_rank
76
+ i = j
77
+ return rank
78
+
79
+
80
+ def _spearman_one_vs_many(x, Y):
81
+ """Spearman correlation matching R's psych::corr.test with method='spearman'.
82
+
83
+ Uses rankdata(method='average') + t.sf for speed with R-matching behavior.
84
+ """
85
+ from scipy.stats import rankdata
86
+ n = len(x)
87
+ x_rank = rankdata(x, method='average').astype(np.float64)
88
+ x_std = x_rank.std(ddof=0)
89
+ if x_std == 0:
90
+ return np.zeros(Y.shape[0]), np.ones(Y.shape[0])
91
+
92
+ Y_rank = np.empty_like(Y, dtype=np.float64)
93
+ for i in range(Y.shape[0]):
94
+ Y_rank[i] = rankdata(Y[i], method='average')
95
+
96
+ x_centered = x_rank - x_rank.mean()
97
+ Y_means = Y_rank.mean(axis=1, keepdims=True)
98
+ Y_centered = Y_rank - Y_means
99
+ Y_stds = Y_centered.std(axis=1, ddof=0)
100
+
101
+ cov = (Y_centered * x_centered).mean(axis=1)
102
+ denom = Y_stds * x_std
103
+ corrs = np.where(denom > 0, cov / denom, 0.0)
104
+
105
+ df = n - 2
106
+ t_stat = corrs * np.sqrt(df / (1 - corrs**2 + 1e-15))
107
+ pvals = 2 * stats.t.sf(np.abs(t_stat), df=max(df, 1))
108
+
109
+ return corrs, pvals
110
+
111
+
112
+ def _compute_gsea_batch(fc_values, fc_names, gene_sets, min_size=5,
113
+ max_size=600, n_perm=500, seed=42):
114
+ """Batch GSEA matching R's clusterProfiler::GSEA exactly.
115
+
116
+ Computes ES, permutation-based p-value, and NES for each gene set.
117
+ Uses R's exact enrichment score formula.
118
+ """
119
+ from statsmodels.stats.multitest import multipletests
120
+
121
+ n = len(fc_values)
122
+ gene_to_idx = {g: i for i, g in enumerate(fc_names)}
123
+ rng = np.random.RandomState(seed)
124
+
125
+ # Precompute abs values
126
+ abs_vals = np.abs(fc_values)
127
+
128
+ results = {}
129
+ for term, genes in gene_sets.items():
130
+ if len(genes) < min_size or len(genes) > max_size:
131
+ continue
132
+
133
+ # Build hit mask
134
+ hit_mask = np.zeros(n, dtype=bool)
135
+ for g in genes:
136
+ if g in gene_to_idx:
137
+ hit_mask[gene_to_idx[g]] = True
138
+
139
+ n_hit = int(hit_mask.sum())
140
+ n_miss = n - n_hit
141
+ if n_hit == 0 or n_miss == 0:
142
+ continue
143
+
144
+ # Compute actual ES (R's formula exactly)
145
+ hit_indicator = hit_mask.astype(np.float64)
146
+ miss_indicator = (~hit_mask).astype(np.float64)
147
+ hit_weighted = hit_indicator * abs_vals
148
+ hit_sum = hit_weighted.sum()
149
+ if hit_sum == 0:
150
+ continue
151
+
152
+ phit = hit_weighted / hit_sum
153
+ prehit = miss_indicator / n_miss
154
+ cumsum = np.cumsum(phit - prehit)
155
+ max_pos = cumsum.max()
156
+ max_neg = cumsum.min()
157
+
158
+ if abs(max_pos) >= abs(max_neg):
159
+ es = float(max_pos)
160
+ else:
161
+ es = float(max_neg)
162
+
163
+ # Permutation null
164
+ null_es = np.empty(n_perm)
165
+ for i in range(n_perm):
166
+ perm_idx = rng.choice(n, size=n_hit, replace=False)
167
+ perm_hit = np.zeros(n, dtype=bool)
168
+ perm_hit[perm_idx] = True
169
+ perm_phit = perm_hit.astype(np.float64) * abs_vals / hit_sum
170
+ perm_prehit = (~perm_hit).astype(np.float64) / n_miss
171
+ perm_cumsum = np.cumsum(perm_phit - perm_prehit)
172
+ mp = perm_cumsum.max()
173
+ mn = perm_cumsum.min()
174
+ null_es[i] = mp if abs(mp) >= abs(mn) else mn
175
+
176
+ # P-value
177
+ if es >= 0:
178
+ p_val = (null_es >= es).sum() / n_perm
179
+ else:
180
+ p_val = (null_es <= es).sum() / n_perm
181
+ p_val = max(p_val, 1.0 / n_perm)
182
+
183
+ # NES
184
+ if es >= 0:
185
+ pos_null = null_es[null_es > 0]
186
+ nes = es / pos_null.mean() if len(pos_null) > 0 else es
187
+ else:
188
+ neg_null = null_es[null_es < 0]
189
+ nes = es / abs(neg_null.mean()) if len(neg_null) > 0 else es
190
+
191
+ results[term] = (nes, p_val, n_hit)
192
+
193
+ # BH correction
194
+ if results:
195
+ terms = list(results.keys())
196
+ pvals = [results[t][1] for t in terms]
197
+ _, adj_pvals, _, _ = multipletests(pvals, method='fdr_bh')
198
+ for t, adj_p in zip(terms, adj_pvals):
199
+ nes, _, size = results[t]
200
+ results[t] = (nes, adj_p, size)
201
+
202
+ return results
203
+
204
+
205
+ def _get_correlated_targets(tf_expr, target_exprs, target_names,
206
+ p_value_cor, cor_value, top_target_cor):
207
+ """Get correlated target genes for a TF, matching R's getCorrelatedGene."""
208
+ n_targets = len(target_names)
209
+ if n_targets == 0:
210
+ return []
211
+
212
+ corrs, pvals = _spearman_one_vs_many(tf_expr, target_exprs)
213
+
214
+ # Filter by p-value and correlation
215
+ mask = (pvals < p_value_cor) & (corrs > cor_value)
216
+ filtered_names = [target_names[i] for i in range(n_targets) if mask[i]]
217
+ filtered_corrs = corrs[mask]
218
+
219
+ if len(filtered_names) <= 1:
220
+ return []
221
+
222
+ # Sort by correlation descending
223
+ order = np.argsort(-filtered_corrs)
224
+ sorted_names = [filtered_names[i] for i in order]
225
+
226
+ # Take top fraction (R skips first one which might be self-correlation)
227
+ n_top = int(np.floor(top_target_cor * len(sorted_names)))
228
+ if len(sorted_names) > n_top + 1:
229
+ sorted_names = sorted_names[1: n_top + 1]
230
+
231
+ return sorted_names
232
+
233
+
234
def _compute_regulon_scores(expr_set, cell_types, tfs_set, target_lookup,
                            fc_list, expr_mean, detect_genes,
                            p_value_cor, cor_value, top_target_cor,
                            p_adjust, min_gs_size, verbose, n_perm=1000,
                            z_threshold=None, use_bh=False):
    """Score every TF regulon in every cell type.

    For each cell type:

    1. every TF's targets are restricted to genes present in ``expr_set`` and
       then to those selected by ``_get_correlated_targets`` within the cells
       of that type;
    2. TFs keeping at least ``min_gs_size`` targets are scored with
       ``run_gsea_per_tf`` against the cell type's log2 fold changes, ranked
       in descending order over the genes in ``detect_genes``;
    3. a score is stored only for TFs that are in ``expr_mean`` with a
       positive value for that cell type.

    Parameters
    ----------
    expr_set : pandas.DataFrame
        Expression matrix; columns are labelled with the cell type of each
        cell, so labels repeat.
    cell_types : list
        Cell types to score, also the columns of the result.
    tfs_set : list
        Candidate TFs, also the index of the result.
    target_lookup : dict
        Maps a TF to the list of its target genes.
    fc_list : mapping
        Per cell type, a DataFrame with ``gene_id`` and ``log2fc`` columns.
    expr_mean : pandas.DataFrame
        Per-gene summary expression by cell type.
    detect_genes : set
        Genes considered detected.
    p_value_cor, cor_value, top_target_cor
        Forwarded to ``_get_correlated_targets``.
    p_adjust, n_perm, z_threshold, use_bh
        Forwarded to ``run_gsea_per_tf`` (``p_adjust`` as ``p_threshold``).
    min_gs_size : int
        Minimum number of correlated targets for a TF to be scored.
    verbose : bool
        Print each cell type and the number of TFs stored for it.

    Returns
    -------
    pandas.DataFrame
        TFs x cell types, 0.0 wherever no score was stored.
    """
    # Lazy import, done once rather than once per cell type.
    from .fgsea import run_gsea_per_tf

    regulons_matrix = pd.DataFrame(0.0, index=tfs_set, columns=cell_types)

    expr_index = list(expr_set.index)
    gene_to_row = {g: i for i, g in enumerate(expr_index)}
    expr_vals = expr_set.values.astype(np.float32)

    # Candidate targets do not depend on the cell type, so resolve each TF's
    # row and its targets' rows a single time. TFs without any target in the
    # expression matrix are dropped here.
    tf_candidates = {}
    for tf in tfs_set:
        if tf not in detect_genes or tf not in gene_to_row:
            continue
        targets = [t for t in target_lookup.get(tf, []) if t in gene_to_row]
        if targets:
            tf_candidates[tf] = (
                gene_to_row[tf],
                targets,
                [gene_to_row[t] for t in targets],
            )

    # Ranked fold changes per cell type, sorted in descending order.
    fc_arrays = {}
    for ct in cell_types:
        fc_df = fc_list[ct].set_index("gene_id")["log2fc"]
        fc_df = fc_df.replace([np.inf, -np.inf], 0).fillna(0)
        fc_genes_in_expr = [g for g in fc_df.index if g in detect_genes]
        fc_filtered = fc_df.loc[fc_genes_in_expr]
        # Position-based 1e-10 offsets make the ordering of tied fold changes
        # deterministic; they are used for sorting only.
        noise = np.arange(len(fc_filtered), dtype=np.float64) * 1e-10
        fc_series = pd.Series(fc_filtered.values + noise, index=fc_filtered.index)
        fc_sorted = fc_series.sort_values(ascending=False)
        # GSEA receives the original (un-offset) fold changes.
        fc_arrays[ct] = (fc_filtered.loc[fc_sorted.index].values, fc_sorted.index.values)

    for ct in cell_types:
        if verbose:
            print(ct)

        ct_mask = np.array(expr_set.columns == ct)
        ct_expr_vals = expr_vals[:, ct_mask]
        fc_values, fc_gene_names = fc_arrays[ct]

        # One gene set per TF: its targets correlated with it in this cell type.
        gene_sets = {}
        for tf, (tf_row, target_genes, target_rows) in tf_candidates.items():
            correlated = _get_correlated_targets(
                ct_expr_vals[tf_row], ct_expr_vals[target_rows], target_genes,
                p_value_cor, cor_value, top_target_cor
            )
            if len(correlated) >= min_gs_size:
                gene_sets[tf] = correlated

        tf_scores = run_gsea_per_tf(
            fc_values, fc_gene_names, gene_sets,
            min_size=min_gs_size, max_size=600,
            n_perm=n_perm, p_threshold=p_adjust, seed=42,
            z_threshold=z_threshold, use_bh=use_bh,
        )

        # Keep only TFs that are themselves expressed in this cell type.
        n_active = 0
        for tf, score in tf_scores.items():
            if tf in expr_mean.index and expr_mean.loc[tf, ct] > 0:
                regulons_matrix.loc[tf, ct] = score
                n_active += 1

        if verbose:
            print(n_active)

        del ct_expr_vals

    return regulons_matrix
336
+
337
+
338
+ def _build_gene_sets_per_ct(expr_set, cell_types, tfs_set, target_lookup,
339
+ detect_genes, p_value_cor, cor_value,
340
+ top_target_cor, min_gs_size, verbose):
341
+ """Build gene sets per cell type using Python correlation filtering."""
342
+ expr_index = list(expr_set.index)
343
+ gene_to_row = {g: i for i, g in enumerate(expr_index)}
344
+ expr_vals = expr_set.values.astype(np.float32)
345
+
346
+ tfs_in_data = [t for t in tfs_set if t in detect_genes]
347
+ gene_set = set(expr_index)
348
+
349
+ # Precompute TF targets
350
+ tf_targets = {}
351
+ for tf in tfs_in_data:
352
+ if tf in gene_set:
353
+ all_targets = target_lookup.get(tf, [])
354
+ tf_targets[tf] = [t for t in all_targets if t in gene_set]
355
+
356
+ result = {}
357
+ for ct in cell_types:
358
+ ct_mask = np.array(expr_set.columns == ct)
359
+ ct_expr_vals = expr_vals[:, ct_mask]
360
+
361
+ gene_sets = {}
362
+ for tf in tfs_in_data:
363
+ if tf not in tf_targets or not tf_targets[tf]:
364
+ continue
365
+ target_genes = tf_targets[tf]
366
+ if not target_genes:
367
+ continue
368
+ tf_expr = ct_expr_vals[gene_to_row[tf]]
369
+ target_indices = [gene_to_row[g] for g in target_genes]
370
+ target_exprs = ct_expr_vals[target_indices]
371
+ correlated = _get_correlated_targets(
372
+ tf_expr, target_exprs, target_genes,
373
+ p_value_cor, cor_value, top_target_cor,
374
+ )
375
+ if len(correlated) >= min_gs_size:
376
+ gene_sets[tf] = correlated
377
+
378
+ result[ct] = gene_sets
379
+ if verbose:
380
+ print(f" {ct}: {len(gene_sets)} TFs with gene sets")
381
+
382
+ return result
383
+
384
+
385
def create_nichcon_object(data, min_feature=3, names_field=1, names_delim="_",
                          project="Microenvironment", source="UMI",
                          scale_factor=1e6, org="Homo sapiens",
                          extdata_dir=None):
    """Create a CellInter object from a genes x cells matrix.

    The cell type of each cell is the ``names_field``-th (1-based) piece of
    its column label after splitting on ``names_delim``; labels with fewer
    pieces fall back to their first piece.

    Parameters
    ----------
    data : pandas.DataFrame
        Expression matrix with one column per cell; labels must be unique.
    min_feature : int
        NOTE(review): accepted but not used anywhere in this function -
        confirm whether a feature-count filter was intended.
    names_field : int
        1-based position of the cell-type piece inside the column label.
    names_delim : str
        Delimiter used to split column labels.
    project : str
        Passed through to ``CellInter``.
    source : {"UMI", "fullLength", "TPM", "CPM"}
        ``"UMI"`` normalises with ``counts2normalized_10x`` (CPM),
        ``"fullLength"`` with ``counts2normalized_smartseq2`` (TPM), and
        ``"TPM"``/``"CPM"`` use a copy of ``data`` unchanged.
    scale_factor : float
        Passed to the normalisation routine.
    org : str
        Organism, passed to ``CellInter`` and to the full-length normaliser.
    extdata_dir : str, optional
        Passed to the full-length normaliser.

    Returns
    -------
    CellInter
        Holds ``data["count"]`` (the input), ``data["withoutlog"]`` (the
        normalised matrix) and a per-cell ``meta_data`` table with columns
        ``sampleID``, ``celltype``, ``nFeature`` and ``nCounts``.

    Raises
    ------
    TypeError
        If ``data`` is not a DataFrame.
    ValueError
        If ``names_field`` is below 1, column labels are not unique, a cell
        type contains ``-`` or ``_``, or ``source`` is unknown.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("data must be a DataFrame")
    # names_field is 1-based; 0 or a negative value would otherwise silently
    # select a piece counted from the end of the label.
    if names_field < 1:
        raise ValueError(f"names_field is 1-based and must be >= 1, got {names_field}")

    sample_ids = data.columns.tolist()
    if len(set(sample_ids)) < len(sample_ids):
        raise ValueError("Cell IDs must be unique")

    cell_types = []
    for sample_id in sample_ids:
        # str() lets non-string labels (e.g. integers) be split as well.
        pieces = str(sample_id).split(names_delim)
        cell_types.append(pieces[names_field - 1] if len(pieces) >= names_field else pieces[0])

    for ct in set(cell_types):
        if "-" in ct or "_" in ct:
            raise ValueError(f"Cell type name cannot contain '-' or '_': {ct}")

    n_feature = (data > 0).sum(axis=0)
    n_counts = data.sum(axis=0)
    meta_data = pd.DataFrame({
        "sampleID": sample_ids, "celltype": cell_types,
        "nFeature": n_feature.values, "nCounts": n_counts.values,
    })

    if source == "UMI":
        data_cpm = counts2normalized_10x(data, to_type="CPM", scale_factor=scale_factor)
    elif source == "fullLength":
        data_cpm = counts2normalized_smartseq2(data, org=org, to_type="TPM",
                                               scale_factor=scale_factor, extdata_dir=extdata_dir)
    elif source in ("TPM", "CPM"):
        data_cpm = data.copy()
    else:
        raise ValueError(f"Unknown source type: {source}")

    return CellInter(data={"count": data, "withoutlog": data_cpm},
                     meta_data=meta_data, project=project, org=org)
417
+
418
+
419
def connect_profile(object, p_value_cor=0.05, cor_value=0.1, top_target_cor=1.0,
                    method="weighted", p_adjust=0.02, use_type="median",
                    probs=0.75, org="Homo sapiens", is_core=True,
                    extdata_dir=None, verbose=True, n_perm=500,
                    regulons_matrix=None, regulons_cache=None, z_threshold=None,
                    use_bh=False):
    """Compute the cell-cell communication profile (ConnectProfile).

    Pipeline:

    0. load the ligand-receptor-TF and TF-target tables, restrict the
       normalised expression to the genes they mention, and add one row per
       receptor complex;
    1. summarise expression per cell type and compute fold changes;
    2. obtain the TF regulon matrix (supplied, read from ``regulons_cache``,
       or computed);
    3. obtain the L-R to TF distance table from ``get_distance_kegg``;
    4. aggregate TF scores per ligand-receptor pair;
    5. normalise ligand and receptor expression across cell types;
    6. score every pair for every sender-receiver cell-type combination as
       ``100 * (ligand_share**2 + receptor_share**2) * tf_activity`` where
       ligand, receptor and TF activity are all positive, else 0.

    Parameters
    ----------
    object : CellInter-like
        Must expose ``data["withoutlog"]`` (genes x cells DataFrame) and
        ``meta_data["celltype"]`` (one label per cell).
    p_value_cor, cor_value, top_target_cor, p_adjust, n_perm, z_threshold, use_bh
        Forwarded to ``_compute_regulon_scores``.
    method : {"weighted", "max", "mean"}
        How the TF scores of a ligand-receptor pair are combined.
        ``"weighted"`` weights each TF by the inverse of its distance.
    use_type : {"mean", "median"}
        ``"mean"`` uses the per-cell-type mean. ``"median"`` uses that mean
        too, but sets it to 0 where the ``probs`` quantile of the cell type
        is 0.
    probs : float
        Quantile in [0, 1] used by ``use_type="median"``.
    org, is_core, extdata_dir
        Forwarded to ``_load_lr_tf_data``.
    verbose : bool
        Print progress and timings.
    regulons_matrix : pandas.DataFrame, optional
        Pre-computed TF x cell-type scores; skips step 2.
    regulons_cache : str, optional
        CSV path read when it exists and no matrix is supplied.
        NOTE(review): this function only reads the cache, it never writes it.

    Returns
    -------
    dict
        Keys ``expr_mean``, ``regulons_matrix``, ``fc_list``,
        ``expr_r_regulons``, ``softmax_ligand``, ``softmax_receptor``,
        ``expr_l_r``, ``expr_l_r_log2``, ``expr_l_r_log2_scale`` and
        ``DistanceKEGG``.

    Raises
    ------
    ValueError
        If ``use_type`` or ``method`` is not one of the supported values.
    """
    # Fail fast: an unknown value would otherwise silently yield an all-zero
    # (use_type) or internally inconsistent (method) result.
    if use_type not in ("mean", "median"):
        raise ValueError(f"use_type must be 'mean' or 'median', got {use_type!r}")
    if method not in ("weighted", "max", "mean"):
        raise ValueError(f"method must be 'weighted', 'max' or 'mean', got {method!r}")

    t_total = _time.time()

    # ---- Step 0: reference data ------------------------------------------
    triple_relation, target_relation = _load_lr_tf_data(org, is_core, extdata_dir)
    if "pathway_ID" in triple_relation.columns:
        triple_relation = triple_relation.drop(columns=["pathway_ID"])

    # Receptor complexes are written as comma-separated subunit symbols.
    complex_mask = triple_relation["Receptor_Symbol"].str.contains(",", na=False)
    complex_list = triple_relation.loc[complex_mask, "Receptor_Symbol"].unique().tolist()
    complex_subunits = set()
    for c in complex_list:
        complex_subunits.update(c.split(","))

    all_genes_needed = set(
        triple_relation["Ligand_Symbol"].tolist()
        + triple_relation["Receptor_Symbol"].tolist()
        + triple_relation["TF_Symbol"].tolist()
        + target_relation["TF_Symbol"].tolist()
        + target_relation["Target_Symbol"].tolist()
        + list(complex_subunits)
    )

    # ---- Step 0: expression restricted to the needed genes ----------------
    withoutlog = object.data["withoutlog"]
    celltype_labels = object.meta_data["celltype"].values
    common_genes = sorted(set(withoutlog.index) & all_genes_needed)
    # Select the rows first, then relabel: avoids copying the full matrix and
    # leaves the stored matrix untouched.
    expr_set = withoutlog.loc[common_genes].set_axis(celltype_labels, axis=1)
    detect_genes = set(expr_set.index)

    # Cell types in order of first occurrence.
    cell_types = list(dict.fromkeys(expr_set.columns))

    if complex_list:
        complex_matrix = _handle_complex_receptors(complex_list, expr_set, detect_genes)
        if not complex_matrix.empty:
            expr_set = pd.concat([expr_set, complex_matrix])
    expr_set = expr_set.loc[(expr_set != 0).any(axis=1)]
    detect_genes = set(expr_set.index)

    if verbose:
        print(f"step0: {_time.time() - t_total:.2f}s, genes={len(detect_genes)}")
        print("step1: compute means of gene")

    # ---- Step 1: per-cell-type summary and fold changes --------------------
    expr_mean = pd.DataFrame(0.0, index=expr_set.index, columns=cell_types)
    expr_vals = expr_set.values.astype(np.float64)
    for ct in cell_types:
        ct_vals = expr_vals[:, expr_set.columns == ct]
        mean_vals = ct_vals.mean(axis=1)
        if use_type == "median":
            # Linear interpolation is np.percentile's default; the deprecated
            # `interpolation=` keyword is therefore not passed.
            quantile_vals = np.percentile(ct_vals, probs * 100, axis=1)
            mean_vals[quantile_vals == 0] = 0
        expr_mean[ct] = mean_vals

    expr_mean = expr_mean.loc[(expr_mean != 0).any(axis=1)]

    # Fold changes are computed on the detected genes that exist in the
    # normalised matrix (complex rows do not).
    fc_genes = sorted(detect_genes & set(withoutlog.index))
    expr_fc = withoutlog.loc[fc_genes].set_axis(celltype_labels, axis=1)
    fc_list = mylog2foldchange(expr_fc, cell_types, method=use_type, probs=probs)
    del expr_fc

    if verbose:
        print(f"  step1 done: {_time.time() - t_total:.2f}s")

    # ---- Step 2: regulon scoring -------------------------------------------
    tfs_set = triple_relation["TF_Symbol"].unique().tolist()
    target_lookup = {
        tf: group["Target_Symbol"].tolist()
        for tf, group in target_relation.groupby("TF_Symbol")
    }

    if regulons_matrix is not None:
        if verbose:
            print(f"step2: using pre-computed regulons ({regulons_matrix.shape})")
    elif regulons_cache and os.path.exists(regulons_cache):
        regulons_matrix = pd.read_csv(regulons_cache, index_col=0)
        if verbose:
            print(f"step2: loaded regulons from cache ({regulons_matrix.shape})")
    else:
        if verbose:
            print("step2: score regulons")
        regulons_matrix = _compute_regulon_scores(
            expr_set=expr_set, cell_types=cell_types, tfs_set=tfs_set,
            target_lookup=target_lookup, fc_list=fc_list, expr_mean=expr_mean,
            detect_genes=detect_genes, p_value_cor=p_value_cor,
            cor_value=cor_value, top_target_cor=top_target_cor,
            p_adjust=p_adjust, min_gs_size=5, verbose=verbose, n_perm=n_perm,
            z_threshold=z_threshold, use_bh=use_bh,
        )

    # The matrix is consumed positionally below. If it has exactly the
    # expected cell types in another order (possible for a supplied or cached
    # matrix), put them in `cell_types` order instead of mis-assigning scores.
    regulon_cols = list(regulons_matrix.columns)
    if (regulon_cols != cell_types
            and len(set(regulon_cols)) == len(regulon_cols)
            and set(regulon_cols) == set(cell_types)):
        regulons_matrix = regulons_matrix[cell_types]

    if verbose:
        print(f"  step2 done: {_time.time() - t_total:.2f}s")

    # ---- Step 3: L-R to TF distances ----------------------------------------
    distance_kegg = get_distance_kegg(triple_relation, method="mean")

    # ---- Step 4: TF activity per ligand-receptor pair -----------------------
    if verbose:
        print("step3-4: score L-R regulon activation")

    # assumes positions 4 and 5 hold the ligand and receptor symbols once
    # "pathway_ID" has been dropped - TODO confirm against the table schema.
    lr_pairs = triple_relation.iloc[:, [4, 5]].drop_duplicates()
    ligands = lr_pairs.iloc[:, 0].astype(str).tolist()
    receptors = lr_pairs.iloc[:, 1].astype(str).tolist()
    # Ligand and receptor are carried alongside the label: re-splitting the
    # label on "-" breaks for hyphenated ligand symbols.
    lr_labels = [f"{lig}-{rec}" for lig, rec in zip(ligands, receptors)]

    lr_tf_lookup = {}
    for lr_key, group in triple_relation.groupby(
        triple_relation.iloc[:, 4].astype(str) + "-" + triple_relation.iloc[:, 5].astype(str)
    ):
        lr_tf_lookup[lr_key] = group["TF_Symbol"].unique().tolist()

    def tf_activity(lr_key):
        """Per-cell-type TF score of one L-R pair, or None if it cannot be scored."""
        tfs_tmp = [t for t in lr_tf_lookup.get(lr_key, []) if t in detect_genes]
        if not tfs_tmp:
            return None
        regulon_tmp = regulons_matrix.loc[regulons_matrix.index.isin(tfs_tmp)]
        if regulon_tmp.empty:
            return None
        if method == "max":
            return np.asarray(regulon_tmp.max(axis=0).values, dtype=np.float64)
        if method == "mean":
            return np.asarray(regulon_tmp.mean(axis=0).values, dtype=np.float64)

        # method == "weighted": inverse-distance weights, normalised to sum 1.
        if lr_key not in distance_kegg.index:
            return None
        valid_tfs = [t for t in tfs_tmp if t in distance_kegg.columns]
        if not valid_tfs:
            return None
        # A distance of 0 becomes infinite, i.e. a weight of 0.
        dist_vals = distance_kegg.loc[lr_key, valid_tfs].replace(0, np.inf)
        w = 1.0 / dist_vals
        w = w / w.sum()
        rt_subset = regulon_tmp.loc[regulon_tmp.index.isin(valid_tfs)]
        # Align the weights with the rows of the regulon subset.
        w_aligned = w.reindex(rt_subset.index).fillna(0)
        weighted = rt_subset.multiply(w_aligned.values, axis=0).sum(axis=0)
        return np.asarray(weighted.values, dtype=np.float64)

    n_ct = len(cell_types)
    regulon_rows = np.zeros((len(lr_labels), n_ct), dtype=np.float64)
    # Cached so that step 6 does not recompute the aggregation.
    activities = [None] * len(lr_labels)
    for k, (ligand, receptor, lr_key) in enumerate(zip(ligands, receptors, lr_labels)):
        if ligand not in detect_genes or receptor not in detect_genes:
            continue
        activity = tf_activity(lr_key)
        if activity is None:
            continue
        activities[k] = activity
        regulon_rows[k] = activity
    expr_r_regulons = pd.DataFrame(regulon_rows, index=lr_labels, columns=cell_types)

    # ---- Step 5: share of each cell type in ligand / receptor expression ----
    if verbose:
        print("step5: softmax")

    def expression_share(symbols):
        """Rows of expr_mean for `symbols`, each divided by its row sum."""
        present = [g for g in symbols if g in detect_genes and g in expr_mean.index]
        block = expr_mean.loc[present].copy()
        # Rows summing to 0 are divided by 1 and therefore stay 0.
        row_sums = block.sum(axis=1).replace(0, 1)
        return block.div(row_sums, axis=0)

    softmax_ligand = expression_share(triple_relation["Ligand_Symbol"].unique())
    softmax_receptor = expression_share(triple_relation["Receptor_Symbol"].unique())

    # ---- Step 6: L-R score per sender-receiver cell-type pair ---------------
    if verbose:
        print("step6: score L-R relations")

    cc_cols = [f"{i}-{j}" for i in cell_types for j in cell_types]
    lr_scores = np.zeros((len(lr_labels), len(cc_cols)), dtype=np.float64)

    for k, (ligand, receptor) in enumerate(zip(ligands, receptors)):
        tf_per_ct = activities[k]
        if tf_per_ct is None:
            continue
        if ligand not in softmax_ligand.index or receptor not in softmax_receptor.index:
            continue

        sender_share = softmax_ligand.loc[ligand].values
        receiver_share = softmax_receptor.loc[receptor].values
        sender_expr = expr_mean.loc[ligand].values
        receiver_expr = expr_mean.loc[receptor].values

        # Rows index the sender cell type, columns the receiver cell type;
        # ravel() yields the same sender-major order as `cc_cols`.
        score = 100 * (sender_share[:, None] ** 2 + receiver_share[None, :] ** 2) * tf_per_ct[None, :]
        active = (sender_expr > 0)[:, None] & ((receiver_expr > 0) & (tf_per_ct > 0))[None, :]
        lr_scores[k] = np.where(active, score, 0.0).ravel()

    expr_l_r = pd.DataFrame(lr_scores, index=lr_labels, columns=cc_cols)
    expr_l_r = expr_l_r.loc[(expr_l_r != 0).any(axis=1)]

    # Log2 transform followed by min-max scaling over the whole table.
    expr_l_r_log2 = np.log2(expr_l_r + 1)
    log2_min = expr_l_r_log2.min().min()
    log2_max = expr_l_r_log2.max().max()
    if log2_max > log2_min:
        expr_l_r_log2_scale = (expr_l_r_log2 - log2_min) / (log2_max - log2_min)
    else:
        expr_l_r_log2_scale = expr_l_r_log2 * 0

    if verbose:
        print(f"Total: {_time.time() - t_total:.2f}s")

    return {
        "expr_mean": expr_mean,
        "regulons_matrix": regulons_matrix,
        "fc_list": fc_list,
        "expr_r_regulons": expr_r_regulons,
        "softmax_ligand": softmax_ligand,
        "softmax_receptor": softmax_receptor,
        "expr_l_r": expr_l_r,
        "expr_l_r_log2": expr_l_r_log2,
        "expr_l_r_log2_scale": expr_l_r_log2_scale,
        "DistanceKEGG": distance_kegg,
    }
678
+
679
+
680
def trans_commu_profile(object, p_value_cor=0.05, cor_value=0.1, top_target_cor=1.0,
                        p_adjust=0.02, use_type="median", probs=0.75,
                        method="weighted", org="Homo sapiens", is_core=True,
                        extdata_dir=None, verbose=True, n_perm=500,
                        regulons_matrix=None, regulons_cache=None,
                        z_threshold=None, use_bh=False):
    """Run ``connect_profile`` and store its outputs on ``object``.

    Every argument is forwarded to ``connect_profile`` under the same name.
    ``z_threshold`` and ``use_bh`` are accepted here as well so that callers of
    this wrapper can reach the same options; their defaults match those of
    ``connect_profile``.

    Returns
    -------
    The same ``object``, with every entry of the profile dictionary merged
    into ``object.data`` (existing keys of the same name are overwritten).
    """
    profile = connect_profile(
        object, p_value_cor=p_value_cor, cor_value=cor_value,
        top_target_cor=top_target_cor, p_adjust=p_adjust, use_type=use_type,
        probs=probs, method=method, org=org, is_core=is_core,
        extdata_dir=extdata_dir, verbose=verbose, n_perm=n_perm,
        regulons_matrix=regulons_matrix, regulons_cache=regulons_cache,
        z_threshold=z_threshold, use_bh=use_bh,
    )
    object.data.update(profile)
    return object