reflex-sdk 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/PKG-INFO +8 -1
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/pyproject.toml +17 -3
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/__init__.py +41 -5
- reflex_sdk-0.3.2/reflex/_region_probe.py +210 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/actions.py +137 -23
- reflex_sdk-0.3.2/reflex/auth_runner.py +146 -0
- reflex_sdk-0.3.2/reflex/cameras/__init__.py +39 -0
- reflex_sdk-0.3.2/reflex/cameras/base.py +39 -0
- reflex_sdk-0.3.2/reflex/cameras/realsense.py +73 -0
- reflex_sdk-0.3.2/reflex/cameras/shm.py +120 -0
- reflex_sdk-0.3.2/reflex/cameras/v4l2.py +75 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/cli.py +982 -16
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/client.py +116 -29
- reflex_sdk-0.3.2/reflex/connect_runner.py +442 -0
- reflex_sdk-0.3.2/reflex/connectors/__init__.py +44 -0
- reflex_sdk-0.3.2/reflex/connectors/base.py +70 -0
- reflex_sdk-0.3.2/reflex/connectors/shell.py +67 -0
- reflex_sdk-0.3.2/reflex/connectors/yam_bimanual.py +456 -0
- reflex_sdk-0.3.2/reflex/deployments.py +143 -0
- reflex_sdk-0.3.2/reflex/models.py +207 -0
- reflex_sdk-0.3.2/reflex/receipts.py +34 -0
- reflex_sdk-0.3.2/reflex/robot_runtime.py +207 -0
- reflex_sdk-0.3.2/reflex/robots.py +194 -0
- reflex_sdk-0.3.2/reflex/sessions.py +113 -0
- reflex_sdk-0.3.2/reflex/so101.py +1024 -0
- reflex_sdk-0.3.2/reflex/transports/__init__.py +44 -0
- reflex_sdk-0.3.2/reflex/transports/_webrtc_client.py +679 -0
- reflex_sdk-0.3.2/reflex/transports/_webrtc_streaming_client.py +326 -0
- reflex_sdk-0.3.2/reflex/transports/base.py +40 -0
- reflex_sdk-0.3.2/reflex/transports/edge_http.py +174 -0
- reflex_sdk-0.3.2/reflex/transports/platform.py +157 -0
- reflex_sdk-0.3.2/reflex/transports/webrtc.py +275 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/PKG-INFO +8 -1
- reflex_sdk-0.3.2/reflex_sdk.egg-info/SOURCES.txt +53 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/requires.txt +8 -0
- reflex_sdk-0.3.2/tests/test_connect_runner.py +352 -0
- reflex_sdk-0.3.2/tests/test_prod1_smoke.py +53 -0
- reflex_sdk-0.3.2/tests/test_product_cli.py +2242 -0
- reflex_sdk-0.3.2/tests/test_public_sdk.py +638 -0
- reflex_sdk-0.3.2/tests/test_region_probe.py +278 -0
- reflex_sdk-0.3.2/tests/test_so101_actions.py +148 -0
- reflex_sdk-0.3.0/reflex/models.py +0 -106
- reflex_sdk-0.3.0/reflex_sdk.egg-info/SOURCES.txt +0 -24
- reflex_sdk-0.3.0/tests/test_product_cli.py +0 -940
- reflex_sdk-0.3.0/tests/test_public_sdk.py +0 -207
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/README.md +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/__main__.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/_convex.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/_transport.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/_version.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/datasets.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/instances.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/keys.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/product.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex/training.py +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/dependency_links.txt +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/entry_points.txt +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/reflex_sdk.egg-info/top_level.txt +0 -0
- {reflex_sdk-0.3.0 → reflex_sdk-0.3.2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: reflex-sdk
|
|
3
|
-
Version: 0.3.0
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: Python SDK for Reflex hosted robot inference and training
|
|
5
5
|
Author: Reflex
|
|
6
6
|
Project-URL: Homepage, https://tryreflex.ai
|
|
@@ -15,8 +15,15 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
16
16
|
Requires-Python: >=3.9
|
|
17
17
|
Description-Content-Type: text/markdown
|
|
18
|
+
Requires-Dist: PyYAML<7,>=6.0
|
|
18
19
|
Requires-Dist: typer<0.24,>=0.23; python_version < "3.10"
|
|
19
20
|
Requires-Dist: typer>=0.24.1; python_version >= "3.10"
|
|
21
|
+
Provides-Extra: webrtc
|
|
22
|
+
Requires-Dist: aiortc>=1.8; extra == "webrtc"
|
|
23
|
+
Requires-Dist: av>=11; extra == "webrtc"
|
|
24
|
+
Requires-Dist: msgpack>=1.0; extra == "webrtc"
|
|
25
|
+
Requires-Dist: numpy>=1.24; extra == "webrtc"
|
|
26
|
+
Requires-Dist: Pillow>=10; extra == "webrtc"
|
|
20
27
|
|
|
21
28
|
# Reflex Python SDK
|
|
22
29
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "reflex-sdk"
|
|
3
|
-
version = "0.3.0"
|
|
3
|
+
version = "0.3.2"
|
|
4
4
|
description = "Python SDK for Reflex hosted robot inference and training"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.9"
|
|
7
7
|
dependencies = [
|
|
8
|
+
"PyYAML>=6.0,<7",
|
|
8
9
|
"typer>=0.23,<0.24; python_version < '3.10'",
|
|
9
10
|
"typer>=0.24.1; python_version >= '3.10'",
|
|
10
11
|
]
|
|
@@ -22,6 +23,19 @@ classifiers = [
|
|
|
22
23
|
"Programming Language :: Python :: 3.12",
|
|
23
24
|
]
|
|
24
25
|
|
|
26
|
+
[project.optional-dependencies]
|
|
27
|
+
# Pulled in by `pip install reflex-sdk[webrtc]`. Required when using
|
|
28
|
+
# `target.kind: webrtc` in a `reflex connect` config — the WebRTCTransport
|
|
29
|
+
# imports numpy/Pillow at module top, and the vendored client uses aiortc
|
|
30
|
+
# (peer connection + H.264 video tracks) and msgpack (DataChannel wire).
|
|
31
|
+
webrtc = [
|
|
32
|
+
"aiortc>=1.8",
|
|
33
|
+
"av>=11",
|
|
34
|
+
"msgpack>=1.0",
|
|
35
|
+
"numpy>=1.24",
|
|
36
|
+
"Pillow>=10",
|
|
37
|
+
]
|
|
38
|
+
|
|
25
39
|
[project.urls]
|
|
26
40
|
Homepage = "https://tryreflex.ai"
|
|
27
41
|
Repository = "https://github.com/reflex-inc/reflex"
|
|
@@ -30,11 +44,11 @@ Repository = "https://github.com/reflex-inc/reflex"
|
|
|
30
44
|
reflex = "reflex.cli:main_reflex"
|
|
31
45
|
|
|
32
46
|
[tool.setuptools]
|
|
33
|
-
packages = ["reflex"]
|
|
47
|
+
packages = ["reflex", "reflex.cameras", "reflex.connectors", "reflex.transports"]
|
|
34
48
|
|
|
35
49
|
[tool.ruff]
|
|
36
50
|
line-length = 100
|
|
37
51
|
|
|
38
52
|
[build-system]
|
|
39
53
|
requires = ["setuptools>=68"]
|
|
40
|
-
build-backend = "setuptools.build_meta"
|
|
54
|
+
build-backend = "setuptools.build_meta"
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from ._version import __version__
|
|
4
4
|
from .actions import ActionStream, action, connect, infer_actions, observation
|
|
5
5
|
from .client import Client
|
|
6
|
+
# TODO PROD.2: re-add ConnectError/ConnectRunner/SessionGrant once auth wrapper is integrated
|
|
6
7
|
from .datasets import (
|
|
7
8
|
complete_dataset,
|
|
8
9
|
create_dataset,
|
|
@@ -12,9 +13,26 @@ from .datasets import (
|
|
|
12
13
|
upload_dataset,
|
|
13
14
|
validate_dataset,
|
|
14
15
|
)
|
|
16
|
+
from .deployments import (
|
|
17
|
+
create_deployment,
|
|
18
|
+
create_deployment_from_spec,
|
|
19
|
+
get_deployment,
|
|
20
|
+
list_deployments,
|
|
21
|
+
run_deployment_doctor,
|
|
22
|
+
)
|
|
15
23
|
from .instances import instance_status, provision_instance, teardown_instance
|
|
16
|
-
from .
|
|
17
|
-
from .
|
|
24
|
+
from .receipts import list_receipts
|
|
25
|
+
from .robots import (
|
|
26
|
+
claim_pairing_token,
|
|
27
|
+
create_pairing_token,
|
|
28
|
+
heartbeat_robot,
|
|
29
|
+
list_robot_schemas,
|
|
30
|
+
list_robots,
|
|
31
|
+
register_robot,
|
|
32
|
+
register_robot_schema,
|
|
33
|
+
)
|
|
34
|
+
from .robot_runtime import RobotExecutionConfig, RobotExecutionResult, run_robot_execution_loop
|
|
35
|
+
from .sessions import close_session, list_sessions, promote_session, start_session
|
|
18
36
|
from .training import (
|
|
19
37
|
AdamParams,
|
|
20
38
|
AdapterHandle,
|
|
@@ -37,35 +55,53 @@ __all__ = [
|
|
|
37
55
|
"AdamParams",
|
|
38
56
|
"AdapterHandle",
|
|
39
57
|
"Client",
|
|
58
|
+
"claim_pairing_token",
|
|
40
59
|
"Datum",
|
|
41
60
|
"ForwardBackwardResult",
|
|
42
61
|
"LoraTrainingClient",
|
|
43
62
|
"OptimStepResult",
|
|
63
|
+
"RobotExecutionConfig",
|
|
64
|
+
"RobotExecutionResult",
|
|
44
65
|
"ServiceClient",
|
|
66
|
+
"SessionGrant",
|
|
45
67
|
"__version__",
|
|
46
68
|
"action",
|
|
47
69
|
"bind_key_to_model",
|
|
48
70
|
"cancel_training_job",
|
|
71
|
+
"close_session",
|
|
49
72
|
"complete_dataset",
|
|
50
73
|
"connect",
|
|
51
74
|
"create_dataset",
|
|
75
|
+
"create_deployment",
|
|
76
|
+
"create_deployment_from_spec",
|
|
77
|
+
"create_pairing_token",
|
|
52
78
|
"create_training_job",
|
|
53
79
|
"delete_model",
|
|
54
80
|
"full_finetune",
|
|
55
81
|
"full_train",
|
|
56
82
|
"get_dataset",
|
|
57
|
-
"
|
|
83
|
+
"get_deployment",
|
|
58
84
|
"get_training_job",
|
|
59
|
-
"
|
|
85
|
+
"heartbeat_robot",
|
|
60
86
|
"infer_actions",
|
|
61
87
|
"instance_status",
|
|
62
|
-
"
|
|
88
|
+
"list_sessions",
|
|
63
89
|
"list_training_jobs",
|
|
64
90
|
"list_datasets",
|
|
91
|
+
"list_deployments",
|
|
92
|
+
"list_receipts",
|
|
93
|
+
"list_robot_schemas",
|
|
94
|
+
"list_robots",
|
|
65
95
|
"lora_finetune",
|
|
66
96
|
"observation",
|
|
67
97
|
"provision_instance",
|
|
98
|
+
"promote_session",
|
|
68
99
|
"register_huggingface_dataset",
|
|
100
|
+
"register_robot",
|
|
101
|
+
"register_robot_schema",
|
|
102
|
+
"run_deployment_doctor",
|
|
103
|
+
"run_robot_execution_loop",
|
|
104
|
+
"start_session",
|
|
69
105
|
"teardown_instance",
|
|
70
106
|
"upload_dataset",
|
|
71
107
|
"validate_dataset",
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""Client-side nearest-region probe for the Reflex SDK CLI (CFG.1).
|
|
2
|
+
|
|
3
|
+
When `REFLEX_INFERENCE_REGION_MAP=region=URL,region=URL,...` is set, the
|
|
4
|
+
CLI hits each URL's `/health` at startup, picks the lowest-p50 region,
|
|
5
|
+
and uses that URL for the inference stream instead of the one the
|
|
6
|
+
platform returned. This buys ~20-60ms RTT on multi-region deploys
|
|
7
|
+
without server-side changes (NetR.3a finding, re-nr7l).
|
|
8
|
+
|
|
9
|
+
Single-entry maps skip the probe entirely (zero startup cost), so the
|
|
10
|
+
default single-region behaviour is unchanged until a multi-region map
|
|
11
|
+
is configured.
|
|
12
|
+
|
|
13
|
+
Stdlib-only (urllib + time) so it runs in any environment the
|
|
14
|
+
rest of the published `reflex` package runs in.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
import time
|
|
21
|
+
import urllib.error
|
|
22
|
+
import urllib.request
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
from typing import Callable, Optional
|
|
25
|
+
|
|
26
|
+
# Name of the environment variable resolve_url() reads the
# `region=URL,region=URL,...` map from when no map is passed explicitly.
DEFAULT_REGION_ENV = "REFLEX_INFERENCE_REGION_MAP"

# Number of measured /health probes per region; successful ones feed the p50.
DEFAULT_PROBE_ATTEMPTS = 5

# Number of unmeasured probes sent to each region before measuring starts.
DEFAULT_PROBE_WARMUP = 1

# Per-request timeout handed to urlopen(), in seconds.
DEFAULT_PROBE_TIMEOUT_S = 5.0

# Pause after every probe (warm-up and measured), in seconds; 0 disables it.
DEFAULT_PROBE_SETTLE_S = 0.05
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
class RegionResult:
    """Immutable summary of probing one region's ``/health`` endpoint.

    Produced by ``probe_region`` and, as an unprobed placeholder for
    single-entry maps, by ``pick_nearest``. Region selection compares
    ``p50_ms`` across results.
    """

    # Region name exactly as written in the region map.
    label: str
    # Base URL of the region; probes are sent to ``<url>/health``.
    url: str
    # How many measured probes succeeded.
    ok_count: int
    # How many measured probes were sent (warm-up excluded); 0 when unprobed.
    attempts: int
    # Nearest-rank p50 of successful round trips in milliseconds (2 dp),
    # or None when no measured probe succeeded or none was sent.
    p50_ms: Optional[float]
    # Last probe error message; populated only when every measured probe failed.
    error_tail: Optional[str] = None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def parse_region_map(spec: str) -> list[tuple[str, str]]:
    """Parse a ``region1=url1,region2=url2[,...]`` spec into ``[(label, url), ...]``.

    Entries are comma-separated. Surrounding whitespace is stripped, and blank
    entries (e.g. from a trailing comma) are ignored. Each entry is split on its
    first ``=`` only, so the URL part may itself contain ``=``.

    Returns an empty list for an empty or whitespace-only spec.

    Raises:
        ValueError: an entry has no ``=``, or its label or URL is empty. A typo
            in the env var therefore fails loudly instead of silently routing
            to the wrong region.
    """
    if not spec or not spec.strip():
        return []
    pairs: list[tuple[str, str]] = []
    for entry in (piece.strip() for piece in spec.split(",")):
        if not entry:
            continue
        label, sep, url = entry.partition("=")
        if not sep:
            raise ValueError(f"region map entry missing '=': {entry!r}")
        label = label.strip()
        url = url.strip()
        if not (label and url):
            raise ValueError(f"empty label or url in region map entry: {entry!r}")
        pairs.append((label, url))
    return pairs
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _pct(samples: list[float], q: float) -> Optional[float]:
    """Return the ``q``-quantile of *samples* by nearest rank, rounded to 2 dp.

    The sorted samples are indexed at ``round(q * (n - 1))``, clamped to
    ``[0, n - 1]``. Python's ``round`` rounds halves to even, so for an even
    sample count at ``q=0.5`` which middle element is chosen depends on ``n``.
    Returns ``None`` when *samples* is empty.
    """
    if not samples:
        return None
    ordered = sorted(samples)
    last = len(ordered) - 1
    position = int(round(q * last))
    position = max(0, min(last, position))
    return round(ordered[position], 2)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _probe_once(url: str, timeout_s: float) -> tuple[Optional[float], Optional[str]]:
    """GET *url* once and report ``(rtt_ms, error)``.

    On a 2xx response, returns the elapsed wall time in milliseconds
    (request construction through reading and closing the body) and ``None``.
    On any failure, returns ``None`` and a short error description: an
    ``HTTPError`` is reported with its code and reason; any other exception
    (bad URL, timeout, connection error, ...) with its type name and message.
    """
    t0 = time.perf_counter()
    try:
        request = urllib.request.Request(
            url, headers={"User-Agent": "reflex-sdk-region-probe/1.0"}
        )
        with urllib.request.urlopen(request, timeout=timeout_s) as resp:
            # Read the whole body so the connection is closed cleanly.
            resp.read()
            if not (200 <= resp.status < 300):
                return None, f"HTTP {resp.status}"
    except urllib.error.HTTPError as exc:
        return None, f"HTTPError {exc.code}: {exc.reason}"
    except Exception as exc:  # noqa: BLE001 — any failure marks the target unviable
        return None, f"{type(exc).__name__}: {exc}"
    elapsed_ms = (time.perf_counter() - t0) * 1000.0
    return elapsed_ms, None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def probe_region(
    label: str,
    url: str,
    *,
    attempts: int = DEFAULT_PROBE_ATTEMPTS,
    warmup: int = DEFAULT_PROBE_WARMUP,
    timeout_s: float = DEFAULT_PROBE_TIMEOUT_S,
    settle_s: float = DEFAULT_PROBE_SETTLE_S,
) -> RegionResult:
    """Probe ``<url>/health`` ``warmup + attempts`` times and summarise the results.

    The first ``warmup`` probes are discarded. The next ``attempts`` probes are
    measured: successful round trips feed ``p50_ms``, and the last error is
    kept (as ``error_tail``) only if none succeeded. When ``settle_s`` is
    positive, the function sleeps that long after every probe, warm-up
    included.

    Each probe goes through ``urlopen``, which opens a fresh connection per
    call. Warm-up therefore absorbs one-off costs (e.g. a cold proxy route)
    rather than per-connection TCP/TLS setup, and measured samples still
    include connection setup.
    """
    health_url = f"{url.rstrip('/')}/health"

    def pause() -> None:
        if settle_s > 0:
            time.sleep(settle_s)

    for _ in range(warmup):
        _probe_once(health_url, timeout_s)
        pause()

    rtts: list[float] = []
    last_error: Optional[str] = None
    for _ in range(attempts):
        rtt, err = _probe_once(health_url, timeout_s)
        if rtt is not None:
            rtts.append(rtt)
        elif err is not None:
            last_error = err
        pause()

    return RegionResult(
        label=label,
        url=url,
        attempts=attempts,
        ok_count=len(rtts),
        p50_ms=_pct(rtts, 0.5),
        error_tail=None if rtts else last_error,
    )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def pick_nearest(
    region_map: list[tuple[str, str]],
    *,
    attempts: int = DEFAULT_PROBE_ATTEMPTS,
    warmup: int = DEFAULT_PROBE_WARMUP,
    timeout_s: float = DEFAULT_PROBE_TIMEOUT_S,
    settle_s: float = DEFAULT_PROBE_SETTLE_S,
    log: Optional[Callable[[RegionResult], None]] = None,
) -> tuple[str, str, list[RegionResult]]:
    """Probe each region in order and return ``(label, url, all_results)`` for the fastest.

    - A single-entry map is not probed. Its entry is returned together with a
      placeholder result (``attempts=0``, ``ok_count=0``, ``p50_ms=None``).
    - Otherwise the region with the lowest ``p50_ms`` wins; on a tie the
      earlier entry wins.
    - If no region produced a ``p50_ms``, the first entry is returned. A
      possibly-stale endpoint is better than refusing to start.

    ``log``, when given, receives each ``RegionResult`` as soon as it exists.

    Raises:
        ValueError: ``region_map`` is empty.
    """
    if not region_map:
        raise ValueError("region_map is empty")

    if len(region_map) == 1:
        only_label, only_url = region_map[0]
        placeholder = RegionResult(
            label=only_label, url=only_url, attempts=0, ok_count=0, p50_ms=None,
        )
        if log is not None:
            log(placeholder)
        return only_label, only_url, [placeholder]

    results: list[RegionResult] = []
    for region_label, region_url in region_map:
        outcome = probe_region(
            region_label,
            region_url,
            attempts=attempts,
            warmup=warmup,
            timeout_s=timeout_s,
            settle_s=settle_s,
        )
        if log is not None:
            log(outcome)
        results.append(outcome)

    measured = [res for res in results if res.p50_ms is not None]
    if measured:
        winner = min(measured, key=lambda res: res.p50_ms)  # type: ignore[arg-type,return-value]
    else:
        winner = results[0]
    return winner.label, winner.url, results
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def resolve_url(
    explicit_url: Optional[str],
    *,
    region_map: Optional[list[tuple[str, str]]] = None,
    env_var: str = DEFAULT_REGION_ENV,
    fallback_url: Optional[str] = None,
    attempts: int = DEFAULT_PROBE_ATTEMPTS,
    warmup: int = DEFAULT_PROBE_WARMUP,
    timeout_s: float = DEFAULT_PROBE_TIMEOUT_S,
    settle_s: float = DEFAULT_PROBE_SETTLE_S,
    log: Optional[Callable[[RegionResult], None]] = None,
) -> str:
    """Choose the URL to use, in priority order: explicit > region map > fallback.

    - A truthy ``explicit_url`` is returned unchanged, without consulting the
      region map or the environment.
    - Otherwise the map is ``region_map`` if given, or else parsed from
      ``os.environ[env_var]``. A malformed env value raises ``ValueError``
      (see ``parse_region_map``).
    - A non-empty map is handed to ``pick_nearest``, and its URL is returned.
    - An empty map returns ``fallback_url``.

    Raises:
        RuntimeError: the map is empty and ``fallback_url`` is ``None``.
    """
    if explicit_url:
        return explicit_url
    entries = (
        region_map
        if region_map is not None
        else parse_region_map(os.environ.get(env_var, ""))
    )
    if entries:
        _, chosen_url, _ = pick_nearest(
            entries,
            attempts=attempts,
            warmup=warmup,
            timeout_s=timeout_s,
            settle_s=settle_s,
            log=log,
        )
        return chosen_url
    if fallback_url is None:
        raise RuntimeError(
            f"resolve_url: env var {env_var} is empty and no fallback_url given"
        )
    return fallback_url
|
|
@@ -9,6 +9,7 @@ import os
|
|
|
9
9
|
import socket
|
|
10
10
|
import ssl
|
|
11
11
|
import struct
|
|
12
|
+
import time
|
|
12
13
|
from types import TracebackType
|
|
13
14
|
from typing import Any
|
|
14
15
|
from urllib import parse
|
|
@@ -17,6 +18,16 @@ from ._version import user_agent as _user_agent
|
|
|
17
18
|
from ._transport import api_key as _api_key
|
|
18
19
|
from ._transport import base_url as _base_url
|
|
19
20
|
|
|
21
|
+
# HTTP statuses from a failed WebSocket upgrade that _retryable_connect_error
# treats as transient, so ActionStream.open() keeps retrying within its window.
RETRYABLE_HANDSHAKE_STATUS_CODES = {
    403, 404,  # NOTE(review): presumably returned while the backend is still coming up — confirm
    408, 425, 429,  # request timeout / too early / rate limited
    500, 502, 503, 504,  # server-side and gateway failures
}
# Retry window in seconds used when REFLEX_WS_CONNECT_RETRY_SECONDS is unset or
# not a valid float; the effective window is also capped at the stream timeout.
DEFAULT_CONNECT_RETRY_SECONDS = 90.0
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class WebSocketHandshakeError(RuntimeError):
    """The server's reply to the WebSocket upgrade request was not ``101``.

    Attributes:
        status_line: The raw HTTP status line the server sent back.
        status_code: Numeric status parsed from ``status_line`` by
            ``_status_code``, or ``None`` if it could not be parsed.
    """

    def __init__(self, status_line: str) -> None:
        message = f"WebSocket handshake failed: {status_line}"
        super().__init__(message)
        self.status_line = status_line
        self.status_code = _status_code(status_line)
|
|
30
|
+
|
|
20
31
|
|
|
21
32
|
def _actions_ws_url(url: str | None) -> str:
|
|
22
33
|
resolved = (url or os.environ.get("REFLEX_ACTIONS_URL", "")).strip()
|
|
@@ -25,6 +36,8 @@ def _actions_ws_url(url: str | None) -> str:
|
|
|
25
36
|
|
|
26
37
|
parsed = parse.urlparse(resolved)
|
|
27
38
|
if parsed.scheme in {"ws", "wss"}:
|
|
39
|
+
if parsed.query:
|
|
40
|
+
return resolved
|
|
28
41
|
if parsed.path and parsed.path != "/":
|
|
29
42
|
return resolved
|
|
30
43
|
return f"{resolved.rstrip('/')}/v1/actions"
|
|
@@ -35,7 +48,42 @@ def _actions_ws_url(url: str | None) -> str:
|
|
|
35
48
|
raise ValueError("Actions URL must start with http://, https://, ws://, or wss://.")
|
|
36
49
|
|
|
37
50
|
|
|
38
|
-
def
|
|
51
|
+
def _status_code(status_line: str) -> int | None:
    """Extract the numeric code from an HTTP status line such as ``HTTP/1.1 403 Forbidden``.

    Returns ``None`` when the line has fewer than two whitespace-separated
    fields or the second field is not an integer.
    """
    fields = status_line.split()
    if len(fields) >= 2:
        try:
            return int(fields[1])
        except ValueError:
            pass
    return None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _connect_retry_seconds(timeout: float) -> float:
    """Return how many seconds ``ActionStream.open`` may spend retrying a connect.

    Reads ``REFLEX_WS_CONNECT_RETRY_SECONDS`` (whitespace-stripped). When it is
    unset, blank, or not a valid float, ``DEFAULT_CONNECT_RETRY_SECONDS`` is used
    instead. The result is capped at ``timeout`` and is never negative.
    """
    override = os.environ.get("REFLEX_WS_CONNECT_RETRY_SECONDS", "").strip()
    try:
        window = float(override) if override else DEFAULT_CONNECT_RETRY_SECONDS
    except ValueError:
        window = DEFAULT_CONNECT_RETRY_SECONDS
    ceiling = max(0.0, timeout)
    return max(0.0, min(ceiling, window))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _retryable_connect_error(exc: BaseException) -> bool:
    """Return True when a failed connect attempt is worth retrying.

    - Any ``OSError`` is treated as transient, except TLS certificate
      verification failures. ``OSError`` covers ``TimeoutError``/
      ``socket.timeout``, refused or reset connections, and DNS failures. A
      bad or untrusted certificate (or hostname mismatch) will not fix itself,
      so retrying would only hide the real error until the retry window closes.
    - A ``WebSocketHandshakeError`` is retried only when its HTTP status is in
      ``RETRYABLE_HANDSHAKE_STATUS_CODES``.
    - Anything else is not retryable.
    """
    # ssl.SSLCertVerificationError subclasses ssl.SSLError -> OSError, so it
    # must be rejected before the generic OSError check below.
    if isinstance(exc, ssl.SSLCertVerificationError):
        return False
    # TimeoutError and socket.timeout are both OSError subclasses.
    if isinstance(exc, OSError):
        return True
    # WebSocketHandshakeError is a RuntimeError, never an OSError, so checking
    # it after the OSError branch does not change its outcome.
    if isinstance(exc, WebSocketHandshakeError):
        return exc.status_code in RETRYABLE_HANDSHAKE_STATUS_CODES
    return False
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _connect_websocket(
|
|
80
|
+
url: str,
|
|
81
|
+
*,
|
|
82
|
+
api_key: str,
|
|
83
|
+
timeout: float,
|
|
84
|
+
timing: dict[str, float] | None = None,
|
|
85
|
+
) -> socket.socket:
|
|
86
|
+
connect_started = time.monotonic()
|
|
39
87
|
parsed = parse.urlparse(url)
|
|
40
88
|
if parsed.scheme not in {"ws", "wss"} or not parsed.hostname:
|
|
41
89
|
raise ValueError("Actions URL must be a ws:// or wss:// URL.")
|
|
@@ -45,17 +93,27 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
|
|
|
45
93
|
if parsed.query:
|
|
46
94
|
path = f"{path}?{parsed.query}"
|
|
47
95
|
|
|
96
|
+
tcp_started = time.monotonic()
|
|
48
97
|
raw_sock = socket.create_connection((parsed.hostname, port), timeout=timeout)
|
|
98
|
+
if timing is not None:
|
|
99
|
+
timing["tcp_ms"] = (time.monotonic() - tcp_started) * 1000.0
|
|
49
100
|
sock: socket.socket
|
|
50
101
|
if parsed.scheme == "wss":
|
|
102
|
+
tls_started = time.monotonic()
|
|
51
103
|
context = ssl.create_default_context()
|
|
52
104
|
sock = context.wrap_socket(raw_sock, server_hostname=parsed.hostname)
|
|
105
|
+
if timing is not None:
|
|
106
|
+
timing["tls_ms"] = (time.monotonic() - tls_started) * 1000.0
|
|
53
107
|
else:
|
|
54
108
|
sock = raw_sock
|
|
109
|
+
if timing is not None:
|
|
110
|
+
timing["tls_ms"] = 0.0
|
|
55
111
|
sock.settimeout(timeout)
|
|
56
112
|
|
|
113
|
+
upgrade_started = time.monotonic()
|
|
57
114
|
key = base64.b64encode(os.urandom(16)).decode("ascii")
|
|
58
115
|
host = parsed.hostname if parsed.port is None else f"{parsed.hostname}:{port}"
|
|
116
|
+
auth_header = f"Authorization: Bearer {api_key}\r\n" if api_key else ""
|
|
59
117
|
request_text = (
|
|
60
118
|
f"GET {path} HTTP/1.1\r\n"
|
|
61
119
|
f"Host: {host}\r\n"
|
|
@@ -63,7 +121,7 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
|
|
|
63
121
|
"Connection: Upgrade\r\n"
|
|
64
122
|
f"Sec-WebSocket-Key: {key}\r\n"
|
|
65
123
|
"Sec-WebSocket-Version: 13\r\n"
|
|
66
|
-
f"
|
|
124
|
+
f"{auth_header}"
|
|
67
125
|
f"User-Agent: {_user_agent('reflex-actions-sdk')}\r\n"
|
|
68
126
|
"\r\n"
|
|
69
127
|
)
|
|
@@ -79,7 +137,7 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
|
|
|
79
137
|
status_line = header_text.split("\r\n", 1)[0]
|
|
80
138
|
if " 101 " not in status_line:
|
|
81
139
|
sock.close()
|
|
82
|
-
raise
|
|
140
|
+
raise WebSocketHandshakeError(status_line)
|
|
83
141
|
|
|
84
142
|
accept_source = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode("ascii")
|
|
85
143
|
expected_accept = base64.b64encode(hashlib.sha1(accept_source).digest()).decode("ascii")
|
|
@@ -91,6 +149,9 @@ def _connect_websocket(url: str, *, api_key: str, timeout: float) -> socket.sock
|
|
|
91
149
|
if headers.get("sec-websocket-accept") != expected_accept:
|
|
92
150
|
sock.close()
|
|
93
151
|
raise RuntimeError("WebSocket handshake failed: invalid Sec-WebSocket-Accept header.")
|
|
152
|
+
if timing is not None:
|
|
153
|
+
timing["websocket_upgrade_ms"] = (time.monotonic() - upgrade_started) * 1000.0
|
|
154
|
+
timing["connect_ms"] = (time.monotonic() - connect_started) * 1000.0
|
|
94
155
|
return sock
|
|
95
156
|
|
|
96
157
|
|
|
@@ -176,6 +237,11 @@ def _normalize_images(images: dict[str, Any] | None) -> dict[str, Any]:
|
|
|
176
237
|
return normalized
|
|
177
238
|
|
|
178
239
|
|
|
240
|
+
def _query_has_token(url: str) -> bool:
    """Report whether *url*'s query string carries a non-blank ``token`` parameter.

    ``parse_qs`` drops blank values by default, so ``?token=`` does not count.
    """
    params = parse.parse_qs(parse.urlsplit(url).query)
    return "token" in params
|
|
243
|
+
|
|
244
|
+
|
|
179
245
|
class ActionStream:
|
|
180
246
|
"""Long-lived Reflex actions session.
|
|
181
247
|
|
|
@@ -193,25 +259,36 @@ class ActionStream:
|
|
|
193
259
|
prompt: str,
|
|
194
260
|
model: str | None = None,
|
|
195
261
|
lora: str | None = None,
|
|
262
|
+
robot: str | None = None,
|
|
263
|
+
action_adapter: str | None = None,
|
|
196
264
|
cameras: list[str] | None = None,
|
|
197
265
|
hz: float | None = None,
|
|
198
266
|
chunk_size: int | None = None,
|
|
199
267
|
max_gpu_seconds: float | None = None,
|
|
200
268
|
session_id: str = "",
|
|
201
269
|
timeout: float = 30.0,
|
|
270
|
+
connect_retry_seconds: float | None = None,
|
|
202
271
|
) -> None:
|
|
203
272
|
self.url = _actions_ws_url(url)
|
|
204
|
-
self.api_key =
|
|
273
|
+
self.api_key = (
|
|
274
|
+
""
|
|
275
|
+
if _query_has_token(self.url) and not (api_key or "").strip()
|
|
276
|
+
else _api_key(api_key)
|
|
277
|
+
)
|
|
205
278
|
self.prompt = prompt
|
|
206
279
|
self.model = model
|
|
207
280
|
self.lora = lora
|
|
281
|
+
self.robot = robot
|
|
282
|
+
self.action_adapter = action_adapter
|
|
208
283
|
self.cameras = cameras or []
|
|
209
284
|
self.hz = hz
|
|
210
285
|
self.chunk_size = chunk_size
|
|
211
286
|
self.max_gpu_seconds = max_gpu_seconds
|
|
212
287
|
self.session_id = session_id
|
|
213
288
|
self.timeout = timeout
|
|
289
|
+
self.connect_retry_seconds = connect_retry_seconds
|
|
214
290
|
self.ready: dict[str, Any] | None = None
|
|
291
|
+
self.open_timing: dict[str, float] = {}
|
|
215
292
|
self._sock: socket.socket | None = None
|
|
216
293
|
self._next_seq = 0
|
|
217
294
|
|
|
@@ -227,24 +304,61 @@ class ActionStream:
|
|
|
227
304
|
self.close()
|
|
228
305
|
|
|
229
306
|
def open(self) -> ActionStream:
    """Connect to the actions endpoint and complete the ``session.open`` handshake.

    Each attempt opens the WebSocket with socket timeout
    ``min(self.timeout, 15.0)``. It then sends a ``session.open`` frame built
    from this stream's session fields and requires the next frame to be
    ``session.ready``.

    A failed attempt is retried while ``_retryable_connect_error`` accepts the
    exception and the retry window has not elapsed. The window is
    ``self.connect_retry_seconds`` clamped to ``[0, self.timeout]`` when set,
    otherwise ``_connect_retry_seconds(self.timeout)``; a zero window disables
    retries. Sleeps between attempts start at 0.5s and grow by x1.6 up to 5s,
    never exceeding the time left in the window.

    On success, sets ``open_timing`` (connect-phase timings plus
    ``session_ready_ms`` and ``open_ms``), ``ready`` and ``session_id``, and
    returns ``self``. On final failure, re-raises the last exception after a
    best-effort close of the attempt's socket.
    """
    if self.connect_retry_seconds is None:
        window = _connect_retry_seconds(self.timeout)
    else:
        window = max(0.0, min(max(0.0, self.timeout), self.connect_retry_seconds))
    give_up_at = time.monotonic() + window
    backoff = 0.5

    def discard_socket() -> None:
        # Best-effort teardown: try the WebSocket close helper first, fall back
        # to closing the raw socket, and never let cleanup errors escape.
        if self._sock is None:
            return
        try:
            _close_websocket(self._sock)
        except Exception:
            try:
                self._sock.close()
            except Exception:
                pass
        self._sock = None

    while True:
        attempt_started = time.monotonic()
        timing: dict[str, float] = {}
        try:
            self._sock = _connect_websocket(
                self.url,
                api_key=self.api_key,
                timeout=min(self.timeout, 15.0),
                timing=timing,
            )
            handshake_started = time.monotonic()
            open_frame = {
                "type": "session.open",
                "session_id": self.session_id,
                "prompt": self.prompt,
                "model": self.model,
                "lora": self.lora,
                "robot": self.robot,
                "action_adapter": self.action_adapter,
                "cameras": self.cameras,
                "hz": self.hz,
                "chunk_size": self.chunk_size,
                "max_gpu_seconds": self.max_gpu_seconds,
            }
            self.send_raw(open_frame)
            ready = self.receive()
            if ready.get("type") != "session.ready":
                raise RuntimeError(f"Expected session.ready, received {ready!r}")
            timing["session_ready_ms"] = (time.monotonic() - handshake_started) * 1000.0
            timing["open_ms"] = (time.monotonic() - attempt_started) * 1000.0
            self.open_timing = timing
            self.ready = ready
            self.session_id = str(ready.get("session_id") or self.session_id)
            return self
        except Exception as exc:
            discard_socket()
            now = time.monotonic()
            retry_allowed = window > 0 and _retryable_connect_error(exc) and now < give_up_at
            if not retry_allowed:
                raise
            time.sleep(min(backoff, max(0.0, give_up_at - now)))
            backoff = min(backoff * 1.6, 5.0)
|
|
248
362
|
|
|
249
363
|
def send_raw(self, frame: dict[str, Any]) -> None:
|
|
250
364
|
if self._sock is None:
|
|
@@ -428,4 +542,4 @@ def infer_actions(
|
|
|
428
542
|
}
|
|
429
543
|
with ActionStream(url=url, api_key=api_key, timeout=timeout, **session_fields) as stream:
|
|
430
544
|
stream.send_observation_frame(observation)
|
|
431
|
-
return stream.recv_action()
|
|
545
|
+
return stream.recv_action()
|