scmcp-shared 0.2.0__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scmcp_shared/server/pp.py CHANGED
@@ -3,7 +3,9 @@ import os
3
3
  import inspect
4
4
  import scanpy as sc
5
5
  from fastmcp import FastMCP , Context
6
+ from fastmcp.exceptions import ToolError
6
7
  from ..schema.pp import *
8
+ from ..schema import AdataModel
7
9
  from ..util import filter_args, add_op_log, forward_request, get_ads, generate_msg
8
10
  from ..logging_config import setup_logger
9
11
  logger = setup_logger()
@@ -14,21 +16,22 @@ pp_mcp = FastMCP("ScanpyMCP-PP-Server")
14
16
 
15
17
  @pp_mcp.tool()
16
18
  async def subset_cells(
17
- request: SubsetCellModel = SubsetCellModel()
19
+ request: SubsetCellModel = SubsetCellModel(),
20
+ adinfo: AdataModel = AdataModel()
18
21
  ):
19
22
  """filter or subset cells based on total genes expressed counts and numbers. or values in adata.obs[obs_key]"""
20
23
 
21
24
  try:
22
- result = await forward_request("subset_cells", request)
25
+ result = await forward_request("subset_cells", request, adinfo)
23
26
  if result is not None:
24
27
  return result
25
28
 
26
29
  ads = get_ads()
27
- adata = ads.get_adata(request=request).copy()
30
+ adata = ads.get_adata(adinfo=adinfo).copy()
28
31
  func_kwargs = filter_args(request, sc.pp.filter_cells)
29
32
  if func_kwargs:
30
33
  sc.pp.filter_cells(adata, **func_kwargs)
31
- add_op_log(adata, sc.pp.filter_cells, func_kwargs)
34
+ add_op_log(adata, sc.pp.filter_cells, func_kwargs, adinfo)
32
35
  # Subset based on obs (cells) criteria
33
36
  if request.obs_key is not None:
34
37
  if request.obs_key not in adata.obs.columns:
@@ -45,36 +48,37 @@ async def subset_cells(
45
48
  {
46
49
  "obs_key": request.obs_key, "obs_value": request.obs_value,
47
50
  "obs_min": request.obs_min, "obs_max": request.obs_max
48
- }
51
+ }, adinfo
49
52
  )
50
- ads.set_adata(adata, request=request)
53
+ ads.set_adata(adata, adinfo=adinfo)
51
54
  return [
52
- generate_msg(request, adata, ads)
55
+ generate_msg(adinfo, adata, ads)
53
56
  ]
54
- except KeyError as e:
55
- raise e
57
+ except ToolError as e:
58
+ raise ToolError(e)
56
59
  except Exception as e:
57
60
  if hasattr(e, '__context__') and e.__context__:
58
- raise Exception(f"{str(e.__context__)}")
61
+ raise ToolError(e.__context__)
59
62
  else:
60
- raise e
63
+ raise ToolError(e)
61
64
 
62
65
 
63
66
  @pp_mcp.tool()
64
67
  async def subset_genes(
65
- request: SubsetGeneModel = SubsetGeneModel()
68
+ request: SubsetGeneModel = SubsetGeneModel(),
69
+ adinfo: AdataModel = AdataModel()
66
70
  ):
67
71
  """filter or subset genes based on number of cells or counts, or values in adata.var[var_key] or subset highly variable genes"""
68
72
  try:
69
- result = await forward_request("pp_subset_genes", request)
73
+ result = await forward_request("pp_subset_genes", request, adinfo)
70
74
  if result is not None:
71
75
  return result
72
76
  func_kwargs = filter_args(request, sc.pp.filter_genes)
73
77
  ads = get_ads()
74
- adata = ads.get_adata(request=request).copy()
78
+ adata = ads.get_adata(adinfo=adinfo).copy()
75
79
  if func_kwargs:
76
80
  sc.pp.filter_genes(adata, **func_kwargs)
77
- add_op_log(adata, sc.pp.filter_genes, func_kwargs)
81
+ add_op_log(adata, sc.pp.filter_genes, func_kwargs, adinfo)
78
82
  if request.var_key is not None:
