opentau 0.1.1-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
opentau/scripts/grpc/robot_inference_pb2_grpc.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
16
+ """Client and server classes corresponding to protobuf-defined services."""
17
+
18
+ import grpc
19
+ from opentau.scripts.grpc import robot_inference_pb2 as robot__inference__pb2
20
+
21
+ GRPC_GENERATED_VERSION = "1.76.0"
22
+ GRPC_VERSION = grpc.__version__
23
+ _version_not_supported = False
24
+
25
+ try:
26
+ from grpc._utilities import first_version_is_lower
27
+
28
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
29
+ except ImportError:
30
+ _version_not_supported = True
31
+
32
+ if _version_not_supported:
33
+ raise RuntimeError(
34
+ f"The grpc package installed is at version {GRPC_VERSION},"
35
+ + " but the generated code in robot_inference_pb2_grpc.py depends on"
36
+ + f" grpcio>={GRPC_GENERATED_VERSION}."
37
+ + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
38
+ + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
39
+ )
40
+
41
+
42
+ class RobotPolicyServiceStub:
43
+ """Service for robot policy inference"""
44
+
45
+ def __init__(self, channel):
46
+ """Constructor.
47
+
48
+ Args:
49
+ channel: A grpc.Channel.
50
+ """
51
+ self.GetActionChunk = channel.unary_unary(
52
+ "/robot_inference.RobotPolicyService/GetActionChunk",
53
+ request_serializer=robot__inference__pb2.ObservationRequest.SerializeToString,
54
+ response_deserializer=robot__inference__pb2.ActionChunkResponse.FromString,
55
+ _registered_method=True,
56
+ )
57
+ self.StreamActionChunks = channel.stream_stream(
58
+ "/robot_inference.RobotPolicyService/StreamActionChunks",
59
+ request_serializer=robot__inference__pb2.ObservationRequest.SerializeToString,
60
+ response_deserializer=robot__inference__pb2.ActionChunkResponse.FromString,
61
+ _registered_method=True,
62
+ )
63
+ self.HealthCheck = channel.unary_unary(
64
+ "/robot_inference.RobotPolicyService/HealthCheck",
65
+ request_serializer=robot__inference__pb2.HealthCheckRequest.SerializeToString,
66
+ response_deserializer=robot__inference__pb2.HealthCheckResponse.FromString,
67
+ _registered_method=True,
68
+ )
69
+
70
+
71
+ class RobotPolicyServiceServicer:
72
+ """Service for robot policy inference"""
73
+
74
+ def GetActionChunk(self, request, context):
75
+ """Get action chunk from observations (single request-response)"""
76
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
77
+ context.set_details("Method not implemented!")
78
+ raise NotImplementedError("Method not implemented!")
79
+
80
+ def StreamActionChunks(self, request_iterator, context):
81
+ """Streaming version for continuous inference (robot sends observations, server sends action chunks)"""
82
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
83
+ context.set_details("Method not implemented!")
84
+ raise NotImplementedError("Method not implemented!")
85
+
86
+ def HealthCheck(self, request, context):
87
+ """Health check"""
88
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
89
+ context.set_details("Method not implemented!")
90
+ raise NotImplementedError("Method not implemented!")
91
+
92
+
93
+ def add_RobotPolicyServiceServicer_to_server(servicer, server):
94
+ rpc_method_handlers = {
95
+ "GetActionChunk": grpc.unary_unary_rpc_method_handler(
96
+ servicer.GetActionChunk,
97
+ request_deserializer=robot__inference__pb2.ObservationRequest.FromString,
98
+ response_serializer=robot__inference__pb2.ActionChunkResponse.SerializeToString,
99
+ ),
100
+ "StreamActionChunks": grpc.stream_stream_rpc_method_handler(
101
+ servicer.StreamActionChunks,
102
+ request_deserializer=robot__inference__pb2.ObservationRequest.FromString,
103
+ response_serializer=robot__inference__pb2.ActionChunkResponse.SerializeToString,
104
+ ),
105
+ "HealthCheck": grpc.unary_unary_rpc_method_handler(
106
+ servicer.HealthCheck,
107
+ request_deserializer=robot__inference__pb2.HealthCheckRequest.FromString,
108
+ response_serializer=robot__inference__pb2.HealthCheckResponse.SerializeToString,
109
+ ),
110
+ }
111
+ generic_handler = grpc.method_handlers_generic_handler(
112
+ "robot_inference.RobotPolicyService", rpc_method_handlers
113
+ )
114
+ server.add_generic_rpc_handlers((generic_handler,))
115
+ server.add_registered_method_handlers("robot_inference.RobotPolicyService", rpc_method_handlers)
116
+
117
+
118
+ # This class is part of an EXPERIMENTAL API.
119
+ class RobotPolicyService:
120
+ """Service for robot policy inference"""
121
+
122
+ @staticmethod
123
+ def GetActionChunk(
124
+ request,
125
+ target,
126
+ options=(),
127
+ channel_credentials=None,
128
+ call_credentials=None,
129
+ insecure=False,
130
+ compression=None,
131
+ wait_for_ready=None,
132
+ timeout=None,
133
+ metadata=None,
134
+ ):
135
+ return grpc.experimental.unary_unary(
136
+ request,
137
+ target,
138
+ "/robot_inference.RobotPolicyService/GetActionChunk",
139
+ robot__inference__pb2.ObservationRequest.SerializeToString,
140
+ robot__inference__pb2.ActionChunkResponse.FromString,
141
+ options,
142
+ channel_credentials,
143
+ insecure,
144
+ call_credentials,
145
+ compression,
146
+ wait_for_ready,
147
+ timeout,
148
+ metadata,
149
+ _registered_method=True,
150
+ )
151
+
152
+ @staticmethod
153
+ def StreamActionChunks(
154
+ request_iterator,
155
+ target,
156
+ options=(),
157
+ channel_credentials=None,
158
+ call_credentials=None,
159
+ insecure=False,
160
+ compression=None,
161
+ wait_for_ready=None,
162
+ timeout=None,
163
+ metadata=None,
164
+ ):
165
+ return grpc.experimental.stream_stream(
166
+ request_iterator,
167
+ target,
168
+ "/robot_inference.RobotPolicyService/StreamActionChunks",
169
+ robot__inference__pb2.ObservationRequest.SerializeToString,
170
+ robot__inference__pb2.ActionChunkResponse.FromString,
171
+ options,
172
+ channel_credentials,
173
+ insecure,
174
+ call_credentials,
175
+ compression,
176
+ wait_for_ready,
177
+ timeout,
178
+ metadata,
179
+ _registered_method=True,
180
+ )
181
+
182
+ @staticmethod
183
+ def HealthCheck(
184
+ request,
185
+ target,
186
+ options=(),
187
+ channel_credentials=None,
188
+ call_credentials=None,
189
+ insecure=False,
190
+ compression=None,
191
+ wait_for_ready=None,
192
+ timeout=None,
193
+ metadata=None,
194
+ ):
195
+ return grpc.experimental.unary_unary(
196
+ request,
197
+ target,
198
+ "/robot_inference.RobotPolicyService/HealthCheck",
199
+ robot__inference__pb2.HealthCheckRequest.SerializeToString,
200
+ robot__inference__pb2.HealthCheckResponse.FromString,
201
+ options,
202
+ channel_credentials,
203
+ insecure,
204
+ call_credentials,
205
+ compression,
206
+ wait_for_ready,
207
+ timeout,
208
+ metadata,
209
+ _registered_method=True,
210
+ )
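
As a quick orientation for the new service, here is a minimal client-side sketch against the generated stub. It assumes the server from this release is listening on localhost:50051 (the port shown in the server's usage string) and that HealthCheckRequest needs no fields; neither assumption is guaranteed by this diff.

    import grpc
    from opentau.scripts.grpc import robot_inference_pb2, robot_inference_pb2_grpc

    # Open an insecure channel to the inference server (assumed address/port).
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = robot_inference_pb2_grpc.RobotPolicyServiceStub(channel)
        # An empty HealthCheckRequest is assumed to be sufficient.
        reply = stub.HealthCheck(robot_inference_pb2.HealthCheckRequest(), timeout=5.0)
        print(reply.healthy, reply.status, reply.model_name, reply.device)
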
opentau/scripts/grpc/server.py ADDED
@@ -0,0 +1,313 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """gRPC server for robot policy inference on GPU.
16
+
17
+ This server loads an ML policy model and serves inference requests from
18
+ robots running ROS 2. It is designed to run on a server with an ML-capable GPU.
19
+
20
+ Usage:
21
+ python src/opentau/scripts/grpc/server.py --config_path=/path/to/config.json \\
22
+ --server.port=50051 --server.max_workers=4
23
+ """
24
+
25
+ import io
26
+ import logging
27
+ from concurrent import futures
28
+ from dataclasses import asdict
29
+ from pprint import pformat
30
+ from typing import Iterator
31
+
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+
36
+ import grpc
37
+ from opentau.configs import parser
38
+ from opentau.configs.train import TrainPipelineConfig
39
+ from opentau.policies.factory import get_policy_class
40
+ from opentau.scripts.grpc import robot_inference_pb2, robot_inference_pb2_grpc
41
+ from opentau.utils.random_utils import set_seed
42
+ from opentau.utils.utils import (
43
+ attempt_torch_compile,
44
+ auto_torch_device,
45
+ init_logging,
46
+ )
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ class RobotPolicyServicer(robot_inference_pb2_grpc.RobotPolicyServiceServicer):
52
+ """gRPC servicer implementing the RobotPolicyService."""
53
+
54
+ def __init__(self, cfg: TrainPipelineConfig):
55
+ """Initialize the servicer with model and configuration.
56
+
57
+ Args:
58
+ cfg: Training pipeline configuration including policy settings.
59
+ """
60
+ self.cfg = cfg
61
+ self.device = auto_torch_device()
62
+ self.dtype = torch.bfloat16
63
+
64
+ logger.info(f"Initializing policy on device: {self.device}")
65
+
66
+ # Load the policy model
67
+ self._load_policy()
68
+
69
+ def _load_policy(self):
70
+ """Load the policy model from pretrained weights."""
71
+ logger.info(f"Loading policy from: {self.cfg.policy.pretrained_path}")
72
+
73
+ policy_class = get_policy_class(self.cfg.policy.type)
74
+ self.policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=self.cfg.policy)
75
+ self.policy.to(device=self.device, dtype=self.dtype)
76
+ self.policy.eval()
77
+ self.policy = attempt_torch_compile(self.policy, device_hint=self.device)
78
+ self.policy.reset()
79
+ logger.info("Policy loaded successfully")
80
+
81
+ def _decode_image(self, camera_image: robot_inference_pb2.CameraImage) -> torch.Tensor:
82
+ """Decode an image from the protobuf message.
83
+
84
+ Args:
85
+ camera_image: CameraImage protobuf message.
86
+
87
+ Returns:
88
+ Tensor of shape (1, C, H, W) normalized to [0, 1].
89
+ """
90
+ if camera_image.encoding in ["jpeg", "png"]:
91
+ # Decode compressed image
92
+ image = Image.open(io.BytesIO(camera_image.image_data))
93
+ image = image.convert("RGB")
94
+ image = image.resize(self.cfg.resolution[::-1]) # PIL uses (W, H)
95
+ image_array = np.array(image, dtype=np.float32) / 255.0
96
+ elif camera_image.encoding == "raw":
97
+ # Raw data: assume a flattened square float32 RGB image
98
+ image_array = np.frombuffer(camera_image.image_data, dtype=np.float32)
99
+ # Reshape assuming square image with 3 channels
100
+ side = int(np.sqrt(len(image_array) / 3))
101
+ image_array = image_array.reshape(side, side, 3)
102
+ # Resize if needed
103
+ if (side, side) != self.cfg.resolution:
104
+ image = Image.fromarray((image_array * 255).astype(np.uint8))
105
+ image = image.resize(self.cfg.resolution[::-1])
106
+ image_array = np.array(image, dtype=np.float32) / 255.0
107
+ else:
108
+ raise ValueError(f"Unknown image encoding: {camera_image.encoding}")
109
+
110
+ # Convert to (C, H, W) tensor
111
+ image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
112
+ return image_tensor.to(device=self.device, dtype=self.dtype)
113
+
114
+ def _prepare_observation(
115
+ self, request: robot_inference_pb2.ObservationRequest
116
+ ) -> dict[str, torch.Tensor]:
117
+ """Convert a protobuf observation request to the policy input format.
118
+
119
+ Args:
120
+ request: ObservationRequest protobuf message.
121
+
122
+ Returns:
123
+ Dictionary of tensors matching the policy's expected input format.
124
+ """
125
+ batch = {}
126
+
127
+ # Process camera images
128
+ img_is_pad = []
129
+ for i, camera_image in enumerate(request.images):
130
+ camera_name = f"camera{i}"
131
+ batch[camera_name] = self._decode_image(camera_image)
132
+ img_is_pad.append(False)
133
+
134
+ # Fill in missing cameras with zeros
135
+ for i in range(len(request.images), self.cfg.num_cams):
136
+ batch[f"camera{i}"] = torch.zeros(
137
+ (1, 3, *self.cfg.resolution),
138
+ dtype=self.dtype,
139
+ device=self.device,
140
+ )
141
+ img_is_pad.append(True)
142
+
143
+ batch["img_is_pad"] = torch.tensor([img_is_pad], dtype=torch.bool, device=self.device)
144
+
145
+ # Process robot state
146
+ if request.robot_state.state:
147
+ state = list(request.robot_state.state)
148
+ # Pad to max_state_dim if needed
149
+ if len(state) < self.cfg.max_state_dim:
150
+ state.extend([0.0] * (self.cfg.max_state_dim - len(state)))
151
+ batch["state"] = torch.tensor(
152
+ [state[: self.cfg.max_state_dim]],
153
+ dtype=self.dtype,
154
+ device=self.device,
155
+ )
156
+ else:
157
+ raise ValueError("Robot state is required but was not provided in the request")
158
+
159
+ # Process prompt
160
+ batch["prompt"] = [request.prompt] if request.prompt else [""]
161
+
162
+ return batch
163
+
164
+ def GetActionChunk(
165
+ self,
166
+ request: robot_inference_pb2.ObservationRequest,
167
+ context: grpc.ServicerContext,
168
+ ) -> robot_inference_pb2.ActionChunkResponse:
169
+ """Handle a single action chunk inference request.
170
+
171
+ Args:
172
+ request: ObservationRequest containing observations.
173
+ context: gRPC context.
174
+
175
+ Returns:
176
+ ActionChunkResponse containing the predicted action chunk.
177
+ """
178
+ import time
179
+
180
+ start_time = time.perf_counter()
181
+ response = robot_inference_pb2.ActionChunkResponse()
182
+ response.request_id = request.request_id
183
+ response.timestamp_ns = time.time_ns()
184
+
185
+ try:
186
+ # Prepare observation batch
187
+ batch = self._prepare_observation(request)
188
+
189
+ # Run inference
190
+ with torch.inference_mode():
191
+ action_chunk = self.policy.sample_actions(batch)
192
+ # action_chunk shape: (n_action_steps, batch_size=1, action_dim)
193
+ # Remove batch dimension and convert to numpy
194
+ action_chunk = action_chunk.squeeze(1).to("cpu", torch.float32).numpy()
195
+
196
+ # Populate 2D action chunk structure
197
+ for action_vector in action_chunk:
198
+ action_vec_msg = robot_inference_pb2.ActionVector()
199
+ action_vec_msg.values.extend(action_vector.tolist())
200
+ response.action_chunk.append(action_vec_msg)
201
+
202
+ except ValueError as e:
203
+ # Invalid request (e.g., missing required fields)
204
+ logger.error(f"Invalid request: {e}")
205
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
206
+
207
+ except Exception as e:
208
+ # Unexpected error during inference
209
+ logger.exception("Error during inference")
210
+ context.abort(grpc.StatusCode.INTERNAL, f"Inference error: {e}")
211
+
212
+ response.inference_time_ms = (time.perf_counter() - start_time) * 1000
213
+ return response
214
+
215
+ def StreamActionChunks(
216
+ self,
217
+ request_iterator: Iterator[robot_inference_pb2.ObservationRequest],
218
+ context: grpc.ServicerContext,
219
+ ) -> Iterator[robot_inference_pb2.ActionChunkResponse]:
220
+ """Handle streaming action chunk inference requests.
221
+
222
+ Args:
223
+ request_iterator: Iterator of ObservationRequest messages.
224
+ context: gRPC context.
225
+
226
+ Yields:
227
+ ActionChunkResponse messages for each observation.
228
+ """
229
+ for request in request_iterator:
230
+ if context.is_active():
231
+ yield self.GetActionChunk(request, context)
232
+ else:
233
+ break
234
+
235
+ def HealthCheck(
236
+ self,
237
+ request: robot_inference_pb2.HealthCheckRequest,
238
+ context: grpc.ServicerContext,
239
+ ) -> robot_inference_pb2.HealthCheckResponse:
240
+ """Check server health and GPU status.
241
+
242
+ Args:
243
+ request: HealthCheckRequest message.
244
+ context: gRPC context.
245
+
246
+ Returns:
247
+ HealthCheckResponse with server status.
248
+ """
249
+ response = robot_inference_pb2.HealthCheckResponse()
250
+ response.healthy = True
251
+ response.status = "Server is running"
252
+ response.model_name = self.cfg.policy.type
253
+ response.device = str(self.device)
254
+
255
+ if torch.cuda.is_available():
256
+ response.gpu_memory_used_gb = torch.cuda.memory_allocated() / 1e9
257
+ response.gpu_memory_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
258
+
259
+ return response
260
+
261
+
262
+ def serve(cfg: TrainPipelineConfig):
263
+ """Start the gRPC server.
264
+
265
+ Args:
266
+ cfg: Training pipeline configuration including server settings.
267
+ """
268
+ server_cfg = cfg.server
269
+ server = grpc.server(
270
+ futures.ThreadPoolExecutor(max_workers=server_cfg.max_workers),
271
+ options=[
272
+ ("grpc.max_send_message_length", server_cfg.max_send_message_length),
273
+ ("grpc.max_receive_message_length", server_cfg.max_receive_message_length),
274
+ ],
275
+ )
276
+
277
+ servicer = RobotPolicyServicer(cfg)
278
+ robot_inference_pb2_grpc.add_RobotPolicyServiceServicer_to_server(servicer, server)
279
+
280
+ server.add_insecure_port(f"[::]:{server_cfg.port}")
281
+ server.start()
282
+
283
+ logger.info(f"Server started on port {server_cfg.port}")
284
+ logger.info(f"Policy: {cfg.policy.type}")
285
+ logger.info(f"Device: {servicer.device}")
286
+ logger.info(f"Max workers: {server_cfg.max_workers}")
287
+ logger.info("Waiting for requests...")
288
+
289
+ try:
290
+ server.wait_for_termination()
291
+ except KeyboardInterrupt:
292
+ logger.info("Shutting down server...")
293
+ server.stop(grace=5)
294
+
295
+
296
+ @parser.wrap()
297
+ def server_main(cfg: TrainPipelineConfig):
298
+ """Main entry point for the gRPC server.
299
+
300
+ Args:
301
+ cfg: Training pipeline configuration parsed from CLI/config file.
302
+ """
303
+ logging.info(pformat(asdict(cfg)))
304
+
305
+ if cfg.seed is not None:
306
+ set_seed(cfg.seed)
307
+
308
+ serve(cfg)
309
+
310
+
311
+ if __name__ == "__main__":
312
+ init_logging()
313
+ server_main()
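
To illustrate how the servicer above expects requests to be shaped, here is a hedged client sketch. The field names (images, encoding, image_data, robot_state.state, prompt, action_chunk, values, inference_time_ms) are inferred from how server.py reads and writes the messages; the server address, image file, prompt text, and 7-element state vector are placeholders.

    import grpc
    from opentau.scripts.grpc import robot_inference_pb2, robot_inference_pb2_grpc

    channel = grpc.insecure_channel("localhost:50051")   # assumed server address
    stub = robot_inference_pb2_grpc.RobotPolicyServiceStub(channel)

    request = robot_inference_pb2.ObservationRequest()
    request.prompt = "pick up the red block"              # optional task prompt
    request.robot_state.state.extend([0.0] * 7)           # placeholder joint state; padded to max_state_dim server-side

    image = request.images.add()                          # repeated CameraImage field
    image.encoding = "jpeg"                                # "jpeg", "png", or "raw" (see _decode_image)
    with open("frame.jpg", "rb") as f:                     # placeholder image file
        image.image_data = f.read()

    response = stub.GetActionChunk(request, timeout=10.0)
    actions = [list(step.values) for step in response.action_chunk]
    print(f"{len(actions)} action steps in {response.inference_time_ms:.1f} ms")
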
opentau/scripts/launch.py CHANGED
@@ -18,10 +18,6 @@ import sys
18
18
  from pathlib import Path
19
19
  from types import ModuleType
20
20
 
21
- import opentau.scripts.eval as eval_script
22
- import opentau.scripts.export_to_onnx as export_script
23
- import opentau.scripts.train as train_script
24
-
25
21
 
26
22
  def launch(script_module: ModuleType, description: str, use_accelerate: bool = True):
27
23
  """Generic launcher for OpenTau scripts using Accelerate or Python."""
@@ -68,12 +64,24 @@ def launch(script_module: ModuleType, description: str, use_accelerate: bool = T
68
64
 
69
65
 
70
66
  def train():
67
+ import opentau.scripts.train as train_script
68
+
71
69
  launch(train_script, "Launch OpenTau training with Accelerate")
72
70
 
73
71
 
74
72
  def eval():
73
+ import opentau.scripts.eval as eval_script
74
+
75
75
  launch(eval_script, "Launch OpenTau evaluation with Accelerate")
76
76
 
77
77
 
78
78
  def export():
79
+ import opentau.scripts.export_to_onnx as export_script
80
+
79
81
  launch(export_script, "Launch OpenTau ONNX export", use_accelerate=False)
82
+
83
+
84
+ def visualize():
85
+ import opentau.scripts.visualize_dataset as visualize_script
86
+
87
+ launch(visualize_script, "Launch OpenTau visualization", use_accelerate=False)