scimappro 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. scimappro/__init__.py +1 -0
  2. scimappro/pl/__init__.py +1 -0
  3. scimappro/pl/archive/barplot 2.py +425 -0
  4. scimappro/pl/archive/barplot.py +442 -0
  5. scimappro/pl/barplot.py +559 -0
  6. scimappro/pl/heatmap.py +505 -0
  7. scimappro/pl/image_viewer.py +363 -0
  8. scimappro/pl/stacked_barplot.py +407 -0
  9. scimappro/pp/__init__.py +1 -0
  10. scimappro/pp/archive/combat-prestream.py +249 -0
  11. scimappro/pp/archive/rescale.py +407 -0
  12. scimappro/pp/combat.py +331 -0
  13. scimappro/pp/log1p.py +267 -0
  14. scimappro/pp/mcmicro_to_scimap.py +343 -0
  15. scimappro/pp/rescale.py +348 -0
  16. scimappro/tl/__init__.py +1 -0
  17. scimappro/tl/archive/cluster.py +285 -0
  18. scimappro/tl/archive/neighCount.py +313 -0
  19. scimappro/tl/archive/neighCount_beforestream.py +242 -0
  20. scimappro/tl/archive/neighExp.py +275 -0
  21. scimappro/tl/archive/neighLda.py +415 -0
  22. scimappro/tl/archive/phenotype pre rest.py +407 -0
  23. scimappro/tl/archive/phenotype pre stream.py +410 -0
  24. scimappro/tl/archive/phenotype.py +431 -0
  25. scimappro/tl/archive/spatialProximityScore.py +281 -0
  26. scimappro/tl/archive/spatialSimilarityLookup.py +231 -0
  27. scimappro/tl/archive/spatial_aggregate.py +201 -0
  28. scimappro/tl/archive/spatial_cooccurrence.py +289 -0
  29. scimappro/tl/archive/spatial_distance.py +200 -0
  30. scimappro/tl/archive/spatial_distance_prestrream.py +138 -0
  31. scimappro/tl/archive/umap.py +209 -0
  32. scimappro/tl/cluster.py +484 -0
  33. scimappro/tl/foldChange.py +253 -0
  34. scimappro/tl/neighCount.py +498 -0
  35. scimappro/tl/neighExp.py +323 -0
  36. scimappro/tl/neighLDA.py +587 -0
  37. scimappro/tl/neighNMF.py +431 -0
  38. scimappro/tl/phenotype.py +540 -0
  39. scimappro/tl/spatialProximityScore.py +321 -0
  40. scimappro/tl/spatialSimilarityLookup.py +426 -0
  41. scimappro/tl/spatial_aggregate.py +191 -0
  42. scimappro/tl/spatial_cooccurrence.py +342 -0
  43. scimappro/tl/spatial_distance.py +219 -0
  44. scimappro/tl/umap.py +224 -0
  45. scimappro-0.1.0.dist-info/METADATA +23 -0
  46. scimappro-0.1.0.dist-info/RECORD +47 -0
  47. scimappro-0.1.0.dist-info/WHEEL +4 -0
scimappro/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import pp as pp
scimappro/pl/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .barplot import barplot
scimappro/pl/archive/barplot 2.py ADDED
@@ -0,0 +1,425 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Refactored stacked bar plot function with cap_anndata streaming support.
5
+ Optimized for memory usage and computational speed for large datasets.
6
+ Author: Ajit Johnson Nirmal (refactored)
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import multiprocessing
12
+ import argparse
13
+ from typing import Optional, Tuple, List, Union
14
+ from pathlib import Path
15
+
16
+ import pandas as pd # For Pandas-based operations
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib import rcParams
19
+ import matplotlib.colors as mcolors
20
+ import plotly.graph_objects as go
21
+ import plotly.io as pio
22
+ from tqdm import tqdm
23
+ import polars as pl
24
+
25
+ # Optional: Import anndata if available.
26
+ try:
27
+ import anndata
28
+ except ImportError:
29
+ anndata = None
30
+
31
# Process-wide plotting defaults, applied as a side effect of importing this
# module (they affect every figure in the session, not just this module's):
# write PDF text using Type 42 (TrueType) fonts, and make Plotly's default
# renderer open figures in a web browser.
rcParams['pdf.fonttype'] = 42
pio.renderers.default = 'browser'
34
+
35
def _as_str_list(values):
    """Normalise a single label or an iterable of labels to a list of strings.

    Labels are stringified because the obs columns are converted with
    ``astype(str)`` before aggregation; non-string labels would never match.
    """
    if isinstance(values, str):
        return [values]
    return [str(value) for value in values]


