scmcp-shared 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. scmcp_shared/__init__.py +1 -3
  2. scmcp_shared/agent.py +38 -21
  3. scmcp_shared/backend.py +44 -0
  4. scmcp_shared/cli.py +75 -46
  5. scmcp_shared/kb.py +139 -0
  6. scmcp_shared/logging_config.py +6 -8
  7. scmcp_shared/mcp_base.py +184 -0
  8. scmcp_shared/schema/io.py +101 -59
  9. scmcp_shared/schema/pl.py +386 -490
  10. scmcp_shared/schema/pp.py +514 -265
  11. scmcp_shared/schema/preset/__init__.py +15 -0
  12. scmcp_shared/schema/preset/io.py +103 -0
  13. scmcp_shared/schema/preset/pl.py +843 -0
  14. scmcp_shared/schema/preset/pp.py +616 -0
  15. scmcp_shared/schema/preset/tl.py +917 -0
  16. scmcp_shared/schema/preset/util.py +123 -0
  17. scmcp_shared/schema/tl.py +355 -407
  18. scmcp_shared/schema/util.py +57 -72
  19. scmcp_shared/server/__init__.py +5 -10
  20. scmcp_shared/server/auto.py +15 -11
  21. scmcp_shared/server/code.py +3 -0
  22. scmcp_shared/server/preset/__init__.py +14 -0
  23. scmcp_shared/server/{io.py → preset/io.py} +26 -22
  24. scmcp_shared/server/{pl.py → preset/pl.py} +162 -78
  25. scmcp_shared/server/{pp.py → preset/pp.py} +123 -65
  26. scmcp_shared/server/{tl.py → preset/tl.py} +142 -79
  27. scmcp_shared/server/{util.py → preset/util.py} +123 -66
  28. scmcp_shared/server/rag.py +13 -0
  29. scmcp_shared/util.py +109 -38
  30. {scmcp_shared-0.4.0.dist-info → scmcp_shared-0.6.0.dist-info}/METADATA +6 -2
  31. scmcp_shared-0.6.0.dist-info/RECORD +35 -0
  32. scmcp_shared/server/base.py +0 -148
  33. scmcp_shared-0.4.0.dist-info/RECORD +0 -24
  34. {scmcp_shared-0.4.0.dist-info → scmcp_shared-0.6.0.dist-info}/WHEEL +0 -0
  35. {scmcp_shared-0.4.0.dist-info → scmcp_shared-0.6.0.dist-info}/licenses/LICENSE +0 -0
scmcp_shared/schema/pp.py CHANGED
@@ -1,191 +1,187 @@
1
- from pydantic import (
2
- Field,
3
- ValidationInfo,
4
- computed_field,
5
- field_validator,
6
- model_validator,BaseModel
7
- )
1
+ from pydantic import Field, field_validator, BaseModel
8
2
 
9
3
  from typing import Optional, Union, List, Dict, Any
10
4
  from typing import Literal
11
5
  import numpy as np
12
6
 
13
7
 
14
- class FilterCells(BaseModel):
8
+ class FilterCellsParam(BaseModel):
15
9
  """Input schema for the filter_cells preprocessing tool."""
16
-
10
+
11
+ adata: str = Field(..., description="The AnnData object variable name.")
17
12
  min_counts: Optional[int] = Field(
18
13
  default=None,
19
- description="Minimum number of counts required for a cell to pass filtering."
14
+ description="Minimum number of counts required for a cell to pass filtering.",
20
15
  )
21
-
16
+
22
17
  min_genes: Optional[int] = Field(
23
18
  default=None,
24
- description="Minimum number of genes expressed required for a cell to pass filtering."
19
+ description="Minimum number of genes expressed required for a cell to pass filtering.",
25
20
  )
26
-
21
+
27
22
  max_counts: Optional[int] = Field(
28
23
  default=None,
29
- description="Maximum number of counts required for a cell to pass filtering."
24
+ description="Maximum number of counts required for a cell to pass filtering.",
30
25
  )
31
-
26
+
32
27
  max_genes: Optional[int] = Field(
33
28
  default=None,
34
- description="Maximum number of genes expressed required for a cell to pass filtering."
29
+ description="Maximum number of genes expressed required for a cell to pass filtering.",
35
30
  )
36
-
37
- @field_validator('min_counts', 'min_genes', 'max_counts', 'max_genes')
31
+
32
+ @field_validator("min_counts", "min_genes", "max_counts", "max_genes")
38
33
  def validate_positive_integers(cls, v: Optional[int]) -> Optional[int]:
39
- """验证整数参数为正数"""
40
34
  if v is not None and v <= 0:
41
- raise ValueError("过滤参数必须是正整数")
35
+ raise ValueError("must be positive_integers")
42
36
  return v
43
37
 
44
38
 
45
- class FilterGenes(BaseModel):
39
+ class FilterGenesParam(BaseModel):
46
40
  """Input schema for the filter_genes preprocessing tool."""
47
-
41
+
42
+ adata: str = Field(..., description="The AnnData object variable name.")
48
43
  min_counts: Optional[int] = Field(
49
44
  default=None,
50
- description="Minimum number of counts required for a gene to pass filtering."
45
+ description="Minimum number of counts required for a gene to pass filtering.",
51
46
  )
52
-
47
+
53
48
  min_cells: Optional[int] = Field(
54
49
  default=None,
55
- description="Minimum number of cells expressed required for a gene to pass filtering."
50
+ description="Minimum number of cells expressed required for a gene to pass filtering.",
56
51
  )
57
-
52
+
58
53
  max_counts: Optional[int] = Field(
59
54
  default=None,
60
- description="Maximum number of counts required for a gene to pass filtering."
55
+ description="Maximum number of counts required for a gene to pass filtering.",
61
56
  )
62
-
57
+
63
58
  max_cells: Optional[int] = Field(
64
59
  default=None,
65
- description="Maximum number of cells expressed required for a gene to pass filtering."
60
+ description="Maximum number of cells expressed required for a gene to pass filtering.",
66
61
  )
