shkit 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. healing_kit/__init__.py +3 -0
  2. healing_kit/auth.py +79 -0
  3. healing_kit/clients/__init__.py +1 -0
  4. healing_kit/clients/databricks_client.py +183 -0
  5. healing_kit/clients/teams_client.py +128 -0
  6. healing_kit/models/__init__.py +1 -0
  7. healing_kit/models/diagnosis.py +45 -0
  8. healing_kit/models/events.py +30 -0
  9. healing_kit/models/evidence.py +83 -0
  10. healing_kit/runtime/__init__.py +6 -0
  11. healing_kit/runtime/approval.py +141 -0
  12. healing_kit/runtime/maintenance.py +52 -0
  13. healing_kit/services/__init__.py +1 -0
  14. healing_kit/services/cache_service.py +120 -0
  15. healing_kit/services/circuit_breaker.py +114 -0
  16. healing_kit/services/context_agent.py +127 -0
  17. healing_kit/services/dependency_graph.py +141 -0
  18. healing_kit/services/diagnosis_engine.py +165 -0
  19. healing_kit/services/identity.py +61 -0
  20. healing_kit/services/model_router.py +52 -0
  21. healing_kit/services/query_guard.py +168 -0
  22. healing_kit/services/resolution_verifier.py +100 -0
  23. healing_kit/services/token_budget.py +137 -0
  24. healing_kit/utils/__init__.py +1 -0
  25. healing_kit/utils/error_hash.py +15 -0
  26. healing_kit/utils/hmac_tokens.py +86 -0
  27. healing_kit/utils/sql_safety.py +84 -0
  28. iic/__init__.py +51 -0
  29. iic/__main__.py +18 -0
  30. iic/_console.py +235 -0
  31. iic/_doctor.py +143 -0
  32. iic/change/__init__.py +7 -0
  33. iic/change/change_detector.py +154 -0
  34. iic/context/__init__.py +7 -0
  35. iic/context/context_builder.py +117 -0
  36. iic/dependency/__init__.py +7 -0
  37. iic/dependency/dependency_analyzer.py +93 -0
  38. iic/diagnosis/__init__.py +7 -0
  39. iic/diagnosis/diagnosis_engine.py +183 -0
  40. iic/dna/__init__.py +7 -0
  41. iic/dna/dna_builder.py +184 -0
  42. iic/impact/__init__.py +7 -0
  43. iic/impact/impact_engine.py +102 -0
  44. iic/ingestion/__init__.py +14 -0
  45. iic/ingestion/base.py +21 -0
  46. iic/ingestion/databricks_source.py +98 -0
  47. iic/ingestion/static_source.py +23 -0
  48. iic/ingestion/webhook_source.py +39 -0
  49. iic/models/__init__.py +44 -0
  50. iic/models/change.py +77 -0
  51. iic/models/context.py +46 -0
  52. iic/models/diagnosis.py +37 -0
  53. iic/models/dna.py +77 -0
  54. iic/models/event.py +78 -0
  55. iic/models/impact.py +60 -0
  56. iic/models/report.py +88 -0
  57. iic/models/routing.py +41 -0
  58. iic/notify/__init__.py +7 -0
  59. iic/notify/teams_notifier.py +112 -0
  60. iic/report/__init__.py +7 -0
  61. iic/report/report_generator.py +67 -0
  62. iic/routing/__init__.py +7 -0
  63. iic/routing/router.py +42 -0
  64. iic/runtime/__init__.py +10 -0
  65. iic/runtime/_sql.py +11 -0
  66. iic/runtime/agent_config.py +48 -0
  67. iic/runtime/agent_runtime.py +70 -0
  68. iic/runtime/antibodies.py +100 -0
  69. iic/runtime/bootstrap.py +157 -0
  70. iic/runtime/constants.py +40 -0
  71. iic/runtime/context.py +46 -0
  72. iic/runtime/detective.py +72 -0
  73. iic/runtime/hooks.py +85 -0
  74. iic/runtime/incident_engine.py +207 -0
  75. iic/runtime/inprocess.py +350 -0
  76. iic/runtime/ledger.py +120 -0
  77. iic/runtime/monitor.py +155 -0
  78. iic/runtime/pattern_store.py +53 -0
  79. iic/runtime/reconciler.py +139 -0
  80. iic/runtime/scope_config.py +127 -0
  81. iic/runtime/store.py +150 -0
  82. iic/runtime/wrapper.py +28 -0
  83. iic_autoload.pth +1 -0
  84. onboarding/__init__.py +1 -0
  85. onboarding/cli.py +168 -0
  86. onboarding/config_schema.py +62 -0
  87. onboarding/manifest.py +27 -0
  88. onboarding/preflight.py +129 -0
  89. onboarding/provisioner.py +573 -0
  90. onboarding/rollback.py +81 -0
  91. shkit-1.2.0.dist-info/METADATA +239 -0
  92. shkit-1.2.0.dist-info/RECORD +94 -0
  93. shkit-1.2.0.dist-info/WHEEL +4 -0
  94. shkit-1.2.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,141 @@
