modelctl-mlflow 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modelctl/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Small MLflow Model Registry utility.
2
+
3
+ The package exposes a command line interface via ``python -m modelctl`` and the
4
+ ``modelctl`` console script declared in ``pyproject.toml``.
5
+ """
6
+
7
+ __version__ = "1.1.2"
modelctl/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Entrypoint for ``python -m modelctl``."""
2
+
3
+ from .cli import main
4
+
5
+ raise SystemExit(main())
modelctl/cli.py ADDED
@@ -0,0 +1,277 @@
1
+ """Command line interface for modelctl."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import sys
8
+ from dataclasses import asdict, is_dataclass
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ from . import __version__
13
+ from .core import (
14
+ DEFAULT_HOST,
15
+ DEFAULT_PORT,
16
+ VerifyResult,
17
+ get_model_info,
18
+ list_model_versions,
19
+ promote_alias,
20
+ pull_model,
21
+ register_model_directory,
22
+ verify_model,
23
+ )
24
+ from .tags import TagError, merge_dicts, parse_key_value_items, read_json_dict
25
+
26
+
27
+ TOP_LEVEL_HELP = """\
28
+ Small MLflow Model Registry utility for arbitrary model payloads.
29
+
30
+ modelctl stores a directory as an opaque payload in MLflow artifacts and keeps
31
+ registry metadata in MLflow Model Registry: versions, aliases, tags and source
32
+ URIs. It does not require the payload to be a Python, sklearn, torch or any
33
+ other framework-specific model.
34
+ """
35
+
36
+ TOP_LEVEL_EPILOG = """\
37
+ Quick start:
38
+ modelctl register ./model my-model
39
+ modelctl list my-model
40
+ modelctl info my-model@champion
41
+ modelctl pull my-model@champion ./downloaded-model
42
+ modelctl verify my-model@champion ./downloaded-model
43
+ modelctl promote my-model 3 champion
44
+
45
+ Model refs:
46
+ my-model@champion Resolve by alias
47
+ my-model:3 Resolve by version
48
+ models:/my-model@champion MLflow alias URI form
49
+ models:/my-model/3 MLflow version URI form
50
+
51
+ Connection:
52
+ By default modelctl uses http://localhost:5000.
53
+ Override with --host/--port or pass --tracking-uri to any command.
54
+ MLflow auth is handled by MLflow environment variables, for example:
55
+ MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD.
56
+
57
+ Registration:
58
+ The first version gets aliases baseline and champion by default.
59
+ Later versions get alias candidate by default.
60
+ Pass --alias one or more times to set explicit aliases.
61
+ Optional metadata can be passed with --general-tag, --training-tag,
62
+ --general-tags-json and --training-tags-json.
63
+
64
+ Pull and verify:
65
+ pull downloads payload only by default. Use --full-package to download the
66
+ whole modelctl package with manifest and metadata files.
67
+ pull verifies the payload hash by default. Use --no-verify only when you
68
+ intentionally want to skip that check.
69
+ verify exits with 0 on match, 2 on hash mismatch and 1 on command failure.
70
+
71
+ Output:
72
+ Commands print machine-readable JSON to stdout.
73
+ Human-readable progress and errors go to stderr.
74
+
75
+ More help:
76
+ modelctl <command> --help
77
+ """
78
+
79
+
80
def main(argv: list[str] | None = None) -> int:
    """Run the modelctl CLI and return a process exit code.

    Args:
        argv: Argument list to parse. ``None`` is forwarded to argparse, which
            then reads the process arguments itself.

    Returns:
        ``0`` on success, ``1`` when the command or the rendering of its
        result raised an exception, and ``2`` when the command returned a
        ``VerifyResult`` whose ``matches`` attribute is falsy.
    """

    parser = build_parser()
    args = parser.parse_args(argv)

    try:
        result = dispatch(args)
        # Render inside the guarded region: a result that cannot be serialised
        # should produce the same clean "ERROR: ..." line as any other failure
        # instead of escaping as a raw traceback.
        if result is not None:
            print_json(result)
    except Exception as exc:  # noqa: BLE001 - CLI should print clear errors.
        print(f"ERROR: {exc}", file=sys.stderr)
        return 1

    # A hash mismatch is reported through a dedicated exit code.
    if isinstance(result, VerifyResult) and not result.matches:
        return 2
    return 0
98
+
99
+
100
def build_parser() -> argparse.ArgumentParser:
    """Create the ``modelctl`` parser with every subcommand attached."""

    root = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        prog="modelctl",
        description=TOP_LEVEL_HELP,
        epilog=TOP_LEVEL_EPILOG,
    )
    version_banner = f"modelctl {__version__}"
    root.add_argument("--version", action="version", version=version_banner)

    commands = root.add_subparsers(dest="command", required=True)
    # Attach the subcommands in the same order as before.
    attachers = (
        add_register_parser,
        add_promote_parser,
        add_pull_parser,
        add_verify_parser,
        add_list_parser,
        add_info_parser,
    )
    for attach in attachers:
        attach(commands)
    return root
119
+
120
+
121
def add_common_connection_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``--host``, ``--port`` and ``--tracking-uri`` options on ``parser``."""

    connection_options = (
        ("--host", {"default": DEFAULT_HOST, "help": "MLflow host. Default: localhost."}),
        ("--port", {"type": int, "default": DEFAULT_PORT, "help": "MLflow port. Default: 5000."}),
        (
            "--tracking-uri",
            {"default": None, "help": "Full MLflow tracking URI. Overrides --host and --port."},
        ),
    )
    for flag, settings in connection_options:
        parser.add_argument(flag, **settings)
