scmcp-shared 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,368 @@
1
+
2
+ import os
3
+ import inspect
4
+ import scanpy as sc
5
+ from fastmcp import FastMCP , Context
6
+ from fastmcp.exceptions import ToolError
7
+ from ..schema.pp import *
8
+ from ..util import filter_args, add_op_log, forward_request, get_ads, generate_msg
9
+ from ..logging_config import setup_logger
10
# Module-level logger, obtained from the package's logging configuration helper.
logger = setup_logger()


# FastMCP server instance for this module; every tool below registers itself
# on it through the @pp_mcp.tool() decorator.
pp_mcp = FastMCP("ScanpyMCP-PP-Server")
14
+
15
+
16
@pp_mcp.tool()
async def subset_cells(
    request: SubsetCellModel = SubsetCellModel()
):
    """filter or subset cells based on total genes expressed counts and numbers. or values in adata.obs[obs_key]"""
    # NOTE: the docstring above is published by FastMCP as the tool description,
    # so its wording is deliberately left untouched.
    #
    # Flow:
    #   1. Give forward_request a chance to handle the call; a non-None result
    #      is returned as-is.
    #   2. Work on a copy of the stored AnnData.
    #   3. Run sc.pp.filter_cells when filter_args yields any keyword arguments.
    #   4. Optionally subset cells on adata.obs[obs_key] using obs_value
    #      (equality), obs_min (>=) and obs_max (<=); all given criteria are
    #      AND-ed together.
    #   5. Store the result and return a one-element list with the summary
    #      produced by generate_msg.
    #
    # Raises ToolError for every failure (other exceptions are wrapped).
    try:
        # NOTE(review): sibling tools forward under a "pp_"-prefixed name
        # (e.g. "pp_subset_genes"); this one uses "subset_cells". Left as-is
        # because the receiving side is not visible here - confirm.
        result = await forward_request("subset_cells", request)
        if result is not None:
            return result

        ads = get_ads()
        adata = ads.get_adata(request=request).copy()
        func_kwargs = filter_args(request, sc.pp.filter_cells)
        if func_kwargs:
            sc.pp.filter_cells(adata, **func_kwargs)
            add_op_log(adata, sc.pp.filter_cells, func_kwargs)

        # Subset based on obs (cells) criteria
        if request.obs_key is not None:
            if request.obs_key not in adata.obs.columns:
                raise ValueError(f"Key '{request.obs_key}' not found in adata.obs")

            column = adata.obs[request.obs_key]
            conditions = []
            if request.obs_value is not None:
                conditions.append(column == request.obs_value)
            if request.obs_min is not None:
                conditions.append(column >= request.obs_min)
            if request.obs_max is not None:
                conditions.append(column <= request.obs_max)

            # Only index when at least one criterion was supplied. The previous
            # implementation started from the scalar `True` and would pass that
            # bare bool to adata[...] when obs_key was given on its own, which
            # is not a boolean mask over cells.
            if conditions:
                mask = conditions[0]
                for condition in conditions[1:]:
                    mask = mask & condition
                adata = adata[mask, :]
                add_op_log(
                    adata,
                    "subset_cells",
                    {
                        "obs_key": request.obs_key, "obs_value": request.obs_value,
                        "obs_min": request.obs_min, "obs_max": request.obs_max,
                    },
                )

        ads.set_adata(adata, request=request)
        return [
            generate_msg(request, adata, ads)
        ]
    except ToolError as e:
        raise ToolError(e)
    except Exception as e:
        # Prefer the exception that was active when `e` was raised, if any;
        # otherwise report `e` itself.
        if hasattr(e, '__context__') and e.__context__:
            raise ToolError(e.__context__)
        else:
            raise ToolError(e)