1
+ """Dependency graph builder and event batching for thundering herd prevention."""
2
+
3
+ from collections import defaultdict, deque
4
+ from dataclasses import dataclass, field
5
+
6
+ from healing_kit.models.events import FailureEvent, RootCauseEvent
7
+ from healing_kit.utils.error_hash import compute_error_hash
8
+
9
+
10
@dataclass
class DependencyGraph:
    """DAG representing task and table dependencies.

    Both adjacency maps default to ``defaultdict(list)`` so edges can be
    appended without registering the key first.
    """

    # task_key → list of upstream task_keys
    upstream: dict[str, list[str]] = field(default_factory=lambda: defaultdict(list))
    # task_key → list of downstream task_keys
    downstream: dict[str, list[str]] = field(default_factory=lambda: defaultdict(list))
    # Task keys of every task definition added to the graph
    nodes: set[str] = field(default_factory=set)


class DependencyGraphBuilder:
    """
    Builds a dependency graph from Databricks Jobs API task dependencies
    and Unity Catalog lineage (table-level upstream/downstream).
    """

    def build_from_tasks(self, tasks: list[dict]) -> DependencyGraph:
        """
        Build DAG from a list of task definitions (from Jobs API).

        Each task has 'task_key' and an optional 'depends_on' list whose
        entries are either ``{"task_key": ...}`` dicts or bare key strings.
        A missing or ``None`` 'depends_on' is treated as "no dependencies".
        """
        graph = DependencyGraph()

        for task in tasks:
            key = task.get("task_key", "")
            graph.nodes.add(key)
            # `or []` tolerates an explicit None as well as an absent field.
            for dep in task.get("depends_on") or []:
                dep_key = dep.get("task_key", "") if isinstance(dep, dict) else dep
                graph.upstream[key].append(dep_key)
                graph.downstream[dep_key].append(key)

        return graph

    def find_root_ancestors(self, graph: DependencyGraph, failed_keys: set[str]) -> set[str]:
        """
        Return the failed tasks that have no failed upstream dependency.

        Every root is itself a member of ``failed_keys``, so a single pass
        over the failed keys finds all of them; no upstream walk is needed.
        Failed tasks that only sit on a dependency cycle of failed tasks
        produce no root.
        """
        # .get() (rather than indexing) avoids inserting keys into a defaultdict.
        return {
            key
            for key in failed_keys
            if not any(parent in failed_keys for parent in graph.upstream.get(key, []))
        }

    def get_all_downstream(self, graph: DependencyGraph, task_key: str) -> list[str]:
        """
        BFS to find all transitively downstream tasks.

        Returns each downstream task exactly once, in discovery order.
        ``task_key`` itself is never included, even if the graph has a cycle.
        """
        ordered: list[str] = []
        # Mark on discovery (not on dequeue) so a task reachable through
        # several paths is reported only once.
        seen = {task_key}
        queue = deque([task_key])

        while queue:
            current = queue.popleft()
            for child in graph.downstream.get(current, []):
                if child not in seen:
                    seen.add(child)
                    ordered.append(child)
                    queue.append(child)

        return ordered
88
+
89
+
90
class EventBatcher:
    """
    Batches failure events by dependency graph, deduplicating cascading
    downstream failures into single root-cause events.
    """

    def __init__(self):
        self.graph_builder = DependencyGraphBuilder()

    def deduplicate_to_root_causes(
        self, events: list[FailureEvent], tasks: list[dict]
    ) -> list[RootCauseEvent]:
        """
        Given a batch of failure events and the pipeline's task definitions,
        identify true root causes and group derived failures.

        Returns one RootCauseEvent per root task, ordered by the position of
        the root's event in ``events`` so the result is stable between runs.
        When several events share a task_key, the last one is used.
        """
        if not events:
            return []

        graph = self.graph_builder.build_from_tasks(tasks)

        # Map events by task_key for lookup (insertion order == event order).
        event_map = {e.task_key: e for e in events}
        failed_keys = set(event_map)
        roots = self.graph_builder.find_root_ancestors(graph, failed_keys)

        root_cause_events = []
        # Iterate the dict rather than the `roots` set: set iteration order
        # is not stable across processes.
        for root_key, root_event in event_map.items():
            if root_key not in roots or not root_event:
                continue

            # Find all downstream failures derived from this root. The key
            # list is de-duplicated so each failure is counted at most once.
            all_downstream = self.graph_builder.get_all_downstream(graph, root_key)
            derived = [
                event_map[k]
                for k in dict.fromkeys(all_downstream)
                if k in event_map and k != root_key
            ]

            error_hash = compute_error_hash(
                # Guard against a missing message before slicing.
                error_type=(root_event.error_message or "")[:50],
                notebook_path=root_key,
                affected_tables=[],
            )

            root_cause_events.append(RootCauseEvent(
                root_run_id=root_event.run_id,
                root_job_id=root_event.job_id,
                root_task_key=root_key,
                error_hash=error_hash,
                derived_failures=derived,
                total_downstream_impact=len(derived),
            ))

        return root_cause_events
