medcheck 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- medcheck/__init__.py +3 -0
- medcheck/core/__init__.py +1 -0
- medcheck/core/config.py +28 -0
- medcheck/core/context.py +90 -0
- medcheck/core/step.py +17 -0
- medcheck/core/workflow.py +117 -0
- medcheck/llm/__init__.py +1 -0
- medcheck/llm/base.py +74 -0
- medcheck/llm/claude.py +61 -0
- medcheck/llm/gemini.py +46 -0
- medcheck/llm/openai_provider.py +57 -0
- medcheck/llm/router.py +43 -0
- medcheck/main.py +173 -0
- medcheck/models/__init__.py +1 -0
- medcheck/pipeline/__init__.py +1 -0
- medcheck/pipeline/ingest.py +66 -0
- medcheck/pipeline/ml_analysis.py +143 -0
- medcheck/pipeline/preprocess.py +129 -0
- medcheck/pipeline/report.py +325 -0
- medcheck/pipeline/vision_analysis.py +225 -0
- medcheck/prompts/anatomy/knee.txt +44 -0
- medcheck/prompts/anatomy/shoulder.txt +9 -0
- medcheck/prompts/anatomy/spine.txt +9 -0
- medcheck/prompts/report_schema.json +39 -0
- medcheck/prompts/system.txt +19 -0
- medcheck/providers/__init__.py +1 -0
- medcheck/providers/base.py +17 -0
- medcheck/providers/easyradiology.py +116 -0
- medcheck/providers/local.py +96 -0
- medcheck/providers/registry.py +46 -0
- medcheck/py.typed +0 -0
- medcheck/web/app.py +39 -0
- medcheck/web/static/.gitkeep +0 -0
- medcheck/web/templates/index.html +726 -0
- medcheck-0.1.0.dist-info/METADATA +275 -0
- medcheck-0.1.0.dist-info/RECORD +39 -0
- medcheck-0.1.0.dist-info/WHEEL +4 -0
- medcheck-0.1.0.dist-info/entry_points.txt +2 -0
- medcheck-0.1.0.dist-info/licenses/LICENSE +195 -0
medcheck/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Core functionality for MedCheck."""
|
medcheck/core/config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Application configuration via environment variables."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Settings:
    """Runtime settings, each resolved from an environment variable.

    Every field uses a ``default_factory``, so the environment is read when a
    ``Settings`` instance is created rather than when this module is imported.
    """

    # Web server bind address and port. ``int()`` raises ValueError at
    # construction time if MEDCHECK_PORT is not an integer string.
    host: str = field(default_factory=lambda: os.getenv("MEDCHECK_HOST", "0.0.0.0"))  # nosec B104
    port: int = field(default_factory=lambda: int(os.getenv("MEDCHECK_PORT", "8080")))

    # Defaults used when a run does not specify them explicitly.
    default_llm_provider: str = field(default_factory=lambda: os.getenv("MEDCHECK_LLM_PROVIDER", "claude"))
    default_language: str = field(default_factory=lambda: os.getenv("MEDCHECK_LANGUAGE", "en"))

    # Optional provider API keys; None when the variable is unset.
    anthropic_api_key: str | None = field(default_factory=lambda: os.getenv("ANTHROPIC_API_KEY"))
    openai_api_key: str | None = field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))
    google_api_key: str | None = field(default_factory=lambda: os.getenv("GOOGLE_API_KEY"))

    def available_llm_providers(self) -> list[str]:
        """List providers whose API key is set (non-empty), followed by ``"local"``.

        The order is claude, openai, gemini, then local. ``"local"`` needs no key
        and is always included.
        """
        keyed = (
            ("claude", self.anthropic_api_key),
            ("openai", self.openai_api_key),
            ("gemini", self.google_api_key),
        )
        return [provider for provider, key in keyed if key] + ["local"]
|
medcheck/core/context.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class PatientInfo:
    """Patient identity and demographics.

    All values are stored as plain strings and default to ``""`` (unknown).
    """

    name: str = ""
    patient_id: str = ""
    birth_date: str = ""  # free-form string; no date format is enforced here
    sex: str = ""
    age: str = ""  # kept as text, unlike ClinicalContext.patient_age (an int)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class StudyInfo:
    """Descriptive metadata about the imaging study and the scanner used.

    Every value is a string defaulting to ``""`` (unknown).
    """

    date: str = ""
    description: str = ""
    institution: str = ""
    # Scanner identification.
    manufacturer: str = ""
    model_name: str = ""
    field_strength: str = ""  # stored as text; units are not normalised here
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class DicomSeries:
    """One image series from a study, plus its descriptive attributes.

    ``slices`` and ``metadata`` get fresh containers per instance, so series
    never share mutable state.
    """

    description: str = ""
    plane: str = ""
    modality: str = ""
    series_number: int = 0
    # Slice objects in whatever form the ingest step produces — element type not fixed here.
    slices: list[Any] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class SignalStats:
    """Signal-intensity statistics gathered over a series.

    NOTE(review): the float lists look like per-slice values and
    ``high_signal_slices`` like slice indices — confirm against the producer.
    """

    mean_intensity: list[float] = field(default_factory=list)
    max_intensity: list[float] = field(default_factory=list)
    high_signal_ratio: list[float] = field(default_factory=list)
    high_signal_slices: list[int] = field(default_factory=list)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class StructureFinding:
    """Assessment of a single anatomical structure.

    Instances are built from the ``structures`` entries of an LLM's JSON reply
    (see ``medcheck.llm.base.parse_llm_response``).
    """

    name: str = ""
    status: str = ""
    findings: str = ""
    confidence: float = 0.0  # NOTE(review): scale (0-1 vs 0-100) not fixed here — confirm with the prompt schema
    slices_evaluated: int = 0
    secondary_signs: list[str] = field(default_factory=list)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class ClinicalContext:
    """Clinical background supplied alongside the images.

    Text fields default to ``""`` and ``patient_age`` to ``0``, meaning "not provided".
    """

    # History.
    symptoms: str = ""
    trauma: str = ""
    trauma_date: str = ""
    suspected_diagnosis: str = ""
    # Patient and region of interest.
    patient_age: int = 0
    patient_sex: str = ""
    anatomy: str = ""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class PipelineContext:
    """Mutable state threaded through every pipeline step.

    Steps read the inputs they need and write their outputs back onto this
    object. The workflow engine passes each step's returned context on to the
    next step.
    """

    # Loaded DICOM data and the patient it belongs to.
    dicom_series: list[DicomSeries] = field(default_factory=list)
    patient: PatientInfo = field(default_factory=PatientInfo)

    # Ingest configuration.
    source: str = ""
    provider_name: str = "local"
    credentials: dict[str, str] = field(default_factory=dict)

    # Study metadata and detected anatomy/planes.
    study: StudyInfo = field(default_factory=StudyInfo)
    volumes: dict[str, Any] = field(default_factory=dict)
    detected_anatomy: str | None = None
    detected_planes: dict[str, Any] = field(default_factory=dict)
    clinical_context: ClinicalContext | None = None

    # Intermediate analysis outputs (key schema set by the producing steps — TODO confirm).
    anomaly_scores: dict[str, Any] = field(default_factory=dict)
    top_slices: dict[str, Any] = field(default_factory=dict)
    signal_analysis: dict[str, Any] = field(default_factory=dict)
    annotated_images: dict[str, Any] = field(default_factory=dict)

    # Interpretation results.
    findings: list[StructureFinding] = field(default_factory=list)
    overall_impression: str = ""
    clinical_correlation: str = ""
    limitations: list[str] = field(default_factory=list)

    # Report output settings.
    report_path: str = ""
    report_format: str = "json"
    report_language: str = "en"
    output_dir: str = ""

    # Configuration for the step currently executing; set by WorkflowEngine.run.
    step_config: dict[str, Any] = field(default_factory=dict)
|
medcheck/core/step.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from medcheck.core.context import PipelineContext
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PipelineStep(ABC):
    """Abstract base class for one stage of the processing pipeline.

    Subclasses set ``name`` and implement :meth:`run`. They may override
    :meth:`validate` to declare preconditions.
    """

    name: str = ""

    @abstractmethod
    def run(self, context: PipelineContext) -> PipelineContext:
        """Do this step's work on *context* and return the context for the next step.

        The returned object may be *context* itself (mutated in place) or a
        replacement.
        """

    def validate(self, context: PipelineContext) -> bool:
        """Check preconditions before :meth:`run`; the default always passes.

        A False result tells the workflow engine to skip this step.
        """
        return True
|
|
medcheck/core/workflow.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
|
|
8
|
+
from medcheck.core.context import PipelineContext
|
|
9
|
+
from medcheck.core.step import PipelineStep
|
|
10
|
+
|
|
11
|
+
# Module-level Rich console shared by WorkflowEngine for workflow/step progress messages.
console = Console()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class StepRegistry:
    """Lookup table from step names to :class:`PipelineStep` subclasses.

    Names are kept in registration order. Registering an existing name again
    replaces its class but keeps the name's original position.
    """

    def __init__(self) -> None:
        self._steps: dict[str, type[PipelineStep]] = {}

    def register(self, name: str, step_class: type[PipelineStep]) -> None:
        """Associate *step_class* with *name*, replacing any earlier entry."""
        self._steps[name] = step_class

    def get(self, name: str) -> type[PipelineStep]:
        """Return the class registered under *name*.

        Raises:
            KeyError: Carrying *name* as its only argument, when nothing is
                registered under that name.
        """
        if name in self._steps:
            return self._steps[name]
        raise KeyError(name)

    def list_steps(self) -> list[str]:
        """Return every registered name, earliest registration first."""
        return [*self._steps]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class WorkflowEngine:
    """Orchestrates sequential execution of pipeline steps.

    Step classes are looked up by name in a :class:`StepRegistry`. A fresh
    instance is created for every step each time :meth:`run` is called.
    """

    def __init__(self, registry: StepRegistry) -> None:
        self.registry = registry

    def run(
        self,
        steps: list[str],
        context: PipelineContext,
        step_configs: dict[str, Any] | None = None,
    ) -> PipelineContext:
        """Instantiate and run each step in *steps* order.

        Before a step is validated, ``context.step_config`` is set to that
        step's entry in *step_configs* (or ``{}`` when it has none). Both
        ``validate`` and ``run`` therefore see the step's own configuration.
        Steps whose ``validate`` returns False are skipped.

        Args:
            steps: Ordered list of step names to execute.
            context: The shared pipeline context passed through all steps.
            step_configs: Optional per-step configuration dicts keyed by name.

        Returns:
            The (potentially mutated or replaced) context after all steps have run.

        Raises:
            KeyError: If a step name is not found in the registry.
        """
        step_configs = step_configs or {}

        for name in steps:
            step_class = self.registry.get(name)  # raises KeyError if unknown
            step_instance = step_class()
            console.print(f"[bold blue]▶ Running step:[/bold blue] {name}")
            # Expose this step's config *before* validate() so preconditions can
            # depend on it. Previously validate() saw the prior step's config.
            context.step_config = step_configs.get(name, {})
            if not step_instance.validate(context):
                console.print(f"[yellow]Skipping {name}: prerequisites not met[/yellow]")
                continue
            context = step_instance.run(context)
            console.print(f"[bold green]✔ Completed step:[/bold green] {name}")

        return context

    def run_from_yaml(
        self,
        yaml_path: str,
        context: PipelineContext,
    ) -> PipelineContext:
        """Load a workflow YAML file and run its steps.

        Expected YAML format::

            name: my_workflow
            steps:
              - plain_step          # a bare name: no config
              - step_name:          # value may be null or a config mapping
              - another_step:
                  param: value

        An empty file, or a missing or null ``steps`` key, runs no steps.
        Configs are keyed by step name. If the same step is listed more than
        once, the last non-null config is used for every occurrence.

        Args:
            yaml_path: Path to the YAML workflow definition file.
            context: The shared pipeline context.

        Returns:
            The context after all steps have run.

        Raises:
            ValueError: If the top level is not a mapping, ``steps`` is not a
                list, a step entry is neither a string nor a mapping, or a
                step config is neither null nor a mapping.
            KeyError: If a step name is not found in the registry.
        """
        with open(yaml_path, encoding="utf-8") as fh:
            workflow_def = yaml.safe_load(fh)

        # safe_load returns None for an empty document; treat that as "no steps".
        if workflow_def is None:
            workflow_def = {}
        if not isinstance(workflow_def, dict):
            raise ValueError(f"Workflow file {yaml_path!r} must contain a mapping at the top level")

        # `steps:` with no value loads as None; treat it like an empty list.
        raw_steps: list[Any] = workflow_def.get("steps") or []
        if not isinstance(raw_steps, list):
            raise ValueError(f"'steps' in workflow file {yaml_path!r} must be a list")

        step_names: list[str] = []
        step_configs: dict[str, Any] = {}

        for entry in raw_steps:
            if isinstance(entry, str):
                step_names.append(entry)
            elif isinstance(entry, dict):
                for step_name, cfg in entry.items():
                    step_names.append(step_name)
                    if cfg is None:
                        continue
                    if not isinstance(cfg, dict):
                        raise ValueError(
                            f"Config for step {step_name!r} in {yaml_path!r} must be a mapping, got {cfg!r}"
                        )
                    step_configs[step_name] = cfg
            else:
                # Previously such entries were silently ignored, hiding typos in the workflow file.
                raise ValueError(f"Unsupported step entry in {yaml_path!r}: {entry!r}")

        workflow_name = workflow_def.get("name", yaml_path)
        console.print(f"[bold cyan]Workflow:[/bold cyan] {workflow_name}")

        return self.run(steps=step_names, context=context, step_configs=step_configs)
|
medcheck/llm/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""LLM client and interaction utilities."""
|
medcheck/llm/base.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from medcheck.core.context import ClinicalContext, StructureFinding
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class AnnotatedImage:
    """A rendered slice image plus the caption that accompanies it to an LLM.

    All four fields are required.
    """

    series_name: str    # series the slice was taken from
    slice_index: int    # slice position (presumably within that series)
    image_bytes: bytes  # encoded image; this package's providers label it image/png
    description: str    # caption text; providers omit it when empty
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class AnalysisResult:
    """Structured outcome of one LLM image-analysis call.

    ``raw_response`` always keeps the unparsed model text. When parsing fails,
    ``overall_impression`` also holds that raw text.
    """

    overall_impression: str = ""
    raw_response: str = ""
    structures: list[StructureFinding] = field(default_factory=list)
    clinical_correlation: str = ""
    limitations: list[str] = field(default_factory=list)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def parse_llm_json(raw: str) -> dict[str, Any]:
    """Return the first JSON object embedded in *raw*.

    LLM replies often wrap JSON in prose or Markdown fences. Each ``{`` in
    *raw* is therefore tried in turn as the start of an object and decoded
    with :meth:`json.JSONDecoder.raw_decode`. That decoder understands string
    literals, so a ``}`` inside a string value no longer ends the object
    early, and it ignores any text after the object. A ``{`` that does not
    start valid JSON (e.g. a stray brace in prose) is skipped rather than
    aborting the search.

    Raises:
        ValueError: If no position in *raw* starts a decodable JSON object.
    """
    decoder = json.JSONDecoder()
    start = raw.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        start = raw.find("{", start + 1)
    raise ValueError("No valid JSON found")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def parse_llm_response(raw: str) -> AnalysisResult:
    """Convert a provider's raw text reply into an :class:`AnalysisResult`.

    Shared by all LLM providers. The first JSON object in *raw* is decoded,
    and each entry of its ``structures`` list becomes a
    :class:`StructureFinding`. Only keys that are ``StructureFinding`` fields
    are kept, so an extra key invented by the model no longer discards the
    whole structured reply. Non-object entries are skipped, and a null or
    missing ``structures`` yields an empty list.

    If no JSON object can be extracted, or the payload has an unusable shape,
    the full reply is returned as ``overall_impression`` so nothing is lost.
    ``raw_response`` always holds *raw*.
    """
    from dataclasses import fields

    try:
        data = parse_llm_json(raw)
        known_keys = {f.name for f in fields(StructureFinding)}
        structures = [
            StructureFinding(**{key: value for key, value in item.items() if key in known_keys})
            for item in data.get("structures") or []
            if isinstance(item, dict)
        ]
        return AnalysisResult(
            structures=structures,
            overall_impression=data.get("overall_impression", ""),
            clinical_correlation=data.get("clinical_correlation", ""),
            limitations=data.get("limitations", []),
            raw_response=raw,
        )
    except (ValueError, TypeError):
        # ValueError also covers json.JSONDecodeError. TypeError covers payloads
        # of the wrong shape, e.g. a non-iterable "structures" value.
        return AnalysisResult(overall_impression=raw, raw_response=raw)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class LLMProvider(ABC):
    """Interface implemented by every LLM backend.

    Class attributes:
        name: Key the provider is registered and selected under.
        supports_vision: Whether the provider accepts image inputs.
    """

    name: str = ""
    supports_vision: bool = False

    @abstractmethod
    def check_available(self) -> bool:
        """Report whether this provider is configured and reachable right now."""

    @abstractmethod
    def analyze_images(
        self,
        images: list[AnnotatedImage],
        prompt: str,
        context: ClinicalContext | None,
    ) -> AnalysisResult:
        """Send *images* and *prompt* (plus optional clinical *context*) to the model.

        Returns the model's reply parsed into an :class:`AnalysisResult`.
        """
