shkit 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- healing_kit/__init__.py +3 -0
- healing_kit/auth.py +79 -0
- healing_kit/clients/__init__.py +1 -0
- healing_kit/clients/databricks_client.py +183 -0
- healing_kit/clients/teams_client.py +128 -0
- healing_kit/models/__init__.py +1 -0
- healing_kit/models/diagnosis.py +45 -0
- healing_kit/models/events.py +30 -0
- healing_kit/models/evidence.py +83 -0
- healing_kit/runtime/__init__.py +6 -0
- healing_kit/runtime/approval.py +141 -0
- healing_kit/runtime/maintenance.py +52 -0
- healing_kit/services/__init__.py +1 -0
- healing_kit/services/cache_service.py +120 -0
- healing_kit/services/circuit_breaker.py +114 -0
- healing_kit/services/context_agent.py +127 -0
- healing_kit/services/dependency_graph.py +141 -0
- healing_kit/services/diagnosis_engine.py +165 -0
- healing_kit/services/identity.py +61 -0
- healing_kit/services/model_router.py +52 -0
- healing_kit/services/query_guard.py +168 -0
- healing_kit/services/resolution_verifier.py +100 -0
- healing_kit/services/token_budget.py +137 -0
- healing_kit/utils/__init__.py +1 -0
- healing_kit/utils/error_hash.py +15 -0
- healing_kit/utils/hmac_tokens.py +86 -0
- healing_kit/utils/sql_safety.py +84 -0
- iic/__init__.py +51 -0
- iic/__main__.py +18 -0
- iic/_console.py +235 -0
- iic/_doctor.py +143 -0
- iic/change/__init__.py +7 -0
- iic/change/change_detector.py +154 -0
- iic/context/__init__.py +7 -0
- iic/context/context_builder.py +117 -0
- iic/dependency/__init__.py +7 -0
- iic/dependency/dependency_analyzer.py +93 -0
- iic/diagnosis/__init__.py +7 -0
- iic/diagnosis/diagnosis_engine.py +183 -0
- iic/dna/__init__.py +7 -0
- iic/dna/dna_builder.py +184 -0
- iic/impact/__init__.py +7 -0
- iic/impact/impact_engine.py +102 -0
- iic/ingestion/__init__.py +14 -0
- iic/ingestion/base.py +21 -0
- iic/ingestion/databricks_source.py +98 -0
- iic/ingestion/static_source.py +23 -0
- iic/ingestion/webhook_source.py +39 -0
- iic/models/__init__.py +44 -0
- iic/models/change.py +77 -0
- iic/models/context.py +46 -0
- iic/models/diagnosis.py +37 -0
- iic/models/dna.py +77 -0
- iic/models/event.py +78 -0
- iic/models/impact.py +60 -0
- iic/models/report.py +88 -0
- iic/models/routing.py +41 -0
- iic/notify/__init__.py +7 -0
- iic/notify/teams_notifier.py +112 -0
- iic/report/__init__.py +7 -0
- iic/report/report_generator.py +67 -0
- iic/routing/__init__.py +7 -0
- iic/routing/router.py +42 -0
- iic/runtime/__init__.py +10 -0
- iic/runtime/_sql.py +11 -0
- iic/runtime/agent_config.py +48 -0
- iic/runtime/agent_runtime.py +70 -0
- iic/runtime/antibodies.py +100 -0
- iic/runtime/bootstrap.py +157 -0
- iic/runtime/constants.py +40 -0
- iic/runtime/context.py +46 -0
- iic/runtime/detective.py +72 -0
- iic/runtime/hooks.py +85 -0
- iic/runtime/incident_engine.py +207 -0
- iic/runtime/inprocess.py +350 -0
- iic/runtime/ledger.py +120 -0
- iic/runtime/monitor.py +155 -0
- iic/runtime/pattern_store.py +53 -0
- iic/runtime/reconciler.py +139 -0
- iic/runtime/scope_config.py +127 -0
- iic/runtime/store.py +150 -0
- iic/runtime/wrapper.py +28 -0
- iic_autoload.pth +1 -0
- onboarding/__init__.py +1 -0
- onboarding/cli.py +168 -0
- onboarding/config_schema.py +62 -0
- onboarding/manifest.py +27 -0
- onboarding/preflight.py +129 -0
- onboarding/provisioner.py +573 -0
- onboarding/rollback.py +81 -0
- shkit-1.2.0.dist-info/METADATA +239 -0
- shkit-1.2.0.dist-info/RECORD +94 -0
- shkit-1.2.0.dist-info/WHEEL +4 -0
- shkit-1.2.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Dependency graph builder and event batching for thundering herd prevention."""
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict, deque
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
|
|
6
|
+
from healing_kit.models.events import FailureEvent, RootCauseEvent
|
|
7
|
+
from healing_kit.utils.error_hash import compute_error_hash
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class DependencyGraph:
    """Directed dependency graph over task keys.

    Every edge is recorded in both directions so a caller can walk towards
    the tasks a task depends on (``upstream``) or towards the tasks that
    depend on it (``downstream``) without scanning the whole graph.
    """

    # task_key → list of upstream task_keys (the tasks it depends on).
    # Backed by a defaultdict(list): indexing a missing key inserts an empty
    # list, so read-only lookups should use .get() to avoid adding keys.
    upstream: dict[str, list[str]] = field(default_factory=lambda: defaultdict(list))
    # task_key → list of downstream task_keys (the tasks that depend on it).
    # Same defaultdict(list) behaviour as ``upstream``.
    downstream: dict[str, list[str]] = field(default_factory=lambda: defaultdict(list))
    # All task keys in the graph
    nodes: set[str] = field(default_factory=set)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DependencyGraphBuilder:
    """
    Builds a task dependency graph from Jobs-API-style task definitions and
    answers the two questions the batcher needs: which failed tasks are root
    causes, and which tasks sit downstream of a given task.
    """

    def build_from_tasks(self, tasks: list[dict]) -> DependencyGraph:
        """
        Build the graph from a list of task definitions (from Jobs API).

        Each task has 'task_key' and an optional 'depends_on' list whose
        entries are either ``{"task_key": ...}`` dicts or plain key strings.
        Only the tasks themselves are added to ``graph.nodes``; a dependency
        that is not also listed as a task appears in the edge maps only.
        """
        graph = DependencyGraph()

        for task in tasks:
            key = task.get("task_key", "")
            graph.nodes.add(key)
            # "depends_on" may be present but null; treat it as "no dependencies"
            # instead of failing on iteration over None.
            for dep in task.get("depends_on") or []:
                dep_key = dep.get("task_key", "") if isinstance(dep, dict) else dep
                graph.upstream[key].append(dep_key)
                graph.downstream[dep_key].append(key)

        return graph

    def find_root_ancestors(self, graph: DependencyGraph, failed_keys: set[str]) -> set[str]:
        """
        Return the failed tasks that have no failed direct upstream dependency.

        Walking upstream from any failed task always ends at such a task, and
        every such task is itself in ``failed_keys``, so a single filter over
        ``failed_keys`` yields the same set as a per-key upstream walk.
        Failed tasks that only depend on each other in a cycle have no root
        and contribute nothing.
        """
        return {
            key
            for key in failed_keys
            # .get() keeps the defaultdict from growing during a read-only lookup.
            if not any(parent in failed_keys for parent in graph.upstream.get(key, []))
        }

    def get_all_downstream(self, graph: DependencyGraph, task_key: str) -> list[str]:
        """
        BFS to find all transitively downstream tasks.

        Returns each downstream task key exactly once, in breadth-first
        discovery order; ``task_key`` itself is never included.
        """
        # Mark a node as seen when it is *enqueued*, not when it is dequeued.
        # Otherwise a task reachable through two parents (a diamond) is listed
        # twice, which would double-count derived failures.
        seen = {task_key}
        ordered = []
        queue = deque([task_key])

        while queue:
            current = queue.popleft()
            for child in graph.downstream.get(current, []):
                if child not in seen:
                    seen.add(child)
                    ordered.append(child)
                    queue.append(child)

        return ordered
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class EventBatcher:
    """
    Batches failure events by dependency graph, deduplicating cascading
    downstream failures into single root-cause events.
    """

    def __init__(self):
        self.graph_builder = DependencyGraphBuilder()

    def deduplicate_to_root_causes(
        self, events: list[FailureEvent], tasks: list[dict]
    ) -> list[RootCauseEvent]:
        """
        Given a batch of failure events and the pipeline's task definitions,
        identify true root causes and group derived failures.

        Returns one RootCauseEvent per failed task that has no failed direct
        upstream dependency, in the order those tasks first appear in
        ``events``. Each carries the failure events of its transitive
        downstream tasks as ``derived_failures``. When several events share a
        task_key only the last one is kept. An empty ``events`` list returns
        an empty list.
        """
        if not events:
            return []

        graph = self.graph_builder.build_from_tasks(tasks)

        # Map events by task_key for lookup (a later event for the same task
        # replaces an earlier one).
        event_map = {e.task_key: e for e in events}
        failed_keys = set(event_map)
        roots = self.graph_builder.find_root_ancestors(graph, failed_keys)

        root_cause_events = []
        # Iterate in event order rather than over the `roots` set so the output
        # order is stable from run to run.
        for root_key, root_event in event_map.items():
            if root_key not in roots or not root_event:
                continue

            # Find all downstream failures derived from this root. dict.fromkeys
            # drops repeats while keeping order, so a task reachable through
            # several paths is counted once.
            all_downstream = self.graph_builder.get_all_downstream(graph, root_key)
            derived_keys = dict.fromkeys(
                k for k in all_downstream if k in event_map and k != root_key
            )
            derived = [event_map[k] for k in derived_keys]

            error_hash = compute_error_hash(
                # Only the first 50 characters feed the hash; a missing message
                # is hashed as the empty string.
                error_type=(root_event.error_message or "")[:50],
                notebook_path=root_key,
                affected_tables=[],
            )

            root_cause_events.append(RootCauseEvent(
                root_run_id=root_event.run_id,
                root_job_id=root_event.job_id,
                root_task_key=root_key,
                error_hash=error_hash,
                derived_failures=derived,
                total_downstream_impact=len(derived),
            ))

        return root_cause_events
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""AI Diagnosis Engine — structured LLM invocation with schema validation."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from healing_kit.clients.databricks_client import DatabricksClient
|
|
7
|
+
from healing_kit.models.diagnosis import ActionId, DiagnosisResponse
|
|
8
|
+
from healing_kit.models.evidence import EvidencePackage
|
|
9
|
+
from healing_kit.services.model_router import ModelRouter
|
|
10
|
+
from healing_kit.services.token_budget import TokenBudgetEnforcer
|
|
11
|
+
|
|
12
|
+
SYSTEM_PROMPT = """You are a Databricks pipeline failure diagnosis expert.
|
|
13
|
+
You receive structured evidence from a failed pipeline task and must return a JSON diagnosis.
|
|
14
|
+
|
|
15
|
+
RULES:
|
|
16
|
+
- NEVER suggest DROP, DELETE, TRUNCATE, or ALTER on user tables.
|
|
17
|
+
- NEVER return free text. Only parseable JSON.
|
|
18
|
+
- confidence_score must be between 0.0 and 1.0.
|
|
19
|
+
- action_id must be one of: CLUSTER_RESIZE, RETRY, SCHEMA_FIX, CODE_PATCH, ESCALATE, UNKNOWN.
|
|
20
|
+
|
|
21
|
+
REQUIRED JSON OUTPUT:
|
|
22
|
+
{
|
|
23
|
+
"root_cause": "specific human-readable explanation",
|
|
24
|
+
"confidence_score": 0.0-1.0,
|
|
25
|
+
"action_id": "CLUSTER_RESIZE | RETRY | SCHEMA_FIX | CODE_PATCH | ESCALATE | UNKNOWN",
|
|
26
|
+
"action_params": {},
|
|
27
|
+
"reasoning": "2-3 sentence explanation citing evidence",
|
|
28
|
+
"evidence_used": ["list of evidence sources you consulted"],
|
|
29
|
+
"diagnostic_query": "SELECT query to show bad rows (or null)",
|
|
30
|
+
"code_issue": "specific line/logic issue (or null)"
|
|
31
|
+
}"""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DiagnosisEngine:
    """
    Sends enriched context to LLM, validates response schema,
    integrates with model router and token budget.
    """

    def __init__(self, client: DatabricksClient, model_router: ModelRouter, token_budget: Optional[TokenBudgetEnforcer] = None):
        self.client = client
        self.router = model_router
        self.budget = token_budget

    def diagnose(self, evidence: EvidencePackage, hint_action_id: str = "UNKNOWN") -> DiagnosisResponse:
        """
        Invoke the appropriate LLM with structured evidence.

        ``hint_action_id`` selects the model tier through the router; a value
        that is not a valid ActionId is treated as UNKNOWN. Returns a
        validated DiagnosisResponse. Budget exhaustion, endpoint errors and
        unparseable replies all come back as an ESCALATE response with
        confidence 0.0 rather than as an exception.
        """
        # Resolve the hint once. The LLM's own action_id gets the same
        # "unrecognised -> UNKNOWN" treatment in _parse_response.
        try:
            hint = ActionId(hint_action_id)
        except ValueError:
            hint = ActionId.UNKNOWN

        # Check token budget
        model_tier = self.router.get_model_tier(hint)
        if self.budget and not self.budget.can_invoke_model(model_tier):
            return DiagnosisResponse(
                root_cause="Token budget exhausted — degraded mode active",
                confidence_score=0.0,
                action_id=ActionId.ESCALATE,
                reasoning="System is in degraded mode due to token budget limits",
            )

        # Select model
        model = self.router.get_model(hint)
        if model is None:
            return DiagnosisResponse(
                root_cause="No LLM needed for this action type",
                confidence_score=1.0,
                action_id=hint,
            )

        # Build prompt
        prompt = self._build_prompt(evidence)

        # Call LLM
        try:
            response_text = self.client.invoke_model(
                endpoint_name=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=2000,
                temperature=0.1,
            )
        except Exception as e:
            return DiagnosisResponse(
                root_cause=f"LLM invocation failed: {str(e)[:200]}",
                confidence_score=0.0,
                action_id=ActionId.ESCALATE,
                reasoning="Model serving endpoint error",
            )

        # An empty reply is handled as unparseable text below instead of
        # crashing the usage estimate.
        if response_text is None:
            response_text = ""

        # Record token usage (estimate: whitespace-separated words, not real tokens)
        if self.budget:
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            self.budget.record_usage(estimated_tokens, estimated_tokens * 0.00001)

        # Parse and validate
        return self._parse_response(response_text)

    def _build_prompt(self, evidence: EvidencePackage) -> str:
        """Build the structured prompt from the evidence package.

        Sections are emitted only for evidence that is present and joined
        with blank lines. List-valued evidence is capped (5 schema errors,
        10 exceptions, 5 missing-table errors, first Spark error truncated to
        2000 characters) to bound prompt size.
        """
        sections = [f"FAILED TASK: {evidence.task_key}", f"JOB: {evidence.job_id}", f"RUN: {evidence.run_id}"]

        stdout = evidence.driver_stdout
        if stdout:
            if stdout.schema_errors:
                sections.append("SCHEMA ERRORS:\n" + "\n".join(stdout.schema_errors[:5]))
            if stdout.data_source_exceptions:
                sections.append("EXCEPTIONS:\n" + "\n".join(stdout.data_source_exceptions[:10]))
            if stdout.missing_table_errors:
                sections.append("MISSING TABLES:\n" + "\n".join(stdout.missing_table_errors[:5]))

        if evidence.spark_event_logs and evidence.spark_event_logs.first_errors:
            # The slice applies to the joined log text only, not to the heading.
            first_error = "\n".join(evidence.spark_event_logs.first_errors[:1])[:2000]
            sections.append("NOTEBOOK SOURCE / SPARK LOGS:\n" + first_error)

        if evidence.task_metrics:
            m = evidence.task_metrics
            sections.append(f"METRICS: spill={m.spill_to_disk_bytes}B, GC={m.gc_overhead_pct}%, shuffle_r={m.shuffle_read_bytes}B")

        if evidence.git_context:
            sections.append(f"GIT: {evidence.git_context.last_commit_author} — {evidence.git_context.last_commit_message}")

        if evidence.cluster_events and evidence.cluster_events.termination_reason:
            sections.append(f"CLUSTER: terminated={evidence.cluster_events.termination_reason}")

        if evidence.missing_sources:
            sections.append(f"UNAVAILABLE SOURCES: {', '.join(evidence.missing_sources)}")

        return "\n\n".join(sections)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without surrounding whitespace or a Markdown code fence."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the whole opening fence line (it may carry a language tag).
            cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        return cleaned.strip()

    @staticmethod
    def _clamp_confidence(raw) -> float:
        """Convert ``raw`` to a float clamped to [0.0, 1.0].

        NaN maps to 0.0. Raises ValueError/TypeError when ``raw`` is not
        numeric (for example ``None`` or free text).
        """
        value = float(raw)
        # NaN is the only float that is not equal to itself. Without this
        # check min(1.0, nan) evaluates to 1.0 and a NaN score would be
        # reported as full confidence.
        if value != value:
            return 0.0
        return max(0.0, min(1.0, value))

    def _parse_response(self, response_text: str) -> DiagnosisResponse:
        """Parse and validate the LLM JSON response.

        An unrecognised ``action_id`` becomes UNKNOWN and the confidence is
        clamped to [0.0, 1.0]. Anything that is not a JSON object with a
        numeric confidence yields an ESCALATE response with confidence 0.0.
        """
        if not isinstance(response_text, str):
            response_text = "" if response_text is None else str(response_text)

        try:
            data = json.loads(self._strip_code_fence(response_text))
            # Valid JSON that is not an object (a list, a bare string, a
            # number) cannot carry the required fields.
            if not isinstance(data, dict):
                raise ValueError("diagnosis must be a JSON object")

            # Validate action_id
            action_id_str = data.get("action_id", "UNKNOWN")
            valid_actions = [a.value for a in ActionId]
            if action_id_str not in valid_actions:
                action_id_str = "UNKNOWN"

            # Validate confidence
            confidence = self._clamp_confidence(data.get("confidence_score", 0))

            return DiagnosisResponse(
                root_cause=data.get("root_cause", "Unknown"),
                confidence_score=confidence,
                action_id=ActionId(action_id_str),
                action_params=data.get("action_params", {}),
                reasoning=data.get("reasoning", ""),
                evidence_used=data.get("evidence_used", []),
            )

        # TypeError covers a null or non-scalar confidence_score.
        except (json.JSONDecodeError, KeyError, ValueError, TypeError):
            return DiagnosisResponse(
                root_cause=f"LLM response not parseable: {response_text[:200]}",
                confidence_score=0.0,
                action_id=ActionId.ESCALATE,
                reasoning="Response failed JSON schema validation",
            )
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Approver authorization (#3).
|
|
2
|
+
|
|
3
|
+
When a user clicks Approve/Reject, we must confirm — by governance — that the
|
|
4
|
+
clicker actually has access to this workspace before honoring the action. The
|
|
5
|
+
clicker is identified by their SSO login or Teams email, then verified against
|
|
6
|
+
the workspace via SCIM (user exists + is active, and optionally is in an
|
|
7
|
+
approvers group). If they don't have access we refuse with a clear message.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class ApproverDecision:
    """Result of authorizing an approver."""

    # True only when every check in authorize_approver passed.
    authorized: bool
    # The stripped login/email that was checked.
    identity: str
    # str() of the matched user's "id"; empty when no user was matched.
    user_id: str = ""
    # The matched user's "displayName" where authorize_approver supplies it.
    display_name: str = ""
    # "ok" on success, otherwise the reason the request was refused.
    reason: str = ""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Message returned to the relay/handler when the clicker isn't entitled.
# authorize_approver uses it as the ``reason`` when no workspace user matches
# the supplied identity.
ACCESS_DENIED_MESSAGE = "You don't have enough access to approve this action."
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def authorize_approver(client, identity: str, approvers_group: str = "") -> ApproverDecision:
    """Decide whether the clicker may approve, based on workspace governance.

    client: a DatabricksClient (with find_user_by_identity / user_in_group).
    identity: the clicker's login or email (verified upstream by the relay's SSO).
    approvers_group: optional group the user must belong to (segregation of duties).

    Checks run in order and the first failure wins: an identity must be
    supplied, a matching user must exist, the user must not be flagged
    inactive, and (when a group is named) the user must be in that group.
    """
    who = (identity or "").strip()

    def refuse(why: str, **known) -> ApproverDecision:
        # Every refusal carries the checked identity plus whatever is known so far.
        return ApproverDecision(False, who, reason=why, **known)

    if not who:
        return refuse("No approver identity supplied")

    account = client.find_user_by_identity(who)
    if not account:
        return refuse(ACCESS_DENIED_MESSAGE)

    uid = str(account.get("id", ""))

    # NOTE(review): only an explicit ``False`` is treated as deactivated; a
    # record with no "active" attribute passes this check — confirm that
    # find_user_by_identity always returns the attribute.
    if account.get("active") is False:
        return refuse("User exists but is deactivated.", user_id=uid)

    shown_name = account.get("displayName", "")

    group_required = bool(approvers_group)
    if group_required and not client.user_in_group(account, approvers_group):
        return refuse(
            f"User is not in the '{approvers_group}' approvers group.",
            user_id=uid,
            display_name=shown_name,
        )

    return ApproverDecision(True, who, user_id=uid, display_name=shown_name, reason="ok")
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Deterministic model routing by action_id."""
|
|
2
|
+
|
|
3
|
+
from healing_kit.models.diagnosis import ActionId
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelRouter:
    """
    Routes LLM calls to the appropriate model based on action_id.

    Routing is fully deterministic and auditable:
    - RETRY, CLUSTER_RESIZE → lightweight model (cost-efficient)
    - SCHEMA_FIX, CODE_PATCH, UNKNOWN → powerful model (needs reasoning)
    - CACHE_HIT, ESCALATE → no LLM call (zero tokens)
    """

    def __init__(self, lightweight_model: str, powerful_model: str):
        self.lightweight_model = lightweight_model
        self.powerful_model = powerful_model

        self._routing_table = {
            ActionId.RETRY: lightweight_model,
            ActionId.CLUSTER_RESIZE: lightweight_model,
            ActionId.SCHEMA_FIX: powerful_model,
            ActionId.CODE_PATCH: powerful_model,
            ActionId.UNKNOWN: powerful_model,
            ActionId.CACHE_HIT: None,
            ActionId.ESCALATE: None,
        }
        # The tier is recorded per action instead of being inferred by
        # comparing endpoint names. When both tiers are configured with the
        # same endpoint, a name comparison would report every action as
        # "lightweight" and budget checks would use the wrong tier.
        self._tier_table = {
            ActionId.RETRY: "lightweight",
            ActionId.CLUSTER_RESIZE: "lightweight",
            ActionId.SCHEMA_FIX: "powerful",
            ActionId.CODE_PATCH: "powerful",
            ActionId.UNKNOWN: "powerful",
            ActionId.CACHE_HIT: "none",
            ActionId.ESCALATE: "none",
        }

    @staticmethod
    def _as_action(action_id):
        """Convert a string action id to ActionId; pass other values through.

        Raises ValueError (from the ActionId constructor) for a string that
        is not a valid action id.
        """
        if isinstance(action_id, str):
            return ActionId(action_id)
        return action_id

    def get_model(self, action_id: ActionId) -> str | None:
        """
        Get the model endpoint name for a given action_id.
        Returns None if no LLM call is needed (or the action is not routed).
        """
        return self._routing_table.get(self._as_action(action_id))

    def get_model_tier(self, action_id: ActionId) -> str:
        """
        Get the tier ('lightweight', 'powerful', or 'none') for budget checking.
        Actions without a routing entry report 'none'.
        """
        return self._tier_table.get(self._as_action(action_id), "none")

    def requires_llm(self, action_id: ActionId) -> bool:
        """Check if this action requires an LLM call."""
        return self.get_model(action_id) is not None
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Query guard — allowlist + classification-aware masking for diagnostic SELECTs.
|
|
2
|
+
|
|
3
|
+
Closes the prompt-injection -> data-exfiltration chain (#2). The LLM-emitted
|
|
4
|
+
diagnostic query is:
|
|
5
|
+
1. checked read-only (single SELECT/CTE, no DDL/DML/comments/stacking),
|
|
6
|
+
2. parsed with sqlglot so every *referenced table* is matched against an
|
|
7
|
+
explicit allowlist (deny-by-default) — far stronger than keyword filtering,
|
|
8
|
+
3. inspected for sensitive columns; result values are masked BY DEFAULT and
|
|
9
|
+
only sent raw to Teams when the data classification + policy permit it.
|
|
10
|
+
|
|
11
|
+
Pure logic, no Spark — unit-tested with plain strings/rows.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
|
|
16
|
+
import sqlglot
|
|
17
|
+
from sqlglot import exp
|
|
18
|
+
|
|
19
|
+
from healing_kit.utils.sql_safety import is_safe_select
|
|
20
|
+
|
|
21
|
+
# Evidence policies (what may leave the workspace into Teams). These are the
# values accepted for QueryGuard's ``policy`` argument; any other value falls
# back to POLICY_MASKED.
POLICY_SCHEMA_ONLY = "schema_only"        # never send values; only column names + row count
POLICY_MASKED = "masked"                  # send masked values (DEFAULT)
POLICY_RAW_IF_ALLOWED = "raw_if_allowed"  # raw only if no sensitive column referenced

# Conservative default set of sensitive column-name fragments (case-insensitive).
# Matching is by substring, so "email" also flags e.g. "customer_email_hash".
DEFAULT_SENSITIVE_FRAGMENTS = (
    "ssn", "social_security", "email", "e_mail", "phone", "mobile", "dob",
    "birth", "address", "passport", "license", "credit", "card", "cvv",
    "iban", "account_number", "acct_no", "salary", "password", "secret",
    "token", "national_id", "tax_id", "mrn", "patient", "diagnosis_code",
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class GuardDecision:
    """Outcome of evaluating a diagnostic query."""

    # Whether the query may be executed at all.
    allowed: bool
    # "ok" when allowed, otherwise a short explanation of the rejection.
    reason: str
    # Qualified, lowercased names of every table the query references.
    tables: list[str] = field(default_factory=list)
    # Lowercased names of the columns the query references.
    columns: list[str] = field(default_factory=list)
    # Subset of ``columns`` whose name matches a sensitive fragment.
    sensitive_columns: list[str] = field(default_factory=list)
    # True when the query contains a ``*`` projection.
    has_star: bool = False
    # How rows must be handled if the query runs: schema_only | masked | raw
    effective_policy: str = POLICY_MASKED
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class QueryGuard:
    """Allowlist + masking gate for LLM-generated diagnostic queries.

    ``evaluate`` decides whether a query may run and which row-handling
    policy applies; ``render_rows`` turns the result rows into text according
    to that decision.
    """

    def __init__(
        self,
        allowed_tables,
        sensitive_fragments=None,
        policy: str = POLICY_MASKED,
        default_catalog: str | None = None,
        default_schema: str | None = None,
        dialect: str = "databricks",
    ):
        # allowed_tables: iterable of "catalog.schema.table", "catalog.schema.*",
        # "catalog.*", or "*" (allow all — discouraged). Stored lowercased.
        # A bare string is wrapped in a list first. Iterating it character by
        # character would put "*" into the set for any wildcard entry such as
        # "cat.sch.*" and silently allow every table.
        if isinstance(allowed_tables, str):
            allowed_tables = [allowed_tables]
        if isinstance(sensitive_fragments, str):
            sensitive_fragments = [sensitive_fragments]
        self.allowed = {t.strip().lower() for t in allowed_tables if t and t.strip()}
        self.sensitive_fragments = tuple(
            f.lower() for f in (sensitive_fragments or DEFAULT_SENSITIVE_FRAGMENTS)
        )
        # Unknown policy names fall back to the masked default.
        self.policy = policy if policy in (POLICY_SCHEMA_ONLY, POLICY_MASKED, POLICY_RAW_IF_ALLOWED) else POLICY_MASKED
        self.default_catalog = (default_catalog or "").lower()
        self.default_schema = (default_schema or "").lower()
        self.dialect = dialect

    # ─── table allowlisting ───

    def _qualify(self, table: exp.Table) -> str:
        """Return the lowercased dotted name of ``table``.

        Missing catalog/schema parts are filled from the configured defaults;
        parts that are still empty are omitted from the result.
        """
        catalog = (table.catalog or self.default_catalog).lower()
        schema = (table.db or self.default_schema).lower()
        name = (table.name or "").lower()
        return ".".join(p for p in (catalog, schema, name) if p)

    def _table_allowed(self, qualified: str) -> bool:
        """Return True when ``qualified`` matches the allowlist exactly or by wildcard."""
        if "*" in self.allowed:
            return True
        if qualified in self.allowed:
            return True
        parts = qualified.split(".")
        # Try progressively broader wildcards: cat.sch.*, cat.*
        for i in range(len(parts) - 1, 0, -1):
            if ".".join(parts[:i] + ["*"]) in self.allowed:
                return True
        return False

    def _is_sensitive(self, col_name: str) -> bool:
        """Return True when any sensitive fragment occurs in the column name."""
        c = col_name.lower()
        return any(frag in c for frag in self.sensitive_fragments)

    def evaluate(self, query: str) -> GuardDecision:
        """Decide whether a query may run and how its rows must be handled.

        The query is rejected when it fails ``is_safe_select``, cannot be
        parsed, references no table, or references a table that is not on
        the allowlist. Otherwise the decision lists the referenced tables and
        columns and the effective row policy.
        """
        if not is_safe_select(query):
            return GuardDecision(False, "Not a single read-only SELECT/CTE")

        try:
            parsed = sqlglot.parse_one(query, read=self.dialect)
        except Exception as ex:  # unparseable -> reject
            return GuardDecision(False, f"Unparseable SQL: {str(ex)[:80]}")

        # NOTE(review): find_all(exp.Table) presumably also yields references
        # to CTE names, which would then have to match the allowlist like a
        # real table. That fails closed, but may reject legitimate CTE
        # queries — confirm, and resolve through sqlglot's scope analysis
        # rather than by name so a CTE cannot shadow a real table.
        tables = sorted({self._qualify(t) for t in parsed.find_all(exp.Table)})
        if not tables:
            return GuardDecision(False, "No table referenced")

        not_allowed = [t for t in tables if not self._table_allowed(t)]
        if not_allowed:
            return GuardDecision(
                False,
                f"Table(s) not on evidence allowlist: {', '.join(not_allowed)}",
                tables=tables,
            )

        has_star = bool(list(parsed.find_all(exp.Star)))
        columns = sorted({c.name.lower() for c in parsed.find_all(exp.Column) if c.name})
        sensitive = sorted({c for c in columns if self._is_sensitive(c)})

        # Decide the effective row policy.
        if self.policy == POLICY_SCHEMA_ONLY:
            effective = POLICY_SCHEMA_ONLY
        elif self.policy == POLICY_RAW_IF_ALLOWED:
            # SELECT * could expose unknown sensitive columns -> downgrade to masked.
            effective = "masked" if (sensitive or has_star) else "raw"
        else:  # POLICY_MASKED (default)
            effective = "masked"

        return GuardDecision(
            allowed=True,
            reason="ok",
            tables=tables,
            columns=columns,
            sensitive_columns=sensitive,
            has_star=has_star,
            effective_policy=effective,
        )

    # ─── output handling ───

    @staticmethod
    def _mask_value(value) -> str:
        """Mask a value: keep at most its first two characters, star the rest.

        Values of two characters or fewer become "**"; at most six stars are
        appended, so the original length is not fully revealed.
        """
        s = str(value)
        if len(s) <= 2:
            return "**"
        return s[:2] + "*" * min(len(s) - 2, 6)

    def render_rows(self, columns: list[str], rows: list[list], decision: GuardDecision, max_rows: int = 8) -> str:
        """Render query rows for Teams according to the decision's effective policy.

        schema_only: column names and the row count only.
        raw: every value, truncated to 20 characters.
        masked: a column is masked when its output name looks sensitive, or
        when the query references a sensitive column and the output name is
        not one of the query's own column names (an alias or computed
        expression, which may carry the sensitive value under another name).
        Other columns are truncated to 20 characters. At most ``max_rows``
        rows are rendered.
        """
        if decision.effective_policy == POLICY_SCHEMA_ONLY:
            return "columns: " + ", ".join(columns) + f"\n({len(rows)} row(s) — values withheld by policy)"

        # Decide once per column whether its cells are masked; the answer does
        # not depend on the row.
        if decision.effective_policy == "raw":
            mask_flags = [False] * len(columns)
        else:
            sensitive_set = set(decision.sensitive_columns)
            referenced = set(decision.columns)
            mask_flags = []
            for col in columns:
                name = col.lower()
                looks_sensitive = name in sensitive_set or self._is_sensitive(col)
                # e.g. ``SELECT ssn AS note``: the output name "note" is not a
                # referenced column, so it may be a renamed sensitive value.
                # NOTE: an alias that reuses the name of another referenced,
                # non-sensitive column is still not detected by this check.
                maybe_renamed = bool(sensitive_set) and name not in referenced
                mask_flags.append(looks_sensitive or maybe_renamed)

        lines = [" | ".join(columns)]
        for row in rows[:max_rows]:
            cells = [
                self._mask_value(val) if masked else str(val)[:20]
                for masked, val in zip(mask_flags, row)
            ]
            lines.append(" | ".join(cells))
        return "\n".join(lines)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Resolution verifier (#5) — close the loop after a fix is applied.
|
|
2
|
+
|
|
3
|
+
"Self-healing" only earns the name if an applied fix is *verified*: re-run the
|
|
4
|
+
failed task, confirm it now succeeds, and roll back / escalate if it doesn't.
|
|
5
|
+
This module owns that loop so the approval handler doesn't blindly mark things
|
|
6
|
+
resolved. RETRY/CLUSTER_RESIZE are verified by re-running; SCHEMA_FIX/CODE_PATCH
|
|
7
|
+
must supply an ``apply_fn`` (and ideally a ``rollback_fn``) and are gated so they
|
|
8
|
+
never silently no-op.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import time
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
# Action types we can verify by simply re-running the task.
# ResolutionVerifier.verify treats these as already applied.
RERUNNABLE = {"RETRY", "CLUSTER_RESIZE"}
# Action types that change code/schema and must apply a fix before re-running.
# verify() refuses to proceed with these unless an apply_fn is supplied.
MUTATING = {"SCHEMA_FIX", "CODE_PATCH"}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class VerificationResult:
    # Outcome of ResolutionVerifier.verify for one action.

    # Upper-cased action id that was verified.
    action_id: str
    # True once the fix counts as applied (apply_fn ran, or the action is rerunnable).
    applied: bool
    # True only when the re-run task ended in SUCCESS.
    verified: bool
    # True when rollback_fn ran without raising after a failed re-run.
    rolled_back: bool
    # One of: STAGED, APPLY_FAILED, NOT_VERIFIABLE, RERUN_FAILED, VERIFIED,
    # ROLLBACK_FAILED, ESCALATE.
    final_state: str
    # Human-readable explanation or truncated error text.
    detail: str = ""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ResolutionVerifier:
    """Applies a fix (when one is needed), re-runs the failed task and reports
    whether the re-run succeeded."""

    def __init__(self, client, poll_seconds: int = 15, max_polls: int = 40):
        # client must provide _get(path, params) and repair_run(run_id, task_keys).
        self.client = client
        self.poll_seconds = poll_seconds
        self.max_polls = max_polls

    def _poll_task_state(self, run_id: int, task_key: str) -> str:
        """Poll runs/get until the task reaches a terminal state. Returns the
        result_state ('SUCCESS'/'FAILED'/...), 'UNKNOWN' when a terminal task
        reports no result_state, 'TIMEOUT' after ``max_polls`` polls, or
        'ERROR: <message>' as soon as a poll raises."""
        terminal = ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")
        for attempt in range(self.max_polls):
            try:
                run = self.client._get("/api/2.1/jobs/runs/get", {"run_id": int(run_id)})
            except Exception as ex:
                return f"ERROR: {str(ex)[:80]}"
            # "tasks" and "state" may be present but null; treat both as empty.
            for t in run.get("tasks") or []:
                if t.get("task_key") == task_key:
                    state = t.get("state") or {}
                    if state.get("life_cycle_state") in terminal:
                        return state.get("result_state", "UNKNOWN")
            # No point waiting after the final poll; report the timeout now.
            if attempt < self.max_polls - 1:
                time.sleep(self.poll_seconds)
        return "TIMEOUT"

    def verify(self, action_id: str, run_id, task_key: str, apply_fn=None, rollback_fn=None) -> VerificationResult:
        """Apply (if needed), re-run, and verify a resolution.

        action_id: compared case-insensitively against MUTATING / RERUNNABLE.
        run_id, task_key: the run to repair and the task to watch.
        apply_fn: zero-argument callable that applies a mutating fix; required
            for MUTATING actions.
        rollback_fn: optional zero-argument callable that undoes apply_fn; it
            is called when a mutating fix was applied and the re-run did not
            end in SUCCESS.

        Never raises for failures of apply_fn, the repair call or
        rollback_fn; each is reported through ``final_state``.
        """
        action = str(action_id).upper()

        if action in MUTATING:
            if apply_fn is None:
                # Never claim success for a fix we can't actually apply.
                return VerificationResult(action, applied=False, verified=False, rolled_back=False,
                                          final_state="STAGED",
                                          detail=f"{action} requires an apply function; staged for manual execution.")
            try:
                apply_fn()
            except Exception as ex:
                return VerificationResult(action, applied=False, verified=False, rolled_back=False,
                                          final_state="APPLY_FAILED", detail=str(ex)[:120])
            applied = True
        elif action in RERUNNABLE:
            applied = True
        else:
            return VerificationResult(action, applied=False, verified=False, rolled_back=False,
                                      final_state="NOT_VERIFIABLE",
                                      detail=f"{action} is not auto-verifiable (e.g. ESCALATE).")

        # Re-run the failed task and watch for the outcome.
        try:
            self.client.repair_run(int(run_id), [task_key])
        except Exception as ex:
            return VerificationResult(action, applied=applied, verified=False, rolled_back=False,
                                      final_state="RERUN_FAILED", detail=str(ex)[:120])

        # NOTE(review): if runs/get still reports the previous attempt's
        # terminal state right after the repair call, this returns that stale
        # result — confirm against the Jobs API behaviour.
        state = self._poll_task_state(run_id, task_key)
        if state == "SUCCESS":
            return VerificationResult(action, applied=applied, verified=True, rolled_back=False,
                                      final_state="VERIFIED", detail="Task re-ran and succeeded.")

        # Did not succeed — roll back mutating fixes if we can.
        # NOTE(review): this also rolls back on TIMEOUT / ERROR, when the
        # re-run may still be in progress — confirm that is intended.
        rolled_back = False
        if action in MUTATING and rollback_fn is not None:
            try:
                rollback_fn()
                rolled_back = True
            except Exception as ex:
                return VerificationResult(action, applied=applied, verified=False, rolled_back=False,
                                          final_state="ROLLBACK_FAILED", detail=str(ex)[:120])

        return VerificationResult(action, applied=applied, verified=False, rolled_back=rolled_back,
                                  final_state="ESCALATE",
                                  detail=f"Re-run ended in {state}; escalating to humans.")
|