@@ -0,0 +1,165 @@
1
+ """AI Diagnosis Engine — structured LLM invocation with schema validation."""
2
+
3
+ import json
4
+ from typing import Optional
5
+
6
+ from healing_kit.clients.databricks_client import DatabricksClient
7
+ from healing_kit.models.diagnosis import ActionId, DiagnosisResponse
8
+ from healing_kit.models.evidence import EvidencePackage
9
+ from healing_kit.services.model_router import ModelRouter
10
+ from healing_kit.services.token_budget import TokenBudgetEnforcer
11
+
12
# System message sent with every diagnosis request. It fixes the JSON contract
# that DiagnosisEngine._parse_response validates (action_id values,
# confidence_score range). The text is sent to the model verbatim.
SYSTEM_PROMPT = """You are a Databricks pipeline failure diagnosis expert.
You receive structured evidence from a failed pipeline task and must return a JSON diagnosis.

RULES:
- NEVER suggest DROP, DELETE, TRUNCATE, or ALTER on user tables.
- NEVER return free text. Only parseable JSON.
- confidence_score must be between 0.0 and 1.0.
- action_id must be one of: CLUSTER_RESIZE, RETRY, SCHEMA_FIX, CODE_PATCH, ESCALATE, UNKNOWN.

REQUIRED JSON OUTPUT:
{
"root_cause": "specific human-readable explanation",
"confidence_score": 0.0-1.0,
"action_id": "CLUSTER_RESIZE | RETRY | SCHEMA_FIX | CODE_PATCH | ESCALATE | UNKNOWN",
"action_params": {},
"reasoning": "2-3 sentence explanation citing evidence",
"evidence_used": ["list of evidence sources you consulted"],
"diagnostic_query": "SELECT query to show bad rows (or null)",
"code_issue": "specific line/logic issue (or null)"
}"""
32
+
33
+
34
class DiagnosisEngine:
    """
    Sends enriched context to LLM, validates response schema,
    integrates with model router and token budget.
    """

    def __init__(self, client: DatabricksClient, model_router: ModelRouter, token_budget: Optional[TokenBudgetEnforcer] = None):
        self.client = client
        self.router = model_router
        self.budget = token_budget

    def diagnose(self, evidence: EvidencePackage, hint_action_id: str = "UNKNOWN") -> DiagnosisResponse:
        """
        Invoke the appropriate LLM with structured evidence.

        ``hint_action_id`` selects the model tier; a value that is not a
        valid ActionId is treated as UNKNOWN instead of raising.
        Returns a validated DiagnosisResponse. Budget exhaustion, endpoint
        errors and unparseable replies are all reported as ESCALATE
        responses with confidence 0.0 rather than raised.
        """
        try:
            hint = ActionId(hint_action_id)
        except ValueError:
            # Same fallback _parse_response applies to unknown action ids.
            hint = ActionId.UNKNOWN

        # Check token budget
        model_tier = self.router.get_model_tier(hint)
        if self.budget and not self.budget.can_invoke_model(model_tier):
            return DiagnosisResponse(
                root_cause="Token budget exhausted — degraded mode active",
                confidence_score=0.0,
                action_id=ActionId.ESCALATE,
                reasoning="System is in degraded mode due to token budget limits",
            )

        # Select model
        model = self.router.get_model(hint)
        if model is None:
            return DiagnosisResponse(
                root_cause="No LLM needed for this action type",
                confidence_score=1.0,
                action_id=hint,
            )

        # Build prompt
        prompt = self._build_prompt(evidence)

        # Call LLM
        try:
            response_text = self.client.invoke_model(
                endpoint_name=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=2000,
                temperature=0.1,
            )
        except Exception as e:
            return DiagnosisResponse(
                root_cause=f"LLM invocation failed: {str(e)[:200]}",
                confidence_score=0.0,
                action_id=ActionId.ESCALATE,
                reasoning="Model serving endpoint error",
            )

        # Record token usage (estimate: whitespace-separated words, not real tokens)
        if self.budget:
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            self.budget.record_usage(estimated_tokens, estimated_tokens * 0.00001)

        # Parse and validate
        return self._parse_response(response_text)

    def _build_prompt(self, evidence: EvidencePackage) -> str:
        """Build the structured prompt from the evidence package.

        Only evidence sections that are present are emitted; list-valued
        sections are capped (5 or 10 entries) and the Spark log excerpt is
        limited to its first entry, truncated to 2000 characters.
        """
        sections = [f"FAILED TASK: {evidence.task_key}", f"JOB: {evidence.job_id}", f"RUN: {evidence.run_id}"]

        if evidence.driver_stdout:
            if evidence.driver_stdout.schema_errors:
                sections.append("SCHEMA ERRORS:\n" + "\n".join(evidence.driver_stdout.schema_errors[:5]))
            if evidence.driver_stdout.data_source_exceptions:
                sections.append("EXCEPTIONS:\n" + "\n".join(evidence.driver_stdout.data_source_exceptions[:10]))
            if evidence.driver_stdout.missing_table_errors:
                sections.append("MISSING TABLES:\n" + "\n".join(evidence.driver_stdout.missing_table_errors[:5]))

        if evidence.spark_event_logs and evidence.spark_event_logs.first_errors:
            sections.append("NOTEBOOK SOURCE / SPARK LOGS:\n" + "\n".join(evidence.spark_event_logs.first_errors[:1])[:2000])

        if evidence.task_metrics:
            m = evidence.task_metrics
            sections.append(f"METRICS: spill={m.spill_to_disk_bytes}B, GC={m.gc_overhead_pct}%, shuffle_r={m.shuffle_read_bytes}B")

        if evidence.git_context:
            sections.append(f"GIT: {evidence.git_context.last_commit_author} — {evidence.git_context.last_commit_message}")

        if evidence.cluster_events and evidence.cluster_events.termination_reason:
            sections.append(f"CLUSTER: terminated={evidence.cluster_events.termination_reason}")

        if evidence.missing_sources:
            sections.append(f"UNAVAILABLE SOURCES: {', '.join(evidence.missing_sources)}")

        return "\n\n".join(sections)

    def _parse_response(self, response_text: str) -> DiagnosisResponse:
        """Parse and validate the LLM JSON response.

        Strips an optional Markdown code fence, requires a JSON object,
        maps unknown action ids to UNKNOWN and clamps confidence to
        [0.0, 1.0] (a non-finite NaN confidence becomes 0.0). Any reply that
        cannot be validated yields an ESCALATE response with confidence 0.0.
        """
        import math

        try:
            cleaned = response_text.strip()
            if cleaned.startswith("```"):
                # Drop the whole opening fence line (e.g. "```json").
                cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]

            data = json.loads(cleaned.strip())
            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, string, number...).
                raise ValueError("LLM response is not a JSON object")

            # Validate action_id
            action_id_str = data.get("action_id", "UNKNOWN")
            valid_actions = [a.value for a in ActionId]
            if action_id_str not in valid_actions:
                action_id_str = "UNKNOWN"

            # Validate confidence. json.loads accepts the literal NaN, and
            # max/min would silently turn NaN into 1.0, so reject it first.
            confidence = float(data.get("confidence_score", 0))
            if math.isnan(confidence):
                confidence = 0.0
            confidence = max(0.0, min(1.0, confidence))

            return DiagnosisResponse(
                root_cause=data.get("root_cause", "Unknown"),
                confidence_score=confidence,
                action_id=ActionId(action_id_str),
                # `or` covers an explicit JSON null as well as a missing key.
                action_params=data.get("action_params") or {},
                reasoning=data.get("reasoning", ""),
                evidence_used=data.get("evidence_used") or [],
            )

        except (json.JSONDecodeError, KeyError, ValueError, TypeError, AttributeError):
            return DiagnosisResponse(
                root_cause=f"LLM response not parseable: {str(response_text)[:200]}",
                confidence_score=0.0,
                action_id=ActionId.ESCALATE,
                reasoning="Response failed JSON schema validation",
            )