62
+
63
+
64
@pp_mcp.tool()
async def subset_genes(
    request: SubsetGeneModel = SubsetGeneModel()
):
    """filter or subset genes based on number of cells or counts, or values in adata.var[var_key] or subset highly variable genes"""
    # NOTE: the docstring above is published by FastMCP as the tool description,
    # so its wording is deliberately left untouched.
    #
    # Flow:
    #   1. Give forward_request a chance to handle the call; a non-None result
    #      is returned as-is.
    #   2. Work on a copy of the stored AnnData.
    #   3. Run sc.pp.filter_genes when filter_args yields any keyword arguments.
    #   4. Optionally subset genes on adata.var[var_key] using var_min (>=) and
    #      var_max (<=), AND-ed together.
    #   5. Optionally keep only genes flagged in adata.var.highly_variable.
    #   6. Store the result and return a one-element list with the summary
    #      produced by generate_msg.
    #
    # Raises ToolError for every failure (other exceptions are wrapped).
    try:
        result = await forward_request("pp_subset_genes", request)
        if result is not None:
            return result
        func_kwargs = filter_args(request, sc.pp.filter_genes)
        ads = get_ads()
        adata = ads.get_adata(request=request).copy()
        if func_kwargs:
            sc.pp.filter_genes(adata, **func_kwargs)
            add_op_log(adata, sc.pp.filter_genes, func_kwargs)

        subset_applied = False

        if request.var_key is not None:
            if request.var_key not in adata.var.columns:
                raise ValueError(f"Key '{request.var_key}' not found in adata.var")

            column = adata.var[request.var_key]
            conditions = []
            if request.var_min is not None:
                conditions.append(column >= request.var_min)
            if request.var_max is not None:
                conditions.append(column <= request.var_max)

            # Only index when a bound was actually supplied; a bare `True`
            # (the old starting value of the mask) is not a valid gene mask.
            if conditions:
                mask = conditions[0]
                for condition in conditions[1:]:
                    mask = mask & condition
                adata = adata[:, mask]
                subset_applied = True

        # Truthiness check on purpose: `highly_variable=False` must NOT subset.
        # The previous `is not None` test treated False the same as True.
        if request.highly_variable:
            if "highly_variable" not in adata.var.columns:
                raise ValueError(
                    "Column 'highly_variable' not found in adata.var"
                )
            adata = adata[:, adata.var.highly_variable]
            subset_applied = True

        if subset_applied:
            add_op_log(
                adata,
                "subset_genes",
                {
                    "var_key": request.var_key,
                    # NOTE(review): var_value is only logged, never used as a
                    # filter in this function, and it is not visible here
                    # whether the request model declares it - hence getattr.
                    "var_value": getattr(request, "var_value", None),
                    "var_min": request.var_min, "var_max": request.var_max,
                    "hpv": request.highly_variable,
                },
            )

        ads.set_adata(adata, request=request)
        return [
            generate_msg(request, adata, ads)
        ]
    except ToolError as e:
        raise ToolError(e)
    except Exception as e:
        # Prefer the exception that was active when `e` was raised, if any;
        # otherwise report `e` itself.
        if hasattr(e, '__context__') and e.__context__:
            raise ToolError(e.__context__)
        else:
            raise ToolError(e)
107
+
108
@pp_mcp.tool()
async def calculate_qc_metrics(
    request: CalculateQCMetrics = CalculateQCMetrics()
):
    """Calculate quality control metrics(common metrics: total counts, gene number, percentage of counts in ribosomal and mitochondrial) for AnnData."""
    # NOTE: the docstring above is published by FastMCP as the tool description,
    # so its wording is deliberately left untouched.
    #
    # Flow:
    #   1. Give forward_request a chance to handle the call; a non-None result
    #      is returned as-is.
    #   2. Run sc.pp.calculate_qc_metrics with inplace=True on the AnnData
    #      returned by the store (no copy is taken and set_adata is not called).
    #   3. Return a one-element list with the summary produced by generate_msg.
    #
    # Raises ToolError for every failure (other exceptions are wrapped).
    try:
        result = await forward_request("pp_calculate_qc_metrics", request)
        if result is not None:
            return result
        logger.info(f"calculate_qc_metrics {request.model_dump()}")
        func_kwargs = filter_args(request, sc.pp.calculate_qc_metrics)
        ads = get_ads()
        adata = ads.get_adata(request=request)
        func_kwargs["inplace"] = True
        try:
            sc.pp.calculate_qc_metrics(adata, **func_kwargs)
        except KeyError as e:
            # Raise ToolError directly. Raising a new KeyError here would have
            # its __context__ set to the original KeyError, and the outer
            # handler reports __context__ in preference to the exception
            # itself - so the explanatory message would never reach the caller.
            raise ToolError(f"Could not find {e} in adata.var")
        # Kept outside the KeyError guard so that a KeyError from logging is
        # not misreported as a missing adata.var column.
        add_op_log(adata, sc.pp.calculate_qc_metrics, func_kwargs)
        return [
            generate_msg(request, adata, ads)
        ]
    except ToolError as e:
        raise ToolError(e)
    except Exception as e:
        # Prefer the exception that was active when `e` was raised, if any;
        # otherwise report `e` itself.
        if hasattr(e, '__context__') and e.__context__:
            raise ToolError(e.__context__)
        else:
            raise ToolError(e)