|
medcheck/llm/claude.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from medcheck.core.context import ClinicalContext
|
|
8
|
+
from medcheck.llm.base import AnalysisResult, AnnotatedImage, LLMProvider, parse_llm_response
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ClaudeProvider(LLMProvider):
    """Anthropic Claude provider (Messages API via the ``anthropic`` SDK)."""

    name = "claude"
    supports_vision = True

    def __init__(self, model: str = "claude-opus-4-7", max_tokens: int = 4096) -> None:
        """Configure the provider.

        Args:
            model: Anthropic model identifier sent with each request.
            max_tokens: Upper bound on tokens generated per reply.
        """
        self.model = model
        self.max_tokens = max_tokens

    def check_available(self) -> bool:
        """Return True when ``ANTHROPIC_API_KEY`` is set to a non-empty value."""
        return bool(os.environ.get("ANTHROPIC_API_KEY"))

    def analyze_images(
        self,
        images: list[AnnotatedImage],
        prompt: str,
        context: ClinicalContext | None,
    ) -> AnalysisResult:
        """Send all images, each followed by its description, then *prompt* in one user message.

        *context* is part of the shared interface but is not sent by this
        provider. Presumably callers fold it into *prompt* — TODO confirm.

        Raises:
            RuntimeError: If ``ANTHROPIC_API_KEY`` is not set.
        """
        import anthropic

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise RuntimeError("ANTHROPIC_API_KEY not set")
        client = anthropic.Anthropic(api_key=api_key)

        content: list[dict[str, Any]] = []
        for img in images:
            b64 = base64.standard_b64encode(img.image_bytes).decode()
            content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": b64,
                    },
                }
            )
            if img.description:
                content.append({"type": "text", "text": img.description})

        content.append({"type": "text", "text": prompt})

        message = client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            messages=[{"role": "user", "content": content}],
        )

        # The reply is a list of content blocks. It may be empty or contain
        # non-text block types, so content[0].text could raise or miss text.
        # Concatenate every text block instead.
        raw = "".join(
            block.text for block in message.content if getattr(block, "type", None) == "text"
        )
        return parse_llm_response(raw)
