zigporter 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zigporter/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,230 @@
1
+ import asyncio
2
+
3
+ import httpx
4
+ import questionary
5
+ from rich.console import Console
6
+
7
+ from zigporter.ha_client import HAClient
8
+ from zigporter.models import CheckResult, CheckStatus
9
+
10
+ console = Console()
11
+
12
# Prompt palette: one accent style for interactive elements and one muted
# style for secondary text.
_ACCENT = "fg:ansicyan bold"
_MUTED = "fg:ansibrightblack"

# questionary style passed to the "Proceed anyway?" confirmation prompt.
_STYLE = questionary.Style(
    [
        ("qmark", _ACCENT),
        ("question", "bold"),
        ("answer", _ACCENT),
        ("pointer", _ACCENT),
        ("highlighted", _ACCENT),
        ("selected", "fg:ansicyan"),
        ("separator", _MUTED),
        ("instruction", _MUTED),
        ("text", ""),
        ("disabled", "fg:ansibrightblack italic"),
    ]
)

# Rich markup printed in front of each check result, keyed by its status.
_STATUS_ICON = {
    CheckStatus.OK: "[green]✓[/green]",
    CheckStatus.FAILED: "[red]✗[/red]",
    CheckStatus.WARNING: "[yellow]![/yellow]",
    CheckStatus.SKIPPED: "[dim]–[/dim]",
}
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Individual checks
37
+ # ---------------------------------------------------------------------------
38
+
39
+
40
async def _check_config(ha_url: str, token: str, z2m_url: str) -> CheckResult:
    """Verify that all three connection settings have a non-empty value.

    Returns a FAILED result listing every missing variable name, or an OK
    result when HA_URL, HA_TOKEN and Z2M_URL are all set.
    """
    settings = {"HA_URL": ha_url, "HA_TOKEN": token, "Z2M_URL": z2m_url}
    absent = [key for key, value in settings.items() if not value]

    if not absent:
        return CheckResult(
            name="Configuration",
            status=CheckStatus.OK,
            message="HA_URL, HA_TOKEN, Z2M_URL are set",
        )

    return CheckResult(
        name="Configuration",
        status=CheckStatus.FAILED,
        message=f"Missing: {', '.join(absent)} — add to .env or set as environment variables",
    )
55
+
56
+
57
async def _check_ha_reachable(ha_url: str, token: str, verify_ssl: bool) -> CheckResult:
    """Probe ``{ha_url}/api/`` with the bearer token and report the outcome.

    Returns SKIPPED when ``ha_url`` is empty, OK when the request succeeds
    with a non-error status, and FAILED (with the exception text) otherwise.
    """
    check_name = "HA reachable"

    if not ha_url:
        return CheckResult(
            name=check_name,
            status=CheckStatus.SKIPPED,
            message="Skipped (no HA_URL configured)",
        )

    auth_headers = {"Authorization": f"Bearer {token}"}
    try:
        async with httpx.AsyncClient(
            headers=auth_headers, verify=verify_ssl, timeout=10
        ) as http:
            response = await http.get(f"{ha_url}/api/")
            response.raise_for_status()
        return CheckResult(name=check_name, status=CheckStatus.OK, message=ha_url)
    except Exception as exc:
        # Any failure (connection, TLS, HTTP error status) is reported as a
        # failed check rather than propagated.
        return CheckResult(
            name=check_name,
            status=CheckStatus.FAILED,
            message=f"Cannot reach {ha_url} — {exc}",
        )
