scmcp-shared 0.2.1__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scmcp_shared/server/pp.py CHANGED
@@ -5,6 +5,7 @@ import scanpy as sc
5
5
  from fastmcp import FastMCP , Context
6
6
  from fastmcp.exceptions import ToolError
7
7
  from ..schema.pp import *
8
+ from ..schema import AdataModel
8
9
  from ..util import filter_args, add_op_log, forward_request, get_ads, generate_msg
9
10
  from ..logging_config import setup_logger
10
11
  logger = setup_logger()
@@ -15,21 +16,22 @@ pp_mcp = FastMCP("ScanpyMCP-PP-Server")
15
16
 
16
17
  @pp_mcp.tool()
17
18
  async def subset_cells(
18
- request: SubsetCellModel = SubsetCellModel()
19
+ request: SubsetCellModel = SubsetCellModel(),
20
+ adinfo: AdataModel = AdataModel()
19
21
  ):
20
22
  """filter or subset cells based on total genes expressed counts and numbers. or values in adata.obs[obs_key]"""
21
23
 
22
24
  try:
23
- result = await forward_request("subset_cells", request)
25
+ result = await forward_request("subset_cells", request, adinfo)
24
26
  if result is not None:
25
27
  return result
26
28
 
27
29
  ads = get_ads()
28
- adata = ads.get_adata(request=request).copy()
30
+ adata = ads.get_adata(adinfo=adinfo).copy()
29
31
  func_kwargs = filter_args(request, sc.pp.filter_cells)
30
32
  if func_kwargs:
31
33
  sc.pp.filter_cells(adata, **func_kwargs)
32
- add_op_log(adata, sc.pp.filter_cells, func_kwargs)
34
+ add_op_log(adata, sc.pp.filter_cells, func_kwargs, adinfo)
33
35
  # Subset based on obs (cells) criteria
34
36
  if request.obs_key is not None:
35
37
  if request.obs_key not in adata.obs.columns:
@@ -46,11 +48,11 @@ async def subset_cells(
46
48
  {
47
49
  "obs_key": request.obs_key, "obs_value": request.obs_value,
48
50
  "obs_min": request.obs_min, "obs_max": request.obs_max
49
- }
51
+ }, adinfo
50
52
  )
51
- ads.set_adata(adata, request=request)
53
+ ads.set_adata(adata, adinfo=adinfo)
52
54
  return [
53
- generate_msg(request, adata, ads)
55
+ generate_msg(adinfo, adata, ads)
54
56
  ]
55
57
  except ToolError as e:
56
58
  raise ToolError(e)
@@ -63,19 +65,20 @@ async def subset_cells(
63
65
 
64
66
  @pp_mcp.tool()
65
67
  async def subset_genes(
66
- request: SubsetGeneModel = SubsetGeneModel()
68
+ request: SubsetGeneModel = SubsetGeneModel(),
69
+ adinfo: AdataModel = AdataModel()
67
70
  ):
68
71
  """filter or subset genes based on number of cells or counts, or values in adata.var[var_key] or subset highly variable genes"""
69
72
  try:
70
- result = await forward_request("pp_subset_genes", request)
73
+ result = await forward_request("pp_subset_genes", request, adinfo)
71
74
  if result is not None:
72
75
  return result
73
76
  func_kwargs = filter_args(request, sc.pp.filter_genes)
74
77
  ads = get_ads()
75
- adata = ads.get_adata(request=request).copy()
78
+ adata = ads.get_adata(adinfo=adinfo).copy()
76
79
  if func_kwargs:
77
80
  sc.pp.filter_genes(adata, **func_kwargs)
78
- add_op_log(adata, sc.pp.filter_genes, func_kwargs)
81
+ add_op_log(adata, sc.pp.filter_genes, func_kwargs, adinfo)
79
82
  if request.var_key is not None:
80
83
  if request.var_key not in adata.var.columns:
81
84
  raise ValueError(f"Key '{request.var_key}' not found in adata.var")
@@ -91,11 +94,11 @@ async def subset_genes(
91
94
  {
92
95
  "var_key": request.var_key, "var_value": request.var_value,
93
96
  "var_min": request.var_min, "var_max": request.var_max, "hpv": request.highly_variable
94
- }
97
+ }, adinfo
95
98
  )
96
- ads.set_adata(adata, request=request)
99
+ ads.set_adata(adata, adinfo=adinfo)
97
100
  return [
98
- generate_msg(request, adata, ads)
101
+ generate_msg(adinfo, adata, ads)
99
102
  ]