|
medcheck/llm/gemini.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from medcheck.core.context import ClinicalContext
|
|
7
|
+
from medcheck.llm.base import AnalysisResult, AnnotatedImage, LLMProvider, parse_llm_response
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GeminiProvider(LLMProvider):
    """Google Gemini provider, using the ``google.generativeai`` SDK."""

    name = "gemini"
    supports_vision = True

    def __init__(self, model: str = "gemini-3.5-flash") -> None:
        # Model identifier handed to genai.GenerativeModel on each call.
        self.model = model

    def check_available(self) -> bool:
        """Report whether a non-empty ``GOOGLE_API_KEY`` is present in the environment."""
        key = os.environ.get("GOOGLE_API_KEY")
        return bool(key)

    def analyze_images(
        self,
        images: list[AnnotatedImage],
        prompt: str,
        context: ClinicalContext | None,
    ) -> AnalysisResult:
        """Send each image (with its caption, if any) followed by *prompt* to Gemini.

        *context* is part of the shared interface but is not sent by this provider.

        Raises:
            RuntimeError: If ``GOOGLE_API_KEY`` is not set.
        """
        import google.generativeai as genai

        api_key = os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise RuntimeError("GOOGLE_API_KEY not set")
        # NOTE(review): genai.configure appears to set SDK-wide rather than
        # per-client state — confirm if different keys are ever used concurrently.
        genai.configure(api_key=api_key)
        gemini = genai.GenerativeModel(self.model)

        parts: list[Any] = []
        for image in images:
            inline_image = {"mime_type": "image/png", "data": image.image_bytes}
            parts.extend([inline_image, image.description] if image.description else [inline_image])
        parts.append(prompt)

        reply = gemini.generate_content(parts)
        return parse_llm_response(reply.text)