79
+
80
+
81
async def _check_zha_active(ha_url: str, token: str, verify_ssl: bool) -> CheckResult:
    """Ask Home Assistant for its ZHA devices and report how many exist.

    Returns SKIPPED when ``ha_url`` is empty, a non-blocking WARNING when the
    query succeeds but yields no devices, OK with the device count otherwise,
    and FAILED when the query raises.
    """
    check_name = "ZHA active"

    if not ha_url:
        return CheckResult(
            name=check_name,
            status=CheckStatus.SKIPPED,
            message="Skipped (no HA_URL configured)",
        )

    try:
        zha_devices = await HAClient(ha_url, token, verify_ssl).get_zha_devices()
        found = len(zha_devices)
        if found != 0:
            return CheckResult(
                name=check_name,
                status=CheckStatus.OK,
                message=f"{found} device(s) found",
            )
        # An empty device list is suspicious but must not block the user.
        return CheckResult(
            name=check_name,
            status=CheckStatus.WARNING,
            message="ZHA is reachable but no devices found — is ZHA configured?",
            blocking=False,
        )
    except Exception as exc:
        return CheckResult(
            name=check_name,
            status=CheckStatus.FAILED,
            message=f"Could not query ZHA — {exc}",
        )
110
+
111
+
112
async def _check_z2m_running(
    ha_url: str, token: str, z2m_url: str, verify_ssl: bool
) -> CheckResult:
    """Check that Zigbee2MQTT answers HTTP requests at ``z2m_url``.

    A GET is sent to ``{z2m_url}/api/devices`` carrying the same bearer token
    used for Home Assistant. Any response with a status below 500 (even 401)
    counts as "running"; a 5xx status or a transport error is a failure.

    Args:
        ha_url: Not used by this check; kept so the signature stays stable.
        token: Bearer token sent in the ``Authorization`` header.
        z2m_url: Base URL of the Zigbee2MQTT frontend.
        verify_ssl: Whether TLS certificates are verified.

    Returns:
        SKIPPED when ``z2m_url`` is empty, FAILED when the server cannot be
        reached or answers with 5xx, otherwise OK. The OK message carries the
        device count only when the body is a JSON list.
    """
    check_name = "Z2M running"

    if not z2m_url:
        return CheckResult(
            name=check_name,
            status=CheckStatus.SKIPPED,
            message="Skipped (no Z2M_URL configured)",
        )

    try:
        async with httpx.AsyncClient(
            headers={"Authorization": f"Bearer {token}"},
            verify=verify_ssl,
            timeout=10,
        ) as client:
            resp = await client.get(f"{z2m_url}/api/devices")
        if resp.status_code >= 500:
            # raise_for_status() always raises for a 5xx response, so the
            # handler below turns it into a FAILED result.
            resp.raise_for_status()
    except Exception as exc:
        return CheckResult(
            name=check_name,
            status=CheckStatus.FAILED,
            message=f"Cannot reach Zigbee2MQTT at {z2m_url} — {exc}",
        )

    # The server answered, so it is running. Parsing the body is best-effort
    # and only used to enrich the message.
    try:
        devices = resp.json()
    except ValueError:  # includes json.JSONDecodeError and UnicodeDecodeError
        devices = None

    if isinstance(devices, list):
        message = f"{len(devices)} device(s) paired"
    else:
        # Non-list bodies (HTML, error objects) carry no usable count.
        message = "Z2M is responding"

    return CheckResult(name=check_name, status=CheckStatus.OK, message=message)
154
+
155
+
156
+ # ---------------------------------------------------------------------------
157
+ # Orchestrator
158
+ # ---------------------------------------------------------------------------
159
+
160
+
161
async def _run_checks(
    ha_url: str,
    token: str,
    verify_ssl: bool,
    z2m_url: str,
) -> list[CheckResult]:
    """Run the preflight checks in dependency order.

    The result list always has four entries, in this order: configuration,
    HA reachable, ZHA active, Z2M running. Checks whose prerequisite failed
    are reported as SKIPPED instead of being executed.
    """
    config_result = await _check_config(ha_url, token, z2m_url)

    # Without a valid configuration no network check can be meaningful.
    if config_result.status != CheckStatus.OK:
        skipped = [
            CheckResult(name=name, status=CheckStatus.SKIPPED, message="Skipped (invalid config)")
            for name in ("HA reachable", "ZHA active", "Z2M running")
        ]
        return [config_result, *skipped]

    ha_result = await _check_ha_reachable(ha_url, token, verify_ssl)

    # The ZHA query goes through Home Assistant, so it needs HA to be up.
    if ha_result.status == CheckStatus.OK:
        zha_result = await _check_zha_active(ha_url, token, verify_ssl)
    else:
        zha_result = CheckResult(
            name="ZHA active",
            status=CheckStatus.SKIPPED,
            message="Skipped (HA not reachable)",
        )

    z2m_result = await _check_z2m_running(ha_url, token, z2m_url, verify_ssl)

    return [config_result, ha_result, zha_result, z2m_result]
