grid-cortex-client 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/PKG-INFO +1 -1
  2. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/cortex_client.py +14 -5
  3. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/model_type.py +27 -0
  4. grid_cortex_client-0.4.0/src/grid_cortex_client/models/graspgenx.py +151 -0
  5. grid_cortex_client-0.4.0/tests/test_graspgenx.py +143 -0
  6. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/.gitignore +0 -0
  7. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/CLAUDE.md +0 -0
  8. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/README.md +0 -0
  9. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/pyproject.toml +0 -0
  10. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/__init__.py +0 -0
  11. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/client.py +0 -0
  12. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/cortex_hub_client.py +0 -0
  13. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/__init__.py +0 -0
  14. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/base_model.py +0 -0
  15. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/da3metric.py +0 -0
  16. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/fast_foundation_stereo.py +0 -0
  17. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/finetune_pi05.py +0 -0
  18. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/foundation_stereo.py +0 -0
  19. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/graspgen.py +0 -0
  20. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/gsam2.py +0 -0
  21. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/locate_anything.py +0 -0
  22. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/metric3d.py +0 -0
  23. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/moondream.py +0 -0
  24. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/oneformer.py +0 -0
  25. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/owlv2.py +0 -0
  26. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/pi05.py +0 -0
  27. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/qwen_vl.py +0 -0
  28. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/sam2.py +0 -0
  29. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/sam3.py +0 -0
  30. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/ur5e_pyroki_collision_batch.py +0 -0
  31. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/ur5e_pyroki_fk.py +0 -0
  32. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/models/zoedepth.py +0 -0
  33. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/preprocessing.py +0 -0
  34. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/tools/__init__.py +0 -0
  35. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/tools/generate_enum.py +0 -0
  36. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/tools/registry.py +0 -0
  37. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/utils.py +0 -0
  38. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/src/grid_cortex_client/ws.py +0 -0
  39. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/README.md +0 -0
  40. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/__init__.py +0 -0
  41. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/conftest.py +0 -0
  42. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_cortex_hub_client.py +0 -0
  43. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_da3metric.py +0 -0
  44. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_fast_foundation_stereo.py +0 -0
  45. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_finetune_pi05.py +0 -0
  46. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_foundation_stereo.py +0 -0
  47. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_graspgen.py +0 -0
  48. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_gsam2.py +0 -0
  49. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_locate_anything.py +0 -0
  50. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_metric3d.py +0 -0
  51. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_moondream.py +0 -0
  52. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_oneformer.py +0 -0
  53. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_owlv2.py +0 -0
  54. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_pi05.py +0 -0
  55. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_qwen_vl.py +0 -0
  56. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_sam2.py +0 -0
  57. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_sam3.py +0 -0
  58. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_ur5e_pyroki_collision_batch.py +0 -0
  59. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_ur5e_pyroki_fk.py +0 -0
  60. {grid_cortex_client-0.3.0 → grid_cortex_client-0.4.0}/tests/test_zoedepth.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grid-cortex-client
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: Python client for Grid Cortex
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Programming Language :: Python :: 3.8
@@ -62,11 +62,20 @@ class _CortexClientBase:
62
62
  _MODEL_ID_TO_HANDLER_CLASS: Dict[str, Type[BaseModel]] = get_registry()
63
63
 
64
64
  def _get_handler_class(self, model_id: Union[str, ModelType]) -> Optional[Type[BaseModel]]:
65
- """Look up the handler class for a model ID."""
66
- model_id_str = model_id.value if isinstance(model_id, ModelType) else str(model_id)
67
- for keyword, HClass in self._MODEL_ID_TO_HANDLER_CLASS.items():
68
- if keyword in model_id_str.lower():
69
- return HClass
65
+ """Look up the handler class for a model ID.
66
+
67
+ Prefer an exact match so a registered id that is a substring of another
68
+ (e.g. ``graspgen`` vs ``graspgenx``) can't shadow it. Fall back to
69
+ substring matching, longest keyword first, to keep the lenient behavior
70
+ for free-form ids while still preventing short-id shadowing.
71
+ """
72
+ model_id_str = (model_id.value if isinstance(model_id, ModelType) else str(model_id)).lower()
73
+ exact = self._MODEL_ID_TO_HANDLER_CLASS.get(model_id_str)
74
+ if exact is not None:
75
+ return exact
76
+ for keyword in sorted(self._MODEL_ID_TO_HANDLER_CLASS, key=len, reverse=True):
77
+ if keyword in model_id_str:
78
+ return self._MODEL_ID_TO_HANDLER_CLASS[keyword]
70
79
  return None
