medgemma-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
medgemma/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """MedGemma – Medical AI on Apple Silicon via MLX."""
+
+ from ._compat import check_platform
+ from ._version import __version__
+ from .client import MedGemma, Response
+
+ check_platform()
+
+ __all__ = ["MedGemma", "Response", "__version__"]
medgemma/__main__.py ADDED
@@ -0,0 +1,5 @@
+ """Allow ``python -m medgemma``."""
+
+ from .cli import cli
+
+ cli()
medgemma/_compat.py ADDED
@@ -0,0 +1,24 @@
+ """Platform compatibility checks for Apple Silicon."""
+
+ import platform
+ import sys
+
+
+ def check_platform() -> None:
+     """Fail fast if not running on macOS with Apple Silicon."""
+     if sys.platform != "darwin":
+         raise RuntimeError(
+             "medgemma requires macOS with Apple Silicon (M1/M2/M3/M4). "
+             f"Detected platform: {sys.platform}"
+         )
+     machine = platform.machine()
+     if machine != "arm64":
+         raise RuntimeError(
+             "medgemma requires Apple Silicon (arm64). "
+             f"Detected architecture: {machine}"
+         )
+
+
+ def is_apple_silicon() -> bool:
+     """Return True if running on macOS arm64."""
+     return sys.platform == "darwin" and platform.machine() == "arm64"
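For context, `check_platform()` raises at import time on unsupported machines, while `is_apple_silicon()` is the non-raising variant. A minimal sketch of how calling code might use the softer check instead of catching the import-time error; the script below and its message are illustrative, not part of the package:

```python
# Hypothetical caller: use the boolean check to degrade gracefully
# rather than letting `import medgemma` raise on other platforms.
from medgemma._compat import is_apple_silicon

if is_apple_silicon():
    from medgemma import MedGemma  # runs check_platform(), which passes here
    mg = MedGemma()
else:
    print("MedGemma needs macOS on Apple Silicon; skipping local inference.")
```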
medgemma/_version.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.1.0"
medgemma/cli.py ADDED
@@ -0,0 +1,91 @@
+ """Click CLI for medgemma."""
+
+ import json
+ import sys
+
+ import click
+
+ from ._version import __version__
+
+
+ @click.group()
+ @click.version_option(__version__, prog_name="medgemma")
+ def cli():
+     """MedGemma – Medical AI on Apple Silicon."""
+
+
+ @cli.command()
+ @click.argument("prompt")
+ @click.option("--image", default=None, help="Path to an image file.")
+ @click.option("--max-tokens", default=None, type=int, help="Max tokens to generate.")
+ @click.option("--temperature", default=None, type=float, help="Sampling temperature.")
+ @click.option("--model-path", default=None, help="Path to a local MLX model.")
+ @click.option("--json", "as_json", is_flag=True, help="Output as JSON with stats.")
+ @click.option("--no-stream", is_flag=True, help="Disable streaming output.")
+ def ask(prompt, image, max_tokens, temperature, model_path, as_json, no_stream):
+     """Send a prompt to MedGemma."""
+     from .client import MedGemma
+
+     kwargs = {}
+     if model_path:
+         kwargs["model_path"] = model_path
+     if max_tokens is not None:
+         kwargs["max_tokens"] = max_tokens
+     if temperature is not None:
+         kwargs["temperature"] = temperature
+
+     mg = MedGemma(**kwargs)
+
+     gen_kwargs = {}
+     if image:
+         gen_kwargs["image"] = image
+
+     if no_stream or as_json:
+         resp = mg.ask(prompt, **gen_kwargs)
+         if as_json:
+             click.echo(
+                 json.dumps(
+                     {
+                         "text": resp.text,
+                         "completion_tokens": resp.completion_tokens,
+                         "tokens_per_second": resp.tokens_per_second,
+                         "elapsed_seconds": resp.elapsed_seconds,
+                     },
+                     indent=2,
+                 )
+             )
+         else:
+             click.echo(resp.text)
+     else:
+         for chunk in mg.stream(prompt, **gen_kwargs):
+             click.echo(chunk, nl=False)
+         click.echo()
+
+
+ @cli.command()
+ @click.option("--local-path", default=None, help="Copy an existing local model.")
+ @click.option("--force", is_flag=True, help="Re-download / overwrite existing cache.")
+ def setup(local_path, force):
+     """Download or set up the MedGemma model."""
+     from .convert import setup_model
+
+     click.echo("Setting up MedGemma model...")
+     try:
+         path = setup_model(force=force, local_path=local_path)
+         click.echo(f"Model ready at: {path}")
+     except Exception as exc:
+         click.echo(f"Error: {exc}", err=True)
+         sys.exit(1)
+
+
+ @cli.command()
+ def info():
+     """Show model and cache information."""
+     from .model import model_info
+
+     mi = model_info()
+     click.echo(f"Cache directory : {mi.cache_dir}")
+     click.echo(f"Model in cache : {'yes' if mi.cache_ready else 'no'}")
+     click.echo(f"Model loaded : {'yes' if mi.loaded else 'no'}")
+     if mi.path:
+         click.echo(f"Loaded from : {mi.path}")
medgemma/client.py ADDED
@@ -0,0 +1,169 @@
+ """High-level MedGemma client – the main public API."""
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Generator
+
+ from .config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, SYSTEM_PROMPT
+ from .model import get_model, model_info, unload_model
+
+
+ @dataclass(frozen=True)
+ class Response:
+     """Result returned by :meth:`MedGemma.ask`."""
+
+     text: str
+     prompt_tokens: int = 0
+     completion_tokens: int = 0
+     tokens_per_second: float = 0.0
+     elapsed_seconds: float = 0.0
+
+
+ class MedGemma:
+     """Friendly wrapper around the MLX MedGemma model.
+
+     Parameters
+     ----------
+     model_path:
+         Path to a local converted MLX model directory. When ``None`` the
+         default cache at ``~/.medgemma/model`` is used (auto-downloaded on
+         first call).
+     max_tokens:
+         Default maximum tokens for generation.
+     temperature:
+         Default sampling temperature.
+     """
+
+     def __init__(
+         self,
+         *,
+         model_path: str | Path | None = None,
+         max_tokens: int = DEFAULT_MAX_TOKENS,
+         temperature: float = DEFAULT_TEMPERATURE,
+     ) -> None:
+         self._model_path = model_path
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+
+     # -- public API --------------------------------------------------------
+
+     def ask(
+         self,
+         prompt: str,
+         *,
+         image: str | Path | None = None,
+         max_tokens: int | None = None,
+         temperature: float | None = None,
+     ) -> Response:
+         """Send a prompt (and optional image) to the model.
+
+         Returns a :class:`Response` with ``.text`` and generation stats.
+         """
+         model, processor = self._ensure_loaded()
+         max_tok = max_tokens if max_tokens is not None else self.max_tokens
+         temp = temperature if temperature is not None else self.temperature
+
+         formatted_prompt = self._apply_template(processor, prompt, image)
+         image_arg = self._resolve_image(image)
+
+         from mlx_vlm import generate
+
+         result = generate(
+             model,
+             processor,
+             formatted_prompt,
+             image=image_arg,
+             max_tokens=max_tok,
+             temperature=temp,
+             verbose=False,
+         )
+
+         # generate returns a GenerationResult dataclass
+         text = result.text if hasattr(result, "text") else str(result)
+         prompt_tokens = getattr(result, "prompt_tokens", 0)
+         gen_tokens = getattr(result, "generation_tokens", 0)
+         gen_tps = getattr(result, "generation_tps", 0.0)
+         # Compute elapsed from tokens / tps
+         elapsed = gen_tokens / gen_tps if gen_tps > 0 else 0.0
+
+         return Response(
+             text=text.strip(),
+             prompt_tokens=prompt_tokens,
+             completion_tokens=gen_tokens,
+             tokens_per_second=round(gen_tps, 1),
+             elapsed_seconds=round(elapsed, 2),
+         )
+
+     def stream(
+         self,
+         prompt: str,
+         *,
+         image: str | Path | None = None,
+         max_tokens: int | None = None,
+         temperature: float | None = None,
+     ) -> Generator[str, None, None]:
+         """Stream generated text chunk by chunk."""
+         model, processor = self._ensure_loaded()
+         max_tok = max_tokens if max_tokens is not None else self.max_tokens
+         temp = temperature if temperature is not None else self.temperature
+
+         formatted_prompt = self._apply_template(processor, prompt, image)
+         image_arg = self._resolve_image(image)
+
+         from mlx_vlm import stream_generate
+
+         for chunk in stream_generate(
+             model,
+             processor,
+             formatted_prompt,
+             image=image_arg,
+             max_tokens=max_tok,
+             temperature=temp,
+         ):
+             if isinstance(chunk, str):
+                 yield chunk
+             elif hasattr(chunk, "text"):
+                 yield chunk.text
+             else:
+                 yield str(chunk)
+
+     def unload(self) -> None:
+         """Release the model from memory."""
+         unload_model()
+
+     @staticmethod
+     def info():
+         """Return model info without loading."""
+         return model_info()
+
+     # -- internals ---------------------------------------------------------
+
+     def _ensure_loaded(self):
+         return get_model(self._model_path)
+
+     @staticmethod
+     def _apply_template(processor, prompt: str, image=None) -> str:
+         """Build chat messages and apply the processor's chat template."""
+         content: list[dict] = []
+         if image is not None:
+             content.append({"type": "image"})
+         content.append({"type": "text", "text": prompt})
+
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": content},
+         ]
+
+         return processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+
+     @staticmethod
+     def _resolve_image(image):
+         if image is None:
+             return None
+         p = Path(image).expanduser()
+         if p.is_file():
+             return [str(p)]
+         # Might be a URL – let mlx_vlm handle it
+         return [str(image)]
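Taken together, `ask()` returns a complete `Response` with token statistics while `stream()` yields text incrementally; both accept an optional image path or URL. A short usage sketch based only on the API defined above; the prompts and the file name `scan.png` are placeholders:

```python
from medgemma import MedGemma

mg = MedGemma(max_tokens=256, temperature=0.2)

# Blocking call: full text plus generation stats.
resp = mg.ask("Summarise the key findings.", image="scan.png")  # placeholder image path
print(resp.text)
print(f"{resp.completion_tokens} tokens at {resp.tokens_per_second} tok/s")

# Streaming call: print chunks as they arrive.
for chunk in mg.stream("List common causes of chest pain."):
    print(chunk, end="", flush=True)
print()

mg.unload()  # release the model weights from memory
```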
medgemma/config.py ADDED
@@ -0,0 +1,24 @@
+ """Constants and default configuration for medgemma."""
+
+ from pathlib import Path
+
+ # Default cache location for the converted MLX model
+ CACHE_DIR = Path.home() / ".medgemma" / "model"
+
+ # Hugging Face repo for the original model
+ HF_REPO_ID = "google/medgemma-4b-it"
+
+ # MLX-VLM quantisation defaults
+ DEFAULT_QUANT_BITS = 4
+ DEFAULT_QUANT_GROUP_SIZE = 64
+
+ # Generation defaults
+ DEFAULT_MAX_TOKENS = 512
+ DEFAULT_TEMPERATURE = 0.1
+
+ # Chat template for MedGemma
+ SYSTEM_PROMPT = (
+     "You are a helpful medical AI assistant. Provide accurate, "
+     "evidence-based medical information. Always recommend consulting "
+     "a healthcare professional for personal medical advice."
+ )
medgemma/convert.py ADDED
@@ -0,0 +1,98 @@
+ """Model download and MLX conversion utilities."""
+
+ import shutil
+ from pathlib import Path
+
+ from .config import CACHE_DIR, DEFAULT_QUANT_BITS, DEFAULT_QUANT_GROUP_SIZE, HF_REPO_ID
+
+
+ def is_model_ready(model_path: str | Path | None = None) -> bool:
+     """Check whether a converted MLX model exists at the given path."""
+     path = Path(model_path) if model_path else CACHE_DIR
+     return (path / "config.json").is_file()
+
+
+ def setup_model(
+     *,
+     force: bool = False,
+     hf_repo: str = HF_REPO_ID,
+     local_path: str | Path | None = None,
+     cache_dir: str | Path | None = None,
+ ) -> Path:
+     """Download / convert the model into the cache directory.
+
+     Parameters
+     ----------
+     force:
+         Re-download even if a model already exists in cache.
+     hf_repo:
+         Hugging Face repo ID to download from.
+     local_path:
+         Path to an already-converted local model directory. When provided the
+         files are copied (or symlinked) into the cache instead of downloading.
+     cache_dir:
+         Override the default cache directory (~/.medgemma/model).
+
+     Returns
+     -------
+     Path to the ready-to-load model directory.
+     """
+     dest = Path(cache_dir) if cache_dir else CACHE_DIR
+
+     if local_path is not None:
+         return _copy_local(Path(local_path).expanduser(), dest, force=force)
+
+     if not force and is_model_ready(dest):
+         return dest
+
+     return _convert_from_hf(hf_repo, dest)
+
+
+ def _copy_local(src: Path, dest: Path, *, force: bool = False) -> Path:
+     """Copy a local converted model into the cache directory."""
+     if not src.is_dir():
+         raise FileNotFoundError(f"Local model path does not exist: {src}")
+     if not (src / "config.json").is_file():
+         raise FileNotFoundError(
+             f"No config.json in {src} – is this a converted MLX model?"
+         )
+
+     if src.resolve() == dest.resolve():
+         return dest
+
+     if force and dest.exists():
+         shutil.rmtree(dest)
+
+     if not dest.exists():
+         dest.parent.mkdir(parents=True, exist_ok=True)
+         shutil.copytree(src, dest)
+
+     return dest
+
+
+ def _convert_from_hf(hf_repo: str, dest: Path) -> Path:
+     """Download from Hugging Face and convert to MLX format."""
+     try:
+         from mlx_vlm import convert as mlx_convert
+     except ImportError as exc:
+         raise ImportError(
+             "mlx-vlm is required for model conversion. "
+             "Install it with: pip install mlx-vlm>=0.3.10"
+         ) from exc
+
+     dest.parent.mkdir(parents=True, exist_ok=True)
+
+     mlx_convert(
+         hf_repo,
+         mlx_path=str(dest),
+         q_bits=DEFAULT_QUANT_BITS,
+         q_group_size=DEFAULT_QUANT_GROUP_SIZE,
+     )
+
+     if not is_model_ready(dest):
+         raise RuntimeError(
+             f"Conversion finished but {dest / 'config.json'} not found. "
+             "The model may not have been converted correctly."
+         )
+
+     return dest
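In practice `setup_model()` is the single entry point: it copies a pre-converted local model, reuses an existing cache, or downloads and quantises from Hugging Face. A sketch of the two paths, assuming only the signatures above; the local directory `~/models/medgemma-mlx` is a hypothetical location:

```python
from medgemma.convert import setup_model, is_model_ready

# Path 1: reuse an already-converted MLX model from disk
# ("~/models/medgemma-mlx" is a placeholder path).
path = setup_model(local_path="~/models/medgemma-mlx")

# Path 2: download from the default Hugging Face repo and convert,
# overwriting whatever is already cached under ~/.medgemma/model.
path = setup_model(force=True)

print(path, is_model_ready(path))
```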
medgemma/model.py ADDED
@@ -0,0 +1,74 @@
+ """Thread-safe lazy-loading model singleton."""
+
+ import threading
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ from .config import CACHE_DIR
+ from .convert import is_model_ready, setup_model
+
+ _lock = threading.Lock()
+ _model: Any | None = None
+ _processor: Any | None = None
+ _model_path: Path | None = None
+
+
+ @dataclass(frozen=True)
+ class ModelInfo:
+     loaded: bool
+     path: str | None
+     cache_dir: str
+     cache_ready: bool
+
+
+ def get_model(model_path: str | Path | None = None) -> tuple[Any, Any]:
+     """Return ``(model, processor)``, loading on first call.
+
+     Thread-safe: only one thread will trigger the load.
+     """
+     global _model, _processor, _model_path
+
+     if _model is not None and _processor is not None:
+         return _model, _processor
+
+     with _lock:
+         # Double-checked locking
+         if _model is not None and _processor is not None:
+             return _model, _processor
+
+         path = Path(model_path).expanduser() if model_path else CACHE_DIR
+
+         if not is_model_ready(path):
+             setup_model(cache_dir=path)
+
+         try:
+             from mlx_vlm import load
+         except ImportError as exc:
+             raise ImportError(
+                 "mlx-vlm is required. Install with: pip install mlx-vlm>=0.3.10"
+             ) from exc
+
+         _model, _processor = load(str(path))
+         _model_path = path
+
+     return _model, _processor
+
+
+ def unload_model() -> None:
+     """Release the loaded model from memory."""
+     global _model, _processor, _model_path
+     with _lock:
+         _model = None
+         _processor = None
+         _model_path = None
+
+
+ def model_info() -> ModelInfo:
+     """Return current model state without triggering a load."""
+     return ModelInfo(
+         loaded=_model is not None,
+         path=str(_model_path) if _model_path else None,
+         cache_dir=str(CACHE_DIR),
+         cache_ready=is_model_ready(),
+     )
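This module keeps a single `(model, processor)` pair in module-level state behind a lock, so repeated `get_model()` calls are cheap and concurrent first calls load only once. A small sketch of the lifecycle using only the functions defined above:

```python
from medgemma.model import get_model, model_info, unload_model

print(model_info().loaded)        # False until the first get_model() call

model, processor = get_model()    # first call loads (downloading/converting if needed)
again, _ = get_model()            # later calls return the same cached objects
assert again is model

unload_model()                    # drop the module-level references
print(model_info().loaded)        # False again
```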
medgemma-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,49 @@
+ Metadata-Version: 2.4
+ Name: medgemma
+ Version: 0.1.0
+ Summary: Medical AI on Apple Silicon – MedGemma 1.5 4B via MLX
+ Author: chiboko
+ License-Expression: MIT
+ Keywords: ai,apple-silicon,medgemma,medical,mlx
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Healthcare Industry
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Operating System :: MacOS
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Requires-Dist: click>=8.0
+ Requires-Dist: mlx-vlm>=0.3.10
+ Provides-Extra: dev
+ Requires-Dist: pytest-mock; extra == 'dev'
+ Requires-Dist: pytest>=7.0; extra == 'dev'
+ Description-Content-Type: text/markdown
+
+ # medgemma
+
+ Medical AI on Apple Silicon – MedGemma 1.5 4B via MLX.
+
+ ## Install
+
+ ```bash
+ pip install medgemma
+ ```
+
+ ## Usage
+
+ ```python
+ from medgemma import MedGemma
+
+ mg = MedGemma()
+ response = mg.ask("What are symptoms of diabetes?")
+ print(response.text)
+ ```
+
+ ### CLI
+
+ ```bash
+ medgemma ask "What are symptoms of diabetes?"
+ medgemma ask "Describe this X-ray" --image xray.png
+ medgemma setup
+ medgemma info
+ ```
medgemma-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
+ medgemma/__init__.py,sha256=G29Y0KqJ6cXMqqwnBFypNR5wSazaUlIJcOLjccgHSns,235
+ medgemma/__main__.py,sha256=bwb3cqcddGbAvdwq9sDSBqhmCjoZG4BaH-CIkpSgsb4,65
+ medgemma/_compat.py,sha256=rwxytYiARMIa24EB42MMhnVbKRD5LceMFWrDHKV_rW4,727
+ medgemma/_version.py,sha256=kUR5RAFc7HCeiqdlX36dZOHkUI5wI6V_43RpEcD8b-0,22
+ medgemma/cli.py,sha256=cjFICqjfyZeCvzC-7A3FW2bbsm-YpXhmq25St1k_4ck,2869
+ medgemma/client.py,sha256=s_urpYSmD7_vcmIo-yF1LOomNtJFQv8whK4qZkXcUzY,5183
+ medgemma/config.py,sha256=bfS-LED3fCdCVKNV4qpJUU5r5kUGfOaKr9rnypfBUiE,672
+ medgemma/convert.py,sha256=1EpvMtRG7p8WNeRxsPGtxG2saAidsEdRiqWgnSpKygc,2889
+ medgemma/model.py,sha256=7Vuhxry8CMXdKxTRoiRqXHF7lx7rPNQ7Vqa3lF8SULQ,1920
+ medgemma-0.1.0.dist-info/METADATA,sha256=1ZTUGvR2jAMgp_kZvHgePoYO4VpbjhDF6kmtPPOl_Mo,1146
+ medgemma-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ medgemma-0.1.0.dist-info/entry_points.txt,sha256=xdJRJThjQNLeVqUIw0vERv9kReFa15FhAoX2X-0x2ck,46
+ medgemma-0.1.0.dist-info/RECORD,,
medgemma-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.28.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
medgemma-0.1.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ medgemma = medgemma.cli:cli