79
83
  if request.var_key not in adata.var.columns:
80
84
  raise ValueError(f"Key '{request.var_key}' not found in adata.var")
@@ -90,274 +94,286 @@ async def subset_genes(
90
94
  {
91
95
  "var_key": request.var_key, "var_value": request.var_value,
92
96
  "var_min": request.var_min, "var_max": request.var_max, "hpv": request.highly_variable
93
- }
97
+ }, adinfo
94
98
  )
95
- ads.set_adata(adata, request=request)
99
+ ads.set_adata(adata, adinfo=adinfo)
96
100
  return [
97
- generate_msg(request, adata, ads)
101
+ generate_msg(adinfo, adata, ads)
98
102
  ]
99
- except KeyError as e:
100
- raise e
103
+ except ToolError as e:
104
+ raise ToolError(e)
101
105
  except Exception as e:
102
106
  if hasattr(e, '__context__') and e.__context__:
103
- raise Exception(f"{str(e.__context__)}")
107
+ raise ToolError(e.__context__)
104
108
  else:
105
- raise e
109
+ raise ToolError(e)
106
110
 
107
111
  @pp_mcp.tool()
108
112
  async def calculate_qc_metrics(
109
- request: CalculateQCMetrics = CalculateQCMetrics()
113
+ request: CalculateQCMetrics = CalculateQCMetrics(),
114
+ adinfo: AdataModel = AdataModel()
110
115
  ):
111
116
  """Calculate quality control metrics(common metrics: total counts, gene number, percentage of counts in ribosomal and mitochondrial) for AnnData."""
112
117
 
113
118
  try:
114
- result = await forward_request("pp_calculate_qc_metrics", request)
119
+ result = await forward_request("pp_calculate_qc_metrics", request, adinfo)
115
120
  if result is not None:
116
121
  return result
117
122
  logger.info(f"calculate_qc_metrics {request.model_dump()}")
118
123
  func_kwargs = filter_args(request, sc.pp.calculate_qc_metrics)
119
124
  ads = get_ads()
120
- adata = ads.get_adata(request=request)
125
+ adata = ads.get_adata(adinfo=adinfo)
121
126
  func_kwargs["inplace"] = True
122
127
  try:
123
128
  sc.pp.calculate_qc_metrics(adata, **func_kwargs)
124
- add_op_log(adata, sc.pp.calculate_qc_metrics, func_kwargs)
129
+ add_op_log(adata, sc.pp.calculate_qc_metrics, func_kwargs, adinfo)
125
130
  except KeyError as e:
126
131
  raise KeyError(f"Cound find {e} in adata.var")
127
132
  return [
128
- generate_msg(request, adata, ads)
133
+ generate_msg(adinfo, adata, ads)
129
134
  ]
135
+ except ToolError as e:
136
+ raise ToolError(e)
130
137
  except Exception as e:
131
138
  if hasattr(e, '__context__') and e.__context__:
132
- raise Exception(f"{str(e.__context__)}")
139
+ raise ToolError(e.__context__)
133
140
  else:
134
- raise e
141
+ raise ToolError(e)
135
142
 
136
143
 
137
144
  @pp_mcp.tool()
138
145
  async def log1p(
139
- request: Log1PModel = Log1PModel()
146
+ request: Log1PModel = Log1PModel(),
147
+ adinfo: AdataModel = AdataModel()
140
148
  ):
141
149
  """Logarithmize the data matrix"""
142
150
 
143
151
  try:
144
- result = await forward_request("pp_log1p", request)
152
+ result = await forward_request("pp_log1p", request, adinfo)
145
153
  if result is not None:
146
154
  return result
147
155
  func_kwargs = filter_args(request, sc.pp.log1p)
148
156
  ads = get_ads()
149
- adata = ads.get_adata(request=request).copy()
157
+ adata = ads.get_adata(adinfo=adinfo).copy()
150
158
  try:
151
159
  sc.pp.log1p(adata, **func_kwargs)
152
160
  adata.raw = adata.copy()
153
- add_op_log(adata, sc.pp.log1p, func_kwargs)
161
+ add_op_log(adata, sc.pp.log1p, func_kwargs, adinfo)
154
162
  except Exception as e:
