uiprotect-1.20.0.tar.gz → uiprotect-2.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of uiprotect might be problematic. (The original page links to more details.)

Files changed (37)
  1. {uiprotect-1.20.0 → uiprotect-2.1.0}/PKG-INFO +1 -1
  2. {uiprotect-1.20.0 → uiprotect-2.1.0}/pyproject.toml +1 -1
  3. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/api.py +126 -88
  4. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/__init__.py +5 -0
  5. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/test_util/__init__.py +6 -2
  6. uiprotect-2.1.0/src/uiprotect/websocket.py +187 -0
  7. uiprotect-1.20.0/src/uiprotect/websocket.py +0 -226
  8. {uiprotect-1.20.0 → uiprotect-2.1.0}/LICENSE +0 -0
  9. {uiprotect-1.20.0 → uiprotect-2.1.0}/README.md +0 -0
  10. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/__init__.py +0 -0
  11. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/__main__.py +0 -0
  12. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/backup.py +0 -0
  13. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/base.py +0 -0
  14. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/cameras.py +0 -0
  15. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/chimes.py +0 -0
  16. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/doorlocks.py +0 -0
  17. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/events.py +0 -0
  18. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/lights.py +0 -0
  19. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/liveviews.py +0 -0
  20. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/nvr.py +0 -0
  21. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/sensors.py +0 -0
  22. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/cli/viewers.py +0 -0
  23. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/__init__.py +0 -0
  24. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/base.py +0 -0
  25. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/bootstrap.py +0 -0
  26. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/convert.py +0 -0
  27. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/devices.py +0 -0
  28. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/nvr.py +0 -0
  29. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/types.py +0 -0
  30. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/user.py +0 -0
  31. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/data/websocket.py +0 -0
  32. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/exceptions.py +0 -0
  33. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/py.typed +0 -0
  34. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/release_cache.json +0 -0
  35. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/stream.py +0 -0
  36. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/test_util/anonymize.py +0 -0
  37. {uiprotect-1.20.0 → uiprotect-2.1.0}/src/uiprotect/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: uiprotect
3
- Version: 1.20.0
3
+ Version: 2.1.0
4
4
  Summary: Python API for Unifi Protect (Unofficial)
5
5
  Home-page: https://github.com/uilibs/uiprotect
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "uiprotect"
3
- version = "1.20.0"
3
+ version = "2.1.0"
4
4
  description = "Python API for Unifi Protect (Unofficial)"
5
5
  authors = ["UI Protect Maintainers <ui@koston.org>"]
6
6
  license = "MIT"
@@ -11,7 +11,8 @@ import sys
11
11
  import time
12
12
  from collections.abc import Callable
13
13
  from datetime import datetime, timedelta
14
- from functools import cached_property
14
+ from functools import cached_property, partial
15
+ from http import HTTPStatus
15
16
  from http.cookies import Morsel, SimpleCookie
16
17
  from ipaddress import IPv4Address, IPv6Address
17
18
  from pathlib import Path
@@ -203,6 +204,7 @@ class BaseApiClient:
203
204
  self._verify_ssl = verify_ssl
204
205
  self._ws_timeout = ws_timeout
205
206
  self._loaded_session = False
207
+ self._update_task: asyncio.Task[Bootstrap | None] | None = None
206
208
 
207
209
  self.config_dir = config_dir or (Path(user_config_dir()) / "ufp")
208
210
  self.cache_dir = cache_dir or (Path(user_cache_dir()) / "ufp_cache")
@@ -221,23 +223,24 @@ class BaseApiClient:
221
223
  """Updates the url after changing _host or _port."""
222
224
  if self._port != 443:
223
225
  self._url = URL(f"https://{self._host}:{self._port}")
226
+ self._ws_url = URL(f"wss://{self._host}:{self._port}{self.ws_path}")
224
227
  else:
225
228
  self._url = URL(f"https://{self._host}")
229
+ self._ws_url = URL(f"wss://{self._host}{self.ws_path}")
226
230
 
227
231
  self.base_url = str(self._url)
228
232
 
229
233
  @property
230
- def ws_url(self) -> str:
234
+ def _ws_url_object(self) -> URL:
231
235
  """Get Websocket URL."""