199
+
200
+
201
def _print_results(results: list[CheckResult]) -> None:
    """Print one line per check (status icon, padded name, message).

    The listing is framed by an empty line before and after.
    """
    console.print()
    for result in results:
        marker = _STATUS_ICON[result.status]
        # Left-align the name in a 20-character column so messages line up.
        title = f"[bold]{result.name:<20}[/bold]"
        console.print(f" {marker} {title} {result.message}")
    console.print()
208
+
209
+
210
def check_command(
    ha_url: str,
    token: str,
    verify_ssl: bool,
    z2m_url: str,
) -> bool:
    """Run all preflight checks and report them on the console.

    When at least one check both failed and is marked blocking, the user is
    asked whether to proceed anyway.

    Returns:
        True if the caller should proceed, False if the user chose to abort
        (or dismissed the prompt without answering).
    """
    console.rule("[bold cyan]Pre-flight checks[/bold cyan]")

    results = asyncio.run(_run_checks(ha_url, token, verify_ssl, z2m_url))
    _print_results(results)

    has_blocking_failure = any(
        item.status == CheckStatus.FAILED and item.blocking for item in results
    )
    if has_blocking_failure:
        console.print("[yellow]One or more checks failed.[/yellow]")
        answer = questionary.confirm("Proceed anyway?", default=False, style=_STYLE).ask()
        if not answer:
            return False

    console.rule()
    return True