155
163
  raise e
156
- ads.set_adata(adata, request=request)
164
+ ads.set_adata(adata, adinfo=adinfo)
157
165
  return [
158
- generate_msg(request, adata, ads)
166
+ generate_msg(adinfo, adata, ads)
159
167
  ]
160
- except KeyError as e:
161
- raise e
168
+ except ToolError as e:
169
+ raise ToolError(e)
162
170
  except Exception as e:
163
171
  if hasattr(e, '__context__') and e.__context__:
164
- raise Exception(f"{str(e.__context__)}")
172
+ raise ToolError(e.__context__)
165
173
  else:
166
- raise e
174
+ raise ToolError(e)
167
175
 
168
176
 
169
177
  @pp_mcp.tool()
170
178
  async def normalize_total(
171
- request: NormalizeTotalModel = NormalizeTotalModel()
179
+ request: NormalizeTotalModel = NormalizeTotalModel(),
180
+ adinfo: AdataModel = AdataModel()
172
181
  ):
173
182
  """Normalize counts per cell to the same total count"""
174
183
 
175
184
  try:
176
- result = await forward_request("pp_normalize_total", request)
185
+ result = await forward_request("pp_normalize_total", request, adinfo)
177
186
  if result is not None:
178
187
  return result
179
188
  func_kwargs = filter_args(request, sc.pp.normalize_total)
180
189
  ads = get_ads()
181
- adata = ads.get_adata(request=request).copy()
190
+ adata = ads.get_adata(adinfo=adinfo).copy()
182
191
  sc.pp.normalize_total(adata, **func_kwargs)
183
- add_op_log(adata, sc.pp.normalize_total, func_kwargs)
184
- ads.set_adata(adata, request=request)
192
+ add_op_log(adata, sc.pp.normalize_total, func_kwargs, adinfo)
193
+ ads.set_adata(adata, adinfo=adinfo)
185
194
  return [
186
- generate_msg(request, adata, ads)
195
+ generate_msg(adinfo, adata, ads)
187
196
  ]
188
- except KeyError as e:
189
- raise e
197
+ except ToolError as e:
198
+ raise ToolError(e)
190
199
  except Exception as e:
191
200
  if hasattr(e, '__context__') and e.__context__:
192
- raise Exception(f"{str(e.__context__)}")
201
+ raise ToolError(e.__context__)
193
202
  else:
194
- raise e
203
+ raise ToolError(e)
195
204
 
196
205
 
197
206
 
198
207
  @pp_mcp.tool()
199
208
  async def highly_variable_genes(
200
- request: HighlyVariableGenesModel = HighlyVariableGenesModel()
209
+ request: HighlyVariableGenesModel = HighlyVariableGenesModel(),
210
+ adinfo: AdataModel = AdataModel()
201
211
  ):
202
212
  """Annotate highly variable genes"""
203
213
 
204
214
  try:
205
- result = await forward_request("pp_highly_variable_genes", request)
215
+ result = await forward_request("pp_highly_variable_genes", request, adinfo)
206
216
  if result is not None:
207
217
  return result
208
218
  try:
209
219
  func_kwargs = filter_args(request, sc.pp.highly_variable_genes)
210
220
  ads = get_ads()
211
- adata = ads.get_adata(request=request)
221
+ adata = ads.get_adata(adinfo=adinfo)
212
222
  sc.pp.highly_variable_genes(adata, **func_kwargs)
213
- add_op_log(adata, sc.pp.highly_variable_genes, func_kwargs)
223
+ add_op_log(adata, sc.pp.highly_variable_genes, func_kwargs, adinfo)
214
224
  except Exception as e:
215
225
  logger.error(f"Error in pp_highly_variable_genes: {str(e)}")
216
226
  raise e
217
227
  return [
218
- generate_msg(request, adata, ads)
228
+ generate_msg(adinfo, adata, ads)
219
229
  ]
220
- except KeyError as e:
221
- raise e
230
+ except ToolError as e:
231
+ raise ToolError(e)
222
232
  except Exception as e:
