tensorzero 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensorzero/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from .client import AsyncTensorZero
2
+ from .types import (
3
+ ChatInferenceResponse,
4
+ ContentBlock,
5
+ FeedbackResponse,
6
+ InferenceChunk,
7
+ InferenceResponse,
8
+ JsonInferenceOutput,
9
+ Text,
10
+ TextChunk,
11
+ ToolCall,
12
+ ToolCallChunk,
13
+ Usage,
14
+ )
15
+
16
+ __all__ = [
17
+ "ChatInferenceResponse",
18
+ "ContentBlock",
19
+ "FeedbackResponse",
20
+ "InferenceChunk",
21
+ "InferenceResponse",
22
+ "JsonInferenceOutput",
23
+ "AsyncTensorZero",
24
+ "Text",
25
+ "TextChunk",
26
+ "ToolCall",
27
+ "ToolCallChunk",
28
+ "Usage",
29
+ ]
tensorzero/client.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ TensorZero Client
3
+
4
+ This module provides an asynchronous client for interacting with the TensorZero gateway.
5
+ It includes functionality for making inference requests and sending feedback.
6
+
7
+ The main class, AsyncTensorZero, offers methods for:
8
+ - Initializing the client with a base URL
9
+ - Making inference requests (with optional streaming)
10
+ - Sending feedback on episodes or inferences
11
+ - Managing the client session using async context managers
12
+
13
+ Usage:
14
+ async with AsyncTensorZero(base_url) as client:
15
+ response = await client.inference(...)
16
+ feedback = await client.feedback(...)
17
+ """
18
+
19
+ import json
20
+ import logging
21
+ from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union
22
+ from urllib.parse import urljoin
23
+ from uuid import UUID
24
+
25
+ import httpx
26
+
27
+ from .types import (
28
+ FeedbackResponse,
29
+ InferenceChunk,
30
+ InferenceResponse,
31
+ parse_inference_chunk,
32
+ parse_inference_response,
33
+ )
34
+
35
+
36
class AsyncTensorZero:
    """
    Asynchronous client for the TensorZero gateway.

    Wraps an `httpx.AsyncClient` and exposes the gateway's `inference` and
    `feedback` endpoints. Use it as an async context manager (or call
    `close()` explicitly) so the underlying HTTP client is closed.
    """

    def __init__(self, base_url: str):
        """
        Initialize the TensorZero client.

        :param base_url: The base URL of the TensorZero gateway. Example: "http://localhost:3000"
        """
        self.base_url = base_url
        self.client = httpx.AsyncClient()
        self.logger = logging.getLogger(__name__)

    def _endpoint(self, path: str) -> str:
        """
        Build the full URL of a gateway endpoint.

        `urljoin` replaces the last path segment of a base URL that has no
        trailing slash (e.g. "http://host/gateway" + "inference" would give
        "http://host/inference"), so the base is normalized to end with
        exactly one slash before joining.

        :param path: Endpoint path relative to the base URL, e.g. "inference"
        :return: The absolute URL of the endpoint
        """
        return urljoin(self.base_url.rstrip("/") + "/", path)

    async def inference(
        self,
        function_name: str,
        input: Dict[str, Any],
        episode_id: Optional[UUID] = None,
        stream: Optional[bool] = None,
        params: Optional[Dict[str, Any]] = None,
        variant_name: Optional[str] = None,
        dryrun: Optional[bool] = None,
        allowed_tools: Optional[List[str]] = None,
        additional_tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[
            Union[Literal["auto", "required", "off"], Dict[Literal["specific"], str]]
        ] = None,
        parallel_tool_calls: Optional[bool] = None,
    ) -> Union[InferenceResponse, AsyncGenerator[InferenceChunk, None]]:
        """
        Make a POST request to the /inference endpoint.

        :param function_name: The name of the function to call
        :param input: The input to the function
                      Structure: {"system": Optional[str], "messages": List[{"role": "user" | "assistant", "content": Any}]}
                      The input will be validated server side against the input schema of the function being called.
        :param episode_id: The episode ID to use for the inference.
                           If this is the first inference in an episode, leave this field blank. The TensorZero gateway will generate and return a new episode ID.
                           Note: Only use episode IDs generated by the TensorZero gateway. Don't generate them yourself.
        :param stream: If set, the TensorZero gateway will stream partial message deltas (e.g. generated tokens) as it receives them from model providers.
        :param params: Override inference-time parameters for a particular variant type. Currently, we support:
                       {"chat_completion": {"temperature": float, "max_tokens": int, "seed": int}}
        :param variant_name: If set, pins the inference request to a particular variant.
                             Note: You should generally not do this, and instead let the TensorZero gateway assign a
                             particular variant. This field is primarily used for testing or debugging purposes.
        :param dryrun: If true, the request will be executed but won't be stored to the database.
        :param allowed_tools: If set, restricts the tools available during this inference request.
                              The list of names should be a subset of the tools configured for the function.
                              Tools provided at inference time in `additional_tools` (if any) are always available.
        :param additional_tools: A list of additional tools to use for the request. Each element should look like {"name": str, "parameters": valid JSON Schema, "description": str}
        :param tool_choice: If set, overrides the tool choice strategy for the request.
                            It should be one of: "auto", "required", "off", or {"specific": str}. The last option pins the request to a specific tool name.
        :param parallel_tool_calls: If true, the request will allow for multiple tool calls in a single inference request.
        :return: If stream is false (or unset), returns an InferenceResponse.
                 If stream is true, returns an async generator that yields InferenceChunks as they come in.
                 The generator owns the HTTP response: iterate it to completion (or close it) so the
                 connection is released.
        :raises httpx.HTTPStatusError: If the gateway answers with a non-success status code.
        """
        data: Dict[str, Any] = {
            "function_name": function_name,
            "input": input,
        }
        # Only arguments the caller actually supplied are sent to the gateway.
        optional_fields: Dict[str, Any] = {
            "episode_id": None if episode_id is None else str(episode_id),
            "stream": stream,
            "params": params,
            "variant_name": variant_name,
            "dryrun": dryrun,
            "allowed_tools": allowed_tools,
            "additional_tools": additional_tools,
            "tool_choice": tool_choice,
            "parallel_tool_calls": parallel_tool_calls,
        }
        data.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )
        url = self._endpoint("inference")

        if not stream:
            response = await self.client.post(url, json=data)
            response.raise_for_status()
            return parse_inference_response(response.json())

        # `client.post` reads the whole body before returning, which would delay
        # every chunk until generation has finished. Sending the request with
        # `stream=True` lets `_stream_sse` yield chunks as they arrive.
        request = self.client.build_request("POST", url, json=data)
        response = await self.client.send(request, stream=True)
        try:
            if not response.is_success:
                # Load the error body so it stays readable on the raised exception.
                await response.aread()
            response.raise_for_status()
        except BaseException:
            # Release the connection before propagating; the generator that
            # would normally close the response is never created on this path.
            await response.aclose()
            raise
        return self._stream_sse(response)

    async def feedback(
        self,
        metric_name: str,
        value: Any,
        inference_id: Optional[UUID] = None,
        episode_id: Optional[UUID] = None,
        dryrun: Optional[bool] = None,
    ) -> FeedbackResponse:
        """
        Make a POST request to the /feedback endpoint.

        :param metric_name: The name of the metric to provide feedback for
        :param value: The value of the feedback. It should correspond to the metric type.
        :param inference_id: The inference ID to assign the feedback to.
                             Only use inference IDs that were returned by the TensorZero gateway.
                             Note: You can assign feedback to either an episode or an inference, but not both.
        :param episode_id: The episode ID to use for the request
                           Only use episode IDs that were returned by the TensorZero gateway.
                           Note: You can assign feedback to either an episode or an inference, but not both.
        :param dryrun: If true, the feedback request will be executed but won't be stored to the database (i.e. no-op).
        :return: A FeedbackResponse built from the gateway's JSON body. The body's
                 fields are passed through as decoded from JSON, without conversion.
        :raises ValueError: If neither or both of episode_id and inference_id are provided.
        :raises httpx.HTTPStatusError: If the gateway answers with a non-success status code.
        """
        if episode_id is None and inference_id is None:
            raise ValueError("Either episode_id or inference_id must be provided")
        if episode_id is not None and inference_id is not None:
            raise ValueError(
                "Only one of episode_id or inference_id can be provided, not both"
            )
        data: Dict[str, Any] = {
            "metric_name": metric_name,
            "value": value,
        }
        if dryrun is not None:
            data["dryrun"] = dryrun
        if episode_id is not None:
            data["episode_id"] = str(episode_id)
        if inference_id is not None:
            data["inference_id"] = str(inference_id)
        response = await self.client.post(self._endpoint("feedback"), json=data)
        response.raise_for_status()
        return FeedbackResponse(**response.json())

    async def close(self):
        """
        Close the connection to the TensorZero gateway.
        """
        await self.client.aclose()

    async def __aenter__(self):
        """Enter the async context manager and return the client itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying HTTP client when leaving the context."""
        await self.close()

    async def _stream_sse(
        self, response: httpx.Response
    ) -> AsyncGenerator[InferenceChunk, None]:
        """
        Parse the SSE stream from the response.

        Lines that do not start with "data: " are ignored, a "[DONE]" payload
        ends the stream, and payloads that are not valid JSON are logged and
        skipped. The response is closed when the generator finishes or is closed.

        :param response: The httpx.Response object
        :yield: InferenceChunk objects parsed from the SSE events
        """
        try:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[6:].strip()
                if payload == "[DONE]":
                    break
                # Keep the `try` limited to JSON decoding so that errors raised
                # while building the chunk are not mistaken for bad JSON.
                try:
                    event = json.loads(payload)
                except json.JSONDecodeError:
                    self.logger.error("Failed to parse SSE data: %s", payload)
                    continue
                yield parse_inference_chunk(event)
        finally:
            await response.aclose()
tensorzero/types.py ADDED
@@ -0,0 +1,175 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Union
3
+ from uuid import UUID
4
+
5
+ # Types for non-streaming inference responses
6
+
7
+
8
@dataclass
class Usage:
    """Token usage attached to inference responses and streamed chunks."""

    # Both counts are filled from the `usage` object of the gateway payload.
    input_tokens: int
    output_tokens: int
12
+
13
+
14
@dataclass
class Text:
    """A text content block of a non-streaming chat inference response."""

    text: str
17
+
18
+
19
@dataclass
class ToolCall:
    """A tool-call content block of a non-streaming chat inference response."""

    name: str
    arguments: Dict[str, Any]
    id: str
    # The parsed fields are None when the payload does not carry them.
    parsed_name: Optional[str]
    parsed_arguments: Optional[Dict[str, Any]]
26
+
27
+
28
+ ContentBlock = Union[Text, ToolCall]
29
+
30
+
31
@dataclass
class JsonInferenceOutput:
    """Output of a JSON inference: the raw string plus its parsed form, if any."""

    raw: str
    parsed: Optional[Dict[str, Any]]
35
+
36
+
37
@dataclass
class ChatInferenceResponse:
    """Non-streaming inference response whose output is a list of content blocks."""

    inference_id: UUID
    episode_id: UUID
    variant_name: str
    # Each element is either a Text or a ToolCall block.
    output: List[ContentBlock]
    usage: Usage
44
+
45
+
46
@dataclass
class JsonInferenceResponse:
    """Non-streaming inference response whose output is a single JSON result."""

    inference_id: UUID
    episode_id: UUID
    variant_name: str
    output: JsonInferenceOutput
    usage: Usage
53
+
54
+
55
+ InferenceResponse = Union[ChatInferenceResponse, JsonInferenceResponse]
56
+
57
+
58
def parse_inference_response(data: Dict[str, Any]) -> InferenceResponse:
    """
    Convert a decoded inference payload into a typed response object.

    The shape of `data["output"]` selects the result type: a list produces a
    ChatInferenceResponse, a dict produces a JsonInferenceResponse.

    :param data: The decoded JSON body of a non-streaming inference response
    :return: The typed inference response
    :raises ValueError: If `output` is missing or is neither a list nor a dict
    """
    raw_output = data["output"] if "output" in data else None
    is_chat = isinstance(raw_output, list)
    if not is_chat and not isinstance(raw_output, dict):
        raise ValueError("Unable to determine response type")

    # Fields shared by both response types.
    inference_id = UUID(data["inference_id"])
    episode_id = UUID(data["episode_id"])
    variant_name = data["variant_name"]

    if is_chat:
        blocks = [parse_content_block(item) for item in raw_output]
        return ChatInferenceResponse(
            inference_id=inference_id,
            episode_id=episode_id,
            variant_name=variant_name,
            output=blocks,
            usage=Usage(**data["usage"]),
        )

    json_output = JsonInferenceOutput(**raw_output)
    return JsonInferenceResponse(
        inference_id=inference_id,
        episode_id=episode_id,
        variant_name=variant_name,
        output=json_output,
        usage=Usage(**data["usage"]),
    )
77
+
78
+
79
+ def parse_content_block(block: Dict[str, Any]) -> ContentBlock:
80
+ block_type = block["type"]
81
+ if block_type == "text":
82
+ return Text(text=block["text"])
83
+ elif block_type == "tool_call":
84
+ return ToolCall(
85
+ name=block["name"],
86
+ arguments=block["arguments"],
87
+ id=block["id"],
88
+ parsed_name=block.get("parsed_name"),
89
+ parsed_arguments=block.get("parsed_arguments"),
90
+ )
91
+ else:
92
+ raise ValueError(f"Unknown content block type: {block}")
93
+
94
+
95
+ # Types for streaming inference responses
96
+
97
+
98
@dataclass
class TextChunk:
    """A piece of text delivered in a streaming inference chunk."""

    # A single streaming response may carry several text messages;
    # `id` tells them apart.
    id: str
    text: str
104
+
105
+
106
@dataclass
class ToolCallChunk:
    """A piece of a tool call delivered in a streaming inference chunk."""

    name: str
    # The tool call ID that many LLM APIs use to associate tool calls with tool responses
    id: str
    # Arrives as partial JSON, so it is kept as a string
    arguments: str
113
+
114
+
115
+ ContentBlockChunk = Union[TextChunk, ToolCallChunk]
116
+
117
+
118
@dataclass
class ChatChunk:
    """Streaming inference chunk that carries a list of content block chunks."""

    inference_id: UUID
    episode_id: UUID
    variant_name: str
    # Each element is either a TextChunk or a ToolCallChunk.
    content: List[ContentBlockChunk]
    # None when the chunk does not report usage.
    usage: Optional[Usage]
125
+
126
+
127
@dataclass
class JsonChunk:
    """Streaming inference chunk that carries a raw string fragment."""

    inference_id: UUID
    episode_id: UUID
    variant_name: str
    raw: str
    # None when the chunk does not report usage.
    usage: Optional[Usage]
134
+
135
+
136
+ InferenceChunk = Union[ChatChunk, JsonChunk]
137
+
138
+
139
def parse_inference_chunk(chunk: Dict[str, Any]) -> InferenceChunk:
    """
    Convert one decoded streaming event into a typed chunk.

    A chunk with a "content" key becomes a ChatChunk; otherwise a chunk with a
    "raw" key becomes a JsonChunk.

    :param chunk: The decoded JSON payload of a single streaming event
    :return: The typed inference chunk
    :raises ValueError: If the chunk has neither a "content" nor a "raw" key
    """
    # A chunk may omit `usage` or carry it as null; both map to None so that
    # a null value does not crash on `Usage(**None)`.
    raw_usage = chunk.get("usage")

    if "content" in chunk:
        return ChatChunk(
            inference_id=UUID(chunk["inference_id"]),
            episode_id=UUID(chunk["episode_id"]),
            variant_name=chunk["variant_name"],
            content=[parse_content_block_chunk(block) for block in chunk["content"]],
            usage=None if raw_usage is None else Usage(**raw_usage),
        )
    if "raw" in chunk:
        return JsonChunk(
            inference_id=UUID(chunk["inference_id"]),
            episode_id=UUID(chunk["episode_id"]),
            variant_name=chunk["variant_name"],
            raw=chunk["raw"],
            usage=None if raw_usage is None else Usage(**raw_usage),
        )
    raise ValueError(f"Unable to determine response type: {chunk}")
158
+
159
+
160
+ def parse_content_block_chunk(block: Dict[str, Any]) -> ContentBlockChunk:
161
+ block_type = block["type"]
162
+ if block_type == "text":
163
+ return TextChunk(id=block["id"], text=block["text"])
164
+ elif block_type == "tool_call":
165
+ return ToolCallChunk(
166
+ name=block["name"], id=block["id"], arguments=block["arguments"]
167
+ )
168
+ else:
169
+ raise ValueError(f"Unknown content block type: {block}")
170
+
171
+
172
+ # Types for feedback
173
@dataclass
class FeedbackResponse:
    """Result of a feedback request, identifying the feedback that was sent."""

    feedback_id: UUID
@@ -0,0 +1,56 @@
1
+ Metadata-Version: 2.3
2
+ Name: tensorzero
3
+ Version: 0.0.1
4
+ Summary: The Python client for TensorZero
5
+ Author-email: Viraj Mehta <viraj@tensorzero.com>, Gabriel Bianconi <gabriel@tensorzero.com>
6
+ Requires-Python: >=3.10
7
+ Requires-Dist: httpx>=0.27.0
8
+ Description-Content-Type: text/markdown
9
+
10
+ # TensorZero Python Client
11
+
12
+ This is an async Python client for the TensorZero gateway. Check out the [docs](https://tensorzero.com/docs/) for more information. This client allows you to easily make inference requests and assign feedback to them via the TensorZero gateway.
13
+
14
+ ## Installation
15
+
16
+ ```bash
17
+ pip install tensorzero
18
+ ```
19
+
20
+ ## Basic Usage
21
+
22
+ ### Non-Streaming Inference
23
+
24
+ ```python
25
+ from tensorzero import AsyncTensorZero
26
+
27
+ async with AsyncTensorZero("http://localhost:3000") as client:
28
+ result = await client.inference(
29
+ function_name="basic_test",
30
+ input={
31
+ "system": {"assistant_name": "Alfred Pennyworth"},
32
+ "messages": [{"role": "user", "content": "Hello"}],
33
+ },
34
+ )
35
+ episode_id = result.episode_id
36
+ output = result.output
37
+ print(output[0].text) # Prints the text of the first content block returned by TensorZero
38
+ ```
39
+
40
+ ### Streaming Inference
41
+
42
+ ```python
43
+ from tensorzero import AsyncTensorZero
44
+
45
+ async with AsyncTensorZero("http://localhost:3000") as client:
46
+ stream = await client.inference(
47
+ function_name="basic_test",
48
+ input={
49
+ "system": {"assistant_name": "Alfred Pennyworth"},
50
+ "messages": [{"role": "user", "content": "Hello"}],
51
+ },
52
+ stream=True,
53
+ )
54
+ async for chunk in stream:
55
+ print(chunk.content[0].text) # Prints the text in each chunk returned by TensorZero
56
+ ```
@@ -0,0 +1,6 @@
1
+ tensorzero/__init__.py,sha256=X7ombaI3m4Slt1qgOCCQqsqujxGz27v-Ux7Q0M-j9dA,527
2
+ tensorzero/client.py,sha256=qWpgC5yPUuxmsYzNvXl3Oiv3c9VRPPjHaWeJA-njMT0,8801
3
+ tensorzero/types.py,sha256=owmbIa1HpdpcyeC8uZJJhGwqr_LDat0rHjo4reJS9m4,4564
4
+ tensorzero-0.0.1.dist-info/METADATA,sha256=1gCyEXjz2eUGJxZmD729NXCQytCLMUIzeNTI-1mubcU,1655
5
+ tensorzero-0.0.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
6
+ tensorzero-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.25.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any