grid-cortex-client 0.1.150__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,34 @@
1
"""Public API for the grid_cortex_client package."""
import logging

# Configure the library logger to be silent by default.
# Applications using this library should configure their own logging handlers
# and levels if they wish to see logs from "grid_cortex_client".
# (A stray "\" line-continuation artifact at the top of the original file
# has been removed.)
logging.getLogger(__name__).addHandler(logging.NullHandler())

# Make key classes and functions available for import
from .client import CortexAPIError, CortexNetworkError, HTTPClient
from .cortex_client import CortexClient
from .models import BaseModel, DepthModel, DetectionModel, SegmentationModel
from .preprocessing import encode_image_to_base64, load_image
from .postprocessing import (
    postprocess_depth_response,
    postprocess_detection_response,
    postprocess_segmentation_response,
)

__all__ = [
    "CortexClient",
    "HTTPClient",
    "CortexAPIError",
    "CortexNetworkError",
    "BaseModel",
    "DepthModel",
    "DetectionModel",
    "SegmentationModel",
    "load_image",
    "encode_image_to_base64",
    "postprocess_depth_response",
    "postprocess_detection_response",
    "postprocess_segmentation_response",
]
@@ -0,0 +1,85 @@
1
+ from typing import Any, Dict, Optional, Union
2
+ import os
3
+ import httpx
4
+ import logging
5
+
6
# Default API endpoint; overridable via the GRID_CORTEX_BASE_URL env var.
BASE_URL = os.environ.get("GRID_CORTEX_BASE_URL", "https://cortex-prod.generalrobotics.dev")

# Library code must not call logging.basicConfig — that would configure the
# application's root logger. A module-level logger is sufficient; the package
# __init__ attaches a NullHandler so this stays silent by default.
logger = logging.getLogger(__name__)
11
+
12
class CortexAPIError(Exception):
    """Raised when the Cortex API returns an HTTP error status.

    Attributes:
        status_code: HTTP status code returned by the server.
        message: Human-readable error message (typically the API's 'detail' field).
        details: Parsed JSON error body when the response was JSON, else None.
    """

    def __init__(self, status_code: int, message: str, details: Optional[Dict[str, Any]] = None):
        self.status_code = status_code
        self.message = message
        self.details = details
        # Only append the details blob when present — the original
        # unconditionally emitted a trailing space when details was None.
        suffix = f" {details}" if details else ""
        super().__init__(f"API Error {status_code}: {message}{suffix}")
19
+
20
class CortexNetworkError(Exception):
    """Raised when a request fails before a response is received (network issues)."""