223
233
  if hasattr(e, '__context__') and e.__context__:
224
- raise Exception(f"{str(e.__context__)}")
234
+ raise ToolError(e.__context__)
225
235
  else:
226
- raise e
236
+ raise ToolError(e)
227
237
 
228
238
 
229
239
  @pp_mcp.tool()
230
240
  async def regress_out(
231
241
  request: RegressOutModel,
232
-
242
+ adinfo: AdataModel = AdataModel()
233
243
  ):
234
244
  """Regress out (mostly) unwanted sources of variation."""
235
245
 
236
246
  try:
237
- result = await forward_request("pp_regress_out", request)
247
+ result = await forward_request("pp_regress_out", request, adinfo)
238
248
  if result is not None:
239
249
  return result
240
250
  func_kwargs = filter_args(request, sc.pp.regress_out)
241
251
  ads = get_ads()
242
- adata = ads.get_adata(request=request).copy()
252
+ adata = ads.get_adata(adinfo=adinfo).copy()
243
253
  sc.pp.regress_out(adata, **func_kwargs)
244
- add_op_log(adata, sc.pp.regress_out, func_kwargs)
245
- ads.set_adata(adata, request=request)
254
+ add_op_log(adata, sc.pp.regress_out, func_kwargs, adinfo)
255
+ ads.set_adata(adata, adinfo=adinfo)
246
256
  return [
247
- generate_msg(request, adata, ads)
257
+ generate_msg(adinfo, adata, ads)
248
258
  ]
249
- except KeyError as e:
250
- raise e
259
+ except ToolError as e:
260
+ raise ToolError(e)
251
261
  except Exception as e:
252
262
  if hasattr(e, '__context__') and e.__context__:
253
- raise Exception(f"{str(e.__context__)}")
263
+ raise ToolError(e.__context__)
254
264
  else:
255
- raise e
265
+ raise ToolError(e)
256
266
 
257
267
  @pp_mcp.tool()
258
268
  async def scale(
259
- request: ScaleModel = ScaleModel()
269
+ request: ScaleModel = ScaleModel(),
270
+ adinfo: AdataModel = AdataModel()
260
271
  ):
261
272
  """Scale data to unit variance and zero mean"""
262
273
 
263
274
  try:
264
- result = await forward_request("pp_scale", request)
275
+ result = await forward_request("pp_scale", request, adinfo)
265
276
  if result is not None:
266
277
  return result
267
278
  func_kwargs = filter_args(request, sc.pp.scale)
268
279
  ads = get_ads()
269
- adata = ads.get_adata(request=request).copy()
280
+ adata = ads.get_adata(adinfo=adinfo).copy()
270
281
 
271
282
  sc.pp.scale(adata, **func_kwargs)
272
- add_op_log(adata, sc.pp.scale, func_kwargs)
283
+ add_op_log(adata, sc.pp.scale, func_kwargs, adinfo)
273
284
 
274
- ads.set_adata(adata, request=request)
285
+ ads.set_adata(adata, adinfo=adinfo)
275
286
  return [
276
- generate_msg(request, adata, ads)
287
+ generate_msg(adinfo, adata, ads)
277
288
  ]
278
- except KeyError as e:
279
- raise e
289
+ except ToolError as e:
290
+ raise ToolError(e)
280
291
  except Exception as e:
281
292
  if hasattr(e, '__context__') and e.__context__:
282
- raise Exception(f"{str(e.__context__)}")
293
+ raise ToolError(e.__context__)
283
294
  else:
284
- raise e
295
+ raise ToolError(e)
285
296
 
286
297
  @pp_mcp.tool()
287
298
  async def combat(
288
- request: CombatModel = CombatModel()
299
+ request: CombatModel = CombatModel(),
300
+ adinfo: AdataModel = AdataModel()
289
301
  ):
290
302
  """ComBat function for batch effect correction"""
291
303
 
292
304
  try:
293
- result = await forward_request("pp_combat", request)
305
+ result = await forward_request("pp_combat", request, adinfo)
294
306
  if result is not None:
295
307
  return result
296
308
  func_kwargs = filter_args(request, sc.pp.combat)
