arthexis 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nodes/models.py CHANGED
@@ -41,6 +41,9 @@ import logging
41
41
  logger = logging.getLogger(__name__)
42
42
 
43
43
 
44
+ ROLE_RENAMES: dict[str, str] = {"Constellation": "Watchtower"}
45
+
46
+
44
47
  class NodeRoleManager(models.Manager):
45
48
  def get_by_natural_key(self, name: str):
46
49
  return self.get(name=name)
@@ -188,8 +191,10 @@ class Node(Entity):
188
191
 
189
192
  DEFAULT_BADGE_COLOR = "#28a745"
190
193
  ROLE_BADGE_COLORS = {
191
- "Constellation": "#daa520", # goldenrod
194
+ "Watchtower": "#daa520", # goldenrod
195
+ "Constellation": "#daa520", # legacy alias
192
196
  "Control": "#673ab7", # deep purple
197
+ "Interface": "#0dcaf0", # cyan
193
198
  }
194
199
 
195
200
  class Relation(models.TextChoices):
@@ -202,6 +207,10 @@ class Node(Entity):
202
207
  address = models.GenericIPAddressField()
203
208
  mac_address = models.CharField(max_length=17, unique=True, null=True, blank=True)
204
209
  port = models.PositiveIntegerField(default=8000)
210
+ message_queue_length = models.PositiveSmallIntegerField(
211
+ default=10,
212
+ help_text="Maximum queued NetMessages to retain for this peer.",
213
+ )
205
214
  badge_color = models.CharField(max_length=7, default=DEFAULT_BADGE_COLOR)
206
215
  role = models.ForeignKey(NodeRole, on_delete=models.SET_NULL, null=True, blank=True)
207
216
  current_relation = models.CharField(
@@ -332,6 +341,7 @@ class Node(Entity):
332
341
  }
333
342
  role_lock = Path(settings.BASE_DIR) / "locks" / "role.lck"
334
343
  role_name = role_lock.read_text().strip() if role_lock.exists() else "Terminal"
344
+ role_name = ROLE_RENAMES.get(role_name, role_name)
335
345
  desired_role = NodeRole.objects.filter(name=role_name).first()
336
346
 
337
347
  if node:
@@ -490,6 +500,27 @@ class Node(Entity):
490
500
  self.public_key = pub_path.read_text()
491
501
  self.save(update_fields=["public_key"])
492
502
 
503
+ def get_private_key(self):
504
+ """Return the private key for this node if available."""
505
+
506
+ if not self.public_endpoint:
507
+ return None
508
+ try:
509
+ self.ensure_keys()
510
+ except Exception:
511
+ return None
512
+ priv_path = (
513
+ Path(self.base_path or settings.BASE_DIR)
514
+ / "security"
515
+ / f"{self.public_endpoint}"
516
+ )
517
+ try:
518
+ return serialization.load_pem_private_key(
519
+ priv_path.read_bytes(), password=None
520
+ )
521
+ except Exception:
522
+ return None
523
+
493
524
  @property
494
525
  def is_local(self):
495
526
  """Determine if this node represents the current host."""
@@ -764,6 +795,7 @@ class Node(Entity):
764
795
  self._sync_screenshot_task(screenshot_enabled)
765
796
  self._sync_landing_lead_task(celery_enabled)
766
797
  self._sync_ocpp_session_report_task(celery_enabled)
798
+ self._sync_upstream_poll_task(celery_enabled)
767
799
 
768
800
  def _sync_clipboard_task(self, enabled: bool):
769
801
  from django_celery_beat.models import IntervalSchedule, PeriodicTask
@@ -869,6 +901,28 @@ class Node(Entity):
869
901
  except (OperationalError, ProgrammingError):
870
902
  logger.debug("Skipping OCPP session report task sync; tables not ready")
871
903
 