def _ordered_labels(present, order):
    """Arrange the ``present`` labels according to ``order``.

    Labels listed in ``order`` come first, in that order (duplicates and labels
    absent from the data are ignored). Any remaining labels follow in their
    existing order, so no data is silently dropped.
    """
    present = list(present)
    if order is None:
        return present
    available = set(present)
    head = [label for label in dict.fromkeys(_as_str_list(order)) if label in available]
    placed = set(head)
    return head + [label for label in present if label not in placed]


def _load_obs_frame(adata, xAxis, yAxis, streamData, verbose):
    """Return a pandas DataFrame holding only the xAxis/yAxis obs columns as strings."""
    columns = [xAxis, yAxis]

    if not streamData:
        if isinstance(adata, (str, os.PathLike)):
            if verbose:
                print(f"Loading AnnData from file path: {adata}")
            if anndata is None:
                raise ImportError("The 'anndata' package is required to load h5ad files.")
            adata = anndata.read_h5ad(adata)
        return adata.obs[columns].astype(str)

    try:
        from cap_anndata import read_h5ad
    except ImportError as err:
        raise ImportError("The 'cap_anndata' package is required when streamData=True.") from err

    import contextlib

    with contextlib.ExitStack() as stack:
        if isinstance(adata, (str, os.PathLike)):
            if verbose:
                print(f"Streaming AnnData from disk: {adata}")
            opened = read_h5ad(adata, edit=False)
            if hasattr(opened, "read_obs"):
                # Defensive: a read_h5ad that returns the object directly.
                # Close the underlying file ourselves since we opened it.
                cap_adata = opened
                file_handle = getattr(opened, "file", None)
                if file_handle is not None and hasattr(file_handle, "close"):
                    stack.callback(file_handle.close)
            else:
                # cap_anndata's read_h5ad is a context manager; entering it
                # yields the CapAnnData object and exiting closes the file.
                cap_adata = stack.enter_context(opened)
        else:
            # Caller-supplied object: the caller owns its file handle.
            cap_adata = adata

        cap_adata.read_obs(columns=columns)
        # Copy out the two columns so nothing references the file after close.
        obs_df = pd.DataFrame(cap_adata.obs[columns]).astype(str)

    if verbose and isinstance(adata, (str, os.PathLike)):
        print("Closed streamed AnnData file.")
    return obs_df


def _aggregate_to_pivot(obs_df, xAxis, yAxis, method, subsetXAxis, subsetYAxis,
                        orderXAxis, orderYAxis, verbose):
    """Count (or proportion) yAxis categories within each xAxis group.

    Returns a pandas DataFrame indexed by xAxis labels with one column per
    yAxis category. Missing combinations are filled with 0.
    """
    if verbose:
        print("Converting data to Polars DataFrame...")
    plDF = pl.from_pandas(obs_df)

    for column, subset in ((xAxis, subsetXAxis), (yAxis, subsetYAxis)):
        if subset is not None:
            plDF = plDF.filter(pl.col(column).is_in(_as_str_list(subset)))

    if verbose:
        print("Performing groupby and aggregation...")
    # pl.count() is deprecated in recent polars in favour of pl.len().
    count_expr = pl.len() if hasattr(pl, "len") else pl.count()
    grouped = plDF.lazy().group_by([xAxis, yAxis]).agg(count_expr.alias("count"))

    if method == 'percent':
        if verbose:
            print("Calculating percentage proportions for each group...")
        # Fraction of each xAxis group's total (values in 0..1).
        grouped = grouped.with_columns(
            (pl.col("count") / pl.col("count").sum().over(xAxis)).alias("value")
        )
    else:
        grouped = grouped.with_columns(pl.col("count").alias("value"))

    if verbose:
        print("Converting aggregated data to Pandas DataFrame for pivoting...")
    df = grouped.collect().to_pandas()

    pivotDF = df.pivot(index=xAxis, columns=yAxis, values="value")
    # Reindex instead of pd.Categorical: labels missing from an order list
    # would become NaN and can make pivot fail on duplicate entries.
    pivotDF = pivotDF.reindex(
        index=_ordered_labels(pivotDF.index, orderXAxis),
        columns=_ordered_labels(pivotDF.columns, orderYAxis),
    ).fillna(0)
    return pivotDF