@@ -0,0 +1,221 @@
1
+ import asyncio
2
+ from datetime import datetime, timezone
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ from rich.console import Console
7
+ from rich.progress import Progress, SpinnerColumn, TextColumn
8
+
9
+ from zigporter.ha_client import HAClient
10
+ from zigporter.models import AutomationRef, ZHADevice, ZHAEntity, ZHAExport
11
+
12
+ console = Console()
13
+
14
+
15
+ def _build_area_map(areas: list[dict[str, Any]]) -> dict[str, str]:
16
+ """Return {area_id: area_name}."""
17
+ return {a["area_id"]: a["name"] for a in areas}
18
+
19
+
20
+ def _build_entity_map(
21
+ entity_registry: list[dict[str, Any]],
22
+ ) -> dict[str, list[dict[str, Any]]]:
23
+ """Return {device_id: [entity_registry_entry, ...]} for ZHA entities only."""
24
+ result: dict[str, list[dict[str, Any]]] = {}
25
+ for entry in entity_registry:
26
+ if entry.get("platform") != "zha":
27
+ continue
28
+ device_id = entry.get("device_id")
29
+ if not device_id:
30
+ continue
31
+ result.setdefault(device_id, []).append(entry)
32
+ return result
33
+
34
+
35
+ def _build_state_map(states: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
36
+ """Return {entity_id: state_entry}."""
37
+ return {s["entity_id"]: s for s in states}
38
+
39
+
40
+ def _extract_entity_ids_from_automation(config: dict[str, Any]) -> list[str]:
41
+ """Walk an automation config dict and collect all entity_id values."""
42
+ entity_ids: set[str] = set()
43
+
44
+ def _walk(node: Any) -> None:
45
+ if isinstance(node, dict):
46
+ if "entity_id" in node:
47
+ value = node["entity_id"]
48
+ if isinstance(value, str):
49
+ entity_ids.add(value)
50
+ elif isinstance(value, list):
51
+ for v in value:
52
+ if isinstance(v, str):
53
+ entity_ids.add(v)
54
+ for v in node.values():
55
+ _walk(v)
56
+ elif isinstance(node, list):
57
+ for item in node:
58
+ _walk(item)
59
+
60
+ _walk(config)
61
+ return sorted(entity_ids)
62
+
63
+
64
def _match_automations_to_devices(
    automation_configs: list[dict[str, Any]],
    entity_to_device: dict[str, str],
) -> dict[str, list[AutomationRef]]:
    """Find, for every device, the automations that reference its entities.

    Args:
        automation_configs: Automation config dicts; ``id`` and ``alias`` are
            read when present.
        entity_to_device: Lookup of ``entity_id -> device_id`` for the
            entities of interest.

    Returns:
        ``{device_id: [AutomationRef, ...]}``. Each automation contributes at
        most one reference per device, listing the entity ids of that device
        it mentions.
    """
    result: dict[str, list[AutomationRef]] = {}

    for config in automation_configs:
        auto_id = config.get("id", "")
        # The alias may be missing, None or empty, and the id may not be a
        # string; normalise to str so the slug below cannot raise.
        alias = str(config.get("alias") or auto_id)

        referenced_devices: dict[str, list[str]] = {}
        for eid in _extract_entity_ids_from_automation(config):
            device_id = entity_to_device.get(eid)
            if device_id:
                referenced_devices.setdefault(device_id, []).append(eid)

        if not referenced_devices:
            continue

        # NOTE(review): the entity id is derived from the alias (lower-cased,
        # spaces replaced by underscores); it may differ from the id Home
        # Assistant actually assigned — confirm before relying on it.
        automation_id = f"automation.{alias.lower().replace(' ', '_')}"

        for device_id, refs in referenced_devices.items():
            result.setdefault(device_id, []).append(
                AutomationRef(
                    automation_id=automation_id,
                    alias=alias,
                    entity_references=refs,
                )
            )

    return result
91
+
92
+
93
def _build_zha_entities(
    entity_entries: list[dict[str, Any]],
    state_map: dict[str, dict[str, Any]],
) -> list[ZHAEntity]:
    """Convert one device's entity-registry entries into ZHAEntity models."""
    entities: list[ZHAEntity] = []
    for entry in entity_entries:
        eid = entry.get("entity_id", "")
        # Entities without a current state still get exported, just without
        # state/attributes.
        state_entry = state_map.get(eid, {})
        attrs = state_entry.get("attributes", {})
        entities.append(
            ZHAEntity(
                entity_id=eid,
                name=attrs.get("friendly_name", eid),
                name_by_user=entry.get("name"),
                platform="zha",
                unique_id=entry.get("unique_id"),
                device_class=entry.get("device_class"),
                disabled=entry.get("disabled_by") is not None,
                state=state_entry.get("state"),
                attributes=attrs,
            )
        )
    return entities


def build_export(
    zha_devices: list[dict[str, Any]],
    device_registry: list[dict[str, Any]],
    entity_registry: list[dict[str, Any]],
    area_registry: list[dict[str, Any]],
    states: list[dict[str, Any]],
    automation_configs: list[dict[str, Any]],
    ha_url: str,
) -> ZHAExport:
    """Join all data sources into a ZHAExport.

    Args:
        zha_devices: ZHA device dicts; ``ieee`` and ``device_reg_id`` link
            them to the registries.
        device_registry: Device-registry entries (``id``, ``area_id``).
        entity_registry: Entity-registry entries; only ``platform == "zha"``
            entries with a device id are used.
        area_registry: Area-registry entries (``area_id``, ``name``).
        states: Entity state entries (``entity_id``, ``state``, ``attributes``).
        automation_configs: Automation config dicts scanned for entity ids.
        ha_url: Recorded in the export as its origin.

    Returns:
        A ZHAExport stamped with the current UTC time, holding one ZHADevice
        per input ZHA device, in input order.
    """
    area_map = _build_area_map(area_registry)
    entity_map = _build_entity_map(entity_registry)
    state_map = _build_state_map(states)

    # device_id -> device-registry entry (used to look up the area)
    dr_map: dict[str, dict[str, Any]] = {d["id"]: d for d in device_registry}

    # entity_id -> device_id for automation matching
    entity_to_device: dict[str, str] = {
        e["entity_id"]: e["device_id"]
        for e in entity_registry
        if e.get("platform") == "zha" and e.get("device_id")
    }

    auto_map = _match_automations_to_devices(automation_configs, entity_to_device)

    devices: list[ZHADevice] = []

    for zha_dev in zha_devices:
        ieee = zha_dev.get("ieee", "")
        device_id = zha_dev.get("device_reg_id", "")

        area_id = dr_map.get(device_id, {}).get("area_id")
        area_name = area_map.get(area_id) if area_id else None

        user_given_name = zha_dev.get("user_given_name")
        # Fall back step by step: user-given name, then the ZHA name, then the
        # IEEE address — also when a key is present but None or empty.
        display_name = user_given_name or zha_dev.get("name") or ieee

        devices.append(
            ZHADevice(
                device_id=device_id,
                ieee=ieee,
                name=display_name,
                name_by_user=user_given_name,
                manufacturer=zha_dev.get("manufacturer"),
                model=zha_dev.get("model"),
                area_id=area_id,
                area_name=area_name,
                device_type=zha_dev.get("device_type", "Unknown"),
                quirk_applied=zha_dev.get("quirk_applied", False),
                quirk_class=zha_dev.get("quirk_class"),
                entities=_build_zha_entities(entity_map.get(device_id, []), state_map),
                automations=auto_map.get(device_id, []),
            )
        )

    return ZHAExport(
        exported_at=datetime.now(tz=timezone.utc),
        ha_url=ha_url,
        devices=devices,
    )
173
+
174
+
175
async def run_export(ha_url: str, token: str, verify_ssl: bool) -> ZHAExport:
    """Fetch registry, ZHA and state data from Home Assistant and build the export.

    A spinner on the shared console shows which step is currently running.
    """
    ha = HAClient(ha_url, token, verify_ssl)
    spinner = Progress(SpinnerColumn(), TextColumn("{task.description}"), console=console)

    with spinner:
        task_id = spinner.add_task("Connecting to Home Assistant...", total=None)

        spinner.update(task_id, description="Fetching ZHA + registry data...")
        registries = await ha.get_all_ws_data()

        spinner.update(task_id, description="Fetching entity states...")
        entity_states = await ha.get_states()

        spinner.update(task_id, description="Building device map...")
        zha_devices = registries["zha_devices"]
        device_registry = registries["device_registry"]
        entity_registry = registries["entity_registry"]
        area_registry = registries["area_registry"]
        automation_configs = registries["automation_configs"]
        result = build_export(
            zha_devices=zha_devices,
            device_registry=device_registry,
            entity_registry=entity_registry,
            area_registry=area_registry,
            states=entity_states,
            automation_configs=automation_configs,
            ha_url=ha_url,
        )
        spinner.stop()

    return result
201
+
202
+
203
def export_command(
    output: Path,
    pretty: bool,
    ha_url: str,
    token: str,
    verify_ssl: bool,
) -> None:
    """Entry point called from the CLI: run the export and write it as JSON.

    Args:
        output: File the JSON export is written to (overwritten if present).
        pretty: Indent the JSON by two spaces when True, compact otherwise.
        ha_url: Home Assistant base URL.
        token: Home Assistant access token.
        verify_ssl: Whether TLS certificates are verified.
    """
    export = asyncio.run(run_export(ha_url, token, verify_ssl))

    indent = 2 if pretty else None
    # Write UTF-8 explicitly: the locale default encoding may be unable to
    # represent non-ASCII device or entity names, and JSON consumers expect
    # UTF-8.
    output.write_text(export.model_dump_json(indent=indent), encoding="utf-8")

    console.print(
        f"\nExport complete: [bold]{len(export.devices)}[/bold] devices, "
        f"[bold]{sum(len(d.entities) for d in export.devices)}[/bold] entities, "
        f"[bold]{sum(len(d.automations) for d in export.devices)}[/bold] automation references\n"
        f"Written to [green]{output}[/green]"
    )