@@ -0,0 +1,61 @@
1
+ """Approver authorization (#3).
2
+
3
+ When a user clicks Approve/Reject, we must confirm — by governance — that the
4
+ clicker actually has access to this workspace before honoring the action. The
5
+ clicker is identified by their SSO login or Teams email, then verified against
6
+ the workspace via SCIM (user exists + is active, and optionally is in an
7
+ approvers group). If they don't have access we refuse with a clear message.
8
+ """
9
+
10
+ from dataclasses import dataclass
11
+
12
+
13
@dataclass
class ApproverDecision:
    """Outcome of an approver authorization check.

    ``authorized`` and ``identity`` are always set; the remaining fields
    default to empty strings and are filled in as far as the check got.
    """

    authorized: bool
    identity: str
    user_id: str = ""
    display_name: str = ""
    reason: str = ""


# Reason reported when the identity cannot be resolved to a workspace user.
ACCESS_DENIED_MESSAGE = "You don't have enough access to approve this action."


def authorize_approver(client, identity: str, approvers_group: str = "") -> ApproverDecision:
    """Decide whether the given login / email may approve an action.

    client: object exposing ``find_user_by_identity(identity)`` and
        ``user_in_group(user, group)``.
    identity: the clicker's login or email; surrounding whitespace is ignored.
    approvers_group: when non-empty, the user must also be a member of it.

    Returns an ApproverDecision; denial is reported through ``authorized`` and
    ``reason`` rather than raised.
    """
    ident = identity.strip() if identity else ""
    if ident == "":
        return ApproverDecision(authorized=False, identity=ident, reason="No approver identity supplied")

    user = client.find_user_by_identity(ident)
    if not user:
        return ApproverDecision(authorized=False, identity=ident, reason=ACCESS_DENIED_MESSAGE)

    uid = str(user.get("id", ""))

    # Only an explicit False counts as deactivated; a missing flag passes.
    if user.get("active") is False:
        return ApproverDecision(
            authorized=False,
            identity=ident,
            user_id=uid,
            reason="User exists but is deactivated.",
        )

    shown_name = user.get("displayName", "")

    # Group membership is only consulted when a group was requested.
    member_ok = (not approvers_group) or client.user_in_group(user, approvers_group)
    if not member_ok:
        return ApproverDecision(
            authorized=False,
            identity=ident,
            user_id=uid,
            display_name=shown_name,
            reason=f"User is not in the '{approvers_group}' approvers group.",
        )

    return ApproverDecision(
        authorized=True,
        identity=ident,
        user_id=uid,
        display_name=shown_name,
        reason="ok",
    )