|
|
medcheck/llm/openai_provider.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from medcheck.core.context import ClinicalContext
|
|
8
|
+
from medcheck.llm.base import AnalysisResult, AnnotatedImage, LLMProvider, parse_llm_response
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAIProvider(LLMProvider):
    """OpenAI GPT provider (Chat Completions via the ``openai`` SDK)."""

    name = "openai"
    supports_vision = True

    def __init__(self, model: str = "gpt-5.5", max_tokens: int = 4096) -> None:
        """Configure the provider.

        Args:
            model: OpenAI model identifier sent with each request.
            max_tokens: Upper bound on completion tokens per reply. It is sent
                as ``max_completion_tokens``; for reasoning models this budget
                also covers hidden reasoning tokens.
        """
        self.model = model
        self.max_tokens = max_tokens

    def check_available(self) -> bool:
        """Return True when ``OPENAI_API_KEY`` is set to a non-empty value."""
        return bool(os.environ.get("OPENAI_API_KEY"))

    def analyze_images(
        self,
        images: list[AnnotatedImage],
        prompt: str,
        context: ClinicalContext | None,
    ) -> AnalysisResult:
        """Send all images (as PNG data URLs, each followed by its description), then *prompt*.

        *context* is part of the shared interface but is not sent by this provider.

        Raises:
            RuntimeError: If ``OPENAI_API_KEY`` is not set.
        """
        from openai import OpenAI

        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY not set")
        client = OpenAI(api_key=api_key)

        content: list[dict[str, Any]] = []
        for img in images:
            b64 = base64.standard_b64encode(img.image_bytes).decode()
            content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{b64}"},
                }
            )
            if img.description:
                content.append({"type": "text", "text": img.description})

        content.append({"type": "text", "text": prompt})

        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": content}],
            # `max_tokens` is deprecated in Chat Completions and rejected by
            # reasoning-model families (o-series, GPT-5). `max_completion_tokens`
            # is the supported replacement and also works for older chat models.
            max_completion_tokens=self.max_tokens,
        )

        # message.content can be None (e.g. refusals); treat that as an empty reply.
        raw = response.choices[0].message.content or ""
        return parse_llm_response(raw)