127
+
128
+
129
def add_register_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``register`` subcommand and its options to ``subparsers``."""

    register = subparsers.add_parser(
        "register",
        help="Register a payload directory as a new MLflow model version.",
    )

    # Positional arguments.
    for dest, text in (
        ("source_dir", "Payload directory to register."),
        ("name", "Registered model name."),
    ):
        register.add_argument(dest, help=text)

    register.add_argument(
        "--alias",
        action="append",
        default=None,
        help="Alias to set on the new version. Can be repeated.",
    )

    # Metadata loaded from JSON files.
    for flag, text in (
        ("--general-tags-json", "Path to JSON object with general metadata."),
        ("--training-tags-json", "Path to JSON object with training metadata."),
    ):
        register.add_argument(flag, default=None, help=text)

    # Repeatable inline key=value metadata.
    for flag, text in (
        ("--general-tag", "Inline general tag key=value. Can be repeated."),
        ("--training-tag", "Inline training tag key=value. Can be repeated."),
    ):
        register.add_argument(flag, action="append", default=None, help=text)

    register.add_argument("--description", default=None, help="Optional model version description.")
    add_common_connection_args(register)
142
+
143
+
144
def add_promote_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``promote`` subcommand (name, version, alias) to ``subparsers``."""

    promote = subparsers.add_parser(
        "promote",
        help="Point an alias to an existing model version.",
    )
    positionals = (
        ("name", "Registered model name."),
        ("version", "Model version number."),
        ("alias", "Alias to set."),
    )
    for dest, text in positionals:
        promote.add_argument(dest, help=text)
    add_common_connection_args(promote)
152
+
153
+
154
def add_pull_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``pull`` subcommand and its boolean switches to ``subparsers``."""

    pull = subparsers.add_parser("pull", help="Download a model version or alias.")

    ref_help = "Model ref: name@alias, name:version, models:/name@alias or models:/name/version."
    pull.add_argument("ref", help=ref_help)
    pull.add_argument("output_dir", help="Destination directory.")

    # All three switches default to False and flip to True when given.
    switches = (
        ("--full-package", "Download the full modelctl package instead of payload only."),
        ("--overwrite", "Replace output_dir if it already exists."),
        ("--no-verify", "Skip post-download payload hash verification."),
    )
    for flag, text in switches:
        pull.add_argument(flag, action="store_true", help=text)

    add_common_connection_args(pull)
164
+
165
+
166
def add_verify_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``verify`` subcommand (model ref plus local path) to ``subparsers``."""

    command_help = "Compare a directory with the registry payload hash."
    ref_help = "Model ref: name@alias, name:version, models:/name@alias or models:/name/version."
    path_help = "Directory to verify. Payload directories and full modelctl packages are both accepted."

    verify = subparsers.add_parser("verify", help=command_help)
    verify.add_argument("ref", help=ref_help)
    verify.add_argument("path", help=path_help)
    add_common_connection_args(verify)