@@ -0,0 +1,52 @@
1
+ """Deterministic model routing by action_id."""
2
+
3
+ from healing_kit.models.diagnosis import ActionId
4
+
5
+
6
class ModelRouter:
    """
    Routes LLM calls to the appropriate model based on action_id.

    Routing is fully deterministic and auditable:
    - RETRY, CLUSTER_RESIZE → lightweight model (cost-efficient)
    - SCHEMA_FIX, CODE_PATCH, UNKNOWN → powerful model (needs reasoning)
    - CACHE_HIT, ESCALATE → no LLM call (zero tokens)
    """

    def __init__(self, lightweight_model: str, powerful_model: str):
        self.lightweight_model = lightweight_model
        self.powerful_model = powerful_model

        self._routing_table = {
            ActionId.RETRY: lightweight_model,
            ActionId.CLUSTER_RESIZE: lightweight_model,
            ActionId.SCHEMA_FIX: powerful_model,
            ActionId.CODE_PATCH: powerful_model,
            ActionId.UNKNOWN: powerful_model,
            ActionId.CACHE_HIT: None,
            ActionId.ESCALATE: None,
        }

        # Tier per action, kept separately from the endpoint names so the
        # tier stays correct when both tiers point at the same endpoint.
        self._tier_table = {
            ActionId.RETRY: "lightweight",
            ActionId.CLUSTER_RESIZE: "lightweight",
            ActionId.SCHEMA_FIX: "powerful",
            ActionId.CODE_PATCH: "powerful",
            ActionId.UNKNOWN: "powerful",
        }

    def get_model(self, action_id: ActionId) -> str | None:
        """
        Get the model endpoint name for a given action_id.

        Accepts an ActionId or its string value (an invalid string raises
        ValueError from the enum). Returns None if no LLM call is needed.
        """
        if isinstance(action_id, str):
            action_id = ActionId(action_id)
        return self._routing_table.get(action_id)

    def get_model_tier(self, action_id: ActionId) -> str:
        """
        Get the tier ('lightweight', 'powerful', or 'none') for budget checking.

        The tier is looked up by action, not inferred from the endpoint
        name, so it is correct even when lightweight_model == powerful_model.
        """
        if isinstance(action_id, str):
            action_id = ActionId(action_id)

        model = self.get_model(action_id)
        if model is None:
            return "none"

        tier = self._tier_table.get(action_id)
        if tier is not None:
            return tier

        # Fallback for routes added to _routing_table after construction.
        if model == self.lightweight_model:
            return "lightweight"
        return "powerful"

    def requires_llm(self, action_id: ActionId) -> bool:
        """Check if this action requires an LLM call."""
        return self.get_model(action_id) is not None
@@ -0,0 +1,168 @@
1
+ """Query guard — allowlist + classification-aware masking for diagnostic SELECTs.
2
+
3
+ Closes the prompt-injection -> data-exfiltration chain (#2). The LLM-emitted
4
+ diagnostic query is:
5
+ 1. checked read-only (single SELECT/CTE, no DDL/DML/comments/stacking),
6
+ 2. parsed with sqlglot so every *referenced table* is matched against an
7
+ explicit allowlist (deny-by-default) — far stronger than keyword filtering,
8
+ 3. inspected for sensitive columns; result values are masked BY DEFAULT and
9
+ only sent raw to Teams when the data classification + policy permit it.
10
+
11
+ Pure logic, no Spark — unit-tested with plain strings/rows.
12
+ """
13
+
14
+ from dataclasses import dataclass, field
15
+
16
+ import sqlglot
17
+ from sqlglot import exp
18
+
19
+ from healing_kit.utils.sql_safety import is_safe_select
20
+
21
# Evidence policies (what may leave the workspace into Teams).
# QueryGuard.__init__ falls back to POLICY_MASKED for any other value.
POLICY_SCHEMA_ONLY = "schema_only"  # never send values; only column names + row count
POLICY_MASKED = "masked"  # send masked values (DEFAULT)
POLICY_RAW_IF_ALLOWED = "raw_if_allowed"  # raw only if no sensitive column referenced