232
- url = f"wss://{self._host}"
233
- if self._port != 443:
234
- url += f":{self._port}"
236
+ if last_update_id := self._get_last_update_id():
237
+ return self._ws_url.with_query(lastUpdateId=last_update_id)
238
+ return self._ws_url
235
239
 
236
- url += self.ws_path
237
- last_update_id = self._get_last_update_id()
238
- if last_update_id is None:
239
- return url
240
- return f"{url}?lastUpdateId={last_update_id}"
240
+ @property
241
+ def ws_url(self) -> str:
242
+ """Get Websocket URL."""
243
+ return str(self._ws_url_object)
241
244
 
242
245
  @property
243
246
  def config_file(self) -> Path:
@@ -253,37 +256,56 @@ class BaseApiClient:
253
256
 
254
257
  return self._session
255
258
 
256
- async def get_websocket(self) -> Websocket:
257
- """Gets or creates current Websocket."""
258
-
259
- async def _auth(force: bool) -> dict[str, str] | None:
260
- if force:
261
- if self._session is not None:
262
- self._session.cookie_jar.clear()
263
- self.set_header("cookie", None)
264
- self.set_header("x-csrf-token", None)
259
+ async def _auth_websocket(self, force: bool) -> dict[str, str] | None:
260
+ """Authenticate for Websocket."""
261
+ if force:
262
+ if self._session is not None:
263
+ self._session.cookie_jar.clear()
264
+ self.set_header("cookie", None)
265
+ self.set_header("x-csrf-token", None)
266
+ self._is_authenticated = False
265
267
 
266
- await self.ensure_authenticated()
267
- return self.headers
268
+ await self.ensure_authenticated()
269
+ return self.headers
268
270
 
271
+ def _get_websocket(self) -> Websocket:
272
+ """Gets or creates current Websocket."""
269
273
  if self._websocket is None:
270
274
  self._websocket = Websocket(
271
- self.get_websocket_url,
272
- _auth,
275
+ self._get_websocket_url,
276
+ self._auth_websocket,
277
+ self._update_bootstrap_soon,
278
+ self.get_session,
279
+ self._process_ws_message,
273
280
  verify=self._verify_ssl,
274
281
  timeout=self._ws_timeout,
275
282
  )
276
- self._websocket.subscribe(self._process_ws_message)
277
-
278
283
  return self._websocket
279
284
 
285
+ def _update_bootstrap_soon(self) -> None:
286
+ """Update bootstrap soon."""
287
+ _LOGGER.debug("Updating bootstrap soon")
288
+ # Force the next bootstrap update
289
+ # since the lastUpdateId is not valid anymore
290
+ if self._update_task and not self._update_task.done():
291
+ return
292
+ self._update_task = asyncio.create_task(self.update(force=True))
293
+
280
294
  async def close_session(self) -> None:
281
- """Closing and delets client session"""
295
+ """Closing and deletes client session"""
296
+ await self._cancel_update_task()
282
297
  if self._session is not None:
283
298
  await self._session.close()
284
299
  self._session = None
285
300
  self._loaded_session = False
286
301
 
302
+ async def _cancel_update_task(self) -> None:
303
+ if self._update_task:
304
+ self._update_task.cancel()
305
+ with contextlib.suppress(asyncio.CancelledError):
306
+ await self._update_task
307
+ self._update_task = None
308
+
287
309
  def set_header(self, key: str, value: str | None) -> None:
288
310
  """Set header."""
289
311
  self.headers = self.headers or {}
@@ -375,15 +397,7 @@ class BaseApiClient:
375
397
 
376
398
  try:
377
399
  if response.status != 200:
378
- reason = await get_response_reason(response)
379
- msg = "Request failed: %s - Status: %s - Reason: %s"
380
- if raise_exception:
381
- if response.status in {401, 403}:
382
- raise NotAuthorized(msg % (url, response.status, reason))
383
- if response.status >= 400 and response.status < 500:
384
- raise BadRequest(msg % (url, response.status, reason))
385
- raise NvrError(msg % (url, response.status, reason))
386
- _LOGGER.debug(msg, url, response.status, reason)
400
+ await self._raise_for_status(response, raise_exception)
387
401
  return None
388
402
 
389
403
  data: bytes | None = await response.read()
@@ -396,6 +410,30 @@ class BaseApiClient:
396
410
  # re-raise exception
397
411
  raise