71
80
 
72
81
  def _preprocess(
@@ -181,6 +181,33 @@ class ModelType(Enum):
181
181
  >>> grasps, conf = client.run(ModelType.GRASPGEN, depth_image=depth_image, seg_image=seg_image, camera_intrinsics=K, aux_args=aux)
182
182
  >>> print(f"Generated {len(grasps)} grasps")
183
183
  """
184
+ GRASPGENX = "graspgenx"
185
+ """
186
+ Generate 6-DoF grasps for a named gripper with GraspGenX.
187
+
188
+ Args:
189
+ aux_args (Dict[str, Any]): Must include "gripper_name"; optional
190
+ "num_grasps", "planner" ("graspmoe"|"diffusion"), "grasp_threshold",
191
+ "camera_extrinsics".
192
+ point_cloud (Union[str, np.ndarray, Sequence, None]): (N, 3) object point cloud.
193
+ depth_image / seg_image / camera_intrinsics: alternative to point_cloud.
194
+ timeout (float | None): Optional HTTP timeout.
195
+
196
+ Returns:
197
+ Dict[str, Any]: Dict with:
198
+ - grasps: (N, 4, 4) array of grasp poses
199
+ - confidence: (N,) array of scores
200
+ - latency_ms: server-reported latency
201
+
202
+ Examples:
203
+ >>> from grid_cortex_client import CortexClient, ModelType
204
+ >>> import numpy as np
205
+ >>> client = CortexClient()
206
+ >>> pc = np.load("object_pc.npy") # (N, 3)
207
+ >>> out = client.run(ModelType.GRASPGENX, point_cloud=pc,
208
+ ... aux_args={"gripper_name": "robotiq_2f_85", "num_grasps": 64})
209
+ >>> print(out["grasps"].shape)
210
+ """
184
211
  GSAM2 = "gsam2"
185
212
  """
186
213
  Segment objects in image using text prompt.
@@ -0,0 +1,151 @@
1
+ # grid_cortex_client/src/grid_cortex_client/models/graspgenx.py
2
+ """GraspGenX wrapper.
3
+
4
+ Cross-embodiment 6-DoF grasp generation (NVlabs/GraspGenX). A single model
5
+ conditioned on a gripper representation, so grasps are requested for a named
6
+ gripper (``gripper_name``) rather than a per-gripper model.
7
+
8
+ Input: a segmented object point cloud (or depth+seg+intrinsics, converted
9
+ server-side). Output: 6-DoF grasp poses + confidences.
10
+
11
+ NOTE: weights are under the NVIDIA Open Model License — see the model recipe.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict, Sequence, Union
17
+
18
+ import numpy as np
19
+ from PIL import Image
20
+
21
+ from ..preprocessing import array_to_npy_bytes, load_image
22
+ from .base_model import BaseModel
23
+
24
+
25
class GraspGenX(BaseModel):
    """Cross-embodiment grasp generation (GraspGenX).

    Grasps are requested for a named gripper through ``aux_args["gripper_name"]``.
    The object is described either by an (N, 3) point cloud or by the
    depth image + segmentation image + camera intrinsics triplet.

    Preferred usage
    ---------------
    ```pycon
    >>> out = CortexClient().run(
    ...     "graspgenx",
    ...     point_cloud=pc,  # (N, 3) float32
    ...     aux_args={"gripper_name": "robotiq_2f_85", "num_grasps": 64},
    ... )
    >>> grasps, conf = out["grasps"], out["confidence"]
    ```
    """

    name: str = "graspgenx"
    model_id: str = "graspgenx"

    # ------------------------------------------------------------------
    # BaseModel implementation
    # ------------------------------------------------------------------

    def preprocess(
        self,
        *,
        aux_args: Dict[str, Any] | None = None,
        point_cloud: Union[str, np.ndarray, Sequence[Sequence[float]], None] = None,
        depth_image: Union[str, Image.Image, np.ndarray, None] = None,
        seg_image: Union[str, Image.Image, np.ndarray, None] = None,
        camera_intrinsics: Union[str, np.ndarray, None] = None,
    ) -> Dict[str, Any]:
        """Prepare the msgpack payload.

        Args:
            aux_args: Dict that must contain "gripper_name"; may also carry
                "num_grasps", "planner", "grasp_threshold",
                "camera_extrinsics". Forwarded unchanged in the payload.
            point_cloud: (N, 3) point cloud (ndarray/list/path to .npy).
            depth_image / seg_image / camera_intrinsics: alternative to
                point_cloud; converted to a point cloud server-side. All
                three must be given together.

        Returns:
            A dict with "aux_args" plus either "point_cloud" (float32, .npy
            bytes) or "depth_image", "seg_image" and "camera_intrinsics"
            (each as .npy bytes).

        Raises:
            ValueError: If ``aux_args`` is missing or lacks "gripper_name",
                if both input forms are supplied, if neither complete form
                is supplied, or if ``point_cloud`` is not an (N, 3) array.
        """
        if aux_args is None:
            raise ValueError("'aux_args' is required (must include 'gripper_name').")
        # Fail fast on the one mandatory key instead of paying a server
        # round trip for a request that cannot name its gripper.
        if "gripper_name" not in aux_args:
            raise ValueError("'aux_args' must include 'gripper_name'.")

        triplet = (depth_image, seg_image, camera_intrinsics)
        has_pc = point_cloud is not None
        has_any_triplet_part = any(part is not None for part in triplet)
        has_full_triplet = all(part is not None for part in triplet)

        if has_pc and has_any_triplet_part:
            raise ValueError("Provide either point_cloud or depth_image+seg_image+camera_intrinsics, not both.")
        if not has_pc and not has_full_triplet:
            raise ValueError("Provide point_cloud, or all of depth_image + seg_image + camera_intrinsics.")

        if has_pc:
            # A string is treated as a path to a .npy file.
            pc_array = np.load(point_cloud) if isinstance(point_cloud, str) else np.asarray(point_cloud)
            if pc_array.ndim != 2 or pc_array.shape[1] != 3:
                raise ValueError("point_cloud must be an (N, 3) array.")
            return {"point_cloud": array_to_npy_bytes(pc_array.astype(np.float32)), "aux_args": aux_args}

        # Paths and PIL images go through load_image; anything else is
        # taken as array-like data.
        depth_array = (
            np.array(load_image(depth_image))
            if isinstance(depth_image, (str, Image.Image))
            else np.asarray(depth_image)
        )
        seg_array = (
            np.array(load_image(seg_image)) if isinstance(seg_image, (str, Image.Image)) else np.asarray(seg_image)
        )
        intrinsics = np.load(camera_intrinsics) if isinstance(camera_intrinsics, str) else np.asarray(camera_intrinsics)
        return {
            "depth_image": array_to_npy_bytes(depth_array),
            "seg_image": array_to_npy_bytes(seg_array),
            "camera_intrinsics": array_to_npy_bytes(intrinsics),
            "aux_args": aux_args,
        }

    def postprocess(self, response_data: Dict[str, Any], **_: Any) -> Dict[str, Any]:  # noqa: D401
        """Decode grasps + confidence from the response.

        Args:
            response_data: Decoded server response; "output" and
                "confidence" are required, "latency_ms" is optional.

        Returns:
            Dict with "grasps" and "confidence" as arrays and "latency_ms"
            (``None`` when the server did not report it).
        """
        grasps = np.array(response_data["output"])
        if grasps.size == 0:
            # Keep the documented (N, 4, 4) contract on the empty path.
            grasps = np.empty((0, 4, 4), dtype=np.float32)
        return {
            "grasps": grasps,
            "confidence": np.array(response_data["confidence"]),
            "latency_ms": response_data.get("latency_ms"),
        }

    def run(
        self,
        *,
        aux_args: Dict[str, Any],
        point_cloud: Union[str, np.ndarray, Sequence[Sequence[float]], None] = None,
        depth_image: Union[str, Image.Image, np.ndarray, None] = None,
        seg_image: Union[str, Image.Image, np.ndarray, None] = None,
        camera_intrinsics: Union[str, np.ndarray, None] = None,
        timeout: float | None = None,
    ) -> Dict[str, Any]:
        """Generate 6-DoF grasps for a named gripper with GraspGenX.

        Args:
            aux_args (Dict[str, Any]): Must include "gripper_name"; optional
                "num_grasps", "planner" ("graspmoe"|"diffusion"), "grasp_threshold",
                "camera_extrinsics".
            point_cloud (Union[str, np.ndarray, Sequence, None]): (N, 3) object point cloud.
            depth_image / seg_image / camera_intrinsics: alternative to point_cloud.
            timeout (float | None): Optional HTTP timeout.

        Returns:
            Dict[str, Any]: Dict with:
                - grasps: (N, 4, 4) array of grasp poses
                - confidence: (N,) array of scores
                - latency_ms: server-reported latency

        Examples:
            >>> from grid_cortex_client import CortexClient, ModelType
            >>> import numpy as np
            >>> client = CortexClient()
            >>> pc = np.load("object_pc.npy")  # (N, 3)
            >>> out = client.run(ModelType.GRASPGENX, point_cloud=pc,
            ...                  aux_args={"gripper_name": "robotiq_2f_85", "num_grasps": 64})
            >>> print(out["grasps"].shape)
        """
        return super().run(
            aux_args=aux_args,
            point_cloud=point_cloud,
            depth_image=depth_image,
            seg_image=seg_image,
            camera_intrinsics=camera_intrinsics,
            timeout=timeout,
        )
@@ -0,0 +1,143 @@
1
+ """Test GraspGenX model (cross-embodiment 6-DoF grasp generation).
2
+
3
+ Runs against a live Ray Serve + CortexHub via CortexClient.run(...). Uses a
4
+ synthetic object point cloud so the test does not depend on staged eval data.
5
+ """
6
+
7
+ import os
8
+ import time
9
+
10
+ import numpy as np
11
+ import pytest
12
+
13
+ from grid_cortex_client import AsyncCortexClient, CortexClient, CortexHubClient, ModelType
14
+
15
+ from .conftest import BaseModelTestHTTP, is_gpu_mem_close, run_concurrent_requests
16
+
17
+
18
+ def _synthetic_object_pc(n: int = 2048, seed: int = 0) -> np.ndarray:
19
+ """A small noisy point cloud standing in for a segmented object."""
20
+ rng = np.random.default_rng(seed)
21
+ return ((rng.random((n, 3), dtype=np.float32) - 0.5) * 0.1).astype(np.float32)
22
+
23
+
24
class TestGraspGenXHTTP(BaseModelTestHTTP):
    """HTTP-based tests for GraspGenX.

    Every test sends a synthetic object point cloud through the client, so
    no staged evaluation data is needed.
    """

    @property
    def batch_size(self) -> int:
        """max_ongoing_requests from config."""
        return 1

    @property
    def gpu_mem(self) -> float:
        """Declared GPU memory (MB). ~688 MB peak in CI smoke (num_grasps=16); 1000 = headroom."""
        return 1000.0

    @property
    def max_input_size(self) -> tuple:
        """Not image-based; placeholder for the ABC contract."""
        return (480, 640)

    @pytest.fixture(scope="class")
    def point_cloud(self) -> np.ndarray:
        """Default synthetic object cloud shared by the tests in this class."""
        return _synthetic_object_pc()

    @pytest.fixture(scope="class")
    def aux_args(self) -> dict:
        """Request options: gripper name plus a small grasp count."""
        return {"gripper_name": "robotiq_2f_85", "num_grasps": 16}

    def test_basic_inference(self, client: CortexClient, point_cloud: np.ndarray, aux_args: dict):
        """Grasp generation from a direct point cloud."""
        result = client.run(ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args)
        assert result is not None
        assert "grasps" in result
        assert "confidence" in result
        assert isinstance(result["grasps"], np.ndarray)
        assert isinstance(result["confidence"], np.ndarray)

    def test_grasp_output_structure(self, client: CortexClient, point_cloud: np.ndarray, aux_args: dict):
        """Grasps are (N, 4, 4); confidence is (N,) when any are returned."""
        result = client.run(ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args)
        grasps, confidence = result["grasps"], result["confidence"]
        # GraspGenX.postprocess normalizes an empty result to (0, 4, 4), so
        # the pose layout is checked even when no grasps come back.
        assert grasps.ndim == 3 and grasps.shape[1:] == (4, 4)
        if len(grasps) > 0:
            assert confidence.ndim == 1
            assert len(grasps) == len(confidence)

    def test_maximum_payload_size(self, client: CortexClient, aux_args: dict):
        """Inference works on a large object point cloud (well above the basic case).

        Uses 8192 points — a realistic dense segmented object, ~4x the basic test.
        (The handler's hard guard is MAX_POINTS=50000, but a single object cloud is
        never that large, and graspmoe's outlier-removal step is superlinear in N.)
        """
        large_pc = _synthetic_object_pc(n=8192)
        result = client.run(ModelType.GRASPGENX, point_cloud=large_pc, aux_args=aux_args)
        assert result is not None
        assert "grasps" in result
        assert "confidence" in result

    def test_raw_response_with_gpu_stats(self, client: CortexClient, point_cloud: np.ndarray, aux_args: dict):
        """Raw response carries output/confidence/latency/gpu_stats."""
        raw = client.run(ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args, raw_response=True)
        assert "output" in raw
        assert "confidence" in raw
        assert "latency_ms" in raw
        assert "gpu_stats" in raw and "mem_usage_mb" in raw["gpu_stats"]

    def test_missing_aux_args(self, client: CortexClient, point_cloud: np.ndarray):
        """Missing aux_args raises a client-side ValueError."""
        with pytest.raises(ValueError):
            client.run(ModelType.GRASPGENX, point_cloud=point_cloud)

    def test_missing_input(self, client: CortexClient, aux_args: dict):
        """Neither point_cloud nor depth-triplet raises a client-side ValueError."""
        with pytest.raises(ValueError):
            client.run(ModelType.GRASPGENX, aux_args=aux_args)

    def test_inference_latency(self, client: CortexClient, point_cloud: np.ndarray, aux_args: dict):
        """Latency within an acceptable bound."""
        # perf_counter is monotonic, so a wall-clock adjustment during the
        # request cannot distort the measured duration.
        start = time.perf_counter()
        result = client.run(ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args)
        elapsed = time.perf_counter() - start
        assert result is not None
        assert elapsed < 50.0, f"Inference took {elapsed:.2f}s, expected < 50s"

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, async_client: AsyncCortexClient, point_cloud: np.ndarray, aux_args: dict):
        """Concurrent requests are handled."""

        async def make_request():
            return await async_client.run(ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args)

        # NOTE(review): min(..., 1) caps this at a single request whatever
        # batch_size is, so nothing actually runs concurrently — confirm
        # that this cap is intentional.
        num_requests = min(self.batch_size, 1)
        results = await run_concurrent_requests(make_request, num_requests=num_requests)
        assert len(results) == num_requests
        assert all(not isinstance(r, Exception) for r in results)
        assert all("grasps" in r for r in results)

    @pytest.mark.asyncio
    async def test_gpu_memory_usage(self, async_client: AsyncCortexClient, point_cloud: np.ndarray, aux_args: dict):
        """GPU memory within the declared budget."""

        async def make_request():
            return await async_client.run(
                ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args, raw_response=True
            )

        results = await run_concurrent_requests(make_request, num_requests=min(self.batch_size, 2))
        assert all(not isinstance(r, Exception) for r in results)
        assert all(is_gpu_mem_close(self.gpu_mem, r["gpu_stats"]["mem_usage_mb"]) for r in results), (
            f"GPU memory usage too high: {[r['gpu_stats']['mem_usage_mb'] for r in results]} MB"
        )

    @pytest.mark.asyncio
    async def test_cortex_hub(self, point_cloud: np.ndarray, aux_args: dict):
        """Inference via CortexHubClient (WebSocket)."""
        api_key = os.getenv("GRID_CORTEX_API_KEY") or os.getenv("CORTEX_API_KEY")
        if not api_key:
            pytest.skip("GRID_CORTEX_API_KEY or CORTEX_API_KEY not set")
        async with CortexHubClient(api_key=api_key) as hub:
            result = await hub.run(ModelType.GRASPGENX, point_cloud=point_cloud, aux_args=aux_args)
        assert result is not None