67
-
68
- @field_validator('min_counts', 'min_cells', 'max_counts', 'max_cells')
62
+
63
+ @field_validator("min_counts", "min_cells", "max_counts", "max_cells")
69
64
  def validate_positive_integers(cls, v: Optional[int]) -> Optional[int]:
70
- """验证整数参数为正数"""
71
65
  if v is not None and v <= 0:
72
66
  raise ValueError("must be positive_integers")
73
67
  return v
74
68
 
75
69
 
76
- class SubsetCellParams(BaseModel):
70
+ class SubsetCellParam(BaseModel):
77
71
  """Input schema for subsetting AnnData objects based on various criteria."""
72
+
73
+ adata: str = Field(..., description="The AnnData object variable name.")
78
74
  obs_key: Optional[str] = Field(
79
75
  default=None,
80
- description="Key in adata.obs to use for subsetting observations/cells."
76
+ description="Key in adata.obs to use for subsetting observations/cells.",
81
77
  )
82
78
  obs_min: Optional[float] = Field(
83
79
  default=None,
84
- description="Minimum value for the obs_key to include in the subset."
80
+ description="Minimum value for the obs_key to include in the subset.",
85
81
  )
86
82
  obs_max: Optional[float] = Field(
87
83
  default=None,
88
- description="Maximum value for the obs_key to include in the subset."
84
+ description="Maximum value for the obs_key to include in the subset.",
89
85
  )
90
86
  obs_value: Optional[Any] = Field(
91
87
  default=None,
92
- description="Exact value for the obs_key to include in the subset (adata.obs[obs_key] == obs_value)."
88
+ description="Exact value for the obs_key to include in the subset (adata.obs[obs_key] == obs_value).",
93
89
  )
94
90
  min_counts: Optional[int] = Field(
95
91
  default=None,
96
- description="Minimum number of counts required for a cell to pass filtering."
97
- )
92
+ description="Minimum number of counts required for a cell to pass filtering.",
93
+ )
98
94
  min_genes: Optional[int] = Field(
99
95
  default=None,
100
- description="Minimum number of genes expressed required for a cell to pass filtering."
96
+ description="Minimum number of genes expressed required for a cell to pass filtering.",
101
97
  )
102
98
  max_counts: Optional[int] = Field(
103
99
  default=None,
104
- description="Maximum number of counts required for a cell to pass filtering."
100
+ description="Maximum number of counts required for a cell to pass filtering.",
105
101
  )
106
102
  max_genes: Optional[int] = Field(
107
103
  default=None,
108
- description="Maximum number of genes expressed required for a cell to pass filtering."
104
+ description="Maximum number of genes expressed required for a cell to pass filtering.",
109
105
  )
110
106
 
111
107
 
112
- class SubsetGeneParams(BaseModel):
108
+ class SubsetGeneParam(BaseModel):
113
109
  """Input schema for subsetting AnnData objects based on various criteria."""
110
+
111
+ adata: str = Field(..., description="The AnnData object variable name.")
114
112
  min_counts: Optional[int] = Field(
115
113
  default=None,
116
- description="Minimum number of counts required for a gene to pass filtering."
114
+ description="Minimum number of counts required for a gene to pass filtering.",
117
115
  )
118
116
  min_cells: Optional[int] = Field(
119
117
  default=None,
120
- description="Minimum number of cells expressed required for a gene to pass filtering."
118
+ description="Minimum number of cells expressed required for a gene to pass filtering.",
121
119
  )
122
120
  max_counts: Optional[int] = Field(
123
121
  default=None,
124
- description="Maximum number of counts required for a gene to pass filtering."
125
- )
122
+ description="Maximum number of counts required for a gene to pass filtering.",
123
+ )
126
124
  max_cells: Optional[int] = Field(
127
125
  default=None,
128
- description="Maximum number of cells expressed required for a gene to pass filtering."
129
- )
126
+ description="Maximum number of cells expressed required for a gene to pass filtering.",
127
+ )
130
128
  var_key: Optional[str] = Field(
131
129
  default=None,
132
- description="Key in adata.var to use for subsetting variables/genes."
130
+ description="Key in adata.var to use for subsetting variables/genes.",
133
131
  )
134
132
  var_min: Optional[float] = Field(
135
133
  default=None,
136
- description="Minimum value for the var_key to include in the subset."
134
+ description="Minimum value for the var_key to include in the subset.",
137
135
  )
138
136
  var_max: Optional[float] = Field(
139
137
  default=None,
140
- description="Maximum value for the var_key to include in the subset."
138
+ description="Maximum value for the var_key to include in the subset.",
141
139
  )
142
140
  highly_variable: Optional[bool] = Field(
143
141
  default=False,
144
- description="If True, subset to highly variable genes. Requires 'highly_variable' column in adata.var."
142
+ description="If True, subset to highly variable genes. Requires 'highly_variable' column in adata.var.",
145
143
  )
146
144
 
147
145
 
148
146
  class CalculateQCMetrics(BaseModel):
149
147
  """Input schema for the calculate_qc_metrics preprocessing tool."""
150
-
151
- expr_type: str = Field(
152
- default="counts",
153
- description="Name of kind of values in X."
154
- )
155
-
148
+
149
+ adata: str = Field(..., description="The AnnData object variable name.")
150
+
151
+ expr_type: str = Field(default="counts", description="Name of kind of values in X.")
152
+
156
153
  var_type: str = Field(
157
- default="genes",
158
- description="The kind of thing the variables are."
154
+ default="genes", description="The kind of thing the variables are."
159
155
  )
160
-
161
- qc_vars: Optional[Union[List[str], str]] = Field(
156
+
157
+ qc_vars: Optional[Union[List[str], str]] = Field(
162
158
  default=[],
163
159
  description=(
164
160
  "Keys for boolean columns of .var which identify variables you could want to control for "
165
161
  "mark_var tool should be called frist when you want to calculate mt, ribo, hb, and check tool output for var columns"
166
- )
162
+ ),
167
163
  )