|
medcheck/llm/router.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from medcheck.llm.base import LLMProvider
|
|
4
|
+
|
|
5
|
+
# Provider names LLMRouter.select tries, in this order, after the caller's preferred
# provider is found unavailable; "local" is the last named fallback.
FALLBACK_ORDER = ["claude", "openai", "gemini", "local"]
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LLMRouter:
    """Holds LLM providers by name and picks one that is currently usable."""

    def __init__(self) -> None:
        self._providers: dict[str, LLMProvider] = {}

    def register(self, provider: LLMProvider) -> None:
        """Store *provider* under ``provider.name``, replacing any earlier one."""
        self._providers[provider.name] = provider

    def select(self, preferred: str) -> LLMProvider:
        """Pick a usable provider, trying candidates in priority order.

        The first candidate whose ``check_available()`` is true is returned:

        1. the provider registered as *preferred*;
        2. providers named in ``FALLBACK_ORDER``, skipping *preferred*;
        3. any registered provider, in registration order.

        Raises:
            RuntimeError: If no registered provider is available.
        """
        lookup = self._providers.get

        chosen = lookup(preferred)
        if chosen and chosen.check_available():
            return chosen

        for fallback_name in (n for n in FALLBACK_ORDER if n != preferred):
            chosen = lookup(fallback_name)
            if chosen and chosen.check_available():
                return chosen

        for chosen in self._providers.values():
            if chosen.check_available():
                return chosen

        raise RuntimeError("No LLM provider available")

    def list_available(self) -> list[str]:
        """Return the names of registered providers that report themselves available."""
        available: list[str] = []
        for provider_name, provider in self._providers.items():
            if provider.check_available():
                available.append(provider_name)
        return available
|