veris-ai 0.2.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of veris-ai might be problematic; see the package registry's notice for this release for more details.

@@ -0,0 +1,153 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
7
+
8
# Public API of this module.
__all__ = [
    # Jaeger data structures
    "Tag",
    "Process",
    "Span",
    "Trace",
    # Response envelopes
    "SearchResponse",
    "GetTraceResponse",
    # Request parameters
    "SearchQuery",
]
17
+
18
+
19
class Tag(BaseModel):
    """Single key/value tag attached to a Jaeger span or process."""

    key: str
    value: Any
    # Optional value-type marker that Jaeger v1 payloads may carry.
    type: str | None = None
25
+
26
+
27
class Process(BaseModel):
    """The *process* section of a Jaeger trace: a service name plus optional tags."""

    # Jaeger's key is camelCase; the explicit alias simply mirrors the attribute name.
    serviceName: str = Field(alias="serviceName")  # noqa: N815
    tags: list[Tag] | None = None
32
+
33
+
34
class Span(BaseModel):
    """A single span within a Jaeger trace.

    Keys returned by the API that are not declared below are kept as extra
    attributes rather than discarded.
    """

    model_config = ConfigDict(extra="allow")

    traceID: str  # noqa: N815
    spanID: str  # noqa: N815
    operationName: str  # noqa: N815
    # NOTE(review): presumably microseconds, per Jaeger convention — confirm.
    startTime: int  # noqa: N815
    duration: int
    tags: list[Tag] | None = None
    # Raw reference objects, intentionally left untyped.
    references: list[dict[str, Any]] | None = None
    processID: str | None = None  # noqa: N815
47
+
48
+
49
class Trace(BaseModel):
    """A complete Jaeger trace as returned by the Query API.

    ``process`` may be a single ``Process`` or a mapping of processes
    (presumably keyed by ``Span.processID`` — confirm against the API).
    Undeclared keys are kept as extra attributes.
    """

    model_config = ConfigDict(extra="allow")

    traceID: str  # noqa: N815
    spans: list[Span]
    process: Process | dict[str, Process] | None = None
    warnings: list[str] | None = None
58
+
59
+
60
class _BaseResponse(BaseModel):
    """Envelope shared by the Jaeger Query API response models."""

    # Keep every extra key the backend sends, so newly added fields are not
    # silently dropped during parsing.
    model_config = ConfigDict(extra="allow")

    # Either a list of traces or a single trace.
    data: list[Trace] | Trace | None = None
    errors: list[str] | None = None
68
+
69
+
70
class SearchResponse(_BaseResponse):
    """Response model for *search* / *find traces* requests.

    Extends the shared envelope with optional ``total`` and ``limit`` values.
    """

    total: int | None = None
    limit: int | None = None
75
+
76
+
77
class GetTraceResponse(_BaseResponse):
    """Response model for *get trace by id* requests.

    Declares no fields beyond ``_BaseResponse``; the separate class exists only
    so call sites can name the response shape they expect.
    """
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Query models
85
+ # ---------------------------------------------------------------------------
86
+
87
+
88
class SearchQuery(BaseModel):
    """Minimal set of query parameters for the `/api/traces` endpoint.

    Parameter interaction rules:

    * **service** – global filter; *all* returned traces must belong to this
      service.
    * **operation** – optional secondary filter; returned traces must contain
      *at least one span* whose ``operationName`` equals the provided value.
    * **tags** – dictionary of key–value pairs; each trace must include a span
      that matches **all** of the pairs (logical AND).
    * **limit** – applied *after* all other filters; truncates the final list
      of traces to the requested maximum.

    Any additional/unknown parameters are forwarded, because the model uses
    ``extra = "allow"``; this keeps the model future-proof.
    """

    # NOTE: Only the fields that are reliably supported by Jaeger's REST API and
    # work with the user's deployment are declared. Other parameters can still
    # be passed as extras.

    service: str = Field(
        ...,
        description="Service name to search for. Example: 'veris-agent'",
    )

    limit: int | None = Field(
        None,
        description="Maximum number of traces to return. Example: 10",
    )

    tags: dict[str, Any] | None = Field(
        None,
        description=(
            "Dictionary of tag filters (AND-combined). "
            "Example: {'error': 'true', 'bt.metrics.time_to_first_token': '0.813544'}"
        ),
    )

    operation: str | None = Field(
        None,
        description="Operation name to search for. Example: 'process_chat_message'",
    )

    model_config = ConfigDict(
        extra="allow",  # allow additional query params implicitly
        populate_by_name=True,
        str_to_lower=False,
    )

    @field_validator("tags", mode="before")
    @classmethod
    def _empty_to_none(cls, v: Any) -> Any:  # noqa: D401, ANN102, ANN401
        """Normalise an empty ``tags`` value to ``None`` so ``to_params`` omits it.

        Runs before validation, so *v* is the raw input (not yet a dict).
        """
        return v or None

    def to_params(self) -> dict[str, Any]:  # noqa: D401
        """Translate the model into a *requests*/*httpx* compatible params dict.

        Fields set to ``None`` are omitted and extra (undeclared) parameters are
        included. ``tags`` is JSON-encoded to a string, matching what the Jaeger
        UI sends.

        Returns:
            A flat mapping of query-parameter names to values.
        """
        # by_alias=True ensures any field alias (none are declared on this
        # model today) is used as the query-parameter name.
        params: dict[str, Any] = self.model_dump(exclude_none=True, by_alias=True)

        if isinstance(params.get("tags"), dict):
            params["tags"] = json.dumps(params["tags"])

        return params