168
-
164
+
169
165
  percent_top: Optional[List[int]] = Field(
170
166
  default=[50, 100, 200, 500],
171
- description="List of ranks (where genes are ranked by expression) at which the cumulative proportion of expression will be reported as a percentage."
167
+ description="List of ranks (where genes are ranked by expression) at which the cumulative proportion of expression will be reported as a percentage.",
172
168
  )
173
-
169
+
174
170
  layer: Optional[str] = Field(
175
171
  default=None,
176
- description="If provided, use adata.layers[layer] for expression values instead of adata.X"
172
+ description="If provided, use adata.layers[layer] for expression values instead of adata.X",
177
173
  )
178
-
174
+
179
175
  use_raw: bool = Field(
180
176
  default=False,
181
- description="If True, use adata.raw.X for expression values instead of adata.X"
177
+ description="If True, use adata.raw.X for expression values instead of adata.X",
182
178
  )
183
179
  log1p: bool = Field(
184
180
  default=True,
185
- description="Set to False to skip computing log1p transformed annotations."
181
+ description="Set to False to skip computing log1p transformed annotations.",
186
182
  )
187
-
188
- @field_validator('percent_top')
183
+
184
+ @field_validator("percent_top")
189
185
  def validate_percent_top(cls, v: Optional[List[int]]) -> Optional[List[int]]:
190
186
  """验证 percent_top 中的值为正整数"""
191
187
  if v is not None:
@@ -193,38 +189,34 @@ class CalculateQCMetrics(BaseModel):
193
189
  if not isinstance(rank, int) or rank <= 0:
194
190
  raise ValueError("percent_top 中的所有值必须是正整数")
195
191
  return v
196
-
197
192
 
198
193
 
199
- class Log1PParams(BaseModel):
194
+ class Log1PParam(BaseModel):
200
195
  """Input schema for the log1p preprocessing tool."""
201
-
196
+
197
+ adata: str = Field(..., description="The AnnData object variable name.")
202
198
  base: Optional[Union[int, float]] = Field(
203
199
  default=None,
204
- description="Base of the logarithm. Natural logarithm is used by default."
200
+ description="Base of the logarithm. Natural logarithm is used by default.",
205
201
  )
206
-
202
+
207
203
  chunked: Optional[bool] = Field(
208
204
  default=None,
209
- description="Process the data matrix in chunks, which will save memory."
205
+ description="Process the data matrix in chunks, which will save memory.",
210
206
  )
211
-
207
+
212
208
  chunk_size: Optional[int] = Field(
213
209
  default=None,
214
- description="Number of observations in the chunks to process the data in."
210
+ description="Number of observations in the chunks to process the data in.",
215
211
  )
216
-
212
+
217
213
  layer: Optional[str] = Field(
218
- default=None,
219
- description="Entry of layers to transform."
214
+ default=None, description="Entry of layers to transform."
220
215
  )
221
-
222
- obsm: Optional[str] = Field(
223
- default=None,
224
- description="Entry of obsm to transform."
225
- )
226
-
227
- @field_validator('chunk_size')
216
+
217
+ obsm: Optional[str] = Field(default=None, description="Entry of obsm to transform.")
218
+
219
+ @field_validator("chunk_size")
228
220
  def validate_chunk_size(cls, v: Optional[int]) -> Optional[int]:
229
221
  """Validate chunk_size is positive integer"""
230
222
  if v is not None and v <= 0:
@@ -232,74 +224,66 @@ class Log1PParams(BaseModel):
232
224
  return v
233
225
 
234
226
 
235
-
236
- class HighlyVariableGenesParams(BaseModel):
227
+ class HighlyVariableGenesParam(BaseModel):
237
228
  """Input schema for the highly_variable_genes preprocessing tool."""
238
-
229
+
230
+ adata: str = Field(..., description="The AnnData object variable name.")
239
231
  layer: Optional[str] = Field(
240
232
  default=None,
241
- description="If provided, use adata.layers[layer] for expression values."
233
+ description="If provided, use adata.layers[layer] for expression values.",
242
234
  )
243
-
235
+
244
236
  n_top_genes: Optional[int] = Field(
245
237
  default=None,
246
238
  description="Number of highly-variable genes to keep. Mandatory if `flavor='seurat_v3'",
247
239
  )
248
-
240
+
249
241
  min_disp: Optional[float] = Field(
250
- default=0.5,
251
- description="Minimum dispersion cutoff for gene selection."
242
+ default=0.5, description="Minimum dispersion cutoff for gene selection."
252
243
  )
253
-
244
+
254
245
  max_disp: Optional[float] = Field(
255
- default=np.inf,
256
- description="Maximum dispersion cutoff for gene selection."
246
+ default=np.inf, description="Maximum dispersion cutoff for gene selection."
257
247
  )
258
248
  min_mean: Optional[float] = Field(
259
- default=0.0125,
260
- description="Minimum mean expression cutoff for gene selection."
249
+ default=0.0125, description="Minimum mean expression cutoff for gene selection."
261
250
  )
262
251
  max_mean: Optional[float] = Field(
263
- default=3,
264
- description="Maximum mean expression cutoff for gene selection."
252
+ default=3, description="Maximum mean expression cutoff for gene selection."
265
253
  )
266
254
  span: Optional[float] = Field(
267
255
  default=0.3,
268
256
  description="Fraction of data used for loess model fit in seurat_v3.",
269
257
  gt=0,
270
- lt=1
258
+ lt=1,
271
259
  )
272
260
  n_bins: Optional[int] = Field(
273
- default=20,
274
- description="Number of bins for mean expression binning.",
275
- gt=0
261
+ default=20, description="Number of bins for mean expression binning.", gt=0
276
262
  )