398
412
 
413
+ async def _raise_for_status(
414
+ self, response: aiohttp.ClientResponse, raise_exception: bool = True
415
+ ) -> None:
416
+ url = response.url
417
+ reason = await get_response_reason(response)
418
+ msg = "Request failed: %s - Status: %s - Reason: %s"
419
+ if raise_exception:
420
+ status = response.status
421
+ if status in {
422
+ HTTPStatus.UNAUTHORIZED.value,
423
+ HTTPStatus.FORBIDDEN.value,
424
+ }:
425
+ raise NotAuthorized(msg % (url, status, reason))
426
+ elif status == HTTPStatus.TOO_MANY_REQUESTS.value:
427
+ _LOGGER.debug("Too many requests - Login is rate limited: %s", response)
428
+ raise NvrError(msg % (url, status, reason))
429
+ elif (
430
+ status >= HTTPStatus.BAD_REQUEST.value
431
+ and status < HTTPStatus.INTERNAL_SERVER_ERROR.value
432
+ ):
433
+ raise BadRequest(msg % (url, status, reason))
434
+ raise NvrError(msg % (url, status, reason))
435
+ _LOGGER.debug(msg, url, status, reason)
436
+
399
437
  async def api_request(
400
438
  self,
401
439
  url: str,
@@ -413,8 +451,13 @@ class BaseApiClient:
413
451
  )
414
452
 
415
453
  if data is not None:
416
- json_data: list[Any] | dict[str, Any] = orjson.loads(data)
417
- return json_data
454
+ json_data: list[Any] | dict[str, Any]
455
+ try:
456
+ json_data = orjson.loads(data)
457
+ return json_data
458
+ except orjson.JSONDecodeError as ex:
459
+ _LOGGER.error("Could not decode JSON from %s", url)
460
+ raise NvrError(f"Could not decode JSON from {url}") from ex
418
461
  return None
419
462
 
420
463
  async def api_request_obj(
@@ -487,6 +530,8 @@ class BaseApiClient:
487
530
  }
488
531
 
489
532
  response = await self.request("post", url=url, json=auth)
533
+ if response.status != 200:
534
+ await self._raise_for_status(response, True)
490
535
  self.set_header("cookie", response.headers.get("set-cookie", ""))
491
536
  self._is_authenticated = True
492
537
  _LOGGER.debug("Authenticated successfully!")
@@ -620,31 +665,17 @@ class BaseApiClient:
620
665
 
621
666
  return token_expires_at >= max_expire_time
622
667
 
623
- async def async_connect_ws(self, force: bool) -> None:
624
- """Connect to Websocket."""
625
- if force and self._websocket is not None:
626
- await self._websocket.disconnect()
627
- self._websocket = None
628
-
629
- websocket = await self.get_websocket()
630
- if not websocket.is_connected:
631
- self._last_ws_status = False
632
- with contextlib.suppress(
633
- TimeoutError,
634
- asyncio.TimeoutError,
635
- asyncio.CancelledError,
636
- ):
637
- await websocket.connect()
638
-
639
- def get_websocket_url(self) -> str:
668
+ def _get_websocket_url(self) -> URL:
640
669
  """Get Websocket URL."""
641
- return self.ws_url
670
+ return self._ws_url_object
642
671
 
643
672
  async def async_disconnect_ws(self) -> None:
644
673
  """Disconnect from Websocket."""
645
- if self._websocket is None:
646
- return
647
- await self._websocket.disconnect()
674
+ if self._websocket:
675
+ websocket = self._get_websocket()
676
+ websocket.stop()
677
+ await websocket.wait_closed()
678
+ self._websocket = None
648
679
 
649
680
  def check_ws(self) -> bool:
650
681
  """Checks current state of Websocket."""
@@ -668,6 +699,9 @@ class BaseApiClient:
668
699
  def _get_last_update_id(self) -> str | None:
669
700
  raise NotImplementedError
670
701
 
702
+ async def update(self, force: bool = False) -> Bootstrap | None:
703
+ raise NotImplementedError
704
+
671
705
 
672
706
  class ProtectApiClient(BaseApiClient):