# Conservative default set of sensitive column-name fragments (case-insensitive).
# Matching is by substring: a column is sensitive if its lowercased name
# contains any of these fragments.
DEFAULT_SENSITIVE_FRAGMENTS = (
    "ssn", "social_security", "email", "e_mail", "phone", "mobile", "dob",
    "birth", "address", "passport", "license", "credit", "card", "cvv",
    "iban", "account_number", "acct_no", "salary", "password", "secret",
    "token", "national_id", "tax_id", "mrn", "patient", "diagnosis_code",
)
33
+
34
+
35
@dataclass
class GuardDecision:
    """Outcome of evaluating a diagnostic query.

    ``allowed`` and ``reason`` are always set. QueryGuard.evaluate fills the
    remaining fields only as far as the evaluation got: a rejection carries
    at most ``tables``.
    """

    # True when the query may run.
    allowed: bool
    # "ok" on success, otherwise a human-readable rejection reason.
    reason: str
    # Referenced tables, qualified and lowercased.
    tables: list[str] = field(default_factory=list)
    # Lowercased names of every column referenced anywhere in the query.
    columns: list[str] = field(default_factory=list)
    # Subset of ``columns`` that matched a sensitive fragment.
    sensitive_columns: list[str] = field(default_factory=list)
    # True when the query contains a ``*`` projection.
    has_star: bool = False
    # How rows must be handled if the query runs: schema_only | masked | raw
    effective_policy: str = POLICY_MASKED