277
- flavor: Optional[Literal['seurat', 'cell_ranger', 'seurat_v3', 'seurat_v3_paper']] = Field(
278
- default='seurat',
279
- description="Method for identifying highly variable genes."
263
+ flavor: Optional[
264
+ Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"]
265
+ ] = Field(
266
+ default="seurat", description="Method for identifying highly variable genes."
280
267
  )
281
268
  subset: Optional[bool] = Field(
282
- default=False,
283
- description="Inplace subset to highly-variable genes if True."
269
+ default=False, description="Inplace subset to highly-variable genes if True."
284
270
  )
285
271
  batch_key: Optional[str] = Field(
286
- default=None,
287
- description="Key in adata.obs for batch information."
272
+ default=None, description="Key in adata.obs for batch information."
288
273
  )
289
-
274
+
290
275
  check_values: Optional[bool] = Field(
291
- default=True,
292
- description="Check if counts are integers for seurat_v3 flavor."
276
+ default=True, description="Check if counts are integers for seurat_v3 flavor."
293
277
  )
294
-
295
- @field_validator('n_top_genes', 'n_bins')
278
+
279
+ @field_validator("n_top_genes", "n_bins")
296
280
  def validate_positive_integers(cls, v: Optional[int]) -> Optional[int]:
297
281
  """Validate positive integers"""
298
282
  if v is not None and v <= 0:
299
283
  raise ValueError("must be a positive integer")
300
284
  return v
301
-
302
- @field_validator('span')
285
+
286
+ @field_validator("span")
303
287
  def validate_span(cls, v: float) -> float:
304
288
  """Validate span is between 0 and 1"""
305
289
  if v <= 0 or v >= 1:
@@ -307,30 +291,28 @@ class HighlyVariableGenesParams(BaseModel):
307
291
  return v
308
292
 
309
293
 
310
- class RegressOutParams(BaseModel):
294
+ class RegressOutParam(BaseModel):
311
295
  """Input schema for the regress_out preprocessing tool."""
312
-
296
+
297
+ adata: str = Field(..., description="The AnnData object variable name.")
313
298
  keys: Union[str, List[str]] = Field(
314
299
  description="Keys for observation annotation on which to regress on."
315
300
  )
316
301
  layer: Optional[str] = Field(
317
- default=None,
318
- description="If provided, which element of layers to regress on."
302
+ default=None, description="If provided, which element of layers to regress on."
319
303
  )
320
304
  n_jobs: Optional[int] = Field(
321
- default=None,
322
- description="Number of jobs for parallel computation.",
323
- gt=0
305
+ default=None, description="Number of jobs for parallel computation.", gt=0
324
306
  )
325
-
326
- @field_validator('n_jobs')
307
+
308
+ @field_validator("n_jobs")
327
309
  def validate_n_jobs(cls, v: Optional[int]) -> Optional[int]:
328
310
  """Validate n_jobs is positive integer"""
329
311
  if v is not None and v <= 0:
330
312
  raise ValueError("n_jobs must be a positive integer")
331
313
  return v
332
-
333
- @field_validator('keys')
314
+
315
+ @field_validator("keys")
334
316
  def validate_keys(cls, v: Union[str, List[str]]) -> Union[str, List[str]]:
335
317
  """Ensure keys is either a string or list of strings"""
336
318
  if isinstance(v, str):
@@ -340,35 +322,34 @@ class RegressOutParams(BaseModel):
340
322
  raise ValueError("keys must be a string or list of strings")
341
323
 
342
324
 
343
- class ScaleParams(BaseModel):
325
+ class ScaleParam(BaseModel):
344
326
  """Input schema for the scale preprocessing tool."""
345
-
327
+
328
+ adata: str = Field(..., description="The AnnData object variable name.")
346
329
  zero_center: bool = Field(
347
330
  default=True,
348
- description="If False, omit zero-centering variables to handle sparse input efficiently."
331
+ description="If False, omit zero-centering variables to handle sparse input efficiently.",
349
332
  )
350
-
333
+
351
334
  max_value: Optional[float] = Field(
352
335
  default=None,
353
- description="Clip (truncate) to this value after scaling. If None, do not clip."
336
+ description="Clip (truncate) to this value after scaling. If None, do not clip.",
354
337
  )
355
-
338
+
356
339
  layer: Optional[str] = Field(
357
- default=None,
358
- description="If provided, which element of layers to scale."
340
+ default=None, description="If provided, which element of layers to scale."
359
341
  )
360
-
342
+
361
343
  obsm: Optional[str] = Field(
362
- default=None,
363
- description="If provided, which element of obsm to scale."
344
+ default=None, description="If provided, which element of obsm to scale."
364
345
  )
365
-
346
+
366
347
  mask_obs: Optional[Union[str, bool]] = Field(
367
348
  default=None,
368
- description="Boolean mask or string referring to obs column for subsetting observations."
349
+ description="Boolean mask or string referring to obs column for subsetting observations.",
369
350
  )
370
-
371
- @field_validator('max_value')
351
+
352
+ @field_validator("max_value")
372
353
  def validate_max_value(cls, v: Optional[float]) -> Optional[float]:
373
354
  """Validate max_value is positive if provided"""
374
355
  if v is not None and v <= 0:
@@ -376,27 +357,28 @@ class ScaleParams(BaseModel):
376
357
  return v
377
358
 
378
359
 
379
- class CombatParams(BaseModel):
360
+ class CombatParam(BaseModel):
380
361
  """Input schema for the combat batch effect correction tool."""
381
-
362
+
363
+ adata: str = Field(..., description="The AnnData object variable name.")
382
364
  key: str = Field(
383
- default='batch',
384
- description="Key to a categorical annotation from adata.obs that will be used for batch effect removal."
365
+ default="batch",
366
+ description="Key to a categorical annotation from adata.obs that will be used for batch effect removal.",
385
367
  )
386
-
368
+
387
369
  covariates: Optional[List[str]] = Field(
388
370
  default=None,
389
- description="Additional covariates besides the batch variable such as adjustment variables or biological condition."
371
+ description="Additional covariates besides the batch variable such as adjustment variables or biological condition.",
390
372
  )
391
-
392
- @field_validator('key')
373
+
374
+ @field_validator("key")
393
375
  def validate_key(cls, v: str) -> str:
394
376
  """Validate key is not empty"""
395
377
  if not v.strip():
396
378
  raise ValueError("key cannot be empty")
397
379
  return v
398
-
399
- @field_validator('covariates')
380
+
381
+ @field_validator("covariates")
400
382
  def validate_covariates(cls, v: Optional[List[str]]) -> Optional[List[str]]:
401
383
  """Validate covariates are non-empty strings if provided"""
402
384
  if v is not None:
@@ -405,243 +387,240 @@ class CombatParams(BaseModel):
405
387
  return v
406
388
 
407
389
 
408
- class ScrubletParams(BaseModel):
390
+ class ScrubletParam(BaseModel):
409
391
  """Input schema for the scrublet doublet prediction tool."""
410
-
392
+
393
+ adata: str = Field(..., description="The AnnData object variable name.")
411
394
  adata_sim: Optional[str] = Field(
412
395
  default=None,
413
- description="Optional path to AnnData object with simulated doublets."
396
+ description="Optional path to AnnData object with simulated doublets.",
414
397
  )
415
-
398
+
416
399
  batch_key: Optional[str] = Field(
417
- default=None,
418
- description="Key in adata.obs for batch information."
400
+ default=None, description="Key in adata.obs for batch information."
419
401
  )
420
-
402
+
421
403
  sim_doublet_ratio: float = Field(
422
404
  default=2.0,
423
405
  description="Number of doublets to simulate relative to observed transcriptomes.",
424
- gt=0
406
+ gt=0,
425
407
  )
426
-
408
+
427
409
  expected_doublet_rate: float = Field(
428
410
  default=0.05,
429
411
  description="Estimated doublet rate for the experiment.",
430
412
  ge=0,
431
- le=1
413
+ le=1,
432
414
  )
433
-
415
+
434
416
  stdev_doublet_rate: float = Field(
435
417
  default=0.02,
436
418
  description="Uncertainty in the expected doublet rate.",
437
419
  ge=0,
438
- le=1
420
+ le=1,
439
421
  )
440
-
422
+
441
423
  synthetic_doublet_umi_subsampling: float = Field(
442
424
  default=1.0,
443
425
  description="Rate for sampling UMIs when creating synthetic doublets.",
444
426
  gt=0,
445
- le=1
427
+ le=1,
446
428
  )
447
-
429
+
448
430
  knn_dist_metric: str = Field(
449
431
  default="euclidean",
450
- description="Distance metric used when finding nearest neighbors."
432
+ description="Distance metric used when finding nearest neighbors.",
451
433
  )
452
-
434
+
453
435
  normalize_variance: bool = Field(
454
436
  default=True,
455
- description="Normalize data such that each gene has variance of 1."
437
+ description="Normalize data such that each gene has variance of 1.",
456
438
  )
457
-
439
+
458
440
  log_transform: bool = Field(
459
- default=False,
460
- description="Whether to log-transform the data prior to PCA."
441
+ default=False, description="Whether to log-transform the data prior to PCA."
461
442
  )
462
-
443
+
463
444
  mean_center: bool = Field(
464
- default=True,
465
- description="Center data such that each gene has mean of 0."
445
+ default=True, description="Center data such that each gene has mean of 0."
466
446
  )
467
-
447
+
468
448
  n_prin_comps: int = Field(
469
449
  default=30,
470
450
  description="Number of principal components used for embedding.",
471
- gt=0
451
+ gt=0,
472
452
  )
473
-
453
+
474
454
  use_approx_neighbors: Optional[bool] = Field(
475
- default=None,
476
- description="Use approximate nearest neighbor method (annoy)."
455
+ default=None, description="Use approximate nearest neighbor method (annoy)."
477
456
  )
478
-
457
+
479
458
  get_doublet_neighbor_parents: bool = Field(
480
459
  default=False,
481
- description="Return parent transcriptomes that generated doublet neighbors."
460
+ description="Return parent transcriptomes that generated doublet neighbors.",
482
461
  )
483
-
462
+
484
463
  n_neighbors: Optional[int] = Field(
485
464
  default=None,
486
465
  description="Number of neighbors used to construct KNN graph.",
487
- gt=0
466
+ gt=0,
488
467
  )
489
-
468
+
490
469
  threshold: Optional[float] = Field(
491
470
  default=None,
492
471
  description="Doublet score threshold for calling a transcriptome a doublet.",
493
472
  ge=0,
494
- le=1
473
+ le=1,
495
474
  )
496
-
497
- @field_validator('sim_doublet_ratio', 'expected_doublet_rate', 'stdev_doublet_rate',
498
- 'synthetic_doublet_umi_subsampling', 'n_prin_comps', 'n_neighbors')
499
- def validate_positive_numbers(cls, v: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
475
+
476
+ @field_validator(
477
+ "sim_doublet_ratio",
478
+ "expected_doublet_rate",
479
+ "stdev_doublet_rate",
480
+ "synthetic_doublet_umi_subsampling",
481
+ "n_prin_comps",
482
+ "n_neighbors",
483
+ )
484
+ def validate_positive_numbers(
485
+ cls, v: Optional[Union[int, float]]
486
+ ) -> Optional[Union[int, float]]:
500
487
  """Validate positive numbers where applicable"""
501
488
  if v is not None and v <= 0:
502
489
  raise ValueError("must be a positive number")
503
490
  return v
504
-
505
- @field_validator('knn_dist_metric')
491
+
492
+ @field_validator("knn_dist_metric")
506
493
  def validate_knn_dist_metric(cls, v: str) -> str:
507
494
  """Validate distance metric is supported"""
508
- valid_metrics = ['euclidean', 'manhattan', 'cosine', 'correlation']
495
+ valid_metrics = ["euclidean", "manhattan", "cosine", "correlation"]
509
496
  if v.lower() not in valid_metrics:
510
497
  raise ValueError(f"knn_dist_metric must be one of {valid_metrics}")
511
498
  return v.lower()
512
499
 
513
500
 
514
- class NeighborsParams(BaseModel):
501
+ class NeighborsParam(BaseModel):
515
502
  """Input schema for the neighbors graph construction tool."""
516
-
503
+
504
+ adata: str = Field(..., description="The AnnData object variable name.")
517
505
  n_neighbors: int = Field(
518
506
  default=15,
519
507
  description="Size of local neighborhood used for manifold approximation.",
520
508
  gt=1,
521
- le=100
509
+ le=100,
522
510
  )
523
-
511
+
524
512
  n_pcs: Optional[int] = Field(
525
513
  default=None,
526
514
  description="Number of PCs to use. If None, automatically determined.",
527
- ge=0
515
+ ge=0,
528
516
  )
529
-
517
+
530
518
  use_rep: Optional[str] = Field(
531
- default=None,
532
- description="Key for .obsm to use as representation."
519
+ default=None, description="Key for .obsm to use as representation."
533
520
  )
534
-
521
+
535
522
  knn: bool = Field(
536
523
  default=True,
537
- description="Whether to use hard threshold for neighbor restriction."
524
+ description="Whether to use hard threshold for neighbor restriction.",
538
525
  )
539
-
540
- method: Literal['umap', 'gauss'] = Field(
541
- default='umap',
542
- description="Method for computing connectivities ('umap' or 'gauss')."
526
+
527
+ method: Literal["umap", "gauss"] = Field(
528
+ default="umap",
529
+ description="Method for computing connectivities ('umap' or 'gauss').",
543
530
  )
544
-
531
+
545
532
  transformer: Optional[str] = Field(
546
533
  default=None,
547
- description="Approximate kNN search implementation ('pynndescent' or 'rapids')."
534
+ description="Approximate kNN search implementation ('pynndescent' or 'rapids').",
548
535
  )
549
-
550
- metric: str = Field(
551
- default='euclidean',
552
- description="Distance metric to use."
553
- )
554
-
536
+
537
+ metric: str = Field(default="euclidean", description="Distance metric to use.")
538
+
555
539
  metric_kwds: Dict[str, Any] = Field(
556
- default_factory=dict,
557
- description="Options for the distance metric."
540
+ default_factory=dict, description="Options for the distance metric."
558
541
  )
559
-
560
- random_state: int = Field(
561
- default=0,
562
- description="Random seed for reproducibility."
563
- )
564
-
542
+
543
+ random_state: int = Field(default=0, description="Random seed for reproducibility.")
544
+
565
545
  key_added: Optional[str] = Field(
566
- default=None,
567
- description="Key prefix for storing neighbor results."
546
+ default=None, description="Key prefix for storing neighbor results."
568
547
  )
569
-
570
- @field_validator('n_neighbors', 'n_pcs')
548
+
549
+ @field_validator("n_neighbors", "n_pcs")
571
550
  def validate_positive_integers(cls, v: Optional[int]) -> Optional[int]:
572
551
  """Validate positive integers where applicable"""
573
552
  if v is not None and v <= 0:
574
553
  raise ValueError("must be a positive integer")
575
554
  return v
576
-
577
- @field_validator('method')
555
+
556
+ @field_validator("method")
578
557
  def validate_method(cls, v: str) -> str:
579
558
  """Validate method is supported"""
580
- if v not in ['umap', 'gauss']:
559
+ if v not in ["umap", "gauss"]:
581
560
  raise ValueError("method must be either 'umap' or 'gauss'")
582
561
  return v
583
-
584
- @field_validator('transformer')
562
+
563
+ @field_validator("transformer")
585
564
  def validate_transformer(cls, v: Optional[str]) -> Optional[str]:
586
565
  """Validate transformer option is supported"""
587
- if v is not None and v not in ['pynndescent', 'rapids']:
566
+ if v is not None and v not in ["pynndescent", "rapids"]:
588
567
  raise ValueError("transformer must be either 'pynndescent' or 'rapids'")
589
568
  return v
590
569
 
591
570
 
592
- class NormalizeTotalParams(BaseModel):
571
+ class NormalizeTotalParam(BaseModel):
593
572
  """Input schema for the normalize_total preprocessing tool."""
594
-
573
+
574
+ adata: str = Field(..., description="The AnnData object variable name.")
595
575
  target_sum: Optional[float] = Field(
596
576
  default=None,
597
- description="If None, after normalization, each cell has a total count equal to the median of total counts before normalization. If a number is provided, each cell will have this total count after normalization."
577
+ description="If None, after normalization, each cell has a total count equal to the median of total counts before normalization. If a number is provided, each cell will have this total count after normalization.",
598
578
  )
599
-
579
+
600
580
  exclude_highly_expressed: bool = Field(
601
581
  default=False,
602
- description="Exclude highly expressed genes for the computation of the normalization factor for each cell."
582
+ description="Exclude highly expressed genes for the computation of the normalization factor for each cell.",
603
583
  )
604
-
584
+
605
585
  max_fraction: float = Field(
606
586
  default=0.05,
607
587
  description="If exclude_highly_expressed=True, consider cells as highly expressed that have more counts than max_fraction of the original total counts in at least one cell.",
608
588
  gt=0,
609
- le=1
589
+ le=1,
610
590
  )
611
-
591
+
612
592
  key_added: Optional[str] = Field(
613
593
  default=None,
614
- description="Name of the field in adata.obs where the normalization factor is stored."
594
+ description="Name of the field in adata.obs where the normalization factor is stored.",
615
595
  )
616
-
596
+
617
597
  layer: Optional[str] = Field(
618
598
  default=None,
619
- description="Layer to normalize instead of X. If None, X is normalized."
599
+ description="Layer to normalize instead of X. If None, X is normalized.",
620
600
  )
621
-
622
- layers: Optional[Union[Literal['all'], List[str]]] = Field(
601
+
602
+ layers: Optional[Union[Literal["all"], List[str]]] = Field(
623
603
  default=None,
624
- description="List of layers to normalize. If 'all', normalize all layers."
604
+ description="List of layers to normalize. If 'all', normalize all layers.",
625
605
  )
626
-
606
+
627
607
  layer_norm: Optional[str] = Field(
628
- default=None,
629
- description="Specifies how to normalize layers."
608
+ default=None, description="Specifies how to normalize layers."
630
609
  )
631
-
610
+
632
611
  inplace: bool = Field(
633
612
  default=True,
634
- description="Whether to update adata or return dictionary with normalized copies."
613
+ description="Whether to update adata or return dictionary with normalized copies.",
635
614
  )
636
-
637
- @field_validator('target_sum')
615
+
616
+ @field_validator("target_sum")
638
617
  def validate_target_sum(cls, v: Optional[float]) -> Optional[float]:
639
618
  """Validate target_sum is positive if provided"""
640
619
  if v is not None and v <= 0:
641
620
  raise ValueError("target_sum must be positive")
642
621
  return v
643
-
644
- @field_validator('max_fraction')
622
+
623
+ @field_validator("max_fraction")
645
624
  def validate_max_fraction(cls, v: float) -> float:
646
625
  """Validate max_fraction is between 0 and 1"""
647
626
  if v <= 0 or v > 1:
@@ -649,3 +628,273 @@ class NormalizeTotalParams(BaseModel):
649
628
  return v
650
629
 
651
630
 
631
+ class BBKNNParam(BaseModel):
632
+ """Input schema for the bbknn (batch balanced kNN) preprocessing tool."""
633
+
634
+ adata: str = Field(..., description="The AnnData object variable name.")
635
+
636
+ batch_key: str = Field(
637
+ default="batch",
638
+ description="adata.obs column name discriminating between your batches.",
639
+ )
640
+
641
+ use_rep: str = Field(
642
+ default="X_pca",
643
+ description="The dimensionality reduction in .obsm to use for neighbour detection. Defaults to PCA.",
644
+ )
645
+
646
+ approx: bool = Field(
647
+ default=True,
648
+ description="If True, use approximate neighbour finding - annoy or PyNNDescent. This results in a quicker run time for large datasets while also potentially increasing the degree of batch correction.",
649
+ )
650
+
651
+ use_annoy: bool = Field(
652
+ default=True,
653
+ description="Only used when approx=True. If True, will use annoy for neighbour finding. If False, will use pyNNDescent instead.",
654
+ )
655
+
656
+ metric: str = Field(
657
+ default="euclidean",
658
+ description="What distance metric to use. The options depend on the choice of neighbour algorithm.",
659
+ )
660
+
661
+ neighbors_within_batch: int = Field(
662
+ default=3,
663
+ description="How many top neighbours to report for each batch; total number of neighbours in the initial k-nearest-neighbours computation will be this number times the number of batches.",
664
+ gt=0,
665
+ )
666
+
667
+ n_pcs: int = Field(
668
+ default=50,
669
+ description="How many dimensions (in case of PCA, principal components) to use in the analysis.",
670
+ gt=0,
671
+ )
672
+
673
+ trim: Optional[int] = Field(
674
+ default=None,
675
+ description="Trim the neighbours of each cell to these many top connectivities. May help with population independence and improve the tidiness of clustering. If None, sets the parameter value automatically to 10 times neighbors_within_batch times the number of batches. Set to 0 to skip.",
676
+ ge=0,
677
+ )
678
+
679
+ annoy_n_trees: int = Field(
680
+ default=10,
681
+ description="Only used with annoy neighbour identification. The number of trees to construct in the annoy forest. More trees give higher precision when querying, at the cost of increased run time and resource intensity.",
682
+ gt=0,
683
+ )
684
+
685
+ pynndescent_n_neighbors: int = Field(
686
+ default=30,
687
+ description="Only used with pyNNDescent neighbour identification. The number of neighbours to include in the approximate neighbour graph. More neighbours give higher precision when querying, at the cost of increased run time and resource intensity.",
688
+ gt=0,
689
+ )
690
+
691
+ pynndescent_random_state: int = Field(
692
+ default=0,
693
+ description="Only used with pyNNDescent neighbour identification. The RNG seed to use when creating the graph.",
694
+ )
695
+
696
+ use_faiss: bool = Field(
697
+ default=True,
698
+ description="If approx=False and the metric is 'euclidean', use the faiss package to compute nearest neighbours if installed. This improves performance at a minor cost to numerical precision as faiss operates on float32.",
699
+ )
700
+
701
+ set_op_mix_ratio: float = Field(
702
+ default=1.0,
703
+ description="UMAP connectivity computation parameter, float between 0 and 1, controlling the blend between a connectivity matrix formed exclusively from mutual nearest neighbour pairs (0) and a union of all observed neighbour relationships with the mutual pairs emphasised (1).",
704
+ ge=0.0,
705
+ le=1.0,
706
+ )
707
+
708
+ local_connectivity: int = Field(
709
+ default=1,
710
+ description="UMAP connectivity computation parameter, how many nearest neighbors of each cell are assumed to be fully connected (and given a connectivity value of 1).",
711
+ gt=0,
712
+ )
713
+
714
+ @field_validator(
715
+ "neighbors_within_batch",
716
+ "n_pcs",
717
+ "annoy_n_trees",
718
+ "pynndescent_n_neighbors",
719
+ "local_connectivity",
720
+ )
721
+ def validate_positive_integers(cls, v: int) -> int:
722
+ """Validate positive integers"""
723
+ if v <= 0:
724
+ raise ValueError("must be a positive integer")
725
+ return v
726
+
727
+ @field_validator("trim")
728
+ def validate_trim(cls, v: Optional[int]) -> Optional[int]:
729
+ """Validate trim is non-negative if provided"""
730
+ if v is not None and v < 0:
731
+ raise ValueError("trim must be non-negative")
732
+ return v
733
+
734
+ @field_validator("set_op_mix_ratio")
735
+ def validate_set_op_mix_ratio(cls, v: float) -> float:
736
+ """Validate set_op_mix_ratio is between 0 and 1"""
737
+ if v < 0 or v > 1:
738
+ raise ValueError("set_op_mix_ratio must be between 0 and 1")
739
+ return v
740
+
741
+ @field_validator("metric")
742
+ def validate_metric(cls, v: str) -> str:
743
+ """Validate metric is supported"""
744
+ valid_metrics = [
745
+ "euclidean",
746
+ "l2",
747
+ "sqeuclidean",
748
+ "manhattan",
749
+ "taxicab",
750
+ "l1",
751
+ "chebyshev",
752
+ "linfinity",
753
+ "linfty",
754
+ "linf",
755
+ "minkowski",
756
+ "seuclidean",
757
+ "standardised_euclidean",
758
+ "wminkowski",
759
+ "angular",
760
+ "hamming",
761
+ ]
762
+ if v.lower() not in valid_metrics:
763
+ raise ValueError(f"metric must be one of {valid_metrics}")
764
+ return v.lower()
765
+
766
+
767
+ class HarmonyIntegrateParam(BaseModel):
768
+ """Input schema for the harmony_integrate preprocessing tool."""
769
+
770
+ adata: str = Field(..., description="The AnnData object variable name.")
771
+
772
+ key: Union[str, List[str]] = Field(
773
+ description="The name of the column in adata.obs that differentiates among experiments/batches. To integrate over two or more covariates, you can pass multiple column names as a list."
774
+ )
775
+
776
+ basis: str = Field(
777
+ default="X_pca",
778
+ description="The name of the field in adata.obsm where the PCA table is stored. Defaults to 'X_pca', which is the default for sc.pp.pca().",
779
+ )
780
+
781
+ adjusted_basis: str = Field(
782
+ default="X_pca_harmony",
783
+ description="The name of the field in adata.obsm where the adjusted PCA table will be stored after running this function. Defaults to X_pca_harmony.",
784
+ )
785
+
786
+ theta: float = Field(
787
+ default=2.0,
788
+ description="Diversity clustering penalty parameter. Theta = 0 does not encourage any diversity. Larger values of theta result in more diverse clusters.",
789
+ gt=0,
790
+ )
791
+
792
+ lambda_: float = Field(
793
+ default=1.0,
794
+ description="Ridge regression penalty parameter. Lambda = 0 gives no regularization. Larger values of lambda result in more regularization.",
795
+ ge=0,
796
+ )
797
+
798
+ sigma: float = Field(
799
+ default=0.1,
800
+ description="Width of soft kmeans clusters. Sigma scales the distance from a cell to cluster centroids. Larger values of sigma result in cells assigned to more clusters. Each cell is assigned to clusters with probability proportional to exp(-distance^2 / sigma^2).",
801
+ gt=0,
802
+ )
803
+
804
+ nclust: Optional[int] = Field(
805
+ default=None,
806
+ description="Number of clusters in Harmony. If None, estimated automatically. If provided, this overrides the theta parameter.",
807
+ gt=0,
808
+ )
809
+
810
+ tau: float = Field(
811
+ default=0.0,
812
+ description="Protection against overclustering small datasets with rare cell types. tau is the expected number of cells per cluster.",
813
+ ge=0,
814
+ )
815
+
816
+ block_size: float = Field(
817
+ default=0.05,
818
+ description="What proportion of cells to update during clustering. Between 0 to 1, e.g. 0.05 updates 5% of cells per iteration.",
819
+ gt=0,
820
+ le=1,
821
+ )
822
+
823
+ max_iter_harmony: int = Field(
824
+ default=10,
825
+ description="Maximum number of rounds to run Harmony. One round of Harmony involves one clustering and one correction step.",
826
+ gt=0,
827
+ )
828
+
829
+ max_iter_kmeans: int = Field(
830
+ default=20,
831
+ description="Maximum number of rounds to run clustering at each round of Harmony. If at least k < nclust clusters contain 1 or fewer cells, then stop early.",
832
+ gt=0,
833
+ )
834
+
835
+ epsilon_cluster: float = Field(
836
+ default=1e-5,
837
+ description="Convergence tolerance for clustering round of Harmony. Set to -Inf to never stop early.",
838
+ gt=0,
839
+ )
840
+
841
+ epsilon_harmony: float = Field(
842
+ default=1e-4,
843
+ description="Convergence tolerance for Harmony. Set to -Inf to never stop early.",
844
+ gt=0,
845
+ )
846
+
847
+ random_state: int = Field(default=0, description="Random seed for reproducibility.")
848
+
849
+ verbose: bool = Field(
850
+ default=True, description="Whether to print progress messages."
851
+ )
852
+
853
+ @field_validator("key")
854
+ def validate_key(cls, v: Union[str, List[str]]) -> Union[str, List[str]]:
855
+ """Ensure key is either a string or list of strings"""
856
+ if isinstance(v, str):
857
+ if not v.strip():
858
+ raise ValueError("key cannot be empty")
859
+ return v
860
+ elif isinstance(v, list) and all(
861
+ isinstance(item, str) and item.strip() for item in v
862
+ ):
863
+ return v
864
+ raise ValueError("key must be a non-empty string or list of non-empty strings")
865
+
866
+ @field_validator("basis", "adjusted_basis")
867
+ def validate_basis_names(cls, v: str) -> str:
868
+ """Validate basis names are not empty"""
869
+ if not v.strip():
870
+ raise ValueError("basis and adjusted_basis cannot be empty")
871
+ return v
872
+
873
+ @field_validator(
874
+ "theta",
875
+ "lambda_",
876
+ "sigma",
877
+ "tau",
878
+ "block_size",
879
+ "epsilon_cluster",
880
+ "epsilon_harmony",
881
+ )
882
+ def validate_positive_floats(cls, v: float) -> float:
883
+ """Validate positive floats"""
884
+ if v < 0:
885
+ raise ValueError("must be non-negative")
886
+ return v
887
+
888
+ @field_validator("max_iter_harmony", "max_iter_kmeans")
889
+ def validate_positive_integers(cls, v: int) -> int:
890
+ """Validate positive integers"""
891
+ if v <= 0:
892
+ raise ValueError("must be a positive integer")
893
+ return v
894
+
895
+ @field_validator("nclust")
896
+ def validate_nclust(cls, v: Optional[int]) -> Optional[int]:
897
+ """Validate nclust is positive if provided"""
898
+ if v is not None and v <= 0:
899
+ raise ValueError("nclust must be positive if provided")
900
+ return v