904
+ def _sync_upstream_poll_task(self, celery_enabled: bool):
905
+ if not self.is_local:
906
+ return
907
+
908
+ from django_celery_beat.models import IntervalSchedule, PeriodicTask
909
+
910
+ task_name = "nodes_poll_upstream_messages"
911
+ if celery_enabled:
912
+ schedule, _ = IntervalSchedule.objects.get_or_create(
913
+ every=1, period=IntervalSchedule.MINUTES
914
+ )
915
+ PeriodicTask.objects.update_or_create(
916
+ name=task_name,
917
+ defaults={
918
+ "interval": schedule,
919
+ "task": "nodes.tasks.poll_unreachable_upstream",
920
+ "enabled": True,
921
+ },
922
+ )
923
+ else:
924
+ PeriodicTask.objects.filter(name=task_name).delete()
925
+
872
926
  def send_mail(
873
927
  self,
874
928
  subject: str,
@@ -1507,6 +1561,193 @@ class NetMessage(Entity):
1507
1561
  self.pk,
1508
1562
  )
1509
1563
 
1564
+ def _build_payload(
1565
+ self,
1566
+ *,
1567
+ sender_id: str | None,
1568
+ origin_uuid: str | None,
1569
+ reach_name: str | None,
1570
+ seen: list[str],
1571
+ ) -> dict[str, object]:
1572
+ payload: dict[str, object] = {
1573
+ "uuid": str(self.uuid),
1574
+ "subject": self.subject,
1575
+ "body": self.body,
1576
+ "seen": list(seen),
1577
+ "reach": reach_name,
1578
+ "sender": sender_id,
1579
+ "origin": origin_uuid,
1580
+ }
1581
+ if self.attachments:
1582
+ payload["attachments"] = self.attachments
1583
+ if self.filter_node:
1584
+ payload["filter_node"] = str(self.filter_node.uuid)
1585
+ if self.filter_node_feature:
1586
+ payload["filter_node_feature"] = self.filter_node_feature.slug
1587
+ if self.filter_node_role:
1588
+ payload["filter_node_role"] = self.filter_node_role.name
1589
+ if self.filter_current_relation:
1590
+ payload["filter_current_relation"] = self.filter_current_relation
1591
+ if self.filter_installed_version:
1592
+ payload["filter_installed_version"] = self.filter_installed_version
1593
+ if self.filter_installed_revision:
1594
+ payload["filter_installed_revision"] = self.filter_installed_revision
1595
+ return payload
1596
+
1597
+ @staticmethod
1598
+ def _serialize_payload(payload: dict[str, object]) -> str:
1599
+ return json.dumps(payload, separators=(",", ":"), sort_keys=True)
1600
+
1601
+ @staticmethod
1602
+ def _sign_payload(payload_json: str, private_key) -> str | None:
1603
+ if not private_key:
1604
+ return None
1605
+ try:
1606
+ signature = private_key.sign(
1607
+ payload_json.encode(),
1608
+ padding.PKCS1v15(),
1609
+ hashes.SHA256(),
1610
+ )
1611
+ except Exception:
1612
+ return None
1613
+ return base64.b64encode(signature).decode()
1614
+
1615
+ def queue_for_node(self, node: "Node", seen: list[str]) -> None:
1616
+ """Queue this message for later delivery to ``node``."""
1617
+
1618
+ if node.current_relation != Node.Relation.DOWNSTREAM:
1619
+ return
1620
+
1621
+ now = timezone.now()
1622
+ expires_at = now + timedelta(hours=1)
1623
+ normalized_seen = [str(value) for value in seen]
1624
+ entry, created = PendingNetMessage.objects.get_or_create(
1625
+ node=node,
1626
+ message=self,
1627
+ defaults={
1628
+ "seen": normalized_seen,
1629
+ "stale_at": expires_at,
1630
+ },
1631
+ )
1632
+ if created:
1633
+ entry.queued_at = now
1634
+ entry.save(update_fields=["queued_at"])
1635
+ else:
1636
+ entry.seen = normalized_seen
1637
+ entry.stale_at = expires_at
1638
+ entry.queued_at = now
1639
+ entry.save(update_fields=["seen", "stale_at", "queued_at"])
1640
+ self._trim_queue(node)
1641
+
1642
+ def clear_queue_for_node(self, node: "Node") -> None:
1643
+ PendingNetMessage.objects.filter(node=node, message=self).delete()
1644
+
1645
+ def _trim_queue(self, node: "Node") -> None:
1646
+ limit = max(int(node.message_queue_length or 0), 0)
1647
+ if limit == 0:
1648
+ PendingNetMessage.objects.filter(node=node).delete()
1649
+ return
1650
+ qs = PendingNetMessage.objects.filter(node=node).order_by("-queued_at")
1651
+ keep_ids = list(qs.values_list("pk", flat=True)[:limit])
1652
+ if keep_ids:
1653
+ PendingNetMessage.objects.filter(node=node).exclude(pk__in=keep_ids).delete()
1654
+ else:
1655
+ qs.delete()
1656
+
1657
+ @classmethod
1658
+ def receive_payload(
1659
+ cls,
1660
+ data: dict[str, object],
1661
+ *,
1662
+ sender: "Node",
1663
+ ) -> "NetMessage":
1664
+ msg_uuid = data.get("uuid")
1665
+ if not msg_uuid:
1666
+ raise ValueError("uuid required")
1667
+ subject = (data.get("subject") or "")[:64]
1668
+ body = (data.get("body") or "")[:256]
1669
+ attachments = cls.normalize_attachments(data.get("attachments"))
1670
+ reach_name = data.get("reach")
1671
+ reach_role = None
1672
+ if reach_name:
1673
+ reach_role = NodeRole.objects.filter(name=reach_name).first()
1674
+ filter_node_uuid = data.get("filter_node")
1675
+ filter_node = None
1676
+ if filter_node_uuid:
1677
+ filter_node = Node.objects.filter(uuid=filter_node_uuid).first()
1678
+ filter_feature_slug = data.get("filter_node_feature")
1679
+ filter_feature = None
1680
+ if filter_feature_slug:
1681
+ filter_feature = NodeFeature.objects.filter(slug=filter_feature_slug).first()
1682
+ filter_role_name = data.get("filter_node_role")
1683
+ filter_role = None
1684
+ if filter_role_name:
1685
+ filter_role = NodeRole.objects.filter(name=filter_role_name).first()
1686
+ filter_relation_value = data.get("filter_current_relation")
1687
+ filter_relation = ""
1688
+ if filter_relation_value:
1689
+ relation = Node.normalize_relation(filter_relation_value)
1690
+ filter_relation = relation.value if relation else ""
1691
+ filter_installed_version = (data.get("filter_installed_version") or "")[:20]
1692
+ filter_installed_revision = (data.get("filter_installed_revision") or "")[:40]
1693
+ seen_values = data.get("seen", [])
1694
+ if not isinstance(seen_values, list):
1695
+ seen_values = list(seen_values) # type: ignore[arg-type]
1696
+ normalized_seen = [str(value) for value in seen_values if value is not None]
1697
+ origin_id = data.get("origin")
1698
+ origin_node = None
1699
+ if origin_id:
1700
+ origin_node = Node.objects.filter(uuid=origin_id).first()
1701
+ if not origin_node:
1702
+ origin_node = sender
1703
+ msg, created = cls.objects.get_or_create(
1704
+ uuid=msg_uuid,
1705
+ defaults={
1706
+ "subject": subject,
1707
+ "body": body,
1708
+ "reach": reach_role,
1709
+ "node_origin": origin_node,
1710
+ "attachments": attachments or None,
1711
+ "filter_node": filter_node,
1712
+ "filter_node_feature": filter_feature,
1713
+ "filter_node_role": filter_role,
1714
+ "filter_current_relation": filter_relation,
1715
+ "filter_installed_version": filter_installed_version,
1716
+ "filter_installed_revision": filter_installed_revision,
1717
+ },
1718
+ )
1719
+ if not created:
1720
+ msg.subject = subject
1721
+ msg.body = body
1722
+ update_fields = ["subject", "body"]
1723
+ if reach_role and msg.reach_id != reach_role.id:
1724
+ msg.reach = reach_role
1725
+ update_fields.append("reach")
1726
+ if msg.node_origin_id is None and origin_node:
1727
+ msg.node_origin = origin_node
1728
+ update_fields.append("node_origin")
1729
+ if attachments and msg.attachments != attachments:
1730
+ msg.attachments = attachments
1731
+ update_fields.append("attachments")
1732
+ field_updates = {
1733
+ "filter_node": filter_node,
1734
+ "filter_node_feature": filter_feature,
1735
+ "filter_node_role": filter_role,
1736
+ "filter_current_relation": filter_relation,
1737
+ "filter_installed_version": filter_installed_version,
1738
+ "filter_installed_revision": filter_installed_revision,
1739
+ }
1740
+ for field, value in field_updates.items():
1741
+ if getattr(msg, field) != value:
1742
+ setattr(msg, field, value)
1743
+ update_fields.append(field)
1744
+ if update_fields:
1745
+ msg.save(update_fields=update_fields)
1746
+ if attachments:
1747
+ msg.apply_attachments(attachments)
1748
+ msg.propagate(seen=normalized_seen)
1749
+ return msg
1750
+
1510
1751
  def propagate(self, seen: list[str] | None = None):