47
+
48
+
49
class QueryGuard:
    """Allowlist + masking gate for LLM-generated diagnostic queries."""

    def __init__(
        self,
        allowed_tables,
        sensitive_fragments=None,
        policy: str = POLICY_MASKED,
        default_catalog: str | None = None,
        default_schema: str | None = None,
        dialect: str = "databricks",
    ):
        """Configure the guard.

        allowed_tables: iterable of "catalog.schema.table", "catalog.schema.*",
            "catalog.*", or "*" (allow all — discouraged). Stored lowercased.
            A single string is treated as one entry; None means "allow nothing".
        sensitive_fragments: column-name fragments to mask; defaults to
            DEFAULT_SENSITIVE_FRAGMENTS.
        policy: one of the POLICY_* constants; anything else falls back to
            POLICY_MASKED.
        default_catalog / default_schema: used to qualify table references
            that omit those parts.
        dialect: SQL dialect handed to sqlglot.
        """
        if allowed_tables is None:
            allowed_tables = ()
        elif isinstance(allowed_tables, str):
            # Iterating a bare string would allowlist its characters.
            allowed_tables = (allowed_tables,)
        self.allowed = {t.strip().lower() for t in allowed_tables if t and t.strip()}
        self.sensitive_fragments = tuple(
            f.lower() for f in (sensitive_fragments or DEFAULT_SENSITIVE_FRAGMENTS)
        )
        self.policy = policy if policy in (POLICY_SCHEMA_ONLY, POLICY_MASKED, POLICY_RAW_IF_ALLOWED) else POLICY_MASKED
        self.default_catalog = (default_catalog or "").lower()
        self.default_schema = (default_schema or "").lower()
        self.dialect = dialect

    # ─── table allowlisting ───

    def _qualify(self, table: exp.Table) -> str:
        """Return the lowercased dotted name of a table, filling in the
        default catalog/schema for parts the query omitted."""
        catalog = (table.catalog or self.default_catalog).lower()
        schema = (table.db or self.default_schema).lower()
        name = (table.name or "").lower()
        return ".".join(p for p in (catalog, schema, name) if p)

    def _table_allowed(self, qualified: str) -> bool:
        """Check a qualified table name against the allowlist (exact match
        or a ``prefix.*`` wildcard)."""
        if "*" in self.allowed:
            return True
        if qualified in self.allowed:
            return True
        parts = qualified.split(".")
        # Try progressively broader wildcards: cat.sch.*, cat.*
        for i in range(len(parts) - 1, 0, -1):
            if ".".join(parts[:i] + ["*"]) in self.allowed:
                return True
        return False

    def _is_sensitive(self, col_name: str) -> bool:
        """True when the column name contains any sensitive fragment."""
        c = col_name.lower()
        return any(frag in c for frag in self.sensitive_fragments)

    def evaluate(self, query: str) -> GuardDecision:
        """Decide whether a query may run and how its rows must be handled.

        Rejects anything that is not a single read-only SELECT/CTE, that
        sqlglot cannot parse, that references no table, or that references
        a table outside the allowlist.
        """
        if not is_safe_select(query):
            return GuardDecision(False, "Not a single read-only SELECT/CTE")

        try:
            parsed = sqlglot.parse_one(query, read=self.dialect)
        except Exception as ex:  # unparseable -> reject
            return GuardDecision(False, f"Unparseable SQL: {str(ex)[:80]}")

        tables = sorted({self._qualify(t) for t in parsed.find_all(exp.Table)})
        if not tables:
            return GuardDecision(False, "No table referenced")

        not_allowed = [t for t in tables if not self._table_allowed(t)]
        if not_allowed:
            return GuardDecision(
                False,
                f"Table(s) not on evidence allowlist: {', '.join(not_allowed)}",
                tables=tables,
            )

        has_star = bool(list(parsed.find_all(exp.Star)))
        columns = sorted({c.name.lower() for c in parsed.find_all(exp.Column) if c.name})
        sensitive = sorted({c for c in columns if self._is_sensitive(c)})

        # Decide the effective row policy.
        if self.policy == POLICY_SCHEMA_ONLY:
            effective = POLICY_SCHEMA_ONLY
        elif self.policy == POLICY_RAW_IF_ALLOWED:
            # SELECT * could expose unknown sensitive columns -> downgrade to masked.
            effective = "masked" if (sensitive or has_star) else "raw"
        else:  # POLICY_MASKED (default)
            effective = "masked"

        return GuardDecision(
            allowed=True,
            reason="ok",
            tables=tables,
            columns=columns,
            sensitive_columns=sensitive,
            has_star=has_star,
            effective_policy=effective,
        )

    # ─── output handling ───

    @staticmethod
    def _mask_value(value) -> str:
        """Keep the first two characters and replace the rest with up to
        six asterisks; values of two characters or fewer become "**"."""
        s = str(value)
        if len(s) <= 2:
            return "**"
        return s[:2] + "*" * min(len(s) - 2, 6)

    def render_rows(self, columns: list[str], rows: list[list], decision: GuardDecision, max_rows: int = 8) -> str:
        """Render query rows for Teams according to the decision's effective policy.

        schema_only: column names and row count only.
        raw: every value, truncated to 20 characters.
        masked: sensitive columns are masked and the rest truncated to 20
            characters. If the query referenced a sensitive column that does
            not appear under its own name in ``columns`` (aliased, wrapped in
            an expression, or used only in a filter), the output columns
            cannot be mapped back to it, so every value is masked.

        At most ``max_rows`` rows are rendered.
        """
        if decision.effective_policy == POLICY_SCHEMA_ONLY:
            return "columns: " + ", ".join(columns) + f"\n({len(rows)} row(s) — values withheld by policy)"

        raw = decision.effective_policy == "raw"
        sensitive_set = {s.lower() for s in decision.sensitive_columns}
        output_names = {c.lower() for c in columns}
        # Name-based masking alone is defeated by `SELECT ssn AS x`.
        mask_all = not raw and any(s not in output_names for s in sensitive_set)

        lines = [" | ".join(columns)]
        for row in rows[:max_rows]:
            cells = []
            for col, val in zip(columns, row):
                if not raw and (mask_all or col.lower() in sensitive_set or self._is_sensitive(col)):
                    cells.append(self._mask_value(val))
                else:
                    cells.append(str(val)[:20])
            lines.append(" | ".join(cells))
        return "\n".join(lines)
@@ -0,0 +1,100 @@
1
+ """Resolution verifier (#5) — close the loop after a fix is applied.
2
+
3
+ "Self-healing" only earns the name if an applied fix is *verified*: re-run the
4
+ failed task, confirm it now succeeds, and roll back / escalate if it doesn't.
5
+ This module owns that loop so the approval handler doesn't blindly mark things
6
+ resolved. RETRY/CLUSTER_RESIZE are verified by re-running; SCHEMA_FIX/CODE_PATCH
7
+ must supply an ``apply_fn`` (and ideally a ``rollback_fn``) and are gated so they
8
+ never silently no-op.
9
+ """
10
+
11
+ import time
12
+ from dataclasses import dataclass
13
+
14
+ # Action types we can verify by simply re-running the task.
15
RERUNNABLE = {"RETRY", "CLUSTER_RESIZE"}
# Action types that change code/schema and must apply a fix before re-running.
MUTATING = {"SCHEMA_FIX", "CODE_PATCH"}