def _resolve_category_colors(categories, color, palette):
    """Return one colour per category (None means "use the backend default").

    ``color`` may be a dict (category -> colour), a list (cycled), a single
    colour string, or a callable colormap. When it is None, ``palette`` (or an
    automatic Set1/tab20/gist_ncar choice by category count) is sampled evenly.
    """
    nCats = len(categories)
    if isinstance(color, dict):
        return [color.get(cat) for cat in categories]
    if isinstance(color, list):
        if not color:
            return [None] * nCats
        return [color[idx % len(color)] for idx in range(nCats)]
    if isinstance(color, str):
        return [color] * nCats

    if color is None:
        if palette is not None:
            cmap = plt.get_cmap(palette)
        elif nCats <= 9:
            cmap = plt.get_cmap("Set1")
        elif nCats <= 20:
            cmap = plt.get_cmap("tab20")
        else:
            cmap = plt.get_cmap("gist_ncar")
    elif callable(color):
        cmap = color
    else:
        return [None] * nCats

    norm = plt.Normalize(vmin=0, vmax=nCats - 1)
    return [mcolors.to_hex(cmap(norm(idx))) for idx in range(nCats)]


def _resolve_output_path(outputDir, default_name, required_suffix=None):
    """Turn ``outputDir`` into a concrete file path, creating directories as needed.

    A path with a suffix that is not an existing directory is treated as the
    output file. Otherwise it is a directory and ``default_name`` is appended.
    When ``required_suffix`` is given, a mismatching file suffix is replaced
    (with a warning).
    """
    out_path = Path(outputDir)
    if out_path.suffix and not out_path.is_dir():
        if required_suffix and out_path.suffix.lower() != required_suffix:
            print(f"Warning: Interactive mode requires a {required_suffix} file. "
                  f"Changing extension to {required_suffix}.")
            out_path = out_path.with_suffix(required_suffix)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        return str(out_path)
    out_path.mkdir(parents=True, exist_ok=True)
    return str(out_path / default_name)