1511
1752
  from core.notifications import notify
1512
1753
  import random
@@ -1541,17 +1782,7 @@ class NetMessage(Entity):
1541
1782
  local_id = str(local.uuid)
1542
1783
  if local_id not in seen:
1543
1784
  seen.append(local_id)
1544
- priv_path = (
1545
- Path(local.base_path or settings.BASE_DIR)
1546
- / "security"
1547
- / f"{local.public_endpoint}"
1548
- )
1549
- try:
1550
- private_key = serialization.load_pem_private_key(
1551
- priv_path.read_bytes(), password=None
1552
- )
1553
- except Exception:
1554
- private_key = None
1785
+ private_key = local.get_private_key()
1555
1786
  for node_id in seen:
1556
1787
  node = Node.objects.filter(uuid=node_id).first()
1557
1788
  if node and (not local or node.pk != local.pk):
@@ -1601,11 +1832,18 @@ class NetMessage(Entity):
1601
1832
  reach_source = self.filter_node_role or self.reach
1602
1833
  reach_name = reach_source.name if reach_source else None
1603
1834
  role_map = {
1835
+ "Interface": ["Interface", "Terminal"],
1604
1836
  "Terminal": ["Terminal"],
1605
1837
  "Control": ["Control", "Terminal"],
1606
1838
  "Satellite": ["Satellite", "Control", "Terminal"],
1839
+ "Watchtower": [
1840
+ "Watchtower",
1841
+ "Satellite",
1842
+ "Control",
1843
+ "Terminal",
1844
+ ],
1607
1845
  "Constellation": [
1608
- "Constellation",
1846
+ "Watchtower",
1609
1847
  "Satellite",
1610
1848
  "Control",
1611
1849
  "Terminal",
@@ -1650,54 +1888,36 @@ class NetMessage(Entity):
1650
1888
  selected_ids = [str(n.uuid) for n in selected]
1651
1889
  payload_seen = seen_list + selected_ids
1652
1890
  for node in selected:
1653
- payload = {
1654
- "uuid": str(self.uuid),
1655
- "subject": self.subject,
1656
- "body": self.body,
1657
- "seen": payload_seen,
1658
- "reach": reach_name,
1659
- "sender": local_id,
1660
- "origin": origin_uuid,
1661
- }
1662
- if self.attachments:
1663
- payload["attachments"] = self.attachments
1664
- if self.filter_node:
1665
- payload["filter_node"] = str(self.filter_node.uuid)
1666
- if self.filter_node_feature:
1667
- payload["filter_node_feature"] = self.filter_node_feature.slug
1668
- if self.filter_node_role:
1669
- payload["filter_node_role"] = self.filter_node_role.name
1670
- if self.filter_current_relation:
1671
- payload["filter_current_relation"] = self.filter_current_relation
1672
- if self.filter_installed_version:
1673
- payload["filter_installed_version"] = self.filter_installed_version
1674
- if self.filter_installed_revision:
1675
- payload["filter_installed_revision"] = self.filter_installed_revision
1676
- payload_json = json.dumps(payload, separators=(",", ":"), sort_keys=True)
1891
+ payload = self._build_payload(
1892
+ sender_id=local_id,
1893
+ origin_uuid=origin_uuid,
1894
+ reach_name=reach_name,
1895
+ seen=payload_seen,
1896
+ )
1897
+ payload_json = self._serialize_payload(payload)
1677
1898
  headers = {"Content-Type": "application/json"}
1678
- if private_key:
1679
- try:
1680
- signature = private_key.sign(
1681
- payload_json.encode(),
1682
- padding.PKCS1v15(),
1683
- hashes.SHA256(),
1684
- )
1685
- headers["X-Signature"] = base64.b64encode(signature).decode()
1686
- except Exception:
1687
- pass
1899
+ signature = self._sign_payload(payload_json, private_key)
1900
+ if signature:
1901
+ headers["X-Signature"] = signature
1902
+ success = False
1688
1903
  try:
1689
- requests.post(
1904
+ response = requests.post(
1690
1905
  f"http://{node.address}:{node.port}/nodes/net-message/",
1691
1906
  data=payload_json,
1692
1907
  headers=headers,
1693
1908
  timeout=1,
1694
1909
  )
1910
+ success = bool(response.ok)
1695
1911
  except Exception:
1696
1912
  logger.exception(
1697
1913
  "Failed to propagate NetMessage %s to node %s",
1698
1914
  self.pk,
1699
1915
  node.pk,
1700
1916
  )
1917
+ if success:
1918
+ self.clear_queue_for_node(node)
1919
+ else:
1920
+ self.queue_for_node(node, payload_seen)
1701
1921
  self.propagated_to.add(node)
1702
1922
 
1703
1923
  save_fields: list[str] = []
@@ -1709,6 +1929,32 @@ class NetMessage(Entity):
1709
1929
  self.save(update_fields=save_fields)
1710
1930
 
1711
1931
 
1932
+ class PendingNetMessage(models.Model):
1933
+ """Queued :class:`NetMessage` awaiting delivery to a downstream node."""
1934
+
1935
+ node = models.ForeignKey(
1936
+ Node, on_delete=models.CASCADE, related_name="pending_net_messages"
1937
+ )
1938
+ message = models.ForeignKey(
1939
+ NetMessage,
1940
+ on_delete=models.CASCADE,
1941
+ related_name="pending_deliveries",
1942
+ )
1943
+ seen = models.JSONField(default=list)
1944
+ queued_at = models.DateTimeField(auto_now_add=True)
1945
+ stale_at = models.DateTimeField()
1946
+
1947
+ class Meta:
1948
+ unique_together = ("node", "message")
1949
+ ordering = ("queued_at",)
1950
+
1951
+ def __str__(self) -> str: # pragma: no cover - simple representation
1952
+ return f"{self.message_id} → {self.node_id}"
1953
+
1954
+ @property
1955
+ def is_stale(self) -> bool:
1956
+ return self.stale_at <= timezone.now()
1957
+
1712
1958
  class ContentSample(Entity):
1713
1959
  """Collected content such as text snippets or screenshots."""
1714
1960
 
nodes/rfid_sync.py CHANGED
@@ -99,7 +99,7 @@ def apply_rfid_payload(
99
99
  last_seen = entry.get("last_seen_on")
100
100
  defaults["last_seen_on"] = parse_datetime(last_seen) if last_seen else None
101
101
 
102
- obj, created = RFID.objects.update_or_create(rfid=rfid_value, defaults=defaults)
102
+ obj, created = RFID.update_or_create_from_code(rfid_value, defaults=defaults)
103
103
 
104
104
  outcome.instance = obj
105
105
  outcome.created = created
nodes/tasks.py CHANGED
@@ -1,11 +1,16 @@
1
+ import base64
2
+ import json
1
3
  import logging
2
4
  from pathlib import Path
3
5
 
4
6
  import pyperclip
5
- from pyperclip import PyperclipException
7
+ import requests
6
8
  from celery import shared_task
9
+ from cryptography.hazmat.primitives import hashes, serialization
10
+ from cryptography.hazmat.primitives.asymmetric import padding
11
+ from pyperclip import PyperclipException
7
12
 
8
- from .models import ContentSample, Node
13
+ from .models import ContentSample, NetMessage, Node
9
14
  from .utils import capture_screenshot, save_screenshot
10
15
 
11
16
  logger = logging.getLogger(__name__)
@@ -44,3 +49,96 @@ def capture_node_screenshot(
44
49
  node = Node.get_local()
45
50
  save_screenshot(path, node=node, method=method)
46
51
  return str(path)
52
+
53
+
54
+ @shared_task
55
+ def poll_unreachable_upstream() -> None:
56
+ """Poll upstream nodes for queued NetMessages."""
57
+
58
+ local = Node.get_local()
59
+ if not local or not local.has_feature("celery-queue"):
60
+ return
61
+
62
+ private_key = local.get_private_key()
63
+ if not private_key:
64
+ logger.warning("Node %s cannot sign upstream polls", getattr(local, "pk", None))
65
+ return
66
+
67
+ requester_payload = {"requester": str(local.uuid)}
68
+ payload_json = json.dumps(requester_payload, separators=(",", ":"), sort_keys=True)
69
+ try:
70
+ signature = base64.b64encode(
71
+ private_key.sign(
72
+ payload_json.encode(),
73
+ padding.PKCS1v15(),
74
+ hashes.SHA256(),
75
+ )
76
+ ).decode()
77
+ except Exception as exc:
78
+ logger.warning("Failed to sign upstream poll request: %s", exc)
79
+ return
80
+
81
+ headers = {"Content-Type": "application/json", "X-Signature": signature}
82
+ upstream_nodes = Node.objects.filter(current_relation=Node.Relation.UPSTREAM)
83
+ for upstream in upstream_nodes:
84
+ if not upstream.public_key:
85
+ continue
86
+ host = (upstream.address or upstream.hostname or "").strip()
87
+ if not host:
88
+ continue
89
+ if ":" in host and not host.startswith("["):
90
+ host = f"[{host}]"
91
+ port = upstream.port or 8000
92
+ if port in {80, 443}:
93
+ url = f"http://{host}/nodes/net-message/pull/"
94
+ else:
95
+ url = f"http://{host}:{port}/nodes/net-message/pull/"
96
+ try:
97
+ response = requests.post(url, data=payload_json, headers=headers, timeout=5)
98
+ except Exception as exc:
99
+ logger.warning("Polling upstream node %s failed: %s", upstream.pk, exc)
100
+ continue
101
+ if not response.ok:
102
+ logger.warning(
103
+ "Upstream node %s returned status %s", upstream.pk, response.status_code
104
+ )
105
+ continue
106
+ try:
107
+ body = response.json()
108
+ except ValueError:
109
+ logger.warning("Upstream node %s returned invalid JSON", upstream.pk)
110
+ continue
111
+ messages = body.get("messages", [])
112
+ if not isinstance(messages, list) or not messages:
113
+ continue
114
+ try:
115
+ public_key = serialization.load_pem_public_key(upstream.public_key.encode())
116
+ except Exception:
117
+ logger.warning("Upstream node %s has invalid public key", upstream.pk)
118
+ continue
119
+ for item in messages:
120
+ if not isinstance(item, dict):
121
+ continue
122
+ payload = item.get("payload")
123
+ payload_signature = item.get("signature")
124
+ if not isinstance(payload, dict) or not payload_signature:
125
+ continue
126
+ payload_text = json.dumps(payload, separators=(",", ":"), sort_keys=True)
127
+ try:
128
+ public_key.verify(
129
+ base64.b64decode(payload_signature),
130
+ payload_text.encode(),
131
+ padding.PKCS1v15(),
132
+ hashes.SHA256(),
133
+ )
134
+ except Exception:
135
+ logger.warning(
136
+ "Signature verification failed for upstream node %s", upstream.pk
137
+ )
138
+ continue
139
+ try:
140
+ NetMessage.receive_payload(payload, sender=upstream)
141
+ except ValueError as exc:
142
+ logger.warning(
143
+ "Discarded upstream message from node %s: %s", upstream.pk, exc
144
+ )