veris-ai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of veris-ai might be problematic. Click here for more details.
- veris_ai/__init__.py +35 -1
- veris_ai/braintrust_tracing.py +282 -0
- veris_ai/jaeger_interface/README.md +109 -0
- veris_ai/jaeger_interface/__init__.py +26 -0
- veris_ai/jaeger_interface/client.py +133 -0
- veris_ai/jaeger_interface/models.py +153 -0
- veris_ai/tool_mock.py +102 -64
- veris_ai/utils.py +200 -1
- veris_ai-1.1.0.dist-info/METADATA +448 -0
- veris_ai-1.1.0.dist-info/RECORD +12 -0
- veris_ai-1.0.0.dist-info/METADATA +0 -239
- veris_ai-1.0.0.dist-info/RECORD +0 -7
- {veris_ai-1.0.0.dist-info → veris_ai-1.1.0.dist-info}/WHEEL +0 -0
- {veris_ai-1.0.0.dist-info → veris_ai-1.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
7
|
+
|
|
8
|
+
# Public names exported by this module's star-import.
__all__ = ["Tag", "Process", "Span", "Trace", "SearchResponse", "GetTraceResponse", "SearchQuery"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Tag(BaseModel):
    """Single key/value tag as it appears in Jaeger span and process data."""

    key: str
    value: Any
    # Optional value-type label; Jaeger's v1 JSON carries it next to the value.
    type: str | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Process(BaseModel):
    """Process metadata from a Jaeger trace: the service name plus optional tags."""

    # The alias equals the attribute name; it pins the wire key explicitly.
    serviceName: str = Field(alias="serviceName")  # noqa: N815
    tags: list[Tag] | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Span(BaseModel):
    """One span inside a Jaeger trace."""

    # Keep span keys that are not modelled below instead of discarding them.
    model_config = ConfigDict(extra="allow")

    # Identity and timing (field names mirror Jaeger's camelCase JSON keys).
    traceID: str  # noqa: N815
    spanID: str  # noqa: N815
    operationName: str  # noqa: N815
    startTime: int  # noqa: N815
    duration: int

    # Optional relationships and metadata.
    tags: list[Tag] | None = None
    references: list[dict[str, Any]] | None = None
    processID: str | None = None  # noqa: N815
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Trace(BaseModel):
    """Complete Jaeger trace (all spans plus process info) from the Query API."""

    # Unknown trace-level keys are retained rather than dropped.
    model_config = ConfigDict(extra="allow")

    traceID: str  # noqa: N815
    spans: list[Span]
    # Either a single process or a mapping of process id -> process.
    # NOTE(review): Jaeger's HTTP JSON typically reports processes under a
    # ``processes`` key; with extra="allow" that arrives as raw data — confirm
    # whether ``process`` is ever populated by the deployment in use.
    process: Process | dict[str, Process] | None = None
    warnings: list[str] | None = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class _BaseResponse(BaseModel):
    """Envelope shared by Jaeger Query API responses: ``data`` plus ``errors``."""

    data: list[Trace] | Trace | None = None
    # Jaeger reports errors as objects such as {"code": 404, "msg": "trace not found"};
    # typing this as list[str] made every error response fail validation.
    # Plain strings are still accepted for backward compatibility.
    errors: list[dict[str, Any] | str] | None = None

    # Allow any additional keys returned by Jaeger so that nothing gets
    # silently dropped if the backend adds new fields we don't know about.
    model_config = ConfigDict(extra="allow")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SearchResponse(_BaseResponse):
    """Response model for *search* / *find traces* requests."""

    # Optional counters from the search envelope; missing keys stay None.
    total: int | None = None
    limit: int | None = None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class GetTraceResponse(_BaseResponse):
    """Response model for *get trace by id* requests.

    Adds no fields of its own; it is a separately named subclass of the shared
    response envelope for readability at call sites.
    """
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# ---------------------------------------------------------------------------
|
|
84
|
+
# Query models
|
|
85
|
+
# ---------------------------------------------------------------------------
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class SearchQuery(BaseModel):
    """Minimal set of query parameters for the `/api/traces` endpoint.

    Parameter interaction rules:

    * **service** – global filter; *all* returned traces must belong to this
      service.
    * **operation** – optional secondary filter; returned traces must contain
      *at least one span* whose ``operationName`` equals the provided value.
    * **tags** – dictionary of key–value pairs; each trace must include a span
      that matches **all** of the pairs (logical AND).
    * **limit** – applied *after* all other filters; truncates the final list
      of traces to the requested maximum.

    Any additional/unknown parameters are forwarded thanks to
    ``extra = "allow"`` – this keeps the model future-proof.
    """

    # NOTE: Only the fields that are reliably supported by Jaeger's REST API and
    # work with the user's deployment are kept. The model remains *open* to any
    # extra parameters thanks to `extra = "allow"`.

    service: str = Field(
        ...,
        description="Service name to search for. Example: 'veris-agent'",
    )

    limit: int | None = Field(
        None,
        description="Maximum number of traces to return. Example: 10",
    )

    tags: dict[str, Any] | None = Field(
        None,
        description=(
            "Dictionary of tag filters (AND-combined). "
            "Example: {'error': 'true', 'bt.metrics.time_to_first_token': '0.813544'}"
        ),
    )

    operation: str | None = Field(
        None,
        description="Operation name to search for. Example: 'process_chat_message'",
    )

    model_config = ConfigDict(
        extra="allow",  # allow additional query params implicitly
        populate_by_name=True,
        str_to_lower=False,
    )

    @field_validator("tags", mode="before")
    @classmethod
    def _empty_to_none(cls, v: dict[str, Any] | None) -> dict[str, Any] | None:  # noqa: D401, ANN102
        """Normalize a falsy ``tags`` value (e.g. ``{}``) to ``None`` so it is omitted."""
        return v or None

    def to_params(self) -> dict[str, Any]:  # noqa: D401
        """Translate the model into a *requests*/*httpx* compatible params dict.

        ``None`` fields are omitted and extra (unknown) fields are passed
        through. ``tags`` is sent as a JSON object string, which matches what
        the Jaeger UI sends. Jaeger's query service decodes that object as a
        string-to-string map, so non-string tag values are converted to
        strings first. Booleans use their JSON spelling (``"true"`` /
        ``"false"``), so ``{"error": True}`` works like ``{"error": "true"}``.

        Returns:
            Mapping of query-parameter name to value.
        """

        def _as_tag_string(value: Any) -> str:  # noqa: ANN401
            # bool is checked explicitly because str(True) would give "True".
            if isinstance(value, str):
                return value
            if isinstance(value, bool):
                return "true" if value else "false"
            return str(value)

        params: dict[str, Any] = self.model_dump(exclude_none=True, by_alias=True)

        tags = params.get("tags")
        if isinstance(tags, dict):
            params["tags"] = json.dumps({key: _as_tag_string(val) for key, val in tags.items()})

        return params
|
veris_ai/tool_mock.py
CHANGED
|
@@ -8,13 +8,14 @@ from contextvars import ContextVar
|
|
|
8
8
|
from functools import wraps
|
|
9
9
|
from typing import (
|
|
10
10
|
Any,
|
|
11
|
+
Literal,
|
|
11
12
|
TypeVar,
|
|
12
13
|
get_type_hints,
|
|
13
14
|
)
|
|
14
15
|
|
|
15
16
|
import httpx
|
|
16
17
|
|
|
17
|
-
from veris_ai.utils import convert_to_type
|
|
18
|
+
from veris_ai.utils import convert_to_type, extract_json_schema
|
|
18
19
|
|
|
19
20
|
logger = logging.getLogger(__name__)
|
|
20
21
|
|
|
@@ -90,72 +91,109 @@ class VerisSDK:
|
|
|
90
91
|
**params_dict,
|
|
91
92
|
)
|
|
92
93
|
|
|
93
|
-
def mock(
|
|
94
|
+
def mock(
|
|
95
|
+
self,
|
|
96
|
+
mode: Literal["tool", "function"] = "tool",
|
|
97
|
+
expects_response: bool | None = None,
|
|
98
|
+
cache_response: bool | None = None,
|
|
99
|
+
) -> Callable:
|
|
94
100
|
"""Decorator for mocking tool calls."""
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
#
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
101
|
+
|
|
102
|
+
def decorator(func: Callable) -> Callable:
|
|
103
|
+
"""Decorator for mocking tool calls."""
|
|
104
|
+
endpoint = os.getenv("VERIS_MOCK_ENDPOINT_URL")
|
|
105
|
+
if not endpoint:
|
|
106
|
+
error_msg = "VERIS_MOCK_ENDPOINT_URL environment variable is not set"
|
|
107
|
+
raise ValueError(error_msg)
|
|
108
|
+
# Default timeout of 30 seconds
|
|
109
|
+
timeout = float(os.getenv("VERIS_MOCK_TIMEOUT", "90.0"))
|
|
110
|
+
|
|
111
|
+
@wraps(func)
|
|
112
|
+
async def wrapper(
|
|
113
|
+
*args: tuple[object, ...],
|
|
114
|
+
**kwargs: dict[str, object],
|
|
115
|
+
) -> object:
|
|
116
|
+
# Check if we're in simulation mode
|
|
117
|
+
env_mode = os.getenv("ENV", "").lower()
|
|
118
|
+
if env_mode != "simulation":
|
|
119
|
+
# If not in simulation mode, execute the original function
|
|
120
|
+
return await func(*args, **kwargs)
|
|
121
|
+
logger.info(f"Simulating function: {func.__name__}")
|
|
122
|
+
sig = inspect.signature(func)
|
|
123
|
+
type_hints = get_type_hints(func)
|
|
124
|
+
|
|
125
|
+
# Extract return type object (not just the name)
|
|
126
|
+
return_type_obj = type_hints.pop("return", Any)
|
|
127
|
+
# Create parameter info
|
|
128
|
+
params_info = {}
|
|
129
|
+
bound_args = sig.bind(*args, **kwargs)
|
|
130
|
+
bound_args.apply_defaults()
|
|
131
|
+
_ = bound_args.arguments.pop("ctx", None)
|
|
132
|
+
_ = bound_args.arguments.pop("self", None)
|
|
133
|
+
_ = bound_args.arguments.pop("cls", None)
|
|
134
|
+
|
|
135
|
+
for param_name, param_value in bound_args.arguments.items():
|
|
136
|
+
params_info[param_name] = {
|
|
137
|
+
"value": str(param_value),
|
|
138
|
+
"type": str(type_hints.get(param_name, Any)),
|
|
139
|
+
}
|
|
140
|
+
# Get function docstring
|
|
141
|
+
docstring = inspect.getdoc(func) or ""
|
|
142
|
+
nonlocal expects_response
|
|
143
|
+
if expects_response is None and mode == "function":
|
|
144
|
+
expects_response = False
|
|
145
|
+
# Prepare payload
|
|
146
|
+
payload = {
|
|
147
|
+
"session_id": self.session_id,
|
|
148
|
+
"expects_response": expects_response,
|
|
149
|
+
"cache_response": cache_response,
|
|
150
|
+
"tool_call": {
|
|
151
|
+
"function_name": func.__name__,
|
|
152
|
+
"parameters": params_info,
|
|
153
|
+
"return_type": json.dumps(extract_json_schema(return_type_obj)),
|
|
154
|
+
"docstring": docstring,
|
|
155
|
+
},
|
|
128
156
|
}
|
|
129
157
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
"
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
mock_result
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
158
|
+
# Send request to endpoint with timeout
|
|
159
|
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
160
|
+
response = await client.post(endpoint, json=payload)
|
|
161
|
+
response.raise_for_status()
|
|
162
|
+
mock_result = response.json()
|
|
163
|
+
logger.info(f"Mock response: {mock_result}")
|
|
164
|
+
|
|
165
|
+
# Convert the mock result to the expected return type
|
|
166
|
+
if mode == "tool":
|
|
167
|
+
return {"content": [{"type": "text", "text": mock_result}]}
|
|
168
|
+
# Parse the mock result if it's a string
|
|
169
|
+
# Extract result field for backwards compatibility
|
|
170
|
+
# Parse the mock result if it's a string
|
|
171
|
+
if isinstance(mock_result, str):
|
|
172
|
+
with suppress(json.JSONDecodeError):
|
|
173
|
+
mock_result = json.loads(mock_result)
|
|
174
|
+
return convert_to_type(mock_result, return_type_obj)
|
|
175
|
+
return convert_to_type(mock_result, return_type_obj)
|
|
176
|
+
|
|
177
|
+
return wrapper
|
|
178
|
+
|
|
179
|
+
return decorator
|
|
180
|
+
|
|
181
|
+
def stub(self, return_value: Any) -> Callable:  # noqa: ANN401
    """Decorator for stubbing tool calls.

    When ``ENV`` equals ``"simulation"`` (case-insensitive), the decorated
    coroutine function is not executed and *return_value* is returned
    instead. Otherwise the original function is awaited as usual.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args: tuple[object, ...], **kwargs: dict[str, object]) -> object:
            simulating = os.getenv("ENV", "").lower() == "simulation"
            if simulating:
                logger.info(f"Simulating function: {func.__name__}")
                return return_value
            # Not simulating: defer to the real implementation.
            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
159
197
|
|
|
160
198
|
|
|
161
199
|
# Module-level VerisSDK instance exposed for import, so callers can write e.g. ``@veris.mock()``.
veris = VerisSDK()
|
veris_ai/utils.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import types
|
|
3
|
+
import typing
|
|
1
4
|
from contextlib import suppress
|
|
2
|
-
from typing import Any, Union, get_args, get_origin
|
|
5
|
+
from typing import Any, ForwardRef, Literal, NotRequired, Required, Union, get_args, get_origin
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
3
8
|
|
|
4
9
|
|
|
5
10
|
def convert_to_type(value: object, target_type: type) -> object:
|
|
@@ -69,3 +74,197 @@ def _convert_simple_type(value: object, target_type: type) -> object:
|
|
|
69
74
|
return target_type(**value)
|
|
70
75
|
|
|
71
76
|
return target_type(value)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _resolve_forward_ref(ref: ForwardRef, module_context: types.ModuleType | None = None) -> Any: # noqa: ANN401
|
|
80
|
+
"""Resolve a ForwardRef to its actual type."""
|
|
81
|
+
if not isinstance(ref, ForwardRef):
|
|
82
|
+
return ref
|
|
83
|
+
|
|
84
|
+
# Try to evaluate the forward reference
|
|
85
|
+
try:
|
|
86
|
+
# Get the module's namespace for evaluation
|
|
87
|
+
namespace = dict(vars(module_context)) if module_context else {}
|
|
88
|
+
|
|
89
|
+
# Add common typing imports to namespace
|
|
90
|
+
namespace.update(
|
|
91
|
+
{
|
|
92
|
+
"Union": Union,
|
|
93
|
+
"Any": Any,
|
|
94
|
+
"Literal": Literal,
|
|
95
|
+
"Required": Required,
|
|
96
|
+
"NotRequired": NotRequired,
|
|
97
|
+
"List": list,
|
|
98
|
+
"Dict": dict,
|
|
99
|
+
"Optional": typing.Optional,
|
|
100
|
+
"Iterable": typing.Iterable,
|
|
101
|
+
"str": str,
|
|
102
|
+
"int": int,
|
|
103
|
+
"float": float,
|
|
104
|
+
"bool": bool,
|
|
105
|
+
},
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
# Try to import from the same module to resolve local references
|
|
109
|
+
if module_context and hasattr(module_context, "__name__"):
|
|
110
|
+
with suppress(Exception):
|
|
111
|
+
# Import all from the module to get access to local types
|
|
112
|
+
exec(f"from {module_context.__name__} import *", namespace) # noqa: S102
|
|
113
|
+
|
|
114
|
+
# Get the forward reference string
|
|
115
|
+
ref_string = ref.__forward_arg__ if hasattr(ref, "__forward_arg__") else str(ref)
|
|
116
|
+
|
|
117
|
+
# Try to evaluate the forward reference string
|
|
118
|
+
return eval(ref_string, namespace, namespace) # noqa: S307
|
|
119
|
+
except Exception:
|
|
120
|
+
# If we can't resolve it, return the ref itself
|
|
121
|
+
return ref
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _unwrap_required(field_type: Any) -> tuple[Any, bool]: # noqa: ANN401
|
|
125
|
+
"""Unwrap Required/NotRequired and return the inner type and whether it's required."""
|
|
126
|
+
origin = get_origin(field_type)
|
|
127
|
+
|
|
128
|
+
# Check if it's Required or NotRequired
|
|
129
|
+
if origin is Required:
|
|
130
|
+
args = get_args(field_type)
|
|
131
|
+
return args[0] if args else field_type, True
|
|
132
|
+
if origin is NotRequired:
|
|
133
|
+
args = get_args(field_type)
|
|
134
|
+
return args[0] if args else field_type, False
|
|
135
|
+
|
|
136
|
+
# Default to required for TypedDict fields
|
|
137
|
+
return field_type, True
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _typed_dict_schema(target_type: Any) -> dict:  # noqa: ANN401
    """Return an object schema for the TypedDict class *target_type*."""
    properties: dict[str, Any] = {}
    required: list[str] = []

    # Globals of the defining module, used to resolve forward references.
    module = sys.modules.get(target_type.__module__)
    total = getattr(target_type, "__total__", True)
    # __required_keys__ already accounts for inherited fields declared under a
    # different ``total=`` setting; fall back to ``total`` when it is absent.
    required_keys = getattr(target_type, "__required_keys__", None)

    for field_name, annotation in target_type.__annotations__.items():
        resolved = annotation
        if isinstance(resolved, ForwardRef):
            resolved = _resolve_forward_ref(resolved, module)

        # An explicit Required[...] / NotRequired[...] marker overrides ``total``.
        marker = get_origin(resolved)
        inner_type, _ = _unwrap_required(resolved)
        properties[field_name] = extract_json_schema(inner_type)

        if marker is Required:
            is_required = True
        elif marker is NotRequired:
            is_required = False
        elif required_keys is not None:
            is_required = field_name in required_keys
        else:
            is_required = bool(total)
        if is_required:
            required.append(field_name)

    schema: dict[str, Any] = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema


def extract_json_schema(target_type: Any) -> dict:  # noqa: PLR0911, PLR0912, C901, ANN401
    """Extract the JSON schema from a type or pydantic model.

    Supported inputs:

    * pydantic models (classes or instances);
    * TypedDicts, honouring ``total``, ``Required`` and ``NotRequired``;
    * ``str``, ``int``, ``float``, ``bool``, ``None`` and ``Any``;
    * ``Literal`` and ``Annotated`` (metadata is ignored);
    * ``list``, ``set``, ``frozenset``, ``tuple`` and ``dict``, bare or
      parameterized, including ``tuple[T, ...]``;
    * unions written as ``Union[...]`` / ``Optional[...]`` or with the
      ``X | Y`` operator.

    Anything else falls back to ``{"type": "object"}``.

    Args:
        target_type: The type or pydantic model to extract the JSON schema from.

    Returns:
        A dictionary representing the JSON schema.

    Example:
        >>> extract_json_schema(int)
        {"type": "integer"}

        >>> extract_json_schema(list[int])
        {"type": "array", "items": {"type": "integer"}}

        >>> extract_json_schema(int | None)
        {"anyOf": [{"type": "integer"}, {"type": "null"}]}

        >>> extract_json_schema(list[User])
        {"type": "array", "items": {"type": "object", "properties": {...}}}
    """
    # Pydantic models (classes or instances) generate their own schema.
    if isinstance(target_type, type) and issubclass(target_type, BaseModel):
        return target_type.model_json_schema()
    if isinstance(target_type, BaseModel):
        return target_type.model_json_schema()

    # TypedDict classes are plain ``type``s carrying annotations plus ``__total__``.
    if (
        isinstance(target_type, type)
        and hasattr(target_type, "__annotations__")
        and hasattr(target_type, "__total__")
    ):
        return _typed_dict_schema(target_type)

    origin = get_origin(target_type)
    args = get_args(target_type)

    # Annotated[T, ...] describes the same values as T. This is checked before
    # the dict lookup below because Annotated metadata may be unhashable.
    if origin is typing.Annotated:
        return extract_json_schema(args[0])

    # A bare ``None`` in an annotation means NoneType.
    if target_type is None:
        return {"type": "null"}

    # Handle built-in types
    type_mapping = {
        str: {"type": "string"},
        int: {"type": "integer"},
        float: {"type": "number"},
        bool: {"type": "boolean"},
        type(None): {"type": "null"},
        Any: {},  # Empty schema for Any type
    }
    if target_type in type_mapping:
        return type_mapping[target_type]

    # Bare (unparameterized) collection types
    if target_type is list or target_type is tuple:
        return {"type": "array"}
    if target_type is set or target_type is frozenset:
        return {"type": "array", "uniqueItems": True}
    if target_type is dict:
        return {"type": "object"}

    if origin is Literal:
        # One value -> const; several values -> enum.
        return {"const": args[0]} if len(args) == 1 else {"enum": list(args)}

    if origin is list:
        return {"type": "array", "items": extract_json_schema(args[0])} if args else {"type": "array"}

    if origin is set or origin is frozenset:
        set_schema: dict[str, Any] = {"type": "array", "uniqueItems": True}
        if args:
            set_schema["items"] = extract_json_schema(args[0])
        return set_schema

    if origin is dict:
        if len(args) == 2:  # noqa: PLR2004
            # For typed dicts like dict[str, int]
            return {"type": "object", "additionalProperties": extract_json_schema(args[1])}
        return {"type": "object"}

    # ``X | Y`` has origin types.UnionType, while Union[...] / Optional[...] have typing.Union.
    if origin is Union or origin is types.UnionType:
        # Handle Optional types (Union[T, None])
        if len(args) == 2 and type(None) in args:  # noqa: PLR2004
            non_none_type = args[0] if args[1] is type(None) else args[1]
            return {"anyOf": [extract_json_schema(non_none_type), {"type": "null"}]}
        return {"anyOf": [extract_json_schema(arg) for arg in args]}

    if origin is tuple:
        # tuple[T, ...] is a variable-length homogeneous sequence.
        if len(args) == 2 and args[1] is Ellipsis:  # noqa: PLR2004
            return {"type": "array", "items": extract_json_schema(args[0])}
        if args:
            return {
                "type": "array",
                "prefixItems": [extract_json_schema(arg) for arg in args],
                "minItems": len(args),
                "maxItems": len(args),
            }
        return {"type": "array"}

    # Default case for unknown types
    return {"type": "object"}
|