def barplot(
    adata,
    xAxis: str = 'imageid',
    yAxis: str = 'phenotype',
    subsetXAxis: Optional[List[str]] = None,
    subsetYAxis: Optional[List[str]] = None,
    orderXAxis: Optional[List[str]] = None,
    orderYAxis: Optional[List[str]] = None,
    method: str = 'percent',
    plottingMode: str = 'standard',  # 'standard' or 'interactive'
    figSize: Optional[Tuple[float, float]] = None,
    fontSize: Optional[int] = None,
    color: Optional[Union[str, List[str], dict]] = None,
    palette: Optional[str] = None,  # If None, auto palette is selected.
    alpha: float = 1.0,
    barWidth: Optional[float] = None,  # Default bar width.
    streamData: bool = False,  # Read obs via cap_anndata.
    maxWorkers: Optional[int] = None,
    verbose: bool = False,
    outputDir: Optional[str] = None,
    show: bool = True,
    dpi: int = 300,
    transparent: bool = False,
    watermark: bool = True,
    matplotlib_bbox_to_anchor=(1, 1.02),
    matplotlib_legend_loc=2,
    ax: Optional[plt.Axes] = None,
    **kwargs,
):
    """
    Generate a stacked bar plot of yAxis categories within each xAxis group.

    Only the ``xAxis`` and ``yAxis`` columns of ``.obs`` are used. Their values
    are converted to strings before counting.

    Parameters
    ----------
    adata : AnnData, cap_anndata object, str or os.PathLike
        Data source. With ``streamData=True``, a path is opened read-only via
        cap_anndata and closed as soon as the two obs columns are read. A
        caller-supplied cap_anndata object is left open for the caller.
    xAxis, yAxis : str
        obs columns defining the bar groups (x) and the stacked categories (y).
    subsetXAxis, subsetYAxis : str or list, optional
        Keep only these labels (compared as strings).
    orderXAxis, orderYAxis : str or list, optional
        Labels to place first, in this order. Labels not listed follow in
        sorted order.
    method : {'percent', 'absolute'}
        'percent' plots each category's fraction (0-1) of its group;
        'absolute' plots raw counts.
    plottingMode : {'standard', 'interactive'}
        matplotlib (standard) or Plotly (interactive) output.
    figSize, fontSize, alpha, barWidth, dpi, transparent, watermark
        Appearance / saving options. ``barWidth`` (default 0.95) applies to
        matplotlib only.
    color : str, list, dict or callable colormap, optional
        A single colour, a list cycled over categories, a category->colour
        mapping (unmapped categories use the backend default), or a colormap.
    palette : str, optional
        Colormap name used when ``color`` is None. If both are None, Set1,
        tab20 or gist_ncar is chosen from the number of categories.
    streamData : bool
        Read obs with cap_anndata instead of loading a full AnnData.
    maxWorkers : int, optional
        Only reported in verbose output; not otherwise used here.
    verbose : bool
        Print progress information.
    outputDir : str, optional
        A path with a file suffix that is not an existing directory is the
        output file (its parent directory is created). Otherwise it is a
        directory receiving ``scimap_barplot.pdf`` / ``scimap_barplot.html``.
        Interactive output is always written as ``.html``.
    show : bool
        Display the figure. In standard mode the figure is closed when False.
    matplotlib_bbox_to_anchor, matplotlib_legend_loc
        Legend placement (matplotlib).
    ax : matplotlib Axes, optional
        Existing axes to draw into (standard mode).
    **kwargs
        Forwarded to ``Axes.bar`` or ``plotly.graph_objects.Bar``.

    Returns
    -------
    (Figure, Axes) in standard mode, a Plotly Figure in interactive mode, or
    None when ``streamData`` is True or the aggregated table is empty.

    Raises
    ------
    ValueError
        For an unknown ``method`` or ``plottingMode``.
    ImportError
        When ``anndata`` / ``cap_anndata`` is required but not installed.
    """
    # Validate options before doing any (possibly large) I/O.
    if plottingMode not in ('standard', 'interactive'):
        raise ValueError("plottingMode must be either 'standard' or 'interactive'")
    if method not in ('percent', 'absolute'):
        raise ValueError("method should be either 'percent' or 'absolute'")

    if maxWorkers is None:
        maxWorkers = max(1, multiprocessing.cpu_count() - 1)
    if verbose:
        print(f"Using {maxWorkers} parallel worker(s) for data processing.")

    obs_df = _load_obs_frame(adata, xAxis, yAxis, streamData, verbose)
    pivotDF = _aggregate_to_pivot(obs_df, xAxis, yAxis, method, subsetXAxis, subsetYAxis,
                                  orderXAxis, orderYAxis, verbose)
    if verbose:
        print("Pivot table created:")
        print(pivotDF)
    if pivotDF.empty:
        if verbose:
            print("Warning: The pivot table is empty. Check your data or subset parameters.")
        return None

    # Options that belong to this function, not to Axes.bar / go.Bar.
    for key in ("plottingMode", "matplotlib_cmap", "matplotlib_bbox_to_anchor",
                "matplotlib_legend_loc", "outputDir"):
        kwargs.pop(key, None)

    categories = list(pivotDF.columns)
    colors = _resolve_category_colors(categories, color, palette)
    yLabel = "Percentage" if method == 'percent' else "Count"

    if plottingMode == 'standard':
        if verbose:
            print("Creating standard matplotlib plot...")
        if ax is None:
            if figSize is None:
                # Narrow default width that grows with the number of bars.
                figSize = (max(6, len(pivotDF.index) * 0.3), 6)
            fig, ax = plt.subplots(figsize=figSize)
        else:
            fig = ax.figure

        width = 0.95 if barWidth is None else barWidth
        xPositions = list(range(len(pivotDF.index)))
        bottoms = [0] * len(xPositions)
        for cat, catColor in tqdm(list(zip(categories, colors)),
                                  desc="Plotting categories", disable=not verbose):
            values = pivotDF[cat].values
            ax.bar(
                xPositions,
                values,
                bottom=bottoms,
                label=str(cat),
                color=catColor,
                alpha=alpha,
                width=width,
                **kwargs
            )
            bottoms = [base + value for base, value in zip(bottoms, values)]

        # Reverse so the legend reads top-to-bottom like the stack.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[::-1], labels[::-1],
                  bbox_to_anchor=matplotlib_bbox_to_anchor, loc=matplotlib_legend_loc)

        ax.set_xticks(xPositions)
        ax.set_xticklabels(pivotDF.index, rotation=45, ha="right",
                           fontsize=fontSize if fontSize else 10)
        ax.set_xlabel(xAxis, fontsize=fontSize if fontSize else 12)
        ax.set_ylabel(yLabel, fontsize=fontSize if fontSize else 12)

        ax.set_facecolor("white")
        fig.patch.set_facecolor("white")
        ax.grid(False)

        # Act on this figure explicitly; plt.* would target the "current"
        # figure, which need not be ax.figure when ax is supplied.
        fig.subplots_adjust(bottom=0.3)
        if watermark:
            fig.text(
                0.99, 0.02, "made with scimap",
                transform=fig.transFigure,
                fontsize=8,
                color="#AAAAAA",
                ha='right',
                va='bottom',
                alpha=0.5
            )
        fig.tight_layout(rect=[0, 0.05, 1, 1])

        if outputDir:
            fullPath = _resolve_output_path(outputDir, "scimap_barplot.pdf")
            fig.savefig(fullPath, dpi=dpi, transparent=transparent)
            if verbose:
                print(f"Saved plot to {fullPath}")

        if show:
            plt.show()
        else:
            plt.close(fig)

    else:
        if verbose:
            print("Creating interactive Plotly plot...")
        fig = go.Figure()
        xCategories = pivotDF.index.tolist()
        for cat, catColor in tqdm(list(zip(categories, colors)),
                                  desc="Plotting categories", disable=not verbose):
            fig.add_trace(
                go.Bar(
                    x=xCategories,
                    y=pivotDF[cat].values,
                    name=str(cat),
                    marker_color=catColor,
                    opacity=alpha,
                    **kwargs
                )
            )
        fig.update_layout(
            barmode='stack',
            xaxis_title=xAxis,
            yaxis_title=yLabel,
            plot_bgcolor='rgba(0, 0, 0, 0)',
            paper_bgcolor='rgba(0, 0, 0, 0)',
            xaxis=dict(showline=True, linecolor='black', linewidth=1, showgrid=False, zeroline=False),
            yaxis=dict(showline=True, linecolor='black', linewidth=1, showgrid=False, zeroline=False),
            margin=dict(b=150)
        )
        if watermark:
            fig.add_annotation(
                text="made with scimap",
                xref="paper", yref="paper",
                x=0.99, y=-0.2,
                showarrow=False,
                font=dict(size=8, color="#AAAAAA"),
                opacity=0.5,
                xanchor='right',
                yanchor='top'
            )
        if outputDir:
            fullPath = _resolve_output_path(outputDir, "scimap_barplot.html",
                                            required_suffix=".html")
            fig.write_html(fullPath)
            if verbose:
                print(f"Saved interactive plot to {fullPath}")
        if show:
            fig.show()

    if streamData:
        # Streaming mode returns None, as documented.
        return None
    return (fig, ax) if plottingMode == 'standard' else fig