veris_ai/tool_mock.py CHANGED
@@ -4,137 +4,196 @@ import logging
4
4
  import os
5
5
  from collections.abc import Callable
6
6
  from contextlib import suppress
7
+ from contextvars import ContextVar
7
8
  from functools import wraps
8
- from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints
9
+ from typing import (
10
+ Any,
11
+ Literal,
12
+ TypeVar,
13
+ get_type_hints,
14
+ )
9
15
 
10
16
  import httpx
11
17
 
18
+ from veris_ai.utils import convert_to_type, extract_json_schema
19
+
12
20
  logger = logging.getLogger(__name__)
13
21
 
14
22
  T = TypeVar("T")
15
23
 
24
+ # Context variable to store session_id for each call
25
+ _session_id_context: ContextVar[str | None] = ContextVar("veris_session_id", default=None)
26
+
16
27
 
17
- class ToolMock:
28
+ class VerisSDK:
18
29
  """Class for mocking tool calls."""
19
30
 
20
31
  def __init__(self) -> None:
21
32
  """Initialize the ToolMock class."""
33
+ self._mcp = None
34
+
35
+ @property
36
+ def session_id(self) -> str | None:
37
+ """Get the session_id from context variable."""
38
+ return _session_id_context.get()
39
+
40
+ def set_session_id(self, session_id: str) -> None:
41
+ """Set the session_id in context variable."""
42
+ _session_id_context.set(session_id)
43
+ logger.info(f"Session ID set to {session_id}")
44
+
45
+ def clear_session_id(self) -> None:
46
+ """Clear the session_id from context variable."""
47
+ _session_id_context.set(None)
48
+ logger.info("Session ID cleared")
49
+
50
+ @property
51
+ def fastapi_mcp(self) -> Any | None: # noqa: ANN401
52
+ """Get the FastAPI MCP server."""
53
+ return self._mcp
54
+
55
+ def set_fastapi_mcp(self, **params_dict: Any) -> None: # noqa: ANN401
56
+ """Set the FastAPI MCP server."""
57
+ from fastapi import Depends, Request # noqa: PLC0415
58
+ from fastapi.security import OAuth2PasswordBearer # noqa: PLC0415
59
+ from fastapi_mcp import ( # type: ignore[import-untyped] # noqa: PLC0415
60
+ AuthConfig,
61
+ FastApiMCP,
62
+ )
63
+
64
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
65
+
66
+ async def authenticate_request(
67
+ _: Request,
68
+ token: str = Depends(oauth2_scheme), # noqa: ARG001
69
+ ) -> None:
70
+ self.set_session_id(token)
71
+
72
+ # Create auth config with dependencies
73
+ auth_config = AuthConfig(
74
+ dependencies=[Depends(authenticate_request)],
75
+ )
76
+
77
+ # Merge the provided params with our auth config
78
+ if "auth_config" in params_dict:
79
+ # Merge the provided auth config with our dependencies
80
+ provided_auth_config = params_dict.pop("auth_config")
81
+ if provided_auth_config.dependencies:
82
+ auth_config.dependencies.extend(provided_auth_config.dependencies)
83
+ # Copy other auth config properties if they exist
84
+ for field, value in provided_auth_config.model_dump(exclude_none=True).items():
85
+ if field != "dependencies" and hasattr(auth_config, field):
86
+ setattr(auth_config, field, value)
87
+
88
+ # Create the FastApiMCP instance with merged parameters
89
+ self._mcp = FastApiMCP(
90
+ auth_config=auth_config,
91
+ **params_dict,
92
+ )
93
+
94
+ def mock(
95
+ self,
96
+ mode: Literal["tool", "function"] = "tool",
97
+ expects_response: bool | None = None,
98
+ cache_response: bool | None = None,
99
+ ) -> Callable:
100
+ """Decorator for mocking tool calls."""
22
101
 