673
707
  """
@@ -748,6 +782,7 @@ class ProtectApiClient(BaseApiClient):
748
782
  self._ignore_stats = ignore_stats
749
783
  self._ws_subscriptions = []
750
784
  self.ignore_unadopted = ignore_unadopted
785
+ self._update_lock = asyncio.Lock()
751
786
 
752
787
  if override_connection_host:
753
788
  self._connection_host = ip_from_host(self._host)
@@ -788,41 +823,37 @@ class ProtectApiClient(BaseApiClient):
788
823
 
789
824
  You can use the various other `get_` methods if you need one off data from UFP
790
825
  """
791
- now = time.monotonic()
792
- now_dt = utc_now()
793
- max_event_dt = now_dt - timedelta(hours=1)
794
- if force:
795
- self._last_update = NEVER_RAN
796
- self._last_update_dt = max_event_dt
797
-
798
- bootstrap_updated = False
799
- if self._bootstrap is None or now - self._last_update > DEVICE_UPDATE_INTERVAL:
800
- bootstrap_updated = True
801
- self._bootstrap = await self.get_bootstrap()
802
- self.__dict__.pop("bootstrap", None)
803
- self._last_update = now
804
- self._last_update_dt = now_dt
826
+ async with self._update_lock:
827
+ now = time.monotonic()
828
+ if force:
829
+ self._last_update = NEVER_RAN
805
830
 
806
- await self.async_connect_ws(force)
807
- if self.check_ws():
808
- # If the websocket is connected/connecting
809
- # we do not need to get events
810
- _LOGGER.debug("Skipping update since websocket is active")
811
- return None
831
+ bootstrap_updated = False
832
+ if (
833
+ self._bootstrap is None
834
+ or now - self._last_update > DEVICE_UPDATE_INTERVAL
835
+ ):
836
+ bootstrap_updated = True
837
+ self._bootstrap = await self.get_bootstrap()
838
+ self.__dict__.pop("bootstrap", None)
839
+ self._last_update = now
812
840
 
813
- if bootstrap_updated:
814
- return None
841
+ if bootstrap_updated:
842
+ return None
843
+ self._last_update = now
844
+ return self._bootstrap
815
845
 
846
+ async def poll_events(self) -> None:
847
+ """Poll for events."""
848
+ now_dt = utc_now()
849
+ max_event_dt = now_dt - timedelta(hours=1)
816
850
  events = await self.get_events(
817
851
  start=self._last_update_dt or max_event_dt,
818
852
  end=now_dt,
819
853
  )
820
854
  for event in events:
821
855
  self.bootstrap.process_event(event)
822
-
823
- self._last_update = now
824
856
  self._last_update_dt = now_dt
825
- return self._bootstrap
826
857
 
827
858
  def emit_message(self, msg: WSSubscriptionMessage) -> None:
828
859
  """Emit message to all subscriptions."""
@@ -1108,13 +1139,20 @@ class ProtectApiClient(BaseApiClient):
1108
1139
 
1109
1140
  Returns a callback that will unsubscribe.
1110
1141
  """
1111
-
1112
- def _unsub_ws_callback() -> None:
1113
- self._ws_subscriptions.remove(ws_callback)
1114
-
1115
1142
  _LOGGER.debug("Adding subscription: %s", ws_callback)
1116
1143
  self._ws_subscriptions.append(ws_callback)
1117
- return _unsub_ws_callback
1144
+ self._get_websocket().start()
1145
+ return partial(self._unsubscribe_websocket, ws_callback)
1146
+
1147
+ def _unsubscribe_websocket(
1148
+ self,
1149
+ ws_callback: Callable[[WSSubscriptionMessage], None],
1150
+ ) -> None:
1151
+ """Unsubscribe to websocket events."""
1152
+ _LOGGER.debug("Removing subscription: %s", ws_callback)
1153
+ self._ws_subscriptions.remove(ws_callback)
1154
+ if not self._ws_subscriptions:
1155
+ self._get_websocket().stop()
1118
1156
 
1119
1157
  async def get_bootstrap(self) -> Bootstrap:
1120
1158
  """
@@ -201,6 +201,7 @@ def shell(ctx: typer.Context) -> None:
201
201
 
202
202
  async def wait_forever() -> None:
203
203
  await protect.update()
204
+ protect.subscribe_websocket(lambda _: None)
204
205
  while True:
205
206
  await asyncio.sleep(10)
206
207
  await protect.update()