100
103
  except ToolError as e:
101
104
  raise ToolError(e)
@@ -107,26 +110,27 @@ async def subset_genes(
107
110
 
108
111
  @pp_mcp.tool()
109
112
  async def calculate_qc_metrics(
110
- request: CalculateQCMetrics = CalculateQCMetrics()
113
+ request: CalculateQCMetrics = CalculateQCMetrics(),
114
+ adinfo: AdataModel = AdataModel()
111
115
  ):
112
116
  """Calculate quality control metrics(common metrics: total counts, gene number, percentage of counts in ribosomal and mitochondrial) for AnnData."""
113
117
 
114
118
  try:
115
- result = await forward_request("pp_calculate_qc_metrics", request)
119
+ result = await forward_request("pp_calculate_qc_metrics", request, adinfo)
116
120
  if result is not None:
117
121
  return result
118
122
  logger.info(f"calculate_qc_metrics {request.model_dump()}")
119
123
  func_kwargs = filter_args(request, sc.pp.calculate_qc_metrics)
120
124
  ads = get_ads()
121
- adata = ads.get_adata(request=request)
125
+ adata = ads.get_adata(adinfo=adinfo)
122
126
  func_kwargs["inplace"] = True
123
127
  try:
124
128
  sc.pp.calculate_qc_metrics(adata, **func_kwargs)
125
- add_op_log(adata, sc.pp.calculate_qc_metrics, func_kwargs)
129
+ add_op_log(adata, sc.pp.calculate_qc_metrics, func_kwargs, adinfo)
126
130
  except KeyError as e:
127
131
  raise KeyError(f"Cound find {e} in adata.var")
128
132
  return [
129
- generate_msg(request, adata, ads)
133
+ generate_msg(adinfo, adata, ads)
130
134
  ]
131
135
  except ToolError as e:
132
136
  raise ToolError(e)
@@ -139,26 +143,27 @@ async def calculate_qc_metrics(
139
143
 
140
144
  @pp_mcp.tool()
141
145
  async def log1p(
142
- request: Log1PModel = Log1PModel()
146
+ request: Log1PModel = Log1PModel(),
147
+ adinfo: AdataModel = AdataModel()
143
148
  ):
144
149
  """Logarithmize the data matrix"""
145
150
 
146
151
  try:
147
- result = await forward_request("pp_log1p", request)
152
+ result = await forward_request("pp_log1p", request, adinfo)
148
153
  if result is not None:
149
154
  return result
150
155
  func_kwargs = filter_args(request, sc.pp.log1p)
151
156
  ads = get_ads()
152
- adata = ads.get_adata(request=request).copy()
157
+ adata = ads.get_adata(adinfo=adinfo).copy()
153
158
  try:
154
159
  sc.pp.log1p(adata, **func_kwargs)
155
160
  adata.raw = adata.copy()
156
- add_op_log(adata, sc.pp.log1p, func_kwargs)
161
+ add_op_log(adata, sc.pp.log1p, func_kwargs, adinfo)
157
162
  except Exception as e:
158
163
  raise e
159
- ads.set_adata(adata, request=request)
164
+ ads.set_adata(adata, adinfo=adinfo)
160
165
  return [
161
- generate_msg(request, adata, ads)
166
+ generate_msg(adinfo, adata, ads)
162
167
  ]
163
168
  except ToolError as e:
164
169
  raise ToolError(e)
@@ -171,22 +176,23 @@ async def log1p(
171
176
 
172
177
  @pp_mcp.tool()
173
178
  async def normalize_total(
174
- request: NormalizeTotalModel = NormalizeTotalModel()
179
+ request: NormalizeTotalModel = NormalizeTotalModel(),
180
+ adinfo: AdataModel = AdataModel()
175
181
  ):
176
182
  """Normalize counts per cell to the same total count"""
177
183
 
178
184
  try:
179
- result = await forward_request("pp_normalize_total", request)
185
+ result = await forward_request("pp_normalize_total", request, adinfo)
180
186
  if result is not None:
181
187
  return result
182
188
  func_kwargs = filter_args(request, sc.pp.normalize_total)
183
189
  ads = get_ads()
184
- adata = ads.get_adata(request=request).copy()
190
+ adata = ads.get_adata(adinfo=adinfo).copy()
185
191
  sc.pp.normalize_total(adata, **func_kwargs)
186
- add_op_log(adata, sc.pp.normalize_total, func_kwargs)
187
- ads.set_adata(adata, request=request)
192
+ add_op_log(adata, sc.pp.normalize_total, func_kwargs, adinfo)
193
+ ads.set_adata(adata, adinfo=adinfo)
188
194
  return [
189
- generate_msg(request, adata, ads)
195
+ generate_msg(adinfo, adata, ads)
190
196
  ]
191
197
  except ToolError as e:
192
198
  raise ToolError(e)
@@ -200,25 +206,26 @@ async def normalize_total(
200
206
 
201
207
  @pp_mcp.tool()
202
208
  async def highly_variable_genes(
203
- request: HighlyVariableGenesModel = HighlyVariableGenesModel()
209
+ request: HighlyVariableGenesModel = HighlyVariableGenesModel(),
210
+ adinfo: AdataModel = AdataModel()
204
211
  ):
205
212
  """Annotate highly variable genes"""
206
213
 
207
214
  try:
208
- result = await forward_request("pp_highly_variable_genes", request)
215
+ result = await forward_request("pp_highly_variable_genes", request, adinfo)
209
216
  if result is not None:
210
217
  return result
211
218
  try:
212
219
  func_kwargs = filter_args(request, sc.pp.highly_variable_genes)
213
220
  ads = get_ads()
214
- adata = ads.get_adata(request=request)
221
+ adata = ads.get_adata(adinfo=adinfo)
215
222
  sc.pp.highly_variable_genes(adata, **func_kwargs)
216
- add_op_log(adata, sc.pp.highly_variable_genes, func_kwargs)
223
+ add_op_log(adata, sc.pp.highly_variable_genes, func_kwargs, adinfo)
217
224
  except Exception as e:
218
225
  logger.error(f"Error in pp_highly_variable_genes: {str(e)}")
219
226
  raise e
220
227
  return [
221
- generate_msg(request, adata, ads)
228
+ generate_msg(adinfo, adata, ads)
222
229
  ]
223
230
  except ToolError as e:
224
231
  raise ToolError(e)
@@ -232,22 +239,22 @@ async def highly_variable_genes(
232
239
  @pp_mcp.tool()
233
240
  async def regress_out(
234
241
  request: RegressOutModel,
235
-
242
+ adinfo: AdataModel = AdataModel()
236
243
  ):
237
244
  """Regress out (mostly) unwanted sources of variation."""
238
245
 
239
246
  try:
240
- result = await forward_request("pp_regress_out", request)
247
+ result = await forward_request("pp_regress_out", request, adinfo)
241
248
  if result is not None:
242
249
  return result
243
250
  func_kwargs = filter_args(request, sc.pp.regress_out)
244
251
  ads = get_ads()
245
- adata = ads.get_adata(request=request).copy()
252
+ adata = ads.get_adata(adinfo=adinfo).copy()
246
253
  sc.pp.regress_out(adata, **func_kwargs)
247
- add_op_log(adata, sc.pp.regress_out, func_kwargs)
248
- ads.set_adata(adata, request=request)
254
+ add_op_log(adata, sc.pp.regress_out, func_kwargs, adinfo)
255
+ ads.set_adata(adata, adinfo=adinfo)
249
256
  return [
250
- generate_msg(request, adata, ads)
257
+ generate_msg(adinfo, adata, ads)
251
258
  ]
252
259
  except ToolError as e:
253
260
  raise ToolError(e)
@@ -259,24 +266,25 @@ async def regress_out(
259
266
 
260
267
  @pp_mcp.tool()
261
268
  async def scale(
262
- request: ScaleModel = ScaleModel()
269
+ request: ScaleModel = ScaleModel(),
270
+ adinfo: AdataModel = AdataModel()
263
271
  ):
264
272
  """Scale data to unit variance and zero mean"""
265
273
 
266
274
  try:
267
- result = await forward_request("pp_scale", request)
275
+ result = await forward_request("pp_scale", request, adinfo)
268
276
  if result is not None:
269
277
  return result
270
278
  func_kwargs = filter_args(request, sc.pp.scale)
271
279
  ads = get_ads()
272
- adata = ads.get_adata(request=request).copy()
280
+ adata = ads.get_adata(adinfo=adinfo).copy()
273
281
 
274
282
  sc.pp.scale(adata, **func_kwargs)
275
- add_op_log(adata, sc.pp.scale, func_kwargs)
283
+ add_op_log(adata, sc.pp.scale, func_kwargs, adinfo)
276
284
 
277
- ads.set_adata(adata, request=request)
285
+ ads.set_adata(adata, adinfo=adinfo)
278
286
  return [
279
- generate_msg(request, adata, ads)
287
+ generate_msg(adinfo, adata, ads)
280
288
  ]
281
289
  except ToolError as e:
282
290
  raise ToolError(e)
@@ -288,24 +296,25 @@ async def scale(
288
296
 
289
297
  @pp_mcp.tool()
290
298
  async def combat(
291
- request: CombatModel = CombatModel()
299
+ request: CombatModel = CombatModel(),
300
+ adinfo: AdataModel = AdataModel()
292
301
  ):
293
302
  """ComBat function for batch effect correction"""
294
303
 
295
304
  try:
296
- result = await forward_request("pp_combat", request)
305
+ result = await forward_request("pp_combat", request, adinfo)
297
306
  if result is not None:
298
307
  return result
299
308
  func_kwargs = filter_args(request, sc.pp.combat)
300
309
  ads = get_ads()
301
- adata = ads.get_adata(request=request).copy()
310
+ adata = ads.get_adata(adinfo=adinfo).copy()
302
311
 
303
312
  sc.pp.combat(adata, **func_kwargs)
304
- add_op_log(adata, sc.pp.combat, func_kwargs)
313
+ add_op_log(adata, sc.pp.combat, func_kwargs, adinfo)
305
314
 
306
- ads.set_adata(adata, request=request)
315
+ ads.set_adata(adata, adinfo=adinfo)
307
316
  return [
308
- generate_msg(request, adata, ads)
317
+ generate_msg(adinfo, adata, ads)
309
318
  ]
310
319
  except ToolError as e:
311
320
  raise ToolError(e)
@@ -317,21 +326,22 @@ async def combat(
317
326
 
318
327
  @pp_mcp.tool()
319
328
  async def scrublet(
320
- request: ScrubletModel = ScrubletModel()
329
+ request: ScrubletModel = ScrubletModel(),
330
+ adinfo: AdataModel = AdataModel()
321
331
  ):
322
332
  """Predict doublets using Scrublet"""
323
333
 
324
334
  try:
325
- result = await forward_request("pp_scrublet", request)
335
+ result = await forward_request("pp_scrublet", request, adinfo)
326
336
  if result is not None:
327
337
  return result
328
338
  func_kwargs = filter_args(request, sc.pp.scrublet)
329
339
  ads = get_ads()
330
- adata = ads.get_adata(request=request)
340
+ adata = ads.get_adata(adinfo=adinfo)
331
341
  sc.pp.scrublet(adata, **func_kwargs)
332
- add_op_log(adata, sc.pp.scrublet, func_kwargs)
342
+ add_op_log(adata, sc.pp.scrublet, func_kwargs, adinfo)
333
343
  return [
334
- generate_msg(request, adata, ads)
344
+ generate_msg(adinfo, adata, ads)
335
345
  ]
336
346
  except ToolError as e:
337
347
  raise ToolError(e)
@@ -343,21 +353,22 @@ async def scrublet(
343
353
 
344
354
  @pp_mcp.tool()
345
355
  async def neighbors(
346
- request: NeighborsModel = NeighborsModel()
356
+ request: NeighborsModel = NeighborsModel(),
357
+ adinfo: AdataModel = AdataModel()
347
358
  ):
348
359
  """Compute nearest neighbors distance matrix and neighborhood graph"""
349
360
 
350
361
  try:
351
- result = await forward_request("pp_neighbors", request)
362
+ result = await forward_request("pp_neighbors", request, adinfo)
352
363
  if result is not None:
353
364
  return result
354
365
  func_kwargs = filter_args(request, sc.pp.neighbors)
355
366
  ads = get_ads()
356
- adata = ads.get_adata(request=request)
367
+ adata = ads.get_adata(adinfo=adinfo)
357
368
  sc.pp.neighbors(adata, **func_kwargs)
358
- add_op_log(adata, sc.pp.neighbors, func_kwargs)
369
+ add_op_log(adata, sc.pp.neighbors, func_kwargs, adinfo)
359
370
  return [
360
- generate_msg(request, adata, ads)
371
+ generate_msg(adinfo, adata, ads)
361
372
  ]
362
373
  except ToolError as e:
363
374
  raise ToolError(e)