390
+
391
def main():
    """Parse command-line options and hand them to ``barplot``.

    Every option's destination name is also a ``barplot`` keyword argument,
    so the parsed namespace is forwarded unchanged.
    """
    cli = argparse.ArgumentParser(
        description="Generate a stacked bar plot from an AnnData object or h5ad file."
    )
    # (flag, add_argument keyword options) in the order they appear in --help.
    argument_specs = (
        ("adata", dict(type=str,
                       help="Path to an h5ad file or identifier for an AnnData object.")),
        ("--xAxis", dict(type=str, default="imageid",
                         help="Column for x-axis categories.")),
        ("--yAxis", dict(type=str, default="phenotype",
                         help="Column for y-axis categories.")),
        ("--method", dict(type=str, default="percent",
                          choices=["percent", "absolute"], help="Plotting method.")),
        ("--plottingMode", dict(type=str, default="standard",
                                choices=["standard", "interactive"], help="Plotting mode.")),
        ("--outputDir", dict(type=str, default=None,
                             help="Output file path (if it has a suffix) or directory.")),
        ("--verbose", dict(action="store_true",
                           help="Enable verbose output.")),
        ("--streamData", dict(action="store_true",
                              help="Enable streaming of AnnData using cap_anndata.")),
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    barplot(**vars(cli.parse_args()))


if __name__ == '__main__':
    main()