@@ -262,12 +263,16 @@ def profile_ws(
262
263
 
263
264
  async def callback() -> None:
264
265
  await protect.update()
266
+ unsub = protect.subscribe_websocket(lambda _: None)
265
267
  await profile_ws_job(
266
268
  protect,
267
269
  wait_time,
268
270
  output_path=output_path,
269
271
  ws_progress=_progress_bar,
270
272
  )
273
+ unsub()
274
+ await protect.async_disconnect_ws()
275
+ await protect.close_session()
271
276
 
272
277
  _setup_logger()
273
278
 
@@ -103,8 +103,10 @@ class SampleDataGenerator:
103
103
  async def async_generate(self, close_session: bool = True) -> None:
104
104
  self.log(f"Output folder: {self.output_folder}")
105
105
  self.output_folder.mkdir(parents=True, exist_ok=True)
106
- websocket = await self.client.get_websocket()
107
- websocket.subscribe(self._handle_ws_message)
106
+ websocket = self.client._get_websocket()
107
+ websocket.start()
108
+ self.log("Websocket started...")
109
+ websocket._subscription = self._handle_ws_message
108
110
 
109
111
  self.log("Updating devices...")
110
112
  await self.client.update()
@@ -131,8 +133,10 @@ class SampleDataGenerator:
131
133
  "chime": len(bootstrap["chimes"]),
132
134
  }
133
135
 
136
+ self.log("Generating event data...")
134
137
  motion_event, smart_detection = await self.generate_event_data()
135
138
  await self.generate_device_data(motion_event, smart_detection)
139
+ self.log("Recording websocket events...")
136
140
  await self.record_ws_events()
137
141
 
138
142
  if close_session:
@@ -0,0 +1,187 @@
1
+ """UniFi Protect Websockets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import contextlib
7
+ import logging
8
+ from collections.abc import Awaitable, Callable, Coroutine
9
+ from http import HTTPStatus
10
+ from typing import Any, Optional
11
+
12
+ from aiohttp import (
13
+ ClientError,
14
+ ClientSession,
15
+ ClientWebSocketResponse,
16
+ WSMessage,
17
+ WSMsgType,
18
+ WSServerHandshakeError,
19
+ )
20
+ from yarl import URL
21
+
22
+ _LOGGER = logging.getLogger(__name__)
23
+ AuthCallbackType = Callable[..., Coroutine[Any, Any, Optional[dict[str, str]]]]
24
+ GetSessionCallbackType = Callable[[], Awaitable[ClientSession]]
25
+ UpdateBootstrapCallbackType = Callable[[], None]
26
+ _CLOSE_MESSAGE_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED}
27
+
28
+
29
+ class Websocket:
30
+ """UniFi Protect Websocket manager."""
31
+
32
+ _running = False
33
+ _headers: dict[str, str] | None = None
34
+ _websocket_loop_task: asyncio.Task[None] | None = None
35
+ _stop_task: asyncio.Task[None] | None = None
36
+ _ws_connection: ClientWebSocketResponse | None = None
37
+
38
+ def __init__(
39
+ self,
40
+ get_url: Callable[[], URL],
41
+ auth_callback: AuthCallbackType,
42
+ update_bootstrap: UpdateBootstrapCallbackType,
43
+ get_session: GetSessionCallbackType,
44
+ subscription: Callable[[WSMessage], None],
45
+ *,
46
+ timeout: float = 30.0,
47
+ backoff: int = 10,
48
+ verify: bool = True,
49
+ ) -> None:
50
+ """Init Websocket."""
51
+ self.get_url = get_url
52
+ self.timeout = timeout
53
+ self.backoff = backoff
54
+ self.verify = verify
55
+ self._get_session = get_session
56
+ self._auth = auth_callback
57
+ self._update_bootstrap = update_bootstrap
58
+ self._subscription = subscription
59
+ self._seen_non_close_message = False
60
+
61
+ @property
62
+ def is_connected(self) -> bool:
63
+ """Return if the websocket is connected."""
64
+ return self._ws_connection is not None and not self._ws_connection.closed
65
+
66
+ async def _websocket_loop(self) -> None:
67
+ """Running loop for websocket."""
68
+ await self.wait_closed()
69
+ backoff = self.backoff
70
+
71
+ while True:
72
+ url = self.get_url()
73
+ try:
74
+ await self._websocket_inner_loop(url)
75
+ except ClientError as ex:
76
+ level = logging.ERROR if self._seen_non_close_message else logging.DEBUG
77
+ if isinstance(ex, WSServerHandshakeError):
78
+ if ex.status == HTTPStatus.UNAUTHORIZED.value:
79
+ _LOGGER.log(level, "Websocket authentication error: %s", url)
80
+ await self._attempt_reauth()
81
+ else:
82
+ _LOGGER.log(
83
+ level, "Websocket handshake error: %s", url, exc_info=True
84
+ )
85
+ else:
86
+ _LOGGER.log(
87
+ level, "Websocket disconnect error: %s", url, exc_info=True
88
+ )
89
+ except asyncio.TimeoutError:
90
+ level = logging.ERROR if self._seen_non_close_message else logging.DEBUG
91
+ _LOGGER.log(level, "Websocket timeout: %s", url)
92
+ except Exception:
93
+ _LOGGER.debug(
94
+ "Unexpected error in websocket reconnect loop, backoff: %s",
95
+ backoff,
96
+ exc_info=True,
97
+ )
98
+
99
+ if self._running is False:
100
+ break
101
+ await asyncio.sleep(self.backoff)
102
+
103
+ async def _websocket_inner_loop(self, url: URL) -> None:
104
+ _LOGGER.debug("Connecting WS to %s", url)
105
+ self._headers = await self._auth(False)
106
+ ssl = None if self.verify else False
107
+ msg: WSMessage | None = None
108
+ self._seen_non_close_message = False
109
+ session = await self._get_session()
110
+ # catch any and all errors for Websocket so we can clean up correctly
111
+ try:
112
+ self._ws_connection = await session.ws_connect(
113
+ url, ssl=ssl, headers=self._headers, timeout=self.timeout
114
+ )
115
+ while True:
116
+ msg = await self._ws_connection.receive(self.timeout)
117
+ msg_type = msg.type
118
+ if msg_type is WSMsgType.ERROR:
119
+ _LOGGER.exception("Error from Websocket: %s", msg.data)
120
+ break
121
+ elif msg_type in _CLOSE_MESSAGE_TYPES:
122
+ _LOGGER.debug("Websocket closed: %s", msg)
123
+ break
124
+
125
+ self._seen_non_close_message = True
126
+ try:
127
+ self._subscription(msg)
128
+ except Exception:
129
+ _LOGGER.exception("Error processing websocket message")
130
+ finally:
131
+ if (
132
+ msg is not None
133
+ and msg.type is WSMsgType.CLOSE
134
+ # If it closes right away or lastUpdateId is in the extra
135
+ # its an indication that we should update the bootstrap
136
+ # since lastUpdateId is invalid
137
+ and (
138
+ not self._seen_non_close_message
139
+ or (msg.extra and "lastUpdateId" in msg.extra)
140
+ )
141
+ ):
142
+ self._update_bootstrap()
143
+ _LOGGER.debug("Websocket disconnected: last message: %s", msg)
144
+ if self._ws_connection is not None and not self._ws_connection.closed:
145
+ await self._ws_connection.close()
146
+ self._ws_connection = None
147
+
148
+ async def _attempt_reauth(self) -> None:
149
+ """Attempt to re-authenticate."""
150
+ try:
151
+ self._headers = await self._auth(True)
152
+ except Exception:
153
+ _LOGGER.exception("Error reauthenticating websocket")
154
+
155
+ def start(self) -> None:
156
+ """Start the websocket."""
157
+ if self._running:
158
+ return
159
+ self._running = True
160
+ self._websocket_loop_task = asyncio.create_task(self._websocket_loop())
161
+
162
+ def stop(self) -> None:
163
+ """Disconnect the websocket."""
164
+ _LOGGER.debug("Disconnecting websocket...")
165
+ if not self._running:
166
+ return
167
+ if self._websocket_loop_task:
168
+ self._websocket_loop_task.cancel()
169
+ self._running = False
170
+ self._stop_task = asyncio.create_task(self._stop())
171
+
172
+ async def wait_closed(self) -> None:
173
+ """Wait for the websocket to close."""
174
+ if self._stop_task:
175
+ with contextlib.suppress(asyncio.CancelledError):
176
+ await self._stop_task
177
+ self._stop_task = None
178
+
179
+ async def _stop(self) -> None:
180
+ """Stop the websocket."""
181
+ if self._ws_connection:
182
+ await self._ws_connection.close()
183
+ self._ws_connection = None
184
+ if self._websocket_loop_task:
185
+ with contextlib.suppress(asyncio.CancelledError):
186
+ await self._websocket_loop_task
187
+ self._websocket_loop_task = None
@@ -1,226 +0,0 @@
1
- """UniFi Protect Websockets."""
2
-
3
- from __future__ import annotations
4
-
5
- import asyncio
6
- import logging
7
- import time
8
- from collections.abc import Callable, Coroutine
9
- from typing import Any, Optional
10
-
11
- from aiohttp import (
12
- ClientError,
13
- ClientSession,
14
- ClientWebSocketResponse,
15
- WSMessage,
16
- WSMsgType,
17
- )
18
-
19
- from .utils import asyncio_timeout
20
-
21
- _LOGGER = logging.getLogger(__name__)
22
- CALLBACK_TYPE = Callable[..., Coroutine[Any, Any, Optional[dict[str, str]]]]
23
- RECENT_FAILURE_CUT_OFF = 30
24
- RECENT_FAILURE_THRESHOLD = 2
25
-
26
-
27
- class Websocket:
28
- """UniFi Protect Websocket manager."""
29
-
30
- url: str
31
- verify: bool
32
- timeout_interval: int
33
- backoff: int
34
- _auth: CALLBACK_TYPE
35
- _timeout: float
36
- _ws_subscriptions: list[Callable[[WSMessage], None]]
37
- _connect_lock: asyncio.Lock
38
-
39
- _headers: dict[str, str] | None = None
40
- _websocket_loop_task: asyncio.Task[None] | None = None
41
- _timer_task: asyncio.Task[None] | None = None
42
- _ws_connection: ClientWebSocketResponse | None = None
43
- _last_connect: float = -1000
44
- _recent_failures: int = 0
45
-
46
- def __init__(
47
- self,
48
- get_url: Callable[[], str],
49
- auth_callback: CALLBACK_TYPE,
50
- *,
51
- timeout: int = 30,
52
- backoff: int = 10,
53
- verify: bool = True,
54
- ) -> None:
55
- """Init Websocket."""
56
- self.get_url = get_url
57
- self.timeout_interval = timeout
58
- self.backoff = backoff
59
- self.verify = verify
60
- self._auth = auth_callback
61
- self._timeout = time.monotonic()
62
- self._ws_subscriptions = []
63
- self._connect_lock = asyncio.Lock()
64
-
65
- @property
66
- def is_connected(self) -> bool:
67
- """Check if Websocket connected."""
68
- return self._ws_connection is not None
69
-
70
- def _get_session(self) -> ClientSession:
71
- # for testing, to make easier to mock
72
- return ClientSession()
73
-
74
- def _process_message(self, msg: WSMessage) -> bool:
75
- if msg.type == WSMsgType.ERROR:
76
- _LOGGER.exception("Error from Websocket: %s", msg.data)
77
- return False
78
-
79
- for sub in self._ws_subscriptions:
80
- try:
81
- sub(msg)
82
- except Exception:
83
- _LOGGER.exception("Error processing websocket message")
84
-
85
- return True
86
-
87
- async def _websocket_loop(self, start_event: asyncio.Event) -> None:
88
- url = self.get_url()
89
- _LOGGER.debug("Connecting WS to %s", url)
90
- self._headers = await self._auth(self._should_reset_auth)
91
-
92
- session = self._get_session()
93
- # catch any and all errors for Websocket so we can clean up correctly
94
- try:
95
- self._ws_connection = await session.ws_connect(
96
- url,
97
- ssl=None if self.verify else False,
98
- headers=self._headers,
99
- )
100
- start_event.set()
101
-
102
- self._reset_timeout()
103
- async for msg in self._ws_connection:
104
- if not self._process_message(msg):
105
- break
106
- self._reset_timeout()
107
- except ClientError:
108
- _LOGGER.exception("Websocket disconnect error: %s", url)
109
- finally:
110
- _LOGGER.debug("Websocket disconnected")
111
- self._increase_failure()
112
- self._cancel_timeout()
113
- if self._ws_connection is not None and not self._ws_connection.closed:
114
- await self._ws_connection.close()
115
- if not session.closed:
116
- await session.close()
117
- self._ws_connection = None
118
- # make sure event does not timeout
119
- start_event.set()
120
-
121
- @property
122
- def has_recent_connect(self) -> bool:
123
- """Check if Websocket has recent connection."""
124
- return time.monotonic() - RECENT_FAILURE_CUT_OFF <= self._last_connect
125
-
126
- @property
127
- def _should_reset_auth(self) -> bool:
128
- if self.has_recent_connect:
129
- if self._recent_failures > RECENT_FAILURE_THRESHOLD:
130
- return True
131
- else:
132
- self._recent_failures = 0
133
- return False
134
-
135
- def _increase_failure(self) -> None:
136
- if self.has_recent_connect:
137
- self._recent_failures += 1
138
- else:
139
- self._recent_failures = 1
140
-
141
- async def _do_timeout(self) -> bool:
142
- _LOGGER.debug("WS timed out")
143
- return await self.reconnect()
144
-
145
- async def _timeout_loop(self) -> None:
146
- while True:
147
- now = time.monotonic()
148
- if now > self._timeout:
149
- _LOGGER.debug("WS timed out")
150
- if not await self.reconnect():
151
- _LOGGER.debug("WS could not reconnect")
152
- continue
153
- sleep_time = self._timeout - now
154
- _LOGGER.debug("WS Timeout loop sleep %s", sleep_time)
155
- await asyncio.sleep(sleep_time)
156
-
157
- def _reset_timeout(self) -> None:
158
- self._timeout = time.monotonic() + self.timeout_interval
159
-
160
- if self._timer_task is None:
161
- self._timer_task = asyncio.create_task(self._timeout_loop())
162
-
163
- def _cancel_timeout(self) -> None:
164
- if self._timer_task:
165
- self._timer_task.cancel()
166
-
167
- async def connect(self) -> bool:
168
- """Connect the websocket."""
169
- if self._connect_lock.locked():
170
- _LOGGER.debug("Another connect is already happening")
171
- return False
172
- try:
173
- async with asyncio_timeout(0.1):
174
- await self._connect_lock.acquire()
175
- except (TimeoutError, asyncio.TimeoutError, asyncio.CancelledError):
176
- _LOGGER.debug("Failed to get connection lock")
177
-
178
- start_event = asyncio.Event()
179
- _LOGGER.debug("Scheduling WS connect...")
180
- self._websocket_loop_task = asyncio.create_task(
181
- self._websocket_loop(start_event),
182
- )
183
-
184
- try:
185
- async with asyncio_timeout(self.timeout_interval):
186
- await start_event.wait()
187
- except (TimeoutError, asyncio.TimeoutError, asyncio.CancelledError):
188
- _LOGGER.warning("Timed out while waiting for Websocket to connect")
189
- await self.disconnect()
190
-
191
- self._connect_lock.release()
192
- if self._ws_connection is None:
193
- _LOGGER.debug("Failed to connect to Websocket")
194
- return False
195
- _LOGGER.debug("Connected to Websocket successfully")
196
- self._last_connect = time.monotonic()
197
- return True
198
-
199
- async def disconnect(self) -> None:
200
- """Disconnect the websocket."""
201
- _LOGGER.debug("Disconnecting websocket...")
202
- if self._ws_connection is None:
203
- return
204
- await self._ws_connection.close()
205
- self._ws_connection = None
206
-
207
- async def reconnect(self) -> bool:
208
- """Reconnect the websocket."""
209
- _LOGGER.debug("Reconnecting websocket...")
210
- await self.disconnect()
211
- await asyncio.sleep(self.backoff)
212
- return await self.connect()
213
-
214
- def subscribe(self, ws_callback: Callable[[WSMessage], None]) -> Callable[[], None]:
215
- """
216
- Subscribe to raw websocket messages.
217
-
218
- Returns a callback that will unsubscribe.
219
- """
220
-
221
- def _unsub_ws_callback() -> None:
222
- self._ws_subscriptions.remove(ws_callback)
223
-
224
- _LOGGER.debug("Adding subscription: %s", ws_callback)
225
- self._ws_subscriptions.append(ws_callback)
226
- return _unsub_ws_callback
File without changes
File without changes