173
+
174
+
175
def add_list_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``list`` subcommand, which takes a registered model name."""

    listing = subparsers.add_parser(
        "list",
        help="List versions of a registered model.",
    )
    listing.add_argument(
        "name",
        help="Registered model name.",
    )
    add_common_connection_args(listing)
181
+
182
+
183
def add_info_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``info`` subcommand, which takes a single model ref."""

    ref_help = "Model ref: name@alias, name:version, models:/name@alias or models:/name/version."

    info = subparsers.add_parser(
        "info",
        help="Show JSON info for one model ref.",
    )
    info.add_argument("ref", help=ref_help)
    add_common_connection_args(info)
189
+
190
+
191
def dispatch(args: argparse.Namespace) -> Any:
    """Route parsed CLI arguments to the core function for the chosen command.

    Returns the selected core function's result unchanged and raises
    ``ValueError`` when ``args.command`` names no known command.
    """

    command = args.command

    def connection() -> dict[str, Any]:
        # Built on demand so these attributes are only read for known commands.
        return {
            "host": args.host,
            "port": args.port,
            "tracking_uri": args.tracking_uri,
        }

    if command == "register":
        # Tags are loaded before anything is sent to the registry.
        general = load_tags(args.general_tags_json, args.general_tag)
        training = load_tags(args.training_tags_json, args.training_tag)
        outcome = register_model_directory(
            args.source_dir,
            args.name,
            aliases=args.alias,
            general_tags=general,
            training_tags=training,
            description=args.description,
            **connection(),
        )
    elif command == "promote":
        outcome = promote_alias(args.name, args.version, args.alias, **connection())
    elif command == "pull":
        outcome = pull_model(
            args.ref,
            args.output_dir,
            full_package=args.full_package,
            overwrite=args.overwrite,
            verify=not args.no_verify,
            **connection(),
        )
    elif command == "verify":
        outcome = verify_model(args.ref, args.path, **connection())
    elif command == "list":
        outcome = list_model_versions(args.name, **connection())
    elif command == "info":
        outcome = get_model_info(args.ref, **connection())
    else:
        raise ValueError(f"Unknown command: {args.command}")

    return outcome
241
+
242
+
243
def load_tags(json_path: str | None, inline_items: list[str] | None) -> dict[str, Any]:
    """Combine tags from an optional JSON file with optional inline items.

    ``TagError`` propagates unchanged; any other exception is re-raised as a
    ``TagError`` carrying the same message, with the original as its cause.
    """

    try:
        from_file = read_json_dict(json_path)
        from_cli = parse_key_value_items(inline_items)
        merged = merge_dicts(from_file, from_cli)
    except Exception as exc:  # noqa: BLE001 - normalize CLI error text.
        if isinstance(exc, TagError):
            raise
        raise TagError(str(exc)) from exc
    return merged
252
+
253
+
254
def print_json(value: Any) -> None:
    """Write ``value`` to stdout as indented, key-sorted JSON.

    The value is first passed through ``to_jsonable``; non-ASCII characters
    are emitted as-is rather than escaped.
    """

    payload = to_jsonable(value)
    rendered = json.dumps(payload, indent=2, sort_keys=True, ensure_ascii=False)
    print(rendered)
258
+
259
+
260
def to_jsonable(value: Any) -> Any:
    """Convert a value recursively into JSON-serializable objects.

    Conversions applied:

    * dataclass -> ``dict`` via ``asdict``, whose result is normalised again
    * ``list`` / ``tuple`` -> ``list`` with every item converted
    * ``dict`` -> ``dict`` with ``str`` keys and converted values
    * ``Path`` -> ``str``

    Any other value is returned unchanged.
    """

    if is_dataclass(value):
        # ``asdict`` recurses into nested dataclasses and containers but leaves
        # leaf values such as ``Path`` (and tuples, non-string keys) untouched,
        # so its output must go through the same normalisation.
        return to_jsonable(asdict(value))
    if isinstance(value, (list, tuple)):
        return [to_jsonable(item) for item in value]
    if isinstance(value, dict):
        return {str(key): to_jsonable(item) for key, item in value.items()}
    if isinstance(value, Path):
        return str(value)
    return value
274
+
275
+
276
+ if __name__ == "__main__":
277
+ raise SystemExit(main())