@dataclass
class VerificationResult:
    """Outcome of one apply / re-run / verify cycle."""

    # Upper-cased action identifier the cycle was run for.
    action_id: str
    # True once the fix step is done (always True for re-run-only actions).
    applied: bool
    # True only when the re-run finished with result state SUCCESS.
    verified: bool
    # True when rollback_fn ran without raising.
    rolled_back: bool
    # STAGED | APPLY_FAILED | NOT_VERIFIABLE | RERUN_FAILED | VERIFIED |
    # ROLLBACK_FAILED | ESCALATE
    final_state: str
    detail: str = ""


class ResolutionVerifier:
    """Applies a fix (if any), re-runs the failed task and checks the outcome."""

    def __init__(self, client, poll_seconds: int = 15, max_polls: int = 40):
        """
        client: object exposing ``_get(path, params)`` and
            ``repair_run(run_id, task_keys)``.
        poll_seconds: delay between two polls of the run state.
        max_polls: number of polls before giving up with 'TIMEOUT'.
        """
        self.client = client
        self.poll_seconds = poll_seconds
        self.max_polls = max_polls

    def _poll_task_state(self, run_id: int, task_key: str) -> str:
        """Poll runs/get until the task reaches a terminal state.

        Returns the task's result_state ('SUCCESS'/'FAILED'/..., or 'UNKNOWN'
        when the API reports none), 'TIMEOUT' after ``max_polls`` polls, or
        'ERROR: <message>' if a poll raises.
        """
        # NOTE(review): right after a repair the task may still report its
        # previous terminal state — confirm against the Jobs API behaviour.
        for attempt in range(self.max_polls):
            try:
                run = self.client._get("/api/2.1/jobs/runs/get", {"run_id": int(run_id)})
            except Exception as ex:
                return f"ERROR: {str(ex)[:80]}"
            for t in run.get("tasks") or []:
                if t.get("task_key") == task_key:
                    state = t.get("state") or {}
                    life = state.get("life_cycle_state")
                    if life in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR"):
                        return state.get("result_state", "UNKNOWN")
            # No point sleeping after the final poll; the loop is about to end.
            if attempt < self.max_polls - 1:
                time.sleep(self.poll_seconds)
        return "TIMEOUT"

    def verify(self, action_id: str, run_id, task_key: str, apply_fn=None, rollback_fn=None) -> VerificationResult:
        """Apply (if needed), re-run, and verify a resolution.

        action_id: action name, case-insensitive; an Enum member is accepted
            and resolved through its ``value``.
        run_id: id of the failed run to repair (anything ``int()`` accepts).
        task_key: task to re-run and watch.
        apply_fn: zero-argument callable that applies the fix; required for
            SCHEMA_FIX / CODE_PATCH.
        rollback_fn: zero-argument callable that undoes the fix; called only
            when a mutating fix was applied and the re-run did not succeed.

        Never raises for apply, re-run or rollback failures; they are
        reported through ``final_state`` and ``detail``.
        """
        # str(Enum member) would give "ActionId.RETRY", so unwrap the value first.
        action = str(getattr(action_id, "value", action_id)).upper()

        if action in MUTATING:
            if apply_fn is None:
                # Never claim success for a fix we can't actually apply.
                return VerificationResult(action, applied=False, verified=False, rolled_back=False,
                                          final_state="STAGED",
                                          detail=f"{action} requires an apply function; staged for manual execution.")
            try:
                apply_fn()
            except Exception as ex:
                return VerificationResult(action, applied=False, verified=False, rolled_back=False,
                                          final_state="APPLY_FAILED", detail=str(ex)[:120])
            applied = True
        elif action in RERUNNABLE:
            applied = True
        else:
            return VerificationResult(action, applied=False, verified=False, rolled_back=False,
                                      final_state="NOT_VERIFIABLE",
                                      detail=f"{action} is not auto-verifiable (e.g. ESCALATE).")

        # Re-run the failed task and watch for the outcome.
        # NOTE(review): a mutating fix stays applied (no rollback) when the
        # re-run cannot be started — confirm this is intended.
        try:
            self.client.repair_run(int(run_id), [task_key])
        except Exception as ex:
            return VerificationResult(action, applied=applied, verified=False, rolled_back=False,
                                      final_state="RERUN_FAILED", detail=str(ex)[:120])

        state = self._poll_task_state(run_id, task_key)
        if state == "SUCCESS":
            return VerificationResult(action, applied=applied, verified=True, rolled_back=False,
                                      final_state="VERIFIED", detail="Task re-ran and succeeded.")

        # Did not succeed — roll back mutating fixes if we can.
        rolled_back = False
        if action in MUTATING and rollback_fn is not None:
            try:
                rollback_fn()
                rolled_back = True
            except Exception as ex:
                return VerificationResult(action, applied=applied, verified=False, rolled_back=False,
                                          final_state="ROLLBACK_FAILED", detail=str(ex)[:120])

        return VerificationResult(action, applied=applied, verified=False, rolled_back=rolled_back,
                                  final_state="ESCALATE",
                                  detail=f"Re-run ended in {state}; escalating to humans.")