23
+
24
class HTTPClient:
    """Thin wrapper around ``httpx.Client`` for talking to the Cortex API.

    Resolves the API key and base URL from explicit arguments or environment
    variables, and translates httpx errors into the library's exception types
    (CortexAPIError / CortexNetworkError).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: float = 60.0,
    ):
        """
        Args:
            api_key: API key. Falls back to the GRID_CORTEX_API_KEY env var.
                A missing key only logs a warning so the client can still be
                used against public/local services that require no auth.
            base_url: Target base URL. Falls back to GRID_CORTEX_BASE_URL or
                the production default.
            timeout: Default request timeout in seconds.
        """
        headers = {"Content-Type": "application/json"}
        # Resolve the key once, then set the header in a single place.
        # (The original duplicated the header assignment across two branches.)
        resolved_key = api_key if api_key is not None else os.getenv("GRID_CORTEX_API_KEY")
        if resolved_key:
            headers["x-api-key"] = resolved_key
        else:
            logger.warning("GRID_CORTEX_API_KEY environment variable not set. Client might not authenticate if API requires a key.")

        # Determine the final base URL.
        final_base_url = base_url or os.getenv("GRID_CORTEX_BASE_URL", "https://cortex-prod.generalrobotics.dev")
        logger.info(f"HTTPClient initialized for base URL: {final_base_url}")

        self._client = httpx.Client(
            base_url=final_base_url, timeout=timeout, headers=headers
        )

    def post(
        self,
        path: str,
        *,
        json: Any = None,
        data: Union[Dict[str, Any], bytes, None] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """POST to *path* and return the decoded JSON response body.

        Args:
            path: Endpoint path, appended to the client's base URL.
            json: JSON-serializable request body.
            data: Form/bytes body (mutually exclusive with ``json`` in practice).
            headers: Extra headers merged over the client defaults.
            timeout: Per-call timeout; None uses the client's default.

        Raises:
            CortexAPIError: The server responded with an HTTP error status.
            CortexNetworkError: The request failed at the transport level.
        """
        logger.info(f"POST {path}")
        try:
            # A None timeout lets httpx fall back to the client-level default.
            resp = self._client.post(path, json=json, data=data, headers=headers, timeout=timeout)
            resp.raise_for_status()
            logger.info(f"POST {path} successful ({resp.status_code})")
            return resp.json()
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP Status Error for {path}: {e.response.status_code} - {e.response.text}")
            error_details_dict: Optional[Dict[str, Any]] = None
            try:
                parsed = e.response.json()
                if isinstance(parsed, dict):
                    error_details_dict = parsed
                    # Prefer the FastAPI-style 'detail' field when present.
                    error_message = parsed.get("detail", str(parsed))
                else:
                    # JSON but not an object (list/string/number): the original
                    # .get() call would have raised AttributeError here.
                    error_message = str(parsed)
            except ValueError:  # response body is not JSON
                error_message = e.response.text
            raise CortexAPIError(status_code=e.response.status_code, message=error_message, details=error_details_dict) from e
        except httpx.RequestError as e:
            logger.error(f"Request Error for {path}: {e}")
            raise CortexNetworkError(f"Network request to {e.request.url} failed: {e}") from e

    def close(self) -> None:
        """Close the underlying httpx client and release its connections."""
        logger.info("Closing HTTPClient.")
        self._client.close()
@@ -0,0 +1,333 @@
1
+ # filepath: /home/pranay/GRID/Grid-Cortex-Infra/grid-cortex-client/src/grid_cortex_client/cortex_client.py
2
+ import logging
3
+ from typing import Any, Dict, Optional, Type # Added Type
4
+
5
+ from PIL import Image # Add this import
6
+
7
+ from .client import (CortexAPIError, CortexNetworkError, HTTPClient)
8
+ from .preprocessing import load_image, encode_image_to_base64 # Add this import
9
+ from .postprocessing import postprocess_depth_response, postprocess_detection_response # Add postprocess_detection_response
10
+ from .models.base_model import BaseModel
11
+ from .models.depth import DepthModel
12
+ from .models.detection import DetectionModel
13
+ from .models.segmentation import SegmentationModel, SAM2Model
14
+ from .models.vlm import VLMModel, MoonDreamModel
15
+ from .models.matching import MatchingModel
16
+ from .models.grasp import GraspModel
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
class CortexClient:
    """Client for interacting with Grid Cortex Ray Serve deployments."""

    # Client-side registry mapping model_id keywords to handler classes.
    # Dispatch is by substring match against the model_id (see run()).
    _MODEL_ID_TO_HANDLER_CLASS: Dict[str, Type[BaseModel]] = {
        # Depth estimation
        "midas": DepthModel,
        "marigold": DepthModel,
        "depthpro": DepthModel,
        "metric3d": DepthModel,
        "depthanything2": DepthModel,
        "zoedepth": DepthModel,
        "vggt": DepthModel,
        # Object detection
        "grounding-dino": DetectionModel,
        "owlv2": DetectionModel,
        "rtdetr": DetectionModel,
        # Segmentation
        "sam2": SAM2Model,
        "gsam2": SegmentationModel,
        "clipseg": SegmentationModel,
        "oneformer": SegmentationModel,
        "lseg": SegmentationModel,
        # Vision-language models
        "moondream": MoonDreamModel,
        "llava": VLMModel,
        "molmo": VLMModel,
        "phi4": VLMModel,
        "magma": VLMModel,
        "robobrain2": VLMModel,
        "robopoint": VLMModel,
        # Matching / grasping
        "lightglue": MatchingModel,
        "m2t2": GraspModel,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: float = 30.0,
    ):
        """
        Initializes the CortexClient.

        Args:
            api_key: API key. Uses GRID_CORTEX_API_KEY env var if None.
            base_url: Base URL of the Cortex API. If None, uses GRID_CORTEX_BASE_URL env var or HTTPClient's default.
            timeout: Default timeout for HTTP requests in seconds.
        """
        self.http_client = HTTPClient(api_key=api_key, base_url=base_url, timeout=timeout)
        # Read the resolved base URL off the underlying httpx.Client for logging.
        effective_http_base_url = self.http_client._client.base_url
        logger.info(f"CortexClient initialized. HTTPClient target: {effective_http_base_url}")

    def _make_request(
        self,
        endpoint: str,
        payload: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """POST *payload* to *endpoint* and return the JSON response.

        Known library errors are re-raised as-is; anything unexpected is
        wrapped in CortexNetworkError so callers have a single failure type.
        """
        try:
            return self.http_client.post(endpoint, json=payload, timeout=timeout)
        except (CortexAPIError, CortexNetworkError) as e:
            logger.error(f"Request to {endpoint} failed: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error during request to {endpoint}: {e}")
            raise CortexNetworkError(f"An unexpected error occurred: {e}") from e

    def run_model(
        self,
        model: BaseModel,
        input_data: Any,
        preprocess_kwargs: Optional[Dict[str, Any]] = None,
        postprocess_kwargs: Optional[Dict[str, Any]] = None,
        visualize: bool = False,
        visualization_kwargs: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
    ) -> Any:
        """
        Runs a specified model with the given input data.

        Orchestrates preprocessing, the API request, postprocessing, and
        optional visualization.

        Args:
            model: An instance of a BaseModel subclass (e.g., DepthModel).
            input_data: The raw input data for the model.
            preprocess_kwargs: Extra keyword arguments for the model's preprocess method.
            postprocess_kwargs: Extra keyword arguments for the model's postprocess method.
            visualize: If True, attempts to run the model's visualize method.
            visualization_kwargs: Extra keyword arguments for the model's visualize method.
            timeout: Optional timeout in seconds for the request.

        Returns:
            The processed output from the model.

        Raises:
            CortexAPIError: If the API returns an error.
            CortexNetworkError: If a network error occurs.
            ValueError: If preprocessing or postprocessing fails.
        """
        logger.info(f"Running model {model.model_id} with input type: {type(input_data)}")

        _preprocess_kwargs = preprocess_kwargs or {}
        _postprocess_kwargs = postprocess_kwargs or {}
        _visualization_kwargs = visualization_kwargs or {}

        try:
            payload = model.preprocess(input_data, **_preprocess_kwargs)
            # The payload must carry the instance's model_id. A mismatching id
            # from preprocess is suspicious enough to warn about; a missing id
            # is expected and silently added.
            if "model_id" in payload and payload["model_id"] != model.model_id:
                logger.warning(
                    f"Payload 'model_id' ('{payload.get('model_id')}') from preprocess "
                    f"does not match model instance's model_id ('{model.model_id}'). "
                    f"Overwriting with instance's model_id."
                )
            elif "model_id" not in payload:
                logger.debug(
                    f"Payload from preprocess for model '{model.model_id}' does not contain 'model_id'. "
                    f"'{self.__class__.__name__}.run_model' will add it."
                )
            payload["model_id"] = model.model_id  # Ensure it's there and correct.
        except Exception as e:
            logger.error(f"Preprocessing failed for model {model.model_id}: {e}", exc_info=True)
            raise ValueError(f"Preprocessing error for {model.model_id}: {e}") from e

        # Per-model endpoint, e.g. /depth-anything-v2-large/run; the
        # HTTPClient prepends its base_url.
        specific_model_endpoint = f"/{model.model_id}/run"
        logger.info(f"Requesting model execution from: {specific_model_endpoint} for model_id: {model.model_id}")

        api_response = self._make_request(specific_model_endpoint, payload=payload, timeout=timeout)
        # DEBUG, not INFO: responses may embed large base64 blobs (images, maps).
        logger.debug(f"API response for model {model.model_id}: {api_response}")

        try:
            _postprocess_kwargs['model_name'] = model.model_id
            processed_output = model.postprocess(api_response, **_postprocess_kwargs)
        except Exception as e:
            logger.error(f"Postprocessing failed for model {model.model_id}: {e}", exc_info=True)
            raise ValueError(f"Postprocessing error for {model.model_id}: {e}") from e

        if visualize:
            try:
                logger.info(f"Attempting visualization for model {model.model_id}.")
                # Pass original input data in case the visualizer needs context.
                model.visualize(processed_output, original_input=input_data, **_visualization_kwargs)
            except Exception as e:
                # Visualization failures are logged, never raised.
                logger.error(f"Visualization failed for model {model.model_id}: {e}", exc_info=True)

        logger.info(f"Successfully ran model {model.model_id}.")
        return processed_output

    def run(
        self,
        model_id: str,
        timeout: Optional[float] = None,
        debug: bool = False,
        **kwargs: Any
    ) -> Any:
        """
        Runs a model by its ID with the given input data using keyword arguments.

        This method simplifies model execution by:
        1. Identifying the appropriate model handler based on model_id.
        2. Instantiating the handler.
        3. Passing all keyword arguments (**kwargs) as the input_data dict
           consumed by the handler's preprocess method via run_model.

        Args:
            model_id: The identifier of the model to run.
            timeout: Optional timeout in seconds for the request.
            debug: If True, sets the library's logger level to DEBUG for this call.
            **kwargs: Arbitrary keyword arguments passed as a dictionary to the
                model handler's preprocess method (e.g., image_input, text_prompts).

        Returns:
            The processed output from the model.

        Raises:
            CortexAPIError: If the API returns an error.
            CortexNetworkError: If a network error occurs.
            NotImplementedError: If no suitable model handler is found for the model_id.
            ValueError: If preprocessing or postprocessing within the handler fails.
        """
        original_level = None
        library_logger = logging.getLogger("grid_cortex_client")

        if debug:
            # Temporarily raise verbosity; restored in the finally block.
            # The library adds no handlers, so whether output is visible
            # still depends on the application's logging configuration.
            original_level = library_logger.getEffectiveLevel()
            library_logger.setLevel(logging.DEBUG)
            logger.info(f"Debug mode enabled for this run. Setting grid_cortex_client logger to DEBUG.")

        logger.info(f"Attempting to run model '{model_id}' with inputs: {list(kwargs.keys())}")

        HandlerClass: Optional[Type[BaseModel]] = None
        model_id_lower = model_id.lower()
        # Match longer keywords first so a more specific keyword (e.g. "gsam2")
        # is not shadowed by a shorter substring (e.g. "sam2") that happens to
        # come earlier in the registry.
        for keyword in sorted(self._MODEL_ID_TO_HANDLER_CLASS, key=len, reverse=True):
            if keyword in model_id_lower:
                HandlerClass = self._MODEL_ID_TO_HANDLER_CLASS[keyword]
                logger.info(f"Found handler {HandlerClass.__name__} for model_id '{model_id}' based on keyword '{keyword}'.")
                break

        if HandlerClass is None:
            logger.error(f"No suitable model handler found for model_id: {model_id}. "
                         f"Available handlers are for keywords: {list(self._MODEL_ID_TO_HANDLER_CLASS.keys())}")
            raise NotImplementedError(
                f"No model handler configured for model_id containing typical keywords for known types: '{model_id}'. "
                f"Please ensure the model_id is correct or update the client's model handler registry."
            )

        try:
            model_handler = HandlerClass(model_id=model_id)
            # All kwargs become the input_data dict; the handler's preprocess
            # is responsible for interpreting them.
            return self.run_model(
                model=model_handler,
                input_data=kwargs,
                timeout=timeout,
            )
        except (CortexAPIError, CortexNetworkError, ValueError, NotImplementedError) as e:
            # Re-raise known errors after logging.
            logger.error(f"Error running model '{model_id}': {e}", exc_info=True)
            raise
        except Exception as e:
            # Wrap anything unexpected in the library's network error type.
            logger.error(f"Unexpected error running model '{model_id}': {e}", exc_info=True)
            raise CortexNetworkError(f"An unexpected error occurred while running model {model_id}: {e}") from e
        finally:
            if debug and original_level is not None:
                library_logger.setLevel(original_level)
                logger.info(f"Debug mode disabled. Restored grid_cortex_client logger level to {logging.getLevelName(original_level)}.")

    def close(self):
        """Closes the underlying HTTP client."""
        self.http_client.close()
        logger.info("CortexClient closed.")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
333
+
@@ -0,0 +1,6 @@
1
def main() -> None:
    """CLI stub entry point: print a greeting."""
    print("Hello from grid-cortex-client!")


if __name__ == "__main__":
    main()
@@ -0,0 +1,21 @@
1
"""Models for interacting with specific Cortex deployments."""
from .base_model import BaseModel
from .depth import DepthModel
from .detection import DetectionModel
from .grasp import GraspModel
from .matching import MatchingModel
from .segmentation import SegmentationModel, SAM2Model
from .vlm import VLMModel, MoonDreamModel

# Public API of the models subpackage; everything imported above is exported.
__all__ = [
    "BaseModel",
    "DepthModel",
    "DetectionModel",
    "SegmentationModel",
    "SAM2Model",
    "MatchingModel",
    "VLMModel",
    "GraspModel",
    "MoonDreamModel",
]
21
+
@@ -0,0 +1,64 @@
1
+ \
2
+ """Base model for Cortex API interactions."""
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Dict, Optional
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
class BaseModel(ABC):
    """Abstract base for all Cortex model handlers.

    A concrete subclass defines how its raw input becomes a /run payload
    (``preprocess``), how the server's JSON reply becomes a usable result
    (``postprocess``), and may optionally visualize that result.
    """

    def __init__(self, model_id: str):
        """Store the model identifier.

        Args:
            model_id: Unique identifier for the model, used in the payload
                sent to the /run endpoint.
        """
        self.model_id = model_id

    @abstractmethod
    def preprocess(self, input_data: Any, **kwargs) -> Dict[str, Any]:
        """Build the JSON payload for the Cortex API's /run endpoint.

        Args:
            input_data: Raw input (e.g., image path, numpy array, PIL Image).
            **kwargs: Extra preprocessing options.

        Returns:
            Payload dict for /run; should include the model_id and the
            processed input.
        """

    @abstractmethod
    def postprocess(self, response_data: Dict[str, Any], **kwargs) -> Any:
        """Convert the API's JSON response into a user-friendly result.

        Args:
            response_data: The JSON response from the API.
            **kwargs: Extra postprocessing options.

        Returns:
            Processed output (e.g., numpy array, custom object).
        """

    def visualize(self, processed_output: Any, original_input: Optional[Any] = None, **kwargs) -> None:
        """Optionally visualize the processed output.

        The default implementation only announces that no visualization is
        available; subclasses override when they can render something.

        Args:
            processed_output: Output of :meth:`postprocess`.
            original_input: The original input, if useful for context.
            **kwargs: Extra visualization options.
        """
        print(f"Visualization not implemented for model {self.model_id}.")
@@ -0,0 +1,131 @@
1
+ \
2
+ """Depth estimation model implementation."""
3
+ import base64
4
+ import io
5
+ from typing import Any, Dict, Union, Optional
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from .base_model import BaseModel
11
+ from ..preprocessing import load_image as general_load_image
12
+ from ..postprocessing import postprocess_depth_response as general_postprocess_depth_response
13
+ from ..visualization import visualize_depth_map_rerun as general_visualize_depth_map_rerun
14
+
15
+
16
class DepthModel(BaseModel):
    """
    Handles depth estimation tasks by interacting with the Cortex API.
    """

    # Placeholder; callers normally pass the real deployment id to __init__.
    DEFAULT_MODEL_ID = "depth_estimation_model_id"

    def __init__(self, model_id: Optional[str] = None):
        """
        Initializes the DepthModel.

        Args:
            model_id: The specific model ID for depth estimation if different
                from the default. This ID is sent to the /run endpoint.
        """
        super().__init__(model_id or self.DEFAULT_MODEL_ID)

    def preprocess(
        self,
        input_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Prepares the image for depth estimation using a dictionary of inputs.

        Args:
            input_data: A dictionary containing the input data. Expected key:
                'image_input' (Union[str, np.ndarray, Image.Image]): Path to
                the image, a NumPy array, or a PIL Image object. (Required)

        Returns:
            A dictionary payload for the /run endpoint, including the model_id
            and the base64-encoded image.

        Raises:
            ValueError: If 'image_input' is not found in input_data.
        """
        image_input = input_data.get('image_input')
        if image_input is None:
            raise ValueError("'image_input' not found in input_data for DepthModel preprocessing.")

        # Default encoding format; resizing is intentionally not handled here —
        # do it before calling the client or let the model server handle it.
        encoding_format = "JPEG"

        pil_image = general_load_image(image_input)
        # JPEG cannot encode alpha or palette images; PIL raises OSError when
        # saving e.g. an RGBA image as JPEG, so normalize to RGB first.
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")

        buffered = io.BytesIO()
        pil_image.save(buffered, format=encoding_format)
        encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # The API expects 'model_id' and the base64 string under 'image_input'
        # at the top level of the payload (not nested under 'inputs').
        return {
            "model_id": self.model_id,
            "image_input": encoded_image,
        }

    def postprocess(self, response_data: Dict[str, Any], **kwargs: Any) -> np.ndarray:
        """
        Processes the API response to extract the depth map.

        Args:
            response_data: The JSON response from the API.
            **kwargs: Additional keyword arguments (not used).

        Returns:
            A NumPy array representing the depth map.
        """
        # The generic postprocessor handles the depth response structure.
        return general_postprocess_depth_response(response_data)

    def visualize(
        self,
        depth_map: np.ndarray,
        original_image: Optional[Union[str, np.ndarray, Image.Image]] = None,
        **kwargs
    ) -> None:
        """
        Visualizes the depth map using Rerun.

        Args:
            depth_map: The depth map (NumPy array) to visualize.
            original_image: Optional original image for context.
            **kwargs: Additional arguments for Rerun visualization.
        """
        try:
            import rerun as rr  # type: ignore
            rr.log("depth_map", rr.DepthImage(depth_map))
            if original_image is not None:
                pil_image = general_load_image(original_image)
                rr.log("original_image", rr.Image(np.array(pil_image)))
            print("Depth map visualized. Check your Rerun viewer.")
        except ImportError:
            # Rerun is an optional dependency; visualization is best-effort.
            print("Rerun SDK not installed. Skipping visualization.")
        except Exception as e:
            print(f"Error during Rerun visualization: {e}")
131
+