23
- def _convert_to_type(self, value: object, target_type: type) -> object:
24
- """Convert a value to the specified type."""
25
- if target_type is Any:
26
- return value
27
-
28
- if target_type in (str, int, float, bool):
29
- return target_type(value)
30
-
31
- origin = get_origin(target_type)
32
- if origin is list:
33
- if not isinstance(value, list):
34
- error_msg = f"Expected list but got {type(value)}"
35
- raise ValueError(error_msg)
36
- item_type = get_args(target_type)[0]
37
- return [self._convert_to_type(item, item_type) for item in value]
38
-
39
- if origin is dict:
40
- if not isinstance(value, dict):
41
- error_msg = f"Expected dict but got {type(value)}"
102
+ def decorator(func: Callable) -> Callable:
103
+ """Decorator for mocking tool calls."""
104
+ endpoint = os.getenv("VERIS_MOCK_ENDPOINT_URL")
105
+ if not endpoint:
106
+ error_msg = "VERIS_MOCK_ENDPOINT_URL environment variable is not set"
42
107
  raise ValueError(error_msg)
43
- key_type, value_type = get_args(target_type)
44
- return {
45
- self._convert_to_type(k, key_type): self._convert_to_type(v, value_type)
46
- for k, v in value.items()
47
- }
48
-
49
- if origin is Union:
50
- error_msg = (
51
- f"Could not convert {value} to any of the union types {get_args(target_type)}"
52
- )
53
- for possible_type in get_args(target_type):
54
- try:
55
- return self._convert_to_type(value, possible_type)
56
- except (ValueError, TypeError):
57
- continue
58
- raise ValueError(error_msg)
59
-
60
- # For other types, try direct conversion
61
- return target_type(value)
62
-
63
- def mock(self, func: Callable) -> Callable:
64
- """Decorator for mocking tool calls."""
65
- endpoint = os.getenv("VERIS_MOCK_ENDPOINT_URL")
66
- if not endpoint:
67
- error_msg = "VERIS_MOCK_ENDPOINT_URL environment variable is not set"
68
- raise ValueError(error_msg)
69
- # Default timeout of 30 seconds
70
- timeout = float(os.getenv("VERIS_MOCK_TIMEOUT", "30.0"))
71
-
72
- @wraps(func)
73
- async def wrapper(
74
- *args: tuple[object, ...],
75
- **kwargs: dict[str, object],
76
- ) -> object:
77
- # Check if we're in simulation mode
78
- env_mode = os.getenv("ENV", "").lower()
79
- if env_mode != "simulation":
80
- # If not in simulation mode, execute the original function
81
- return await func(*args, **kwargs)
82
-
83
- logger.info(f"Simulating function: {func.__name__}")
84
- sig = inspect.signature(func)
85
- type_hints = get_type_hints(func)
86
-
87
- # Extract return type object (not just the name)
88
- return_type_obj = type_hints.pop("return", Any)
89
-
90
- # Create parameter info
91
- params_info = {}
92
- bound_args = sig.bind(*args, **kwargs)
93
- bound_args.apply_defaults()
94
-
95
- ctx = bound_args.arguments.pop("ctx", None)
96
- session_id = None
97
- if ctx:
98
- try:
99
- session_id = ctx.request_context.lifespan_context.session_id
100
- except AttributeError:
101
- logger.warning("Cannot get session_id from context.")
102
-
103
- for param_name, param_value in bound_args.arguments.items():
104
- params_info[param_name] = {
105
- "value": param_value,
106
- "type": type_hints.get(param_name, Any).__name__,
108
+ # Default timeout of 30 seconds
109
+ timeout = float(os.getenv("VERIS_MOCK_TIMEOUT", "90.0"))
110
+
111
+ @wraps(func)
112
+ async def wrapper(
113
+ *args: tuple[object, ...],
114
+ **kwargs: dict[str, object],
115
+ ) -> object:
116
+ # Check if we're in simulation mode
117
+ env_mode = os.getenv("ENV", "").lower()
118
+ if env_mode != "simulation":
119
+ # If not in simulation mode, execute the original function
120
+ return await func(*args, **kwargs)
121
+ logger.info(f"Simulating function: {func.__name__}")
122
+ sig = inspect.signature(func)
123
+ type_hints = get_type_hints(func)
124
+
125
+ # Extract return type object (not just the name)
126
+ return_type_obj = type_hints.pop("return", Any)
127
+ # Create parameter info
128
+ params_info = {}
129
+ bound_args = sig.bind(*args, **kwargs)
130
+ bound_args.apply_defaults()
131
+ _ = bound_args.arguments.pop("ctx", None)
132
+ _ = bound_args.arguments.pop("self", None)
133
+ _ = bound_args.arguments.pop("cls", None)
134
+
135
+ for param_name, param_value in bound_args.arguments.items():
136
+ params_info[param_name] = {
137
+ "value": str(param_value),
138
+ "type": str(type_hints.get(param_name, Any)),
139
+ }
140
+ # Get function docstring
141
+ docstring = inspect.getdoc(func) or ""
142
+ nonlocal expects_response
143
+ if expects_response is None and mode == "function":
144
+ expects_response = False
145
+ # Prepare payload
146
+ payload = {
147
+ "session_id": self.session_id,
148
+ "expects_response": expects_response,
149
+ "cache_response": cache_response,
150
+ "tool_call": {
151
+ "function_name": func.__name__,
152
+ "parameters": params_info,
153
+ "return_type": json.dumps(extract_json_schema(return_type_obj)),
154
+ "docstring": docstring,
155
+ },
107
156
  }
108
157
 
109
- # Get function docstring
110
- docstring = inspect.getdoc(func) or ""
111
- # Prepare payload
112
- payload = {
113
- "session_id": session_id,
114
- "tool_call": {
115
- "function_name": func.__name__,
116
- "parameters": params_info,
117
- "return_type": return_type_obj.__name__,
118
- "docstring": docstring,
119
- },
120
- }
121
-
122
- # Send request to endpoint with timeout
123
- async with httpx.AsyncClient(timeout=timeout) as client:
124
- response = await client.post(endpoint, json=payload)
125
- response.raise_for_status()
126
- mock_result = response.json()["result"]
127
- logger.info(f"Mock response: {mock_result}")
158
+ # Send request to endpoint with timeout
159
+ async with httpx.AsyncClient(timeout=timeout) as client:
160
+ response = await client.post(endpoint, json=payload)
161
+ response.raise_for_status()
162
+ mock_result = response.json()
163
+ logger.info(f"Mock response: {mock_result}")
128
164
 
165
+ # Convert the mock result to the expected return type
166
+ if mode == "tool":
167
+ return {"content": [{"type": "text", "text": mock_result}]}
168
+ # Parse the mock result if it's a string
169
+ # Extract result field for backwards compatibility
129
170
  # Parse the mock result if it's a string
130
171
  if isinstance(mock_result, str):
131
172
  with suppress(json.JSONDecodeError):
132
173
  mock_result = json.loads(mock_result)
174
+ return convert_to_type(mock_result, return_type_obj)
175
+ return convert_to_type(mock_result, return_type_obj)
133
176
 
134
- # Convert the mock result to the expected return type
135
- return self._convert_to_type(mock_result, return_type_obj)
177
+ return wrapper
178
+
179
+ return decorator
180
+
181
+ def stub(self, return_value: Any) -> Callable: # noqa: ANN401
182
+ """Decorator for stubbing tool calls."""
183
+
184
+ def decorator(func: Callable) -> Callable:
185
+ @wraps(func)
186
+ async def wrapper(*args: tuple[object, ...], **kwargs: dict[str, object]) -> object:
187
+ env_mode = os.getenv("ENV", "").lower()
188
+ if env_mode != "simulation":
189
+ # If not in simulation mode, execute the original function
190
+ return await func(*args, **kwargs)
191
+ logger.info(f"Simulating function: {func.__name__}")
192
+ return return_value
193
+
194
+ return wrapper
136
195
 
137
- return wrapper
196
+ return decorator
138
197
 
139
198
 
140
- veris = ToolMock()
199
+ veris = VerisSDK()