refactorai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,171 @@
1
+ """Model policy resolution and enforcement helpers (R15)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import shutil
7
+ import subprocess
8
+ from dataclasses import dataclass
9
+
10
+ from refactor_cli.control_plane import resolve_policy
11
+
12
+
13
@dataclass
class ModelDecision:
    """Outcome of checking one model id against the model policy.

    Produced by ``evaluate_model``. ``allowed`` is the verdict and
    ``reason_code`` is a short machine-readable tag explaining it.
    """

    # Verdict: True when the model may be used.
    allowed: bool
    # Tag such as "allowed", "override_allowed", "invalid_model_id",
    # "model_unknown", "hardware_block", "license_block" or a blocklist reason.
    reason_code: str
    # Human-readable explanation of the verdict.
    message: str
    # Policy default model for the caller's profile ("" when none applies).
    suggested_model_id: str = ""
    # True only when an override entry produced the verdict.
    override_applied: bool = False
20
+
21
+
22
+ def _read_total_mem_mb() -> int:
23
+ try:
24
+ pages = os.sysconf("SC_PHYS_PAGES")
25
+ page_size = os.sysconf("SC_PAGE_SIZE")
26
+ total = int((int(pages) * int(page_size)) / (1024 * 1024))
27
+ return max(0, total)
28
+ except (ValueError, OSError, AttributeError):
29
+ return 0
30
+
31
+
32
def detect_machine_profile() -> dict:
    """Describe the host for model-policy decisions.

    Returns a dict with:
      * ``profile``: "gpu-standard" if a GPU vendor tool was found, else
        "cpu-balanced" with at least 16 GiB RAM, else "cpu-small";
      * ``has_gpu``: whether a vendor tool was found;
      * ``gpu_vendor``: "nvidia", "amd", "apple" or "none";
      * ``memory_mb``: total memory in MiB (0 if unknown).

    GPU presence is inferred solely from which tool is on PATH, checked in
    the order nvidia-smi, rocm-smi, system_profiler.
    """
    memory_mb = _read_total_mem_mb()
    # NOTE(review): system_profiler is presumably present on every macOS host,
    # so any Mac is classified as "apple"/GPU here — confirm that is intended.
    vendor_tools = (
        ("nvidia-smi", "nvidia"),
        ("rocm-smi", "amd"),
        ("system_profiler", "apple"),
    )
    vendor = next((name for tool, name in vendor_tools if shutil.which(tool)), "")
    gpu_present = vendor != ""

    if gpu_present:
        profile = "gpu-standard"
    elif memory_mb >= 16 * 1024:
        profile = "cpu-balanced"
    else:
        profile = "cpu-small"

    return {
        "profile": profile,
        "has_gpu": gpu_present,
        "gpu_vendor": vendor if vendor else "none",
        "memory_mb": memory_mb,
    }
53
+
54
+
55
def get_policy_bundle(*, force_refresh: bool = False) -> dict:
    """Return the policy bundle resolved by the control plane.

    Pass-through to ``resolve_policy``; ``force_refresh`` is forwarded as-is.
    """
    bundle = resolve_policy(force_refresh=force_refresh)
    return bundle
57
+
58
+
59
+ def _model_policy(policy_bundle: dict) -> dict:
60
+ return dict((policy_bundle.get("bundle") or {}).get("model_policy") or {})
61
+
62
+
63
def recommended_model_id(policy_bundle: dict, *, profile: str) -> str:
    """Return the policy's default model id for ``profile``.

    Looks up ``default_model_by_profile[profile]``, then falls back to the
    "cpu-small" entry, and finally to "" when neither yields a truthy value.
    """
    by_profile = dict(_model_policy(policy_bundle).get("default_model_by_profile") or {})
    for key in (profile, "cpu-small"):
        candidate = by_profile.get(key)
        if candidate:
            return str(candidate)
    return ""
67
+
68
+
69
def list_model_entries(policy_bundle: dict) -> list[dict]:
    """Return all allowlist entries followed by all blocklist entries.

    Each entry is copied into a new dict. Entries without a ``status`` key
    get "allowed" (allowlist) or "blocked" (blocklist); falsy entries become
    a dict holding only that default status.
    """
    section = _model_policy(policy_bundle)
    allowed_raw = list(section.get("allowlist") or [])
    blocked_raw = list(section.get("blocklist") or [])
    tagged = [(raw, "allowed") for raw in allowed_raw] + [(raw, "blocked") for raw in blocked_raw]

    result: list[dict] = []
    for raw, default_status in tagged:
        row = dict(raw or {})
        row.setdefault("status", default_status)
        result.append(row)
    return result
83
+
84
+
85
def _entry_for_model(entries: list, wanted: str):
    """Return the first policy entry whose ``model_id`` equals ``wanted``, or None.

    Falsy entries are treated as ``{}``; ``model_id`` is compared as a string.
    """
    for candidate in entries:
        if str((candidate or {}).get("model_id") or "") == wanted:
            return candidate
    return None


def evaluate_model(
    *,
    model_id: str,
    policy_bundle: dict,
    profile: str,
) -> ModelDecision:
    """Decide whether ``model_id`` may be used on a host with ``profile``.

    Checks run in this order; the first one that matches decides:

    1. blank id                            -> denied, "invalid_model_id"
    2. listed in ``overrides``             -> allowed, "override_allowed"
       (no further checks, ``override_applied=True``)
    3. listed in ``blocklist``             -> denied, entry ``reason_code``
       or "policy_block"
    4. not listed in ``allowlist``         -> denied, "model_unknown"
    5. ``requires_gpu`` and profile not starting with "gpu"
                                           -> denied, "hardware_block"
    6. ``commercial_allowed`` missing/falsy -> denied, "license_block"
    7. otherwise                           -> allowed, "allowed"

    Every decision after step 1 carries the profile's recommended model in
    ``suggested_model_id``.
    """
    wanted = str(model_id or "").strip()
    if not wanted:
        return ModelDecision(False, "invalid_model_id", "Model id is required.")

    section = _model_policy(policy_bundle)
    allow_entries = list(section.get("allowlist") or [])
    block_entries = list(section.get("blocklist") or [])
    override_entries = list(section.get("overrides") or [])
    suggested = recommended_model_id(policy_bundle, profile=profile)

    # Overrides win over every other rule, including the blocklist.
    if _entry_for_model(override_entries, wanted):
        return ModelDecision(
            True,
            "override_allowed",
            "Model allowed by signed policy override.",
            suggested_model_id=suggested,
            override_applied=True,
        )

    block_hit = _entry_for_model(block_entries, wanted)
    if block_hit:
        reason = str(block_hit.get("reason_code") or "policy_block")
        return ModelDecision(
            False,
            reason,
            f"Model '{wanted}' is blocked by policy ({reason}).",
            suggested_model_id=suggested,
        )

    allow_hit = _entry_for_model(allow_entries, wanted)
    if not allow_hit:
        return ModelDecision(
            False,
            "model_unknown",
            f"Model '{wanted}' is not in the signed allowlist.",
            suggested_model_id=suggested,
        )

    if bool(allow_hit.get("requires_gpu", False)) and not profile.startswith("gpu"):
        return ModelDecision(
            False,
            "hardware_block",
            f"Model '{wanted}' requires GPU profile.",
            suggested_model_id=suggested,
        )

    if not bool(allow_hit.get("commercial_allowed", False)):
        return ModelDecision(
            False,
            "license_block",
            f"Model '{wanted}' is not approved for commercial usage.",
            suggested_model_id=suggested,
        )

    return ModelDecision(
        True,
        "allowed",
        f"Model '{wanted}' is allowed by signed policy.",
        suggested_model_id=suggested,
    )
154
+
155
+
156
def ollama_installed() -> bool:
    """Return True when an ``ollama`` executable can be found on PATH."""
    return shutil.which("ollama") is not None
158
+
159
+
160
def probe_ollama_model(model_id: str, *, timeout: float | None = 30.0) -> tuple[bool, str]:
    """Check whether ``model_id`` is available to the local ollama install.

    Runs ``ollama show <model_id>`` and reports the outcome.

    Args:
        model_id: Model name passed verbatim as the argument to ``ollama show``.
        timeout: Seconds to wait for the command; ``None`` waits indefinitely.

    Returns:
        ``(True, "model is available locally")`` when the command exits 0.
        Otherwise ``(False, reason)``, where ``reason`` is the command's error
        output (stderr preferred, then stdout) or a short explanation. Launch
        failures and timeouts are reported this way instead of raising.
    """
    if not ollama_installed():
        return False, "ollama not installed on host."
    try:
        proc = subprocess.run(
            ["ollama", "show", model_id],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return False, f"ollama show timed out after {timeout} seconds"
    except OSError as exc:
        # The binary can disappear or be unexecutable between the PATH check and launch.
        return False, f"could not run ollama: {exc}"
    if proc.returncode != 0:
        # Failures are usually explained on stderr; fall back to stdout, then a generic message.
        detail = (proc.stderr or "").strip() or (proc.stdout or "").strip()
        return False, detail or "model not found"
    return True, "model is available locally"
@@ -0,0 +1,241 @@
1
+ """Local runtime artifact manager for proprietary runtime delivery (R12).
2
+
3
+ This module manages versioned runtime artifacts under ``~/.refactor/runtime``:
4
+ manifest persistence, checksum verification, activation, rollback pointers, and
5
+ status reporting. It intentionally avoids any code execution semantics and only
6
+ handles artifact state.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ import json
13
+ import platform
14
+ from dataclasses import dataclass
15
+ from datetime import datetime, timezone
16
+ from pathlib import Path
17
+
18
+ import httpx
19
+
20
+ from refactor_core.paths import ensure_dir, refactor_home
21
+
22
+ from refactor_cli.control_plane import ensure_lease
23
+ from refactor_cli.credentials import resolve_developer_key
24
+ from refactor_cli.settings import platform_url
25
+
26
# Layout of the runtime store beneath refactor_home():
#   runtime/state.json                              -> active / rollback pointers
#   runtime/versions/<version>/refactorai-core.bin  -> stored artifact bytes
#   runtime/versions/<version>/manifest.json        -> manifest recorded at activation
RUNTIME_DIR = "runtime"
VERSIONS_DIR = "versions"
STATE_FILE = "state.json"
ARTIFACT_FILE = "refactorai-core.bin"
MANIFEST_FILE = "manifest.json"
31
+
32
+
33
+ @dataclass
34
+ class RuntimeManifest:
35
+ runtime_version: str
36
+ artifact_url: str
37
+ sha256: str
38
+ signature: str
39
+ min_cli_version: str | None = None
40
+
41
+
42
+ def _now_iso() -> str:
43
+ return datetime.now(timezone.utc).isoformat()
44
+
45
+
46
def runtime_root() -> Path:
    """Return the runtime store directory: ``refactor_home() / "runtime"``."""
    home = refactor_home()
    return home / RUNTIME_DIR