138
+
139
+
140
@pp_mcp.tool()
async def log1p(
    request: Log1PModel = Log1PModel()
):
    """Logarithmize the data matrix"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Applies sc.pp.log1p to a copy of the stored AnnData, snapshots the
    # transformed object into `.raw`, records the operation, writes the copy
    # back to the store and returns a one-element list with the generate_msg
    # summary. Any failure surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_log1p", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.log1p)
        store = get_ads()
        working = store.get_adata(request=request).copy()

        sc.pp.log1p(working, **kwargs)
        # Snapshot of the log-transformed data, taken before the op is logged.
        working.raw = working.copy()
        add_op_log(working, sc.pp.log1p, kwargs)

        store.set_adata(working, request=request)
        return [generate_msg(request, working, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
170
+
171
+
172
@pp_mcp.tool()
async def normalize_total(
    request: NormalizeTotalModel = NormalizeTotalModel()
):
    """Normalize counts per cell to the same total count"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Runs sc.pp.normalize_total on a copy of the stored AnnData, records the
    # operation, writes the copy back to the store and returns a one-element
    # list with the generate_msg summary. Any failure surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_normalize_total", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.normalize_total)
        store = get_ads()
        working = store.get_adata(request=request).copy()

        sc.pp.normalize_total(working, **kwargs)
        add_op_log(working, sc.pp.normalize_total, kwargs)

        store.set_adata(working, request=request)
        return [generate_msg(request, working, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
198
+
199
+
200
+
201
@pp_mcp.tool()
async def highly_variable_genes(
    request: HighlyVariableGenesModel = HighlyVariableGenesModel()
):
    """Annotate highly variable genes"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Runs sc.pp.highly_variable_genes directly on the AnnData returned by the
    # store (no copy is taken and set_adata is not called), records the
    # operation and returns a one-element list with the generate_msg summary.
    # Errors from the computation are logged before being re-raised; any
    # failure ultimately surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_highly_variable_genes", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        try:
            kwargs = filter_args(request, sc.pp.highly_variable_genes)
            store = get_ads()
            target = store.get_adata(request=request)
            sc.pp.highly_variable_genes(target, **kwargs)
            add_op_log(target, sc.pp.highly_variable_genes, kwargs)
        except Exception as e:
            logger.error(f"Error in pp_highly_variable_genes: {str(e)}")
            raise

        return [generate_msg(request, target, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
230
+
231
+
232
@pp_mcp.tool()
async def regress_out(
    request: RegressOutModel,

):
    """Regress out (mostly) unwanted sources of variation."""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Unlike its sibling tools, `request` has no default here, so the caller
    # must always supply it.
    #
    # Runs sc.pp.regress_out on a copy of the stored AnnData, records the
    # operation, writes the copy back to the store and returns a one-element
    # list with the generate_msg summary. Any failure surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_regress_out", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.regress_out)
        store = get_ads()
        working = store.get_adata(request=request).copy()

        sc.pp.regress_out(working, **kwargs)
        add_op_log(working, sc.pp.regress_out, kwargs)

        store.set_adata(working, request=request)
        return [generate_msg(request, working, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
259
+
260
@pp_mcp.tool()
async def scale(
    request: ScaleModel = ScaleModel()
):
    """Scale data to unit variance and zero mean"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Runs sc.pp.scale on a copy of the stored AnnData, records the operation,
    # writes the copy back to the store and returns a one-element list with
    # the generate_msg summary. Any failure surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_scale", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.scale)
        store = get_ads()
        working = store.get_adata(request=request).copy()

        sc.pp.scale(working, **kwargs)
        add_op_log(working, sc.pp.scale, kwargs)

        store.set_adata(working, request=request)
        return [generate_msg(request, working, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
288
+
289
@pp_mcp.tool()
async def combat(
    request: CombatModel = CombatModel()
):
    """ComBat function for batch effect correction"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Runs sc.pp.combat on a copy of the stored AnnData, records the operation,
    # writes the copy back to the store and returns a one-element list with
    # the generate_msg summary. Any failure surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_combat", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.combat)
        store = get_ads()
        working = store.get_adata(request=request).copy()

        sc.pp.combat(working, **kwargs)
        add_op_log(working, sc.pp.combat, kwargs)

        store.set_adata(working, request=request)
        return [generate_msg(request, working, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
317
+
318
@pp_mcp.tool()
async def scrublet(
    request: ScrubletModel = ScrubletModel()
):
    """Predict doublets using Scrublet"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Runs sc.pp.scrublet directly on the AnnData returned by the store (no
    # copy is taken and set_adata is not called), records the operation and
    # returns a one-element list with the generate_msg summary. Any failure
    # surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_scrublet", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.scrublet)
        store = get_ads()
        target = store.get_adata(request=request)

        sc.pp.scrublet(target, **kwargs)
        add_op_log(target, sc.pp.scrublet, kwargs)

        return [generate_msg(request, target, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)
343
+
344
@pp_mcp.tool()
async def neighbors(
    request: NeighborsModel = NeighborsModel()
):
    """Compute nearest neighbors distance matrix and neighborhood graph"""
    # The docstring doubles as the tool description published by FastMCP and
    # is therefore kept verbatim.
    #
    # Runs sc.pp.neighbors directly on the AnnData returned by the store (no
    # copy is taken and set_adata is not called), records the operation and
    # returns a one-element list with the generate_msg summary. Any failure
    # surfaces as ToolError.
    try:
        forwarded = await forward_request("pp_neighbors", request)
        if forwarded is not None:
            # Another handler produced the answer; pass it through unchanged.
            return forwarded

        kwargs = filter_args(request, sc.pp.neighbors)
        store = get_ads()
        target = store.get_adata(request=request)

        sc.pp.neighbors(target, **kwargs)
        add_op_log(target, sc.pp.neighbors, kwargs)

        return [generate_msg(request, target, store)]
    except ToolError as err:
        raise ToolError(err)
    except Exception as err:
        # Report the exception that was being handled when `err` was raised,
        # falling back to `err` itself when there is none.
        raise ToolError(err.__context__ or err)