XspecT: xspect-0.5.0-py3-none-any.whl → xspect-0.5.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of XspecT might be problematic. Click here for more details.
- xspect/classify.py +61 -13
- xspect/definitions.py +61 -13
- xspect/download_models.py +10 -2
- xspect/file_io.py +115 -48
- xspect/filter_sequences.py +81 -29
- xspect/main.py +90 -39
- xspect/mlst_feature/mlst_helper.py +3 -0
- xspect/mlst_feature/pub_mlst_handler.py +43 -1
- xspect/model_management.py +84 -14
- xspect/models/probabilistic_filter_mlst_model.py +75 -37
- xspect/models/probabilistic_filter_model.py +201 -19
- xspect/models/probabilistic_filter_svm_model.py +106 -13
- xspect/models/probabilistic_single_filter_model.py +73 -9
- xspect/models/result.py +77 -10
- xspect/ncbi.py +48 -12
- xspect/train.py +19 -11
- xspect/web.py +68 -12
- xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
- xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
- xspect/xspect-web/dist/index.html +2 -2
- xspect/xspect-web/src/App.tsx +4 -2
- xspect/xspect-web/src/api.tsx +23 -1
- xspect/xspect-web/src/components/filter-form.tsx +16 -3
- xspect/xspect-web/src/components/filtering-result.tsx +65 -0
- xspect/xspect-web/src/components/result.tsx +2 -2
- xspect/xspect-web/src/types.tsx +5 -0
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/METADATA +11 -5
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/RECORD +32 -31
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/WHEEL +1 -1
- xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/entry_points.txt +0 -0
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/top_level.txt +0 -0
xspect/main.py
CHANGED
|
@@ -7,21 +7,19 @@ import uvicorn
|
|
|
7
7
|
from xspect import classify
|
|
8
8
|
from xspect.web import app
|
|
9
9
|
from xspect.download_models import download_test_models
|
|
10
|
-
from xspect
|
|
10
|
+
from xspect import filter_sequences
|
|
11
11
|
from xspect.train import train_from_directory, train_from_ncbi
|
|
12
12
|
from xspect.definitions import (
|
|
13
13
|
get_xspect_model_path,
|
|
14
14
|
)
|
|
15
|
-
from xspect.mlst_feature.mlst_helper import pick_scheme
|
|
15
|
+
from xspect.mlst_feature.mlst_helper import pick_scheme
|
|
16
16
|
from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
|
|
17
17
|
from xspect.models.probabilistic_filter_mlst_model import (
|
|
18
18
|
ProbabilisticFilterMlstSchemeModel,
|
|
19
19
|
)
|
|
20
20
|
from xspect.model_management import (
|
|
21
|
-
get_genus_model,
|
|
22
21
|
get_model_metadata,
|
|
23
22
|
get_models,
|
|
24
|
-
get_species_model,
|
|
25
23
|
)
|
|
26
24
|
|
|
27
25
|
|
|
@@ -173,8 +171,9 @@ def train_mlst(choose_schemes):
|
|
|
173
171
|
scheme_path = pick_scheme(handler.get_scheme_paths())
|
|
174
172
|
species_name = str(scheme_path).split("/")[-2]
|
|
175
173
|
scheme_name = str(scheme_path).split("/")[-1]
|
|
174
|
+
scheme_url = handler.scheme_mapping[str(scheme_path)]
|
|
176
175
|
model = ProbabilisticFilterMlstSchemeModel(
|
|
177
|
-
31, f"{species_name}:{scheme_name}", get_xspect_model_path()
|
|
176
|
+
31, f"{species_name}:{scheme_name}", get_xspect_model_path(), scheme_url
|
|
178
177
|
)
|
|
179
178
|
click.echo("Creating mlst model")
|
|
180
179
|
model.fit(scheme_path)
|
|
@@ -211,19 +210,27 @@ def classify_seqs():
|
|
|
211
210
|
help="Path to FASTA or FASTQ file for classification.",
|
|
212
211
|
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
213
212
|
prompt=True,
|
|
213
|
+
default=Path("."),
|
|
214
214
|
)
|
|
215
215
|
@click.option(
|
|
216
216
|
"-o",
|
|
217
217
|
"--output-path",
|
|
218
218
|
help="Path to the output file.",
|
|
219
|
-
type=click.Path(dir_okay=
|
|
219
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
220
220
|
default=Path(".") / f"result_{uuid4()}.json",
|
|
221
221
|
)
|
|
222
|
-
|
|
222
|
+
@click.option(
|
|
223
|
+
"--sparse-sampling-step",
|
|
224
|
+
type=int,
|
|
225
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
226
|
+
default=1,
|
|
227
|
+
)
|
|
228
|
+
def classify_genus(model_genus, input_path, output_path, sparse_sampling_step):
|
|
223
229
|
"""Classify samples using a genus model."""
|
|
224
230
|
click.echo("Classifying...")
|
|
225
|
-
classify.classify_genus(
|
|
226
|
-
|
|
231
|
+
classify.classify_genus(
|
|
232
|
+
model_genus, Path(input_path), Path(output_path), sparse_sampling_step
|
|
233
|
+
)
|
|
227
234
|
|
|
228
235
|
|
|
229
236
|
@classify_seqs.command(
|
|
@@ -244,12 +251,13 @@ def classify_genus(model_genus, input_path, output_path):
|
|
|
244
251
|
help="Path to FASTA or FASTQ file for classification.",
|
|
245
252
|
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
246
253
|
prompt=True,
|
|
254
|
+
default=Path("."),
|
|
247
255
|
)
|
|
248
256
|
@click.option(
|
|
249
257
|
"-o",
|
|
250
258
|
"--output-path",
|
|
251
259
|
help="Path to the output file.",
|
|
252
|
-
type=click.Path(dir_okay=
|
|
260
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
253
261
|
default=Path(".") / f"result_{uuid4()}.json",
|
|
254
262
|
)
|
|
255
263
|
@click.option(
|
|
@@ -264,7 +272,6 @@ def classify_species(model_genus, input_path, output_path, sparse_sampling_step)
|
|
|
264
272
|
classify.classify_species(
|
|
265
273
|
model_genus, Path(input_path), Path(output_path), sparse_sampling_step
|
|
266
274
|
)
|
|
267
|
-
click.echo(f"Result saved as {output_path}.")
|
|
268
275
|
|
|
269
276
|
|
|
270
277
|
@classify_seqs.command(
|
|
@@ -277,19 +284,22 @@ def classify_species(model_genus, input_path, output_path, sparse_sampling_step)
|
|
|
277
284
|
help="Path to FASTA-file for mlst identification.",
|
|
278
285
|
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
279
286
|
prompt=True,
|
|
287
|
+
default=Path("."),
|
|
280
288
|
)
|
|
281
289
|
@click.option(
|
|
282
290
|
"-o",
|
|
283
291
|
"--output-path",
|
|
284
292
|
help="Path to the output file.",
|
|
285
|
-
type=click.Path(dir_okay=
|
|
286
|
-
default=Path(".") / f"
|
|
293
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
294
|
+
default=Path(".") / f"MLST_result_{uuid4()}.json",
|
|
287
295
|
)
|
|
288
|
-
|
|
296
|
+
@click.option(
|
|
297
|
+
"-l", "--limit", is_flag=True, help="Limit the output to 5 results for each locus."
|
|
298
|
+
)
|
|
299
|
+
def classify_mlst(input_path, output_path, limit):
|
|
289
300
|
"""MLST classify a sample."""
|
|
290
301
|
click.echo("Classifying...")
|
|
291
|
-
classify.classify_mlst(Path(input_path), Path(output_path))
|
|
292
|
-
click.echo(f"Result saved as {output_path}.")
|
|
302
|
+
classify.classify_mlst(Path(input_path), Path(output_path), limit)
|
|
293
303
|
|
|
294
304
|
|
|
295
305
|
# # # # # # # # # # # # # # #
|
|
@@ -321,37 +331,54 @@ def filter_seqs():
|
|
|
321
331
|
help="Path to FASTA or FASTQ file for classification.",
|
|
322
332
|
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
323
333
|
prompt=True,
|
|
334
|
+
default=Path("."),
|
|
324
335
|
)
|
|
325
336
|
@click.option(
|
|
326
337
|
"-o",
|
|
327
338
|
"--output-path",
|
|
328
339
|
help="Path to the output file.",
|
|
329
|
-
type=click.Path(dir_okay=
|
|
340
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
330
341
|
prompt=True,
|
|
342
|
+
default=Path(".") / f"genus_filtered_{uuid4()}.fasta",
|
|
331
343
|
)
|
|
332
344
|
@click.option(
|
|
345
|
+
"--classification-output-path",
|
|
346
|
+
help="Optional path to the classification output file.",
|
|
347
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
348
|
+
)
|
|
349
|
+
@click.option(
|
|
350
|
+
"-t",
|
|
333
351
|
"--threshold",
|
|
334
|
-
type=
|
|
352
|
+
type=click.FloatRange(0, 1),
|
|
335
353
|
help="Threshold for filtering (default: 0.7).",
|
|
336
354
|
default=0.7,
|
|
337
355
|
prompt=True,
|
|
338
356
|
)
|
|
339
|
-
|
|
357
|
+
@click.option(
|
|
358
|
+
"--sparse-sampling-step",
|
|
359
|
+
type=int,
|
|
360
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
361
|
+
default=1,
|
|
362
|
+
)
|
|
363
|
+
def filter_genus(
|
|
364
|
+
model_genus,
|
|
365
|
+
input_path,
|
|
366
|
+
output_path,
|
|
367
|
+
classification_output_path,
|
|
368
|
+
threshold,
|
|
369
|
+
sparse_sampling_step,
|
|
370
|
+
):
|
|
340
371
|
"""Filter samples using a genus model."""
|
|
341
372
|
click.echo("Filtering...")
|
|
342
|
-
genus_model = get_genus_model(model_genus)
|
|
343
|
-
result = genus_model.predict(Path(input_path))
|
|
344
|
-
included_ids = result.get_filtered_subsequence_labels(model_genus, threshold)
|
|
345
|
-
if not included_ids:
|
|
346
|
-
click.echo("No sequences found for the given genus.")
|
|
347
|
-
return
|
|
348
373
|
|
|
349
|
-
filter_sequences(
|
|
374
|
+
filter_sequences.filter_genus(
|
|
375
|
+
model_genus,
|
|
350
376
|
Path(input_path),
|
|
351
377
|
Path(output_path),
|
|
352
|
-
|
|
378
|
+
threshold,
|
|
379
|
+
Path(classification_output_path) if classification_output_path else None,
|
|
380
|
+
sparse_sampling_step=sparse_sampling_step,
|
|
353
381
|
)
|
|
354
|
-
click.echo(f"Filtered sequences saved at {output_path}.")
|
|
355
382
|
|
|
356
383
|
|
|
357
384
|
@filter_seqs.command(
|
|
@@ -378,24 +405,51 @@ def filter_genus(model_genus, input_path, output_path, threshold):
|
|
|
378
405
|
help="Path to FASTA or FASTQ file for classification.",
|
|
379
406
|
type=click.Path(exists=True, dir_okay=True, file_okay=True),
|
|
380
407
|
prompt=True,
|
|
408
|
+
default=Path("."),
|
|
381
409
|
)
|
|
382
410
|
@click.option(
|
|
383
411
|
"-o",
|
|
384
412
|
"--output-path",
|
|
385
413
|
help="Path to the output file.",
|
|
386
|
-
type=click.Path(dir_okay=
|
|
414
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
387
415
|
prompt=True,
|
|
416
|
+
default=Path(".") / f"species_filtered_{uuid4()}.fasta",
|
|
388
417
|
)
|
|
389
418
|
@click.option(
|
|
419
|
+
"--classification-output-path",
|
|
420
|
+
help="Optional path to the classification output file.",
|
|
421
|
+
type=click.Path(dir_okay=False, file_okay=True),
|
|
422
|
+
)
|
|
423
|
+
@click.option(
|
|
424
|
+
"-t",
|
|
390
425
|
"--threshold",
|
|
391
426
|
type=float,
|
|
392
427
|
help="Threshold for filtering (default: 0.7). Use -1 to filter for the highest scoring species.",
|
|
393
428
|
default=0.7,
|
|
394
429
|
prompt=True,
|
|
395
430
|
)
|
|
396
|
-
|
|
431
|
+
@click.option(
|
|
432
|
+
"--sparse-sampling-step",
|
|
433
|
+
type=int,
|
|
434
|
+
help="Sparse sampling step (e. g. only every 500th kmer for '--sparse-sampling-step 500').",
|
|
435
|
+
default=1,
|
|
436
|
+
)
|
|
437
|
+
def filter_species(
|
|
438
|
+
model_genus,
|
|
439
|
+
model_species,
|
|
440
|
+
input_path,
|
|
441
|
+
output_path,
|
|
442
|
+
threshold,
|
|
443
|
+
classification_output_path,
|
|
444
|
+
sparse_sampling_step,
|
|
445
|
+
):
|
|
397
446
|
"""Filter a sample using the species model."""
|
|
398
447
|
|
|
448
|
+
if threshold != -1 and (threshold < 0 or threshold > 1):
|
|
449
|
+
raise click.BadParameter(
|
|
450
|
+
"Threshold must be between 0 and 1, or -1 for filtering by the highest scoring species."
|
|
451
|
+
)
|
|
452
|
+
|
|
399
453
|
available_species = get_model_metadata(f"{model_genus}-species")["display_names"]
|
|
400
454
|
available_species = {
|
|
401
455
|
id: name.replace(f"{model_genus} ", "")
|
|
@@ -420,18 +474,15 @@ def filter_species(model_genus, model_species, input_path, output_path, threshol
|
|
|
420
474
|
][0]
|
|
421
475
|
|
|
422
476
|
click.echo("Filtering...")
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
if not included_ids:
|
|
427
|
-
click.echo("No sequences found for the given species.")
|
|
428
|
-
return
|
|
429
|
-
filter_sequences(
|
|
477
|
+
filter_sequences.filter_species(
|
|
478
|
+
model_genus,
|
|
479
|
+
model_species,
|
|
430
480
|
Path(input_path),
|
|
431
481
|
Path(output_path),
|
|
432
|
-
|
|
482
|
+
threshold,
|
|
483
|
+
Path(classification_output_path) if classification_output_path else None,
|
|
484
|
+
sparse_sampling_step=sparse_sampling_step,
|
|
433
485
|
)
|
|
434
|
-
click.echo(f"Filtered sequences saved at {output_path}.")
|
|
435
486
|
|
|
436
487
|
|
|
437
488
|
if __name__ == "__main__":
|
|
xspect/mlst_feature/mlst_helper.py
CHANGED
|
@@ -194,11 +194,13 @@ class MlstResult:
|
|
|
194
194
|
scheme_model: str,
|
|
195
195
|
steps: int,
|
|
196
196
|
hits: dict[str, list[dict]],
|
|
197
|
+
input_source: str = None,
|
|
197
198
|
):
|
|
198
199
|
"""Initialise an MlstResult object."""
|
|
199
200
|
self.scheme_model = scheme_model
|
|
200
201
|
self.steps = steps
|
|
201
202
|
self.hits = hits
|
|
203
|
+
self.input_source = input_source
|
|
202
204
|
|
|
203
205
|
def get_results(self) -> dict:
|
|
204
206
|
"""
|
|
@@ -221,6 +223,7 @@ class MlstResult:
|
|
|
221
223
|
"Scheme": self.scheme_model,
|
|
222
224
|
"Steps": self.steps,
|
|
223
225
|
"Results": self.get_results(),
|
|
226
|
+
"Input_source": self.input_source,
|
|
224
227
|
}
|
|
225
228
|
return result
|
|
226
229
|
|
|
xspect/mlst_feature/pub_mlst_handler.py
CHANGED
|
@@ -17,7 +17,7 @@ from xspect.definitions import get_xspect_mlst_path, get_xspect_upload_path
|
|
|
17
17
|
class PubMLSTHandler:
|
|
18
18
|
"""Class for communicating with PubMLST and downloading alleles (FASTA-Format) from all loci."""
|
|
19
19
|
|
|
20
|
-
base_url = "
|
|
20
|
+
base_url = "https://rest.pubmlst.org/db"
|
|
21
21
|
|
|
22
22
|
def __init__(self):
|
|
23
23
|
"""Initialise a PubMLSTHandler object."""
|
|
@@ -27,6 +27,7 @@ class PubMLSTHandler:
|
|
|
27
27
|
self.base_url + "/pubmlst_abaumannii_seqdef/schemes/2",
|
|
28
28
|
]
|
|
29
29
|
self.scheme_paths = []
|
|
30
|
+
self.scheme_mapping = {}
|
|
30
31
|
|
|
31
32
|
def get_scheme_paths(self) -> dict:
|
|
32
33
|
"""
|
|
@@ -103,6 +104,7 @@ class PubMLSTHandler:
|
|
|
103
104
|
|
|
104
105
|
species_name = scheme.split("_")[1] # name = pubmlst_abaumannii_seqdef
|
|
105
106
|
scheme_path = get_xspect_mlst_path() / species_name / scheme_name
|
|
107
|
+
self.scheme_mapping[str(scheme_path)] = scheme
|
|
106
108
|
self.scheme_paths.append(scheme_path)
|
|
107
109
|
|
|
108
110
|
for locus_url in locus_list:
|
|
@@ -143,3 +145,43 @@ class PubMLSTHandler:
|
|
|
143
145
|
# Example: 'Pas_fusA': [{'href': some URL, 'allele_id': '2'}]
|
|
144
146
|
print(locus + ":" + meta_data[0]["allele_id"], end="; ")
|
|
145
147
|
print("\nStrain Type:", response["fields"])
|
|
148
|
+
|
|
149
|
+
def get_strain_type_name(self, highest_results: dict, post_url: str) -> str:
|
|
150
|
+
"""
|
|
151
|
+
Send an API-POST request to PubMLST with the highest result of each locus as payload.
|
|
152
|
+
|
|
153
|
+
This function formats the highest_result dict into an accepted input for the request.
|
|
154
|
+
It gets a response from the site which is the strain type name.
|
|
155
|
+
The name is based on the allele id with the highest score for each locus.
|
|
156
|
+
Example of post_url for the oxford scheme of A.baumannii:
|
|
157
|
+
https://rest.pubmlst.org/db/pubmlst_abaumannii_seqdef/schemes/1/designations
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
highest_results (dict): The allele ids with the highest kmer matches.
|
|
161
|
+
post_url (str): The specific url for the scheme of a species
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
str: The response (ST name or No ST found) of the POST request.
|
|
165
|
+
"""
|
|
166
|
+
payload = {
|
|
167
|
+
"designations": {
|
|
168
|
+
locus: [{"allele": str(allele)}]
|
|
169
|
+
for locus, allele in highest_results.items()
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
response = requests.post(post_url + "/designations", json=payload)
|
|
174
|
+
|
|
175
|
+
if response.status_code == 200:
|
|
176
|
+
data = response.json()
|
|
177
|
+
if "fields" in data:
|
|
178
|
+
post_response = data["fields"]
|
|
179
|
+
return post_response
|
|
180
|
+
else:
|
|
181
|
+
post_response = "No matching Strain Type found in the database. "
|
|
182
|
+
post_response += "Possibly a novel Strain Type."
|
|
183
|
+
return post_response
|
|
184
|
+
else:
|
|
185
|
+
post_response = "Error:" + str(response.status_code)
|
|
186
|
+
post_response += response.text
|
|
187
|
+
return post_response
|
xspect/model_management.py
CHANGED
|
@@ -9,22 +9,55 @@ from xspect.models.probabilistic_filter_svm_model import ProbabilisticFilterSVMM
|
|
|
9
9
|
from xspect.definitions import get_xspect_model_path
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def get_genus_model(genus):
|
|
13
|
-
"""
|
|
12
|
+
def get_genus_model(genus) -> ProbabilisticSingleFilterModel:
|
|
13
|
+
"""
|
|
14
|
+
Get a genus model for the specified genus.
|
|
15
|
+
|
|
16
|
+
This function retrieves a pre-trained genus classification model based on the provided genus name.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
genus (str): The genus name for which the model is to be retrieved.
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
ProbabilisticSingleFilterModel: An instance of the genus classification model.
|
|
23
|
+
"""
|
|
14
24
|
genus_model_path = get_xspect_model_path() / (genus.lower() + "-genus.json")
|
|
15
25
|
genus_filter_model = ProbabilisticSingleFilterModel.load(genus_model_path)
|
|
16
26
|
return genus_filter_model
|
|
17
27
|
|
|
18
28
|
|
|
19
|
-
def get_species_model(genus):
|
|
20
|
-
"""
|
|
29
|
+
def get_species_model(genus) -> ProbabilisticFilterSVMModel:
|
|
30
|
+
"""
|
|
31
|
+
Get a species classification model for the specified genus.
|
|
32
|
+
|
|
33
|
+
This function retrieves a pre-trained species classification model based on the provided genus name.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
genus (str): The genus name for which the species model is to be retrieved.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
ProbabilisticFilterSVMModel: An instance of the species classification model.
|
|
40
|
+
"""
|
|
21
41
|
species_model_path = get_xspect_model_path() / (genus.lower() + "-species.json")
|
|
22
42
|
species_filter_model = ProbabilisticFilterSVMModel.load(species_model_path)
|
|
23
43
|
return species_filter_model
|
|
24
44
|
|
|
25
45
|
|
|
26
|
-
def get_model_metadata(model: str | Path):
|
|
27
|
-
"""
|
|
46
|
+
def get_model_metadata(model: str | Path) -> dict:
|
|
47
|
+
"""
|
|
48
|
+
Get metadata of a specified model.
|
|
49
|
+
|
|
50
|
+
This function retrieves the metadata of a model from its JSON file.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
model (str | Path): The slug of the model (as a string) or the path to the model JSON file.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
dict: A dictionary containing the model metadata.
|
|
57
|
+
|
|
58
|
+
Raises:
|
|
59
|
+
ValueError: If the model does not exist or is not a valid file.
|
|
60
|
+
"""
|
|
28
61
|
if isinstance(model, str):
|
|
29
62
|
model_path = get_xspect_model_path() / (model.lower() + ".json")
|
|
30
63
|
elif isinstance(model, Path):
|
|
@@ -40,8 +73,17 @@ def get_model_metadata(model: str | Path):
|
|
|
40
73
|
return model_json
|
|
41
74
|
|
|
42
75
|
|
|
43
|
-
def update_model_metadata(model_slug: str, author: str, author_email: str):
|
|
44
|
-
"""
|
|
76
|
+
def update_model_metadata(model_slug: str, author: str, author_email: str) -> None:
|
|
77
|
+
"""
|
|
78
|
+
Update the metadata of a model.
|
|
79
|
+
|
|
80
|
+
This function updates the author and author email in the model's metadata JSON file.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
model_slug (str): The slug of the model to update.
|
|
84
|
+
author (str): The name of the author to set in the metadata.
|
|
85
|
+
author_email (str): The email of the author to set in the metadata.
|
|
86
|
+
"""
|
|
45
87
|
model_metadata = get_model_metadata(model_slug)
|
|
46
88
|
model_metadata["author"] = author
|
|
47
89
|
model_metadata["author_email"] = author_email
|
|
@@ -51,8 +93,19 @@ def update_model_metadata(model_slug: str, author: str, author_email: str):
|
|
|
51
93
|
file.write(dumps(model_metadata, indent=4))
|
|
52
94
|
|
|
53
95
|
|
|
54
|
-
def update_model_display_name(
|
|
55
|
-
|
|
96
|
+
def update_model_display_name(
|
|
97
|
+
model_slug: str, filter_id: str, display_name: str
|
|
98
|
+
) -> None:
|
|
99
|
+
"""
|
|
100
|
+
Update the display name of a filter in a model.
|
|
101
|
+
|
|
102
|
+
This function updates the display name of a specific filter in the model's metadata JSON file.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
model_slug (str): The slug of the model to update.
|
|
106
|
+
filter_id (str): The ID of the filter whose display name is to be updated.
|
|
107
|
+
display_name (str): The new display name for the filter.
|
|
108
|
+
"""
|
|
56
109
|
model_metadata = get_model_metadata(model_slug)
|
|
57
110
|
model_metadata["display_names"][filter_id] = display_name
|
|
58
111
|
|
|
@@ -61,8 +114,15 @@ def update_model_display_name(model_slug: str, filter_id: str, display_name: str
|
|
|
61
114
|
file.write(dumps(model_metadata, indent=4))
|
|
62
115
|
|
|
63
116
|
|
|
64
|
-
def get_models():
|
|
65
|
-
"""
|
|
117
|
+
def get_models() -> dict[str, list[dict]]:
|
|
118
|
+
"""
|
|
119
|
+
Get a list of all available models in a dictionary by type.
|
|
120
|
+
|
|
121
|
+
This function scans the model directory for JSON files and organizes them by their model type.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
dict[str, list[dict]]: A dictionary where keys are model types and values are lists of model display names.
|
|
125
|
+
"""
|
|
66
126
|
model_dict = {}
|
|
67
127
|
for model_file in get_xspect_model_path().glob("*.json"):
|
|
68
128
|
model_metadata = get_model_metadata(model_file)
|
|
@@ -73,7 +133,17 @@ def get_models():
|
|
|
73
133
|
return model_dict
|
|
74
134
|
|
|
75
135
|
|
|
76
|
-
def get_model_display_names(model_slug: str):
|
|
77
|
-
"""
|
|
136
|
+
def get_model_display_names(model_slug: str) -> list[str]:
|
|
137
|
+
"""
|
|
138
|
+
Get the display names included in a model.
|
|
139
|
+
|
|
140
|
+
This function retrieves the display names of individual filters from the model's metadata.
|
|
141
|
+
|
|
142
|
+
Args:
|
|
143
|
+
model_slug (str): The slug of the model for which to retrieve display names.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
list[str]: A list of display names for the individual filters in the model.
|
|
147
|
+
"""
|
|
78
148
|
model_metadata = get_model_metadata(model_slug)
|
|
79
149
|
return list(model_metadata["display_names"].values())
|