reflex-sdk 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/PKG-INFO +8 -1
  2. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/pyproject.toml +17 -3
  3. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/__init__.py +41 -5
  4. reflex_sdk-0.3.2/reflex/_region_probe.py +210 -0
  5. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/actions.py +137 -23
  6. reflex_sdk-0.3.2/reflex/auth_runner.py +146 -0
  7. reflex_sdk-0.3.2/reflex/cameras/__init__.py +39 -0
  8. reflex_sdk-0.3.2/reflex/cameras/base.py +39 -0
  9. reflex_sdk-0.3.2/reflex/cameras/realsense.py +73 -0
  10. reflex_sdk-0.3.2/reflex/cameras/shm.py +120 -0
  11. reflex_sdk-0.3.2/reflex/cameras/v4l2.py +75 -0
  12. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/cli.py +982 -16
  13. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/client.py +116 -29
  14. reflex_sdk-0.3.2/reflex/connect_runner.py +442 -0
  15. reflex_sdk-0.3.2/reflex/connectors/__init__.py +44 -0
  16. reflex_sdk-0.3.2/reflex/connectors/base.py +70 -0
  17. reflex_sdk-0.3.2/reflex/connectors/shell.py +67 -0
  18. reflex_sdk-0.3.2/reflex/connectors/yam_bimanual.py +456 -0
  19. reflex_sdk-0.3.2/reflex/deployments.py +143 -0
  20. reflex_sdk-0.3.2/reflex/models.py +207 -0
  21. reflex_sdk-0.3.2/reflex/receipts.py +34 -0
  22. reflex_sdk-0.3.2/reflex/robot_runtime.py +207 -0
  23. reflex_sdk-0.3.2/reflex/robots.py +194 -0
  24. reflex_sdk-0.3.2/reflex/sessions.py +113 -0
  25. reflex_sdk-0.3.2/reflex/so101.py +1024 -0
  26. reflex_sdk-0.3.2/reflex/transports/__init__.py +44 -0
  27. reflex_sdk-0.3.2/reflex/transports/_webrtc_client.py +679 -0
  28. reflex_sdk-0.3.2/reflex/transports/_webrtc_streaming_client.py +326 -0
  29. reflex_sdk-0.3.2/reflex/transports/base.py +40 -0
  30. reflex_sdk-0.3.2/reflex/transports/edge_http.py +174 -0
  31. reflex_sdk-0.3.2/reflex/transports/platform.py +157 -0
  32. reflex_sdk-0.3.2/reflex/transports/webrtc.py +275 -0
  33. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/PKG-INFO +8 -1
  34. reflex_sdk-0.3.2/reflex_sdk.egg-info/SOURCES.txt +53 -0
  35. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/requires.txt +8 -0
  36. reflex_sdk-0.3.2/tests/test_connect_runner.py +352 -0
  37. reflex_sdk-0.3.2/tests/test_prod1_smoke.py +53 -0
  38. reflex_sdk-0.3.2/tests/test_product_cli.py +2242 -0
  39. reflex_sdk-0.3.2/tests/test_public_sdk.py +638 -0
  40. reflex_sdk-0.3.2/tests/test_region_probe.py +278 -0
  41. reflex_sdk-0.3.2/tests/test_so101_actions.py +148 -0
  42. reflex_sdk-0.3.0/reflex/models.py +0 -106
  43. reflex_sdk-0.3.0/reflex_sdk.egg-info/SOURCES.txt +0 -24
  44. reflex_sdk-0.3.0/tests/test_product_cli.py +0 -940
  45. reflex_sdk-0.3.0/tests/test_public_sdk.py +0 -207
  46. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/README.md +0 -0
  47. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/__main__.py +0 -0
  48. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/_convex.py +0 -0
  49. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/_transport.py +0 -0
  50. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/_version.py +0 -0
  51. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/datasets.py +0 -0
  52. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/instances.py +0 -0
  53. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/keys.py +0 -0
  54. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/product.py +0 -0
  55. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/training.py +0 -0
  56. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/dependency_links.txt +0 -0
  57. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/entry_points.txt +0 -0
  58. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/top_level.txt +0 -0
  59. {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: reflex-sdk
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Python SDK for Reflex hosted robot inference and training
5
5
  Author: Reflex
6
6
  Project-URL: Homepage, https://tryreflex.ai
@@ -15,8 +15,15 @@ Classifier: Programming Language :: Python :: 3.11
15
15
  Classifier: Programming Language :: Python :: 3.12
16
16
  Requires-Python: >=3.9
17
17
  Description-Content-Type: text/markdown
18
+ Requires-Dist: PyYAML<7,>=6.0
18
19
  Requires-Dist: typer<0.24,>=0.23; python_version < "3.10"
19
20
  Requires-Dist: typer>=0.24.1; python_version >= "3.10"
21
+ Provides-Extra: webrtc
22
+ Requires-Dist: aiortc>=1.8; extra == "webrtc"
23
+ Requires-Dist: av>=11; extra == "webrtc"
24
+ Requires-Dist: msgpack>=1.0; extra == "webrtc"
25
+ Requires-Dist: numpy>=1.24; extra == "webrtc"
26
+ Requires-Dist: Pillow>=10; extra == "webrtc"
20
27
 
21
28
  # Reflex Python SDK
22
29
 
@@ -1,10 +1,11 @@
1
1
  [project]
2
2
  name = "reflex-sdk"
3
- version = "0.3.0"
3
+ version = "0.3.2"
4
4
  description = "Python SDK for Reflex hosted robot inference and training"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.9"
7
7
  dependencies = [
8
+ "PyYAML>=6.0,<7",
8
9
  "typer>=0.23,<0.24; python_version < '3.10'",
9
10
  "typer>=0.24.1; python_version >= '3.10'",
10
11
  ]
@@ -22,6 +23,19 @@ classifiers = [
22
23
  "Programming Language :: Python :: 3.12",
23
24
  ]
24
25
 
26
+ [project.optional-dependencies]
27
+ # Pulled in by `pip install reflex-sdk[webrtc]`. Required when using
28
+ # `target.kind: webrtc` in a `reflex connect` config — the WebRTCTransport
29
+ # imports numpy/Pillow at module top, and the vendored client uses aiortc
30
+ # (peer connection + H.264 video tracks) and msgpack (DataChannel wire).
31
+ webrtc = [
32
+ "aiortc>=1.8",
33
+ "av>=11",
34
+ "msgpack>=1.0",
35
+ "numpy>=1.24",
36
+ "Pillow>=10",
37
+ ]
38
+
25
39
  [project.urls]
26
40
  Homepage = "https://tryreflex.ai"
27
41
  Repository = "https://github.com/reflex-inc/reflex"
@@ -30,11 +44,11 @@ Repository = "https://github.com/reflex-inc/reflex"
30
44
  reflex = "reflex.cli:main_reflex"
31
45
 
32
46
  [tool.setuptools]
33
- packages = ["reflex"]
47
+ packages = ["reflex", "reflex.cameras", "reflex.connectors", "reflex.transports"]
34
48
 
35
49
  [tool.ruff]
36
50
  line-length = 100
37
51
 
38
52
  [build-system]
39
53
  requires = ["setuptools>=68"]
40
- build-backend = "setuptools.build_meta"
54
+ build-backend = "setuptools.build_meta"
@@ -3,6 +3,7 @@
3
3
  from ._version import __version__
4
4
  from .actions import ActionStream, action, connect, infer_actions, observation
5
5
  from .client import Client
6
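+ # TODO(PROD.2): re-add the ConnectError/ConnectRunner/SessionGrant imports once the auth wrapper is integrated. NOTE(review): `__all__` lists "SessionGrant" — confirm the name resolves, otherwise `from reflex import *` raises AttributeError.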
6
7
  from .datasets import (
7
8
  complete_dataset,
8
9
  create_dataset,
@@ -12,9 +13,26 @@ from .datasets import (
12
13
  upload_dataset,
13
14
  validate_dataset,
14
15
  )
16
+ from .deployments import (
17
+ create_deployment,
18
+ create_deployment_from_spec,
19
+ get_deployment,
20
+ list_deployments,
21
+ run_deployment_doctor,
22
+ )
15
23
  from .instances import instance_status, provision_instance, teardown_instance
16
- from .keys import bind_key_to_model
17
- from .models import delete_model, get_model, import_from_hf, list_models
24
+ from .receipts import list_receipts
25
+ from .robots import (
26
+ claim_pairing_token,
27
+ create_pairing_token,
28
+ heartbeat_robot,
29
+ list_robot_schemas,
30
+ list_robots,
31
+ register_robot,
32
+ register_robot_schema,
33
+ )
34
+ from .robot_runtime import RobotExecutionConfig, RobotExecutionResult, run_robot_execution_loop
35
+ from .sessions import close_session, list_sessions, promote_session, start_session
18
36
  from .training import (
19
37
  AdamParams,
20
38
  AdapterHandle,
@@ -37,35 +55,53 @@ __all__ = [
37
55
  "AdamParams",
38
56
  "AdapterHandle",
39
57
  "Client",
58
+ "claim_pairing_token",
40
59
  "Datum",
41
60
  "ForwardBackwardResult",
42
61
  "LoraTrainingClient",
43
62
  "OptimStepResult",
63
+ "RobotExecutionConfig",
64
+ "RobotExecutionResult",
44
65
  "ServiceClient",
66
+ "SessionGrant",
45
67
  "__version__",
46
68
  "action",
47
69
  "bind_key_to_model",
48
70
  "cancel_training_job",
71
+ "close_session",
49
72
  "complete_dataset",
50
73
  "connect",
51
74
  "create_dataset",
75
+ "create_deployment",
76
+ "create_deployment_from_spec",
77
+ "create_pairing_token",
52
78
  "create_training_job",
53
79
  "delete_model",
54
80
  "full_finetune",
55
81
  "full_train",
56
82
  "get_dataset",
57
- "get_model",
83
+ "get_deployment",
58
84
  "get_training_job",
59
- "import_from_hf",
85
+ "heartbeat_robot",
60
86
  "infer_actions",
61
87
  "instance_status",
62
- "list_models",
88
+ "list_sessions",
63
89
  "list_training_jobs",
64
90
  "list_datasets",
91
+ "list_deployments",
92
+ "list_receipts",
93
+ "list_robot_schemas",
94
+ "list_robots",
65
95
  "lora_finetune",
66
96
  "observation",
67
97
  "provision_instance",
98
+ "promote_session",
68
99
  "register_huggingface_dataset",
100
+ "register_robot",
101
+ "register_robot_schema",
102
+ "run_deployment_doctor",
103
+ "run_robot_execution_loop",
104
+ "start_session",
69
105
  "teardown_instance",
70
106
  "upload_dataset",
71
107
  "validate_dataset",
@@ -0,0 +1,210 @@
1
+ """Client-side nearest-region probe for the Reflex SDK CLI (CFG.1).
2
+
3
+ When `REFLEX_INFERENCE_REGION_MAP=region=URL,region=URL,...` is set, the
4
+ CLI hits each URL's `/health` at startup, picks the lowest-p50 region,
5
+ and uses that URL for the inference stream instead of the one the
6
+ platform returned. This buys ~20-60ms RTT on multi-region deploys
7
+ without server-side changes (NetR.3a finding, re-nr7l).
8
+
9
+ Single-entry maps skip the probe entirely (zero startup cost), so the
10
+ default single-region behaviour is unchanged until a multi-region map
11
+ is configured.
12
+
13
+ Stdlib-only (urllib + time + dataclasses) so it runs in any environment the
14
+ rest of the published `reflex` package runs in.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import time
21
+ import urllib.error
22
+ import urllib.request
23
+ from dataclasses import dataclass
24
+ from typing import Callable, Optional
25
+
26
+ DEFAULT_REGION_ENV = "REFLEX_INFERENCE_REGION_MAP"
27
+ DEFAULT_PROBE_ATTEMPTS = 5
28
+ DEFAULT_PROBE_WARMUP = 1
29
+ DEFAULT_PROBE_TIMEOUT_S = 5.0
30
+ DEFAULT_PROBE_SETTLE_S = 0.05
31
+
32
+
33
@dataclass(frozen=True)
class RegionResult:
    """Immutable summary of probing one region endpoint."""

    # Region name (left-hand side of a `label=url` region-map entry).
    label: str
    # Base URL of the region; the probe appends `/health` to it.
    url: str
    # Number of measured probes that returned a 2xx response.
    ok_count: int
    # Number of measured probes requested (0 when probing was skipped).
    attempts: int
    # Median round-trip time in milliseconds; None when no probe succeeded.
    p50_ms: Optional[float]
    # Last error description, reported only when every measured probe failed.
    error_tail: Optional[str] = None
41
+
42
+
43
def parse_region_map(spec: str) -> list[tuple[str, str]]:
    """Turn `label=url,label=url,...` into a list of `(label, url)` pairs.

    Blank input and blank entries (e.g. a trailing comma) are ignored.
    Only the first `=` of an entry separates label from URL, so URLs may
    themselves contain `=`.

    Raises:
        ValueError: an entry has no `=`, or its label or URL is empty.
            A typo in the env var therefore fails loudly instead of
            silently routing to the wrong region.
    """
    pairs: list[tuple[str, str]] = []
    for entry in (piece.strip() for piece in (spec or "").split(",")):
        if not entry:
            continue
        name, separator, target = entry.partition("=")
        if not separator:
            raise ValueError(f"region map entry missing '=': {entry!r}")
        name, target = name.strip(), target.strip()
        if not (name and target):
            raise ValueError(f"empty label or url in region map entry: {entry!r}")
        pairs.append((name, target))
    return pairs
65
+
66
+
67
def _pct(samples: list[float], q: float) -> Optional[float]:
    """Return the nearest-rank `q`-quantile of `samples`, rounded to 2 decimals.

    `q` is a fraction (0.5 is the median). The rank is `round(q * (n - 1))`
    clamped into the valid index range, so out-of-range `q` values map to
    the smallest or largest sample. Returns None for an empty list.
    """
    if not samples:
        return None
    ordered = sorted(samples)
    last = len(ordered) - 1
    rank = int(round(q * last))
    rank = min(last, max(0, rank))
    return round(ordered[rank], 2)
73
+
74
+
75
def _probe_once(url: str, timeout_s: float) -> tuple[Optional[float], Optional[str]]:
    """Issue one GET to `url` and time it.

    Returns `(elapsed_ms, None)` for a 2xx response, otherwise
    `(None, description)`. The elapsed time covers building the request,
    the round trip and reading the full body. No exception escapes: any
    failure just marks the target as unviable for this attempt.
    """
    t0 = time.perf_counter()
    probe_headers = {"User-Agent": "reflex-sdk-region-probe/1.0"}
    try:
        request = urllib.request.Request(url, headers=probe_headers)
        with urllib.request.urlopen(request, timeout=timeout_s) as resp:
            resp.read()  # drain so the conn closes cleanly
            status = resp.status
        if status < 200 or status >= 300:
            return None, f"HTTP {status}"
    except urllib.error.HTTPError as http_exc:
        return None, f"HTTPError {http_exc.code}: {http_exc.reason}"
    except Exception as other_exc:  # noqa: BLE001 — any failure → unviable target
        return None, f"{type(other_exc).__name__}: {other_exc}"
    elapsed_ms = (time.perf_counter() - t0) * 1000.0
    return elapsed_ms, None
90
+
91
+
92
def probe_region(
    label: str,
    url: str,
    *,
    attempts: int = DEFAULT_PROBE_ATTEMPTS,
    warmup: int = DEFAULT_PROBE_WARMUP,
    timeout_s: float = DEFAULT_PROBE_TIMEOUT_S,
    settle_s: float = DEFAULT_PROBE_SETTLE_S,
) -> RegionResult:
    """Hit `<url>/health` `warmup + attempts` times; return per-region stats.

    Warmup probes amortise TCP/TLS handshake + Modal proxy cold-route so
    the p50 reflects steady-state RTT. Their results are discarded.

    Args:
        label: Region name, copied into the result.
        url: Region base URL; a trailing `/` is stripped before `/health`
            is appended.
        attempts: Number of measured probes.
        warmup: Number of unmeasured probes issued first.
        timeout_s: Per-request timeout in seconds.
        settle_s: Pause in seconds after each warmup probe and between
            measured probes; values <= 0 disable pausing.

    Returns:
        A `RegionResult` whose `p50_ms` is None when no measured probe
        succeeded; `error_tail` is set only in that case.
    """
    health = url.rstrip("/") + "/health"
    pause = settle_s > 0

    for _ in range(warmup):
        _probe_once(health, timeout_s)
        if pause:
            time.sleep(settle_s)

    samples: list[float] = []
    last_err: Optional[str] = None
    for attempt in range(attempts):
        # Settle only *between* measured probes: sleeping after the final
        # one measures nothing and just delays CLI startup per region.
        if pause and attempt > 0:
            time.sleep(settle_s)
        rtt, err = _probe_once(health, timeout_s)
        if rtt is not None:
            samples.append(rtt)
        elif err is not None:
            last_err = err

    return RegionResult(
        label=label,
        url=url,
        attempts=attempts,
        ok_count=len(samples),
        p50_ms=_pct(samples, 0.5),
        # An error is only interesting when the region produced no sample.
        error_tail=last_err if not samples else None,
    )
129
+
130
+
131
def pick_nearest(
    region_map: list[tuple[str, str]],
    *,
    attempts: int = DEFAULT_PROBE_ATTEMPTS,
    warmup: int = DEFAULT_PROBE_WARMUP,
    timeout_s: float = DEFAULT_PROBE_TIMEOUT_S,
    settle_s: float = DEFAULT_PROBE_SETTLE_S,
    log: Optional[Callable[[RegionResult], None]] = None,
) -> tuple[str, str, list[RegionResult]]:
    """Probe every region and return `(label, url, all_results)` for the fastest.

    A one-entry map is returned without any network traffic (its result
    has `attempts=0`). When every region fails, the first entry is
    returned so the caller always gets a usable URL — better to attempt
    against a known-stale endpoint than refuse to start. Ties on p50 go
    to the earlier entry.

    `log`, when given, is called with each `RegionResult` as soon as it
    is available.

    Raises:
        ValueError: `region_map` is empty.
    """
    if not region_map:
        raise ValueError("region_map is empty")

    if len(region_map) == 1:
        only_label, only_url = region_map[0]
        skipped = RegionResult(
            label=only_label,
            url=only_url,
            ok_count=0,
            attempts=0,
            p50_ms=None,
        )
        if log is not None:
            log(skipped)
        return only_label, only_url, [skipped]

    probed: list[RegionResult] = []
    for region_label, region_url in region_map:
        outcome = probe_region(
            region_label,
            region_url,
            attempts=attempts,
            warmup=warmup,
            timeout_s=timeout_s,
            settle_s=settle_s,
        )
        if log is not None:
            log(outcome)
        probed.append(outcome)

    reachable = [outcome for outcome in probed if outcome.p50_ms is not None]
    if reachable:
        winner = min(reachable, key=lambda o: o.p50_ms)  # type: ignore[arg-type,return-value]
    else:
        winner = probed[0]
    return winner.label, winner.url, probed
172
+
173
+
174
def resolve_url(
    explicit_url: Optional[str],
    *,
    region_map: Optional[list[tuple[str, str]]] = None,
    env_var: str = DEFAULT_REGION_ENV,
    fallback_url: Optional[str] = None,
    attempts: int = DEFAULT_PROBE_ATTEMPTS,
    warmup: int = DEFAULT_PROBE_WARMUP,
    timeout_s: float = DEFAULT_PROBE_TIMEOUT_S,
    settle_s: float = DEFAULT_PROBE_SETTLE_S,
    log: Optional[Callable[[RegionResult], None]] = None,
) -> str:
    """Resolve the URL to use, honouring explicit > region-map > fallback.

    - `explicit_url` truthy → returned as-is; neither the region map nor
      the environment is consulted.
    - Else use `region_map`, or parse `os.environ[env_var]` when
      `region_map` is None.
    - Empty map → return `fallback_url` (preserves current behaviour).
    - Non-empty map → `pick_nearest` and return its URL.

    Raises:
        RuntimeError: the map is empty and `fallback_url` is None.
        ValueError: the env var holds a malformed region map
            (propagated from `parse_region_map`).
    """
    if explicit_url:
        return explicit_url

    from_env = region_map is None
    if from_env:
        region_map = parse_region_map(os.environ.get(env_var, ""))

    if not region_map:
        if fallback_url is None:
            # Name the source that was actually empty: the env var is only
            # read when the caller did not pass a region_map.
            source = f"env var {env_var}" if from_env else "region_map"
            raise RuntimeError(
                f"resolve_url: {source} is empty and no fallback_url given"
            )
        return fallback_url

    _, url, _ = pick_nearest(
        region_map,
        attempts=attempts,
        warmup=warmup,
        timeout_s=timeout_s,
        settle_s=settle_s,
        log=log,
    )
    return url
@@ -9,6 +9,7 @@ import os
9
9
  import socket
10
10
  import ssl
11
11
  import struct
12
+ import time
12
13
  from types import TracebackType
13
14
  from typing import Any
14
15
  from urllib import parse
@@ -17,6 +18,16 @@ from ._version import user_agent as _user_agent
17
18
  from ._transport import api_key as _api_key
18
19
  from ._transport import base_url as _base_url
19
20
 
21
+ RETRYABLE_HANDSHAKE_STATUS_CODES = {403, 404, 408, 425, 429, 500, 502, 503, 504}
22
+ DEFAULT_CONNECT_RETRY_SECONDS = 90.0
23
+
24
+
25
class WebSocketHandshakeError(RuntimeError):
    """Handshake failure that keeps the server's HTTP status line.

    Attributes set on the instance:
        status_line: the raw first line of the server response.
        status_code: the integer parsed from it by `_status_code`, or None.
    """

    def __init__(self, status_line: str) -> None:
        message = f"WebSocket handshake failed: {status_line}"
        super().__init__(message)
        self.status_code = _status_code(status_line)
        self.status_line = status_line
30
+
20
31
 
21
32
  def _actions_ws_url(url: str | None) -> str:
22
33
  resolved = (url or os.environ.get("REFLEX_ACTIONS_URL", "")).strip()
@@ -25,6 +36,8 @@ def _actions_ws_url(url: str | None) -> str:
25
36
 
26
37
  parsed = parse.urlparse(resolved)
27
38
  if parsed.scheme in {"ws", "wss"}:
39
+ if parsed.query:
40
+ return resolved
28
41
  if parsed.path and parsed.path != "/":
29
42
  return resolved
30
43
  return f"{resolved.rstrip('/')}/v1/actions"
@@ -35,7 +48,42 @@ def _actions_ws_url(url: str | None) -> str:
35
48
  raise ValueError("Actions URL must start with http://, https://, ws://, or wss://.")
36
49
 
37
50
 
38
- def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.socket:
51
def _status_code(status_line: str) -> int | None:
    """Extract the numeric status from an HTTP status line.

    Returns the second whitespace-separated field as an int, e.g. 101 for
    `HTTP/1.1 101 Switching Protocols`, or None when the field is missing
    or not an integer.
    """
    try:
        return int(status_line.split()[1])
    except (IndexError, ValueError):
        return None
59
+
60
+
61
def _connect_retry_seconds(timeout: float) -> float:
    """Return how long, in seconds, connection attempts may be retried.

    The budget comes from `REFLEX_WS_CONNECT_RETRY_SECONDS` when that is
    set to a parseable float, otherwise from
    `DEFAULT_CONNECT_RETRY_SECONDS`. It is capped at `timeout` and never
    negative, so a zero or negative setting disables retries.
    """
    override = os.environ.get("REFLEX_WS_CONNECT_RETRY_SECONDS", "").strip()
    try:
        budget = float(override) if override else DEFAULT_CONNECT_RETRY_SECONDS
    except ValueError:
        # Unparseable override: fall back to the default budget.
        budget = DEFAULT_CONNECT_RETRY_SECONDS
    ceiling = max(0.0, timeout)
    return max(0.0, min(ceiling, budget))
71
+
72
+
73
def _retryable_connect_error(exc: BaseException) -> bool:
    """Decide whether a failed connection attempt is worth retrying.

    - TLS certificate verification failures are never retried: the same
      certificate would be rejected again, so retrying only delays the
      error until the retry budget runs out.
    - Handshake rejections are retried only for the status codes in
      `RETRYABLE_HANDSHAKE_STATUS_CODES`.
    - Any other `OSError` (which includes `TimeoutError`,
      `socket.timeout` and connection resets/refusals) is retried.
    - Everything else is treated as permanent.
    """
    # Checked first because SSLCertVerificationError is itself an OSError.
    if isinstance(exc, ssl.SSLCertVerificationError):
        return False
    if isinstance(exc, WebSocketHandshakeError):
        return exc.status_code in RETRYABLE_HANDSHAKE_STATUS_CODES
    return isinstance(exc, OSError)
77
+
78
+
79
+ def _connect_websocket(
80
+ url: str,
81
+ *,
82
+ api_key: str,
83
+ timeout: float,
84
+ timing: dict[str, float] | None = None,
85
+ ) -> socket.socket:
86
+ connect_started = time.monotonic()
39
87
  parsed = parse.urlparse(url)
40
88
  if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
41
89
  raise ValueError("Actions URL must be a ws:// or wss:// URL.")
@@ -45,17 +93,27 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
45
93
  if parsed.query:
46
94
  path = f"{path}?{parsed.query}"
47
95
 
96
+ tcp_started = time.monotonic()
48
97
  raw_sock = socket.create_connection((parsed.hostname, port), timeout=timeout)
98
+ if timing is not None:
99
+ timing["tcp_ms"] = (time.monotonic() - tcp_started) * 1000.0
49
100
  sock: socket.socket
50
101
  if parsed.scheme == "wss":
102
+ tls_started = time.monotonic()
51
103
  context = ssl.create_default_context()
52
104
  sock = context.wrap_socket(raw_sock, server_hostname=parsed.hostname)
105
+ if timing is not None:
106
+ timing["tls_ms"] = (time.monotonic() - tls_started) * 1000.0
53
107
  else:
54
108
  sock = raw_sock
109
+ if timing is not None:
110
+ timing["tls_ms"] = 0.0
55
111
  sock.settimeout(timeout)
56
112
 
113
+ upgrade_started = time.monotonic()
57
114
  key = base64.b64encode(os.urandom(16)).decode("ascii")
58
115
  host = parsed.hostname if parsed.port is None else f"{parsed.hostname}:{port}"
116
+ auth_header = f"Authorization: Bearer {api_key}\r\n" if api_key else ""
59
117
  request_text = (
60
118
  f"GET {path} HTTP/1.1\r\n"
61
119
  f"Host: {host}\r\n"
@@ -63,7 +121,7 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
63
121
  "Connection: Upgrade\r\n"
64
122
  f"Sec-WebSocket-Key: {key}\r\n"
65
123
  "Sec-WebSocket-Version: 13\r\n"
66
- f"Authorization: Bearer {api_key}\r\n"
124
+ f"{auth_header}"
67
125
  f"User-Agent: {_user_agent('reflex-actions-sdk')}\r\n"
68
126
  "\r\n"
69
127
  )
@@ -79,7 +137,7 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
79
137
  status_line = header_text.split("\r\n", 1)[0]
80
138
  if " 101 " not in status_line:
81
139
  sock.close()
82
- raise RuntimeError(f"WebSocket handshake failed: {status_line}")
140
+ raise WebSocketHandshakeError(status_line)
83
141
 
84
142
  accept_source = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("ascii")
85
143
  expected_accept = base64.b64encode(hashlib.sha1(accept_source).digest()).decode("ascii")
@@ -91,6 +149,9 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
91
149
  if headers.get("sec-websocket-accept") != expected_accept:
92
150
  sock.close()
93
151
  raise RuntimeError("WebSocket handshake failed: invalid Sec-WebSocket-Accept header.")
152
+ if timing is not None:
153
+ timing["websocket_upgrade_ms"] = (time.monotonic() - upgrade_started) * 1000.0
154
+ timing["connect_ms"] = (time.monotonic() - connect_started) * 1000.0
94
155
  return sock
95
156
 
96
157
 
@@ -176,6 +237,11 @@ def _normalize_images(images: dict[str, Any] | None) -> dict[str, Any]:
176
237
  return normalized
177
238
 
178
239
 
240
def _query_has_token(url: str) -> bool:
    """Return True when the URL's query string has a non-blank `token` value.

    A blank value (`?token=`) does not count, because `parse_qs` drops
    blank values by default.
    """
    query = parse.urlparse(url).query
    tokens = parse.parse_qs(query).get("token", [])
    return len(tokens) > 0
243
+
244
+
179
245
  class ActionStream:
180
246
  """Long-lived Reflex actions session.
181
247
 
@@ -193,25 +259,36 @@ class ActionStream:
193
259
  prompt: str,
194
260
  model: str | None = None,
195
261
  lora: str | None = None,
262
+ robot: str | None = None,
263
+ action_adapter: str | None = None,
196
264
  cameras: list[str] | None = None,
197
265
  hz: float | None = None,
198
266
  chunk_size: int | None = None,
199
267
  max_gpu_seconds: float | None = None,
200
268
  session_id: str = "",
201
269
  timeout: float = 30.0,
270
+ connect_retry_seconds: float | None = None,
202
271
  ) -> None:
203
272
  self.url = _actions_ws_url(url)
204
- self.api_key = _api_key(api_key)
273
+ self.api_key = (
274
+ ""
275
+ if _query_has_token(self.url) and not (api_key or "").strip()
276
+ else _api_key(api_key)
277
+ )
205
278
  self.prompt = prompt
206
279
  self.model = model
207
280
  self.lora = lora
281
+ self.robot = robot
282
+ self.action_adapter = action_adapter
208
283
  self.cameras = cameras or []
209
284
  self.hz = hz
210
285
  self.chunk_size = chunk_size
211
286
  self.max_gpu_seconds = max_gpu_seconds
212
287
  self.session_id = session_id
213
288
  self.timeout = timeout
289
+ self.connect_retry_seconds = connect_retry_seconds
214
290
  self.ready: dict[str, Any] | None = None
291
+ self.open_timing: dict[str, float] = {}
215
292
  self._sock: socket.socket | None = None
216
293
  self._next_seq = 0
217
294
 
@@ -227,24 +304,61 @@ class ActionStream:
227
304
  self.close()
228
305
 
229
306
  def open(self) -> ActionStream:
230
- self._sock = _connect_websocket(self.url, api_key=self.api_key, timeout=self.timeout)
231
- self.send_raw({
232
- "type": "session.open",
233
- "session_id": self.session_id,
234
- "prompt": self.prompt,
235
- "model": self.model,
236
- "lora": self.lora,
237
- "cameras": self.cameras,
238
- "hz": self.hz,
239
- "chunk_size": self.chunk_size,
240
- "max_gpu_seconds": self.max_gpu_seconds,
241
- })
242
- ready = self.receive()
243
- if ready.get("type") != "session.ready":
244
- raise RuntimeError(f"Expected session.ready, received {ready!r}")
245
- self.ready = ready
246
- self.session_id = str(ready.get("session_id") or self.session_id)
247
- return self
307
+ retry_seconds = (
308
+ _connect_retry_seconds(self.timeout)
309
+ if self.connect_retry_seconds is None
310
+ else max(0.0, min(max(0.0, self.timeout), self.connect_retry_seconds))
311
+ )
312
+ deadline = time.monotonic() + retry_seconds
313
+ delay = 0.5
314
+ while True:
315
+ open_started = time.monotonic()
316
+ timing: dict[str, float] = {}
317
+ try:
318
+ self._sock = _connect_websocket(
319
+ self.url,
320
+ api_key=self.api_key,
321
+ timeout=min(self.timeout, 15.0),
322
+ timing=timing,
323
+ )
324
+ ready_started = time.monotonic()
325
+ self.send_raw({
326
+ "type": "session.open",
327
+ "session_id": self.session_id,
328
+ "prompt": self.prompt,
329
+ "model": self.model,
330
+ "lora": self.lora,
331
+ "robot": self.robot,
332
+ "action_adapter": self.action_adapter,
333
+ "cameras": self.cameras,
334
+ "hz": self.hz,
335
+ "chunk_size": self.chunk_size,
336
+ "max_gpu_seconds": self.max_gpu_seconds,
337
+ })
338
+ ready = self.receive()
339
+ if ready.get("type") != "session.ready":
340
+ raise RuntimeError(f"Expected session.ready, received {ready!r}")
341
+ timing["session_ready_ms"] = (time.monotonic() - ready_started) * 1000.0
342
+ timing["open_ms"] = (time.monotonic() - open_started) * 1000.0
343
+ self.open_timing = timing
344
+ self.ready = ready
345
+ self.session_id = str(ready.get("session_id") or self.session_id)
346
+ return self
347
+ except Exception as exc:
348
+ if self._sock is not None:
349
+ try:
350
+ _close_websocket(self._sock)
351
+ except Exception:
352
+ try:
353
+ self._sock.close()
354
+ except Exception:
355
+ pass
356
+ self._sock = None
357
+ now = time.monotonic()
358
+ if retry_seconds <= 0 or not _retryable_connect_error(exc) or now >= deadline:
359
+ raise
360
+ time.sleep(min(delay, max(0.0, deadline - now)))
361
+ delay = min(delay * 1.6, 5.0)
248
362
 
249
363
  def send_raw(self, frame: dict[str, Any]) -> None:
250
364
  if self._sock is None:
@@ -428,4 +542,4 @@ def infer_actions(
428
542
  }
429
543
  with ActionStream(url=url, api_key=api_key, timeout=timeout, **session_fields) as stream:
430
544
  stream.send_observation_frame(observation)
431
- return stream.recv_action()
545
+ return stream.recv_action()