297
309
  ads = get_ads()
298
- adata = ads.get_adata(request=request).copy()
310
+ adata = ads.get_adata(adinfo=adinfo).copy()
299
311
 
300
312
  sc.pp.combat(adata, **func_kwargs)
301
- add_op_log(adata, sc.pp.combat, func_kwargs)
313
+ add_op_log(adata, sc.pp.combat, func_kwargs, adinfo)
302
314
 
303
- ads.set_adata(adata, request=request)
315
+ ads.set_adata(adata, adinfo=adinfo)
304
316
  return [
305
- generate_msg(request, adata, ads)
317
+ generate_msg(adinfo, adata, ads)
306
318
  ]
307
- except KeyError as e:
308
- raise e
319
+ except ToolError as e:
320
+ raise ToolError(e)
309
321
  except Exception as e:
310
322
  if hasattr(e, '__context__') and e.__context__:
311
- raise Exception(f"{str(e.__context__)}")
323
+ raise ToolError(e.__context__)
312
324
  else:
313
- raise e
325
+ raise ToolError(e)
314
326
 
315
327
  @pp_mcp.tool()
316
328
  async def scrublet(
317
- request: ScrubletModel = ScrubletModel()
329
+ request: ScrubletModel = ScrubletModel(),
330
+ adinfo: AdataModel = AdataModel()
318
331
  ):
319
332
  """Predict doublets using Scrublet"""
320
333
 
321
334
  try:
322
- result = await forward_request("pp_scrublet", request)
335
+ result = await forward_request("pp_scrublet", request, adinfo)
323
336
  if result is not None:
324
337
  return result
325
338
  func_kwargs = filter_args(request, sc.pp.scrublet)
326
339
  ads = get_ads()
327
- adata = ads.get_adata(request=request)
340
+ adata = ads.get_adata(adinfo=adinfo)
328
341
  sc.pp.scrublet(adata, **func_kwargs)
329
- add_op_log(adata, sc.pp.scrublet, func_kwargs)
342
+ add_op_log(adata, sc.pp.scrublet, func_kwargs, adinfo)
330
343
  return [
331
- generate_msg(request, adata, ads)
344
+ generate_msg(adinfo, adata, ads)
332
345
  ]
346
+ except ToolError as e:
347
+ raise ToolError(e)
333
348
  except Exception as e:
334
349
  if hasattr(e, '__context__') and e.__context__:
335
- raise Exception(f"{str(e.__context__)}")
350
+ raise ToolError(e.__context__)
336
351
  else:
337
- raise e
352
+ raise ToolError(e)
338
353
 
339
354
  @pp_mcp.tool()
340
355
  async def neighbors(
341
- request: NeighborsModel = NeighborsModel()
356
+ request: NeighborsModel = NeighborsModel(),
357
+ adinfo: AdataModel = AdataModel()
342
358
  ):
343
359
  """Compute nearest neighbors distance matrix and neighborhood graph"""
344
360
 
345
361
  try:
346
- result = await forward_request("pp_neighbors", request)
362
+ result = await forward_request("pp_neighbors", request, adinfo)
347
363
  if result is not None:
348
364
  return result
349
365
  func_kwargs = filter_args(request, sc.pp.neighbors)
350
366
  ads = get_ads()
351
- adata = ads.get_adata(request=request)
367
+ adata = ads.get_adata(adinfo=adinfo)
352
368
  sc.pp.neighbors(adata, **func_kwargs)
353
- add_op_log(adata, sc.pp.neighbors, func_kwargs)
369
+ add_op_log(adata, sc.pp.neighbors, func_kwargs, adinfo)
354
370
  return [
355
- generate_msg(request, adata, ads)
371
+ generate_msg(adinfo, adata, ads)
356
372
  ]
357
- except KeyError as e:
358
- raise e
373
+ except ToolError as e:
374
+ raise ToolError(e)
359
375
  except Exception as e:
360
376
  if hasattr(e, '__context__') and e.__context__:
361
- raise Exception(f"{str(e.__context__)}")
377
+ raise ToolError(e.__context__)
362
378
  else:
363
- raise e
379
+ raise ToolError(e)