48
+
49
+
50
def runtime_versions_dir() -> Path:
    """Return the directory that holds one sub-directory per runtime version."""
    root = runtime_root()
    return root / VERSIONS_DIR
52
+
53
+
54
def runtime_state_path() -> Path:
    """Return the path of the runtime state file (``runtime/state.json``)."""
    root = runtime_root()
    return root / STATE_FILE
56
+
57
+
58
+ def _default_state() -> dict:
59
+ return {
60
+ "active_version": "",
61
+ "rollback_version": "",
62
+ "channel": "stable",
63
+ "updated_at": "",
64
+ "last_error": "",
65
+ }
66
+
67
+
68
def read_runtime_state() -> dict:
    """Load the runtime state file merged over the default state.

    Only the keys present in ``_default_state()`` are returned. Values found
    in the file replace the defaults, unknown keys are ignored, and missing
    keys keep their default.

    A missing or unreadable file, non-UTF-8 content, malformed JSON, or JSON
    whose top level is not an object all yield the plain defaults instead of
    raising.
    """
    out = _default_state()
    path = runtime_state_path()
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (ValueError, OSError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers a missing file or a directory at this path.
        return out
    if not isinstance(data, dict):
        # e.g. a list or null written by hand; treat it as corrupt state.
        return out
    for key in out:
        if key in data:
            out[key] = data[key]
    return out
79
+
80
+
81
def write_runtime_state(state: dict) -> None:
    """Persist ``state``, merged over the defaults, to the runtime state file.

    The JSON is first written to ``state.json.tmp`` and then moved over the
    real file with ``Path.replace`` (an atomic rename on POSIX), so the state
    file is replaced in a single step.
    """
    ensure_dir(runtime_root())
    payload = _default_state()
    payload.update(state or {})
    target = runtime_state_path()
    staging = target.with_name(target.name + ".tmp")
    staging.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    staging.replace(target)
89
+
90
+
91
def runtime_version_dir(version: str) -> Path:
    """Return ``runtime/versions/<version>`` for a validated version label.

    Version labels arrive from server manifests and from the on-disk state
    file, so they are treated as untrusted. A label must be a single path
    component that cannot escape the versions directory.

    Raises:
        ValueError: if the version is empty or blank, is "." or "..", or
            contains a path separator ("/" or "\\"), a drive colon, or a NUL
            byte.
    """
    safe = str(version or "").strip()
    if not safe:
        raise ValueError("Runtime version is required")
    if safe in (".", "..") or any(ch in safe for ch in ("/", "\\", ":", "\x00")):
        raise ValueError(f"Invalid runtime version: {safe!r}")
    return runtime_versions_dir() / safe
96
+
97
+
98
def runtime_artifact_path(version: str) -> Path:
    """Return the stored artifact path for ``version``.

    The result is ``runtime_version_dir(version) / "refactorai-core.bin"``.
    Raises ``ValueError`` for any version that ``runtime_version_dir`` rejects.
    """
    version_dir = runtime_version_dir(version)
    return version_dir / ARTIFACT_FILE
100
+
101
+
102
def runtime_manifest_path(version: str) -> Path:
    """Return the stored manifest path for ``version``.

    The result is ``runtime_version_dir(version) / "manifest.json"``.
    Raises ``ValueError`` for any version that ``runtime_version_dir`` rejects.
    """
    version_dir = runtime_version_dir(version)
    return version_dir / MANIFEST_FILE
104
+
105
+
106
def sha256_hex(content: bytes) -> str:
    """Return the lowercase hexadecimal SHA-256 digest of ``content``."""
    digest = hashlib.sha256()
    digest.update(content)
    return digest.hexdigest()
108
+
109
+
110
def verify_sha256(content: bytes, expected_sha256: str) -> bool:
    """Return True only if ``content`` hashes to ``expected_sha256``.

    The expected digest is stripped and lowercased before comparison. An
    empty or None expectation never verifies, so callers cannot skip the
    check by omitting the digest.
    """
    wanted = str(expected_sha256 or "").strip().lower()
    return bool(wanted) and sha256_hex(content) == wanted
115
+
116
+
117
def parse_manifest(payload: dict) -> RuntimeManifest:
    """Build a ``RuntimeManifest`` from a decoded JSON object.

    Each field is converted to a stripped string, with missing or falsy
    values becoming "". ``sha256`` is lowercased, and an empty
    ``min_cli_version`` becomes None.
    """
    def _field(key: str) -> str:
        return str(payload.get(key) or "").strip()

    return RuntimeManifest(
        runtime_version=_field("runtime_version"),
        artifact_url=_field("artifact_url"),
        sha256=_field("sha256").lower(),
        signature=_field("signature"),
        min_cli_version=_field("min_cli_version") or None,
    )
125
+
126
+
127
def resolve_runtime_manifest(*, channel: str = "stable", timeout: float = 20.0) -> RuntimeManifest:
    """Ask the platform which runtime artifact this host should install.

    Authentication uses a control-plane lease token when ``ensure_lease``
    succeeds. Otherwise it falls back to the configured developer key. The
    CLI version, OS, architecture and channel are POSTed to
    ``<platform_url>/v1/runtime/resolve``.

    Args:
        channel: Release channel to resolve.
        timeout: HTTP timeout in seconds.

    Returns:
        The parsed manifest, with non-empty runtime_version, artifact_url and
        sha256.

    Raises:
        RuntimeError: if no credentials are available, the request fails,
            the server returns an error status, the body is not a JSON
            object, or required manifest fields are missing.
    """
    try:
        lease = ensure_lease()
        auth_header = f"Bearer {lease.token}"
    except RuntimeError:
        resolved = resolve_developer_key()
        if not resolved:
            raise RuntimeError(
                "No developer key configured. Run `refactor login`, set REFACTOR_API_KEY, "
                "or add developer_key to refactor.config."
            )
        auth_header = f"Bearer {resolved.key}"

    os_name = str(platform.system() or "linux").lower()
    arch = str(platform.machine() or "x86_64").lower()
    payload = {
        # NOTE(review): hard-coded; presumably should track the installed package version.
        "cli_version": "0.1.0",
        "os": os_name,
        "arch": arch,
        "channel": channel,
    }
    try:
        response = httpx.post(
            f"{platform_url()}/v1/runtime/resolve",
            headers={"Authorization": auth_header, "Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
    except httpx.HTTPError as exc:
        raise RuntimeError(f"Could not resolve runtime manifest: {exc}") from exc
    if response.status_code >= 400:
        raise RuntimeError(f"Runtime manifest resolve failed ({response.status_code}): {response.text[:500]}")

    if not response.content:
        data: object = {}
    else:
        try:
            data = response.json()
        except ValueError as exc:
            # Covers JSONDecodeError / UnicodeDecodeError from a non-JSON body (e.g. a proxy HTML page).
            raise RuntimeError(f"Runtime manifest response is not valid JSON: {exc}") from exc
    if not isinstance(data, dict):
        raise RuntimeError("Runtime manifest response is not a JSON object")

    manifest = parse_manifest(data)
    if not manifest.runtime_version or not manifest.artifact_url or not manifest.sha256:
        raise RuntimeError("Runtime manifest is missing required fields")
    return manifest
163
+
164
+
165
def download_artifact(url: str, *, timeout: float = 60.0) -> bytes:
    """Download the runtime artifact at ``url`` and return its raw bytes.

    Redirects are followed explicitly, because httpx does not follow them by
    default. Artifact URLs that answer with a 3xx redirect therefore return
    the final payload rather than the redirect response body.

    Raises:
        RuntimeError: on transport errors (including too many redirects) or
            when the final response status is not 2xx.
    """
    try:
        response = httpx.get(url, timeout=timeout, follow_redirects=True)
    except httpx.HTTPError as exc:
        raise RuntimeError(f"Could not download runtime artifact: {exc}") from exc
    if not 200 <= response.status_code < 300:
        raise RuntimeError(f"Runtime artifact download failed ({response.status_code})")
    return response.content
173
+
174
+
175
def activate_runtime(manifest: RuntimeManifest, artifact: bytes, *, channel: str = "stable") -> Path:
    """Verify, store and activate a downloaded runtime artifact.

    Steps:
      1. Check the artifact against ``manifest.sha256``.
      2. Write the artifact and ``manifest.json`` into
         ``runtime/versions/<version>/``.
      3. Update the state file to make this version active. The previously
         active version becomes the rollback target, unless it is the same
         version being re-activated; in that case the existing rollback
         pointer is kept.

    Args:
        manifest: Manifest describing the artifact.
        artifact: Downloaded artifact bytes.
        channel: Channel recorded in the state file.

    Returns:
        Path of the stored artifact file.

    Raises:
        RuntimeError: if the checksum does not match.
    """
    if not verify_sha256(artifact, manifest.sha256):
        raise RuntimeError("Runtime artifact checksum verification failed")
    # NOTE(review): manifest.signature is only persisted below; it is not verified in this module.
    ensure_dir(runtime_versions_dir())
    target_dir = ensure_dir(runtime_version_dir(manifest.runtime_version))
    artifact_path = target_dir / ARTIFACT_FILE
    manifest_path = target_dir / MANIFEST_FILE

    # Stage then rename so an interrupted write never leaves a truncated file
    # under the final name (status/rollback only check that this file exists).
    staged = artifact_path.with_name(artifact_path.name + ".tmp")
    staged.write_bytes(artifact)
    staged.replace(artifact_path)

    manifest_path.write_text(
        json.dumps(
            {
                "runtime_version": manifest.runtime_version,
                "artifact_url": manifest.artifact_url,
                "sha256": manifest.sha256,
                "signature": manifest.signature,
                "min_cli_version": manifest.min_cli_version,
                "activated_at": _now_iso(),
            },
            indent=2,
        ),
        encoding="utf-8",
    )

    state = read_runtime_state()
    previous = str(state.get("active_version") or "").strip()
    if previous == str(manifest.runtime_version or "").strip():
        # Re-activating the current version must not point rollback at itself.
        rollback = str(state.get("rollback_version") or "").strip()
    else:
        rollback = previous
    write_runtime_state(
        {
            "active_version": manifest.runtime_version,
            "rollback_version": rollback,
            "channel": channel,
            "updated_at": _now_iso(),
            "last_error": "",
        }
    )
    return artifact_path
208
+
209
+
210
def rollback_runtime() -> str:
    """Swap the active and rollback runtime versions; return the new active one.

    The channel recorded in the state file is kept, and ``last_error`` is
    cleared.

    Raises:
        RuntimeError: if no rollback version is recorded, or its artifact file
            is not present on disk.
    """
    current = read_runtime_state()
    previously_active = str(current.get("active_version") or "").strip()
    target = str(current.get("rollback_version") or "").strip()
    if not target:
        raise RuntimeError("No rollback runtime version is available")
    if not runtime_artifact_path(target).is_file():
        raise RuntimeError(f"Rollback artifact for version '{target}' is missing")
    swapped = {
        "active_version": target,
        "rollback_version": previously_active,
        "channel": current.get("channel", "stable"),
        "updated_at": _now_iso(),
        "last_error": "",
    }
    write_runtime_state(swapped)
    return target
228
+
229
+
230
def runtime_status() -> dict:
    """Summarise the recorded runtime state.

    Returns the active and rollback versions, channel and last update time.
    It also reports whether the active version's artifact file exists on
    disk; this is False when no version is active.
    """
    state = read_runtime_state()
    active = str(state.get("active_version") or "").strip()
    has_artifact = False
    if active:
        has_artifact = runtime_artifact_path(active).is_file()
    return {
        "active_version": active,
        "rollback_version": str(state.get("rollback_version") or "").strip(),
        "channel": str(state.get("channel") or "stable"),
        "updated_at": str(state.get("updated_at") or ""),
        "active_artifact_exists": has_artifact,
    }
@@ -0,0 +1,33 @@
1
+ """Resolved CLI settings and constants."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
# Fallback platform API base URL returned by platform_url().
DEFAULT_PLATFORM_URL = "http://localhost:8000"
# Names of the environment variables consulted by the settings helpers.
ENV_API_KEY = "REFACTOR_API_KEY"
ENV_PLATFORM_URL = "REFACTOR_PLATFORM_URL"
ENV_POLICY_SIGNING_KEY = "REFACTOR_POLICY_SIGNING_KEY"
11
+
12
+
13
def platform_url() -> str:
    """Return the platform API base URL without trailing slashes.

    Reads ``REFACTOR_PLATFORM_URL`` and strips surrounding whitespace. If the
    variable is unset, empty, or whitespace-only, ``DEFAULT_PLATFORM_URL`` is
    used instead. Blank values are treated the same way ``env_api_key`` and
    ``policy_signing_key`` treat them, rather than producing an empty base
    URL.
    """
    configured = os.environ.get(ENV_PLATFORM_URL, "").strip()
    return (configured or DEFAULT_PLATFORM_URL).rstrip("/")
15
+
16
+
17
def env_api_key() -> str | None:
    """Return the API key from ``REFACTOR_API_KEY``, or None if unset or blank."""
    raw = os.environ.get(ENV_API_KEY)
    if raw is None:
        return None
    cleaned = raw.strip()
    return cleaned if cleaned else None
20
+
21
+
22
def policy_signing_key() -> str:
    """Return the policy signing key from ``REFACTOR_POLICY_SIGNING_KEY``.

    If the variable is unset or blank, the built-in value
    "dev-policy-signing-key" is returned.
    """
    # SECURITY NOTE(review): the fallback is a fixed string visible in the
    # published package. If this key authenticates signed policies, any host
    # without the env var accepts policies signed with it. Consider failing
    # closed outside development.
    configured = os.environ.get(ENV_POLICY_SIGNING_KEY, "").strip()
    if configured:
        return configured
    return "dev-policy-signing-key"
25
+
26
+
27
def mask_key(key: str, *, prefix: int = 6, suffix: int = 4) -> str:
    """Return a masked representation that never reveals the full key.

    Long keys are shown as their first ``prefix`` and last ``suffix``
    characters joined by an ellipsis. The affixes appear only when at least
    as many characters stay hidden as are revealed. Shorter keys are replaced
    entirely by ``*`` (one per character), so a short key can never be
    reconstructed from its affixes. Empty or None keys render as "<none>".

    Args:
        key: Secret to mask.
        prefix: Number of leading characters to reveal on long keys.
        suffix: Number of trailing characters to reveal on long keys.
    """
    if not key:
        return "<none>"
    head_len = max(0, prefix)
    tail_len = max(0, suffix)
    shown = head_len + tail_len
    if shown == 0 or len(key) < 2 * shown:
        return "*" * len(key)
    head = key[:head_len]
    # Slice by absolute index: key[-0:] would return the whole key.
    tail = key[len(key) - tail_len:] if tail_len else ""
    return f"{head}…{tail}"
+ return f"{key[:6]}…{key[-4:]}"