tracellm-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +1 -0
- app/database/__init__.py +1 -0
- app/database/mongodb.py +94 -0
- app/database/project_service.py +97 -0
- app/database/trace_service.py +417 -0
- app/main.py +44 -0
- app/models/__init__.py +14 -0
- app/models/health.py +5 -0
- app/models/project.py +32 -0
- app/models/trace.py +71 -0
- app/models/trace_model.py +62 -0
- app/routes/__init__.py +1 -0
- app/routes/health.py +10 -0
- app/routes/observability.py +60 -0
- app/routes/projects.py +25 -0
- app/websocket/__init__.py +1 -0
- app/websocket/socket.py +64 -0
- sdk/__init__.py +3 -0
- sdk/tracer.py +8 -0
- tracellm/__init__.py +6 -0
- tracellm/banner.py +34 -0
- tracellm/cli.py +124 -0
- tracellm/db.py +75 -0
- tracellm/exporter.py +65 -0
- tracellm/integrations/__init__.py +4 -0
- tracellm/integrations/langchain.py +186 -0
- tracellm/integrations/openai.py +234 -0
- tracellm/integrations/tool_tracer.py +151 -0
- tracellm/mascot.py +49 -0
- tracellm/monitor.py +381 -0
- tracellm/palette.py +186 -0
- tracellm/replay.py +80 -0
- tracellm/startup.py +121 -0
- tracellm/summary.py +53 -0
- tracellm/trace_stream.py +68 -0
- tracellm/tracer.py +598 -0
- tracellm/tree_renderer.py +78 -0
- tracellm/utils.py +390 -0
- tracellm_cli-0.1.0.dist-info/METADATA +30 -0
- tracellm_cli-0.1.0.dist-info/RECORD +43 -0
- tracellm_cli-0.1.0.dist-info/WHEEL +5 -0
- tracellm_cli-0.1.0.dist-info/entry_points.txt +2 -0
- tracellm_cli-0.1.0.dist-info/top_level.txt +3 -0
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import uuid
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
from langchain_core.callbacks import BaseCallbackHandler
|
|
7
|
+
from langchain_core.messages import BaseMessage
|
|
8
|
+
|
|
9
|
+
from tracellm.utils import build_tool_step, estimate_tokens, console
|
|
10
|
+
from tracellm.db import save_trace_payload
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _get_trace_context() -> dict[str, Any] | None:
|
|
14
|
+
try:
|
|
15
|
+
from tracellm.tracer import _current_trace_context
|
|
16
|
+
return _current_trace_context.get()
|
|
17
|
+
except ImportError:
|
|
18
|
+
return None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _finalize_lc_trace(trace_data: dict[str, Any]) -> None:
    """Fill in project metadata on a LangChain trace payload and persist it.

    Keys already present in ``trace_data`` are never overwritten. Missing ones
    are taken from the active trace context, with ``"default"`` as the
    fallback project id and ``"development"`` as the fallback environment.
    Persistence is best-effort: an error from ``save_trace_payload`` is
    printed to the console instead of being raised.

    Args:
        trace_data: The trace payload; it is mutated in place.
    """
    context = _get_trace_context() or {}
    fallbacks = {
        "project_id": context.get("project_id") or "default",
        "project_name": context.get("project_name"),
        "environment": context.get("environment") or "development",
        "api_key": context.get("api_key"),
    }
    for key, value in fallbacks.items():
        trace_data.setdefault(key, value)

    try:
        save_trace_payload(trace_data)
    except Exception as e:
        console.print(f"[yellow]LangChain trace persist skipped: {e}[/yellow]")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TracellmCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler that records a run as a TraceLLM trace.

    Chain, tool and retriever runs become *steps*. A step is opened by the
    ``on_*_start`` callback. It is closed by the matching ``on_*_end`` or
    ``on_*_error`` callback, which records its duration in milliseconds, its
    output and a success flag. LLM prompts and outputs are collected
    separately. Call :meth:`flush_trace` when the run is over to build and
    persist the trace.
    """

    def __init__(self) -> None:
        super().__init__()
        self.trace_id = str(uuid.uuid4())
        # Finished steps, in completion order.
        self.steps: list[dict[str, Any]] = []
        self.start_time = time.perf_counter()
        self.start_dt = datetime.now(timezone.utc)
        # Steps that have started but not finished yet, in start order.
        self._chain_stack: list[dict[str, Any]] = []
        # Incremented on every tool error; reported as the trace's retry_count.
        self._retry_count = 0
        # Last LLM/chain error message; any value marks the trace as failed.
        self._error: str | None = None
        self._llm_inputs: list[str] = []
        self._llm_outputs: list[str] = []

    @property
    def always_verbose(self) -> bool:
        """Report this handler as always verbose."""
        return True

    # ------------------------------------------------------------------
    # Step bookkeeping
    # ------------------------------------------------------------------

    def _open_step(self, tool_name: Any, input_data: dict[str, Any], run_id: Any = None) -> None:
        """Start timing a new step and push it onto the in-flight stack."""
        self._chain_stack.append(
            {
                "step_id": str(uuid.uuid4()),
                "tool_name": tool_name,
                "input": input_data,
                "output": {},
                "duration": 0.0,
                "success": True,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                # Private bookkeeping; both keys are removed when the step closes.
                "_start": time.perf_counter(),
                "_run_id": run_id,
            }
        )

    def _pop_step(self, run_id: Any = None) -> dict[str, Any] | None:
        """Remove and return the in-flight step that an end/error event refers to.

        When the callback received a ``run_id`` keyword (LangChain's callback
        API supplies one), only the step opened with that same id is returned.
        This stops interleaved runs from closing each other's steps. ``None``
        is returned if no such step is in flight. Without a ``run_id``, the
        most recently opened step is returned, which assumes strict nesting.
        """
        if run_id is not None:
            for index in range(len(self._chain_stack) - 1, -1, -1):
                if self._chain_stack[index].get("_run_id") == run_id:
                    return self._chain_stack.pop(index)
            return None
        return self._chain_stack.pop() if self._chain_stack else None

    def _close_step(self, run_id: Any, output: dict[str, Any], success: bool = True) -> None:
        """Finish the matching in-flight step and move it to ``self.steps``."""
        step = self._pop_step(run_id)
        if step is None:
            return
        step["duration"] = round((time.perf_counter() - step.pop("_start")) * 1000, 2)
        step.pop("_run_id", None)
        step["output"] = output
        step["success"] = success
        self.steps.append(step)

    # ------------------------------------------------------------------
    # LLM callbacks
    # ------------------------------------------------------------------

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
        """Remember the prompts sent to the LLM (stringified)."""
        self._llm_inputs.append(str(prompts))

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Remember the LLM output, truncated to 500 characters."""
        content: Any = ""
        try:
            if hasattr(response, "generations"):
                content = str(response.generations)
            elif hasattr(response, "text"):
                content = response.text
        except Exception:
            content = str(response)
        # ``response.text`` may not be a str; coerce before slicing.
        self._llm_outputs.append(str(content)[:500])

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Record the LLM error; it marks the whole trace as failed."""
        self._error = str(error)

    # ------------------------------------------------------------------
    # Chain callbacks
    # ------------------------------------------------------------------

    def on_chain_start(self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any) -> None:
        """Open a step named after the last element of ``serialized["id"]``."""
        # Guard against ``serialized`` being None or lacking ``id``. Fall back
        # to the run's ``name`` keyword, then to "unknown".
        id_path = (serialized or {}).get("id")
        tool_name = id_path[-1] if id_path else (kwargs.get("name") or "unknown")
        self._open_step(tool_name, {"inputs": inputs}, kwargs.get("run_id"))

    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        """Close the chain's step with its outputs."""
        self._close_step(kwargs.get("run_id"), {"outputs": outputs})

    def on_chain_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Close the chain's step as failed and mark the trace as failed."""
        self._close_step(kwargs.get("run_id"), {"error": str(error)}, success=False)
        self._error = str(error)

    # ------------------------------------------------------------------
    # Tool callbacks
    # ------------------------------------------------------------------

    def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> None:
        """Open a step named after the tool."""
        tool_name = (serialized or {}).get("name") or kwargs.get("name") or "unknown_tool"
        self._open_step(tool_name, {"input": input_str}, kwargs.get("run_id"))

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Close the tool's step with its output, truncated to 500 characters."""
        # Tool output is not necessarily a string (e.g. message objects), so
        # stringify before truncating instead of slicing the raw object.
        self._close_step(kwargs.get("run_id"), {"output": str(output)[:500]})

    def on_tool_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Close the tool's step as failed and count it toward ``retry_count``."""
        self._close_step(kwargs.get("run_id"), {"error": str(error)}, success=False)
        self._retry_count += 1

    # ------------------------------------------------------------------
    # Retriever callbacks
    # ------------------------------------------------------------------

    def on_retriever_start(self, serialized: dict[str, Any], query: str, **kwargs: Any) -> None:
        """Open a ``retriever`` step for ``query``."""
        self._open_step("retriever", {"query": query}, kwargs.get("run_id"))

    def on_retriever_end(self, documents: list[Any], **kwargs: Any) -> None:
        """Close the retriever step with previews of up to 5 documents."""
        previews = [str(doc)[:200] for doc in documents[:5]]
        self._close_step(kwargs.get("run_id"), {"documents": previews, "count": len(documents)})

    def on_retriever_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Close the retriever step as failed."""
        self._close_step(kwargs.get("run_id"), {"error": str(error)}, success=False)

    # ------------------------------------------------------------------
    # Trace output
    # ------------------------------------------------------------------

    def flush_trace(
        self,
        prompt: str | None = None,
        response: str | None = None,
        status: str = "success",
    ) -> dict[str, Any]:
        """Build the trace for everything recorded so far, persist it, and return it.

        Steps that were started but never finished are not included.

        Args:
            prompt: Fallback prompt text; used only when no LLM call was seen.
            response: Fallback response text; used only when no LLM output was seen.
            status: Trace status. It is replaced by ``"failed"`` if any LLM
                or chain error was recorded.

        Returns:
            The trace payload, including the project metadata that
            ``_finalize_lc_trace`` adds to it.
        """
        latency = (time.perf_counter() - self.start_time) * 1000
        total_input = " ".join(self._llm_inputs) if self._llm_inputs else (prompt or "")
        total_output = " ".join(self._llm_outputs) if self._llm_outputs else (response or "")

        trace_data = {
            "trace_id": self.trace_id,
            "prompt": total_input,
            "response": total_output,
            "latency": round(latency, 2),
            "token_count": estimate_tokens(total_input, total_output),
            "model_name": "langchain",
            "status": status if not self._error else "failed",
            "steps": self.steps,
            "retry_count": self._retry_count,
            "failure_reason": self._error,
            # Runs taking 1500 ms or longer are flagged as slow.
            "slow_request": latency >= 1500.0,
            "created_at": self.start_dt.isoformat(),
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        _finalize_lc_trace(trace_data)
        return trace_data
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import time
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from typing import Any, Iterator
|
|
6
|
+
|
|
7
|
+
from openai import OpenAI, Stream
|
|
8
|
+
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|
9
|
+
|
|
10
|
+
from tracellm.utils import (
|
|
11
|
+
console,
|
|
12
|
+
estimate_tokens,
|
|
13
|
+
build_tool_step,
|
|
14
|
+
)
|
|
15
|
+
from tracellm.db import save_trace_payload
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _get_openai_trace_context() -> dict[str, Any] | None:
|
|
19
|
+
try:
|
|
20
|
+
from tracellm.tracer import _current_trace_context
|
|
21
|
+
return _current_trace_context.get()
|
|
22
|
+
except ImportError:
|
|
23
|
+
return None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _build_openai_step(
    tool_name: str,
    input_data: dict[str, Any],
    output_data: dict[str, Any],
    duration: float,
    success: bool = True,
) -> dict[str, Any]:
    """Build a trace step for an OpenAI call.

    This is a thin adapter over ``build_tool_step``. Callers in this module
    pass ``duration`` in milliseconds.
    """
    step_fields = (tool_name, input_data, output_data, duration, success)
    return build_tool_step(*step_fields)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _finalize_trace(
    trace_id: str,
    prompt: str,
    response: str,
    model: str,
    latency: float,
    token_count: int,
    steps: list[dict[str, Any]],
    status: str,
    error: str | None = None,
    retry_count: int = 0,
) -> dict[str, Any]:
    """Assemble an OpenAI trace payload, try to persist it, and return it.

    Project metadata comes from the active trace context when there is one.
    The project id falls back to ``"default"`` and the environment to
    ``"development"``. ``created_at`` and ``updated_at`` are both set to the
    moment this function runs. Persistence is best-effort: a failure from
    ``save_trace_payload`` is printed, not raised.

    Args:
        trace_id: Identifier of the trace.
        prompt: Stringified request messages.
        response: Response text ("" on failure).
        model: Requested model name.
        latency: Total latency in milliseconds; 1500 ms or more flags the
            trace as a slow request.
        token_count: Token count (reported or estimated).
        steps: Step records collected for the call.
        status: ``"success"`` or ``"failed"``.
        error: Failure message, if any.
        retry_count: Number of retries performed.

    Returns:
        The trace payload dict.
    """
    context = _get_openai_trace_context() or {}
    stamp = datetime.now(timezone.utc).isoformat()
    trace_data = dict(
        trace_id=trace_id,
        prompt=prompt,
        response=response,
        latency=round(latency, 2),
        token_count=token_count,
        model_name=model,
        project_id=context.get("project_id") or "default",
        project_name=context.get("project_name") or None,
        api_key=context.get("api_key") or None,
        environment=context.get("environment") or "development",
        status=status,
        steps=steps,
        retry_count=retry_count,
        failure_reason=error,
        slow_request=latency >= 1500.0,
        created_at=stamp,
        updated_at=stamp,
    )
    try:
        save_trace_payload(trace_data)
    except Exception as e:
        console.print(f"[yellow]Trace persist skipped: {e}[/yellow]")
    return trace_data
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _extract_token_usage(completion: ChatCompletion) -> dict[str, int]:
    """Return the completion's prompt/completion/total token counts.

    An empty dict is returned when the completion has no ``usage`` block.
    """
    usage = completion.usage
    if not usage:
        return {}
    counters = ("prompt_tokens", "completion_tokens", "total_tokens")
    return {counter: getattr(usage, counter) for counter in counters}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _estimate_streaming_tokens(chunks: list[ChatCompletionChunk]) -> dict[str, int]:
    """Estimate the total token usage of a streamed chat completion.

    If a chunk carries a ``usage`` block with ``total_tokens``, the last such
    value is returned. OpenAI sends one on the final chunk when the request
    sets ``stream_options={"include_usage": True}``. Otherwise every chunk
    that carries content counts as one token, which is only a rough
    approximation.

    Args:
        chunks: All chunks received from the stream, in order.

    Returns:
        ``{"total_tokens": n}``. Only this one key is returned because
        callers sum the dict's values.
    """
    for chunk in reversed(chunks):
        reported = getattr(getattr(chunk, "usage", None), "total_tokens", None)
        if reported is not None:
            return {"total_tokens": reported}
    content_chunks = sum(
        1 for chunk in chunks if chunk.choices and chunk.choices[0].delta.content
    )
    return {"total_tokens": content_chunks}
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def wrap_openai(client: OpenAI) -> OpenAI:
    """Patch ``client.chat.completions.create`` in place so every call is traced.

    Args:
        client: The OpenAI client to instrument. It is mutated and returned.

    Returns:
        The same ``client`` object.
    """
    original_create = client.chat.completions.create

    @functools.wraps(original_create)
    def traced_create(*args: Any, **kwargs: Any) -> Any:
        if args:
            # Tracing only understands keyword arguments. Forward positional
            # calls untouched so the SDK applies its own signature checks,
            # instead of silently dropping the positional arguments.
            return original_create(*args, **kwargs)
        return _traced_chat_completion(client, kwargs, original_create)

    client.chat.completions.create = traced_create
    return client
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _traced_chat_completion(client: OpenAI, kwargs: dict[str, Any], original_create: Any) -> Any:
    """Run one traced ``chat.completions.create`` call, retrying failed requests.

    Args:
        client: The wrapped client. Unused here; kept for interface
            compatibility.
        kwargs: Keyword arguments the caller passed to ``create``. Any
            ``max_retries`` entry is consumed by the tracer and NOT forwarded,
            because ``create`` itself does not accept that keyword.
        original_create: The unwrapped ``create`` callable.

    Returns:
        The completion object for non-streaming calls. For ``stream=True``, a
        generator that re-yields every chunk and persists the trace once the
        stream ends.

    Raises:
        Exception: Whatever ``original_create`` raised on the last attempt,
            after the failed trace has been persisted.
    """
    # This function must not contain ``yield`` itself. A generator function
    # would turn every call, streaming or not, into a lazy generator and
    # never return the ChatCompletion to non-streaming callers.
    call_kwargs = dict(kwargs)
    max_retries = call_kwargs.pop("max_retries", 0) or 0
    call = {
        "trace_id": str(uuid.uuid4()),
        "started_at": time.perf_counter(),
        "model": call_kwargs.get("model", "unknown"),
        "messages": call_kwargs.get("messages", []),
    }
    steps: list[dict[str, Any]] = []

    result, retry_count, step_start = _create_with_retries(
        original_create, call_kwargs, max_retries, call, steps
    )
    if call_kwargs.get("stream", False):
        return _trace_stream(result, call, steps, retry_count, step_start)

    try:
        _trace_completion(result, call, steps, retry_count, step_start)
    except Exception as exc:
        # Tracing is best-effort: never fail, or re-send, a request that succeeded.
        console.print(f"[yellow]Trace capture skipped: {exc}[/yellow]")
    return result


def _create_with_retries(
    original_create: Any,
    call_kwargs: dict[str, Any],
    max_retries: int,
    call: dict[str, Any],
    steps: list[dict[str, Any]],
) -> tuple[Any, int, float]:
    """Call ``original_create``, waiting with capped exponential backoff between failures.

    Each failed attempt is recorded as an unsuccessful step. When the last
    attempt fails, the failed trace is persisted and the exception is
    re-raised. A negative ``max_retries`` is treated as 0.

    Returns:
        ``(result, retry_count, step_start)`` for the first successful attempt.
    """
    max_attempts = max(max_retries, 0) + 1
    for attempt in range(max_attempts):
        step_start = time.perf_counter()
        try:
            result = original_create(**call_kwargs)
        except Exception as exc:
            is_last = attempt == max_attempts - 1
            steps.append(_build_openai_step(
                "openai_chat" if is_last else f"openai_chat_retry_{attempt + 1}",
                {"model": call["model"], "messages": call["messages"], "attempt": attempt + 1},
                {"error": str(exc)},
                (time.perf_counter() - step_start) * 1000,
                success=False,
            ))
            if is_last:
                _finalize_trace(
                    trace_id=call["trace_id"],
                    prompt=str(call["messages"]),
                    response="",
                    model=call["model"],
                    latency=(time.perf_counter() - call["started_at"]) * 1000,
                    token_count=0,
                    steps=steps,
                    status="failed",
                    error=str(exc),
                    retry_count=attempt,
                )
                raise
            # Backoff between attempts: 0.5s, 1s, 2s, ... capped at 5s.
            time.sleep(min(0.5 * (2 ** attempt), 5.0))
        else:
            return result, attempt, step_start


def _trace_completion(
    result: Any,
    call: dict[str, Any],
    steps: list[dict[str, Any]],
    retry_count: int,
    step_start: float,
) -> None:
    """Record a successful non-streaming completion as a step and persist the trace."""
    step_duration = (time.perf_counter() - step_start) * 1000
    choice = result.choices[0] if result.choices else None
    content = choice.message.content if choice is not None else ""

    token_usage = _extract_token_usage(result)
    token_count = token_usage.get("total_tokens")
    if token_count is None:
        # The response has no usage block; fall back to a text-based estimate.
        token_count = estimate_tokens(str(call["messages"]), content or "")

    steps.append(_build_openai_step(
        "openai_chat",
        {"model": call["model"], "messages": call["messages"]},
        {
            "content": content,
            "finish_reason": choice.finish_reason if choice is not None else None,
            "usage": token_usage,
        },
        step_duration,
    ))
    _finalize_trace(
        trace_id=call["trace_id"],
        prompt=str(call["messages"]),
        # ``content`` can be None (e.g. tool-call responses); store "" instead.
        response=content or "",
        model=call["model"],
        latency=(time.perf_counter() - call["started_at"]) * 1000,
        token_count=token_count,
        steps=steps,
        status="success",
        retry_count=retry_count,
    )


def _trace_stream(
    stream: Any,
    call: dict[str, Any],
    steps: list[dict[str, Any]],
    retry_count: int,
    step_start: float,
) -> Iterator[ChatCompletionChunk]:
    """Re-yield every chunk of ``stream`` and persist the trace when it ends.

    The trace is persisted in ``finally``, so it is also recorded when the
    consumer stops early (with the partial content seen so far). A
    mid-stream error is recorded as ``failed`` and re-raised, never retried,
    because earlier chunks have already reached the caller.
    """
    collected: list[ChatCompletionChunk] = []
    parts: list[str] = []
    error: str | None = None
    try:
        for chunk in stream:
            collected.append(chunk)
            if chunk.choices and chunk.choices[0].delta.content:
                parts.append(chunk.choices[0].delta.content)
            yield chunk
    except Exception as exc:
        error = str(exc)
        raise
    finally:
        full_content = "".join(parts)
        steps.append(_build_openai_step(
            "openai_chat_stream",
            {"model": call["model"], "messages": call["messages"]},
            {"content_preview": full_content[:200], "chunks": len(collected)},
            (time.perf_counter() - step_start) * 1000,
            success=error is None,
        ))
        _finalize_trace(
            trace_id=call["trace_id"],
            prompt=str(call["messages"]),
            response=full_content,
            model=call["model"],
            latency=(time.perf_counter() - call["started_at"]) * 1000,
            token_count=sum(_estimate_streaming_tokens(collected).values()),
            steps=steps,
            status="success" if error is None else "failed",
            error=error,
            retry_count=retry_count,
        )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class TraceOpenAI(OpenAI):
    """Drop-in ``OpenAI`` client whose chat completions are traced automatically.

    Construction arguments are passed through to ``OpenAI`` unchanged. The
    new instance is then instrumented with ``wrap_openai``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Patch this instance's chat.completions.create once the base client exists.
        wrap_openai(self)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Any, Callable, TypeVar
|
|
8
|
+
|
|
9
|
+
from tracellm.utils import console
|
|
10
|
+
from tracellm.db import save_trace_payload
|
|
11
|
+
|
|
12
|
+
F = TypeVar("F", bound=Callable[..., Any])


def trace_tool(
    name: str | None = None,
    max_retries: int = 0,
    capture_input: bool = True,
    capture_output: bool = True,
) -> Callable[[F], F]:
    """Decorate a sync or async function so each call is recorded as a trace step.

    Each call appends exactly one step to the active trace context. On
    success, the step holds the captured result. If the last attempt fails,
    the step holds the error and the exception then propagates. The step's
    ``duration`` (ms) spans all attempts, including backoff sleeps.

    Args:
        name: Step name; defaults to the function's ``__name__``.
        max_retries: Extra attempts after a failure; negative values count
            as 0. Attempts are separated by a capped exponential backoff
            (0.5s, 1s, 2s, ... up to 5s).
        capture_input: Record the bound call arguments (stringified, truncated).
        capture_output: Record the stringified, truncated return value.

    Returns:
        A decorator that keeps the wrapped function's metadata and its
        sync/async nature.
    """
    # Without clamping, a negative max_retries would mean zero attempts:
    # the wrapped function would silently never run and the wrapper would
    # return None.
    attempts = max(max_retries, 0) + 1

    def decorator(func: F) -> F:
        tool_name = name or func.__name__

        def build_step(
            step_id: str,
            inputs: dict[str, Any],
            output: dict[str, Any],
            started_at: float,
            success: bool,
        ) -> dict[str, Any]:
            """Assemble a step record; duration is ms since the first attempt."""
            return {
                "step_id": step_id,
                "tool_name": tool_name,
                "input": inputs,
                "output": output,
                "duration": round((time.perf_counter() - started_at) * 1000, 2),
                "success": success,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }

        def backoff(attempt: int) -> float:
            """Seconds to wait after failed 0-based ``attempt``."""
            return min(0.5 * (2 ** attempt), 5.0)

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                step_id = str(uuid.uuid4())
                started_at = time.perf_counter()
                # Capture arguments before calling, so the record shows what
                # was passed in even if ``func`` mutates its arguments.
                inputs = _capture_args(func, args, kwargs) if capture_input else {}
                for attempt in range(attempts):
                    if attempt > 0:
                        console.print(f" [yellow]retry {attempt}/{max_retries}[/yellow] {tool_name}")
                    try:
                        result = await func(*args, **kwargs)
                    except Exception as e:
                        if attempt < attempts - 1:
                            await asyncio.sleep(backoff(attempt))
                            continue
                        _try_append_step(build_step(step_id, inputs, {"error": str(e)}, started_at, False))
                        raise
                    output = _capture_output(result) if capture_output else {}
                    _try_append_step(build_step(step_id, inputs, output, started_at, True))
                    return result

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            step_id = str(uuid.uuid4())
            started_at = time.perf_counter()
            # Capture arguments before calling (see async_wrapper).
            inputs = _capture_args(func, args, kwargs) if capture_input else {}
            for attempt in range(attempts):
                if attempt > 0:
                    console.print(f" [yellow]retry {attempt}/{max_retries}[/yellow] {tool_name}")
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    if attempt < attempts - 1:
                        time.sleep(backoff(attempt))
                        continue
                    _try_append_step(build_step(step_id, inputs, {"error": str(e)}, started_at, False))
                    raise
                output = _capture_output(result) if capture_output else {}
                _try_append_step(build_step(step_id, inputs, output, started_at, True))
                return result

        return sync_wrapper  # type: ignore

    return decorator
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _capture_args(func: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
130
|
+
try:
|
|
131
|
+
sig = inspect.signature(func)
|
|
132
|
+
bound = sig.bind(*args, **kwargs)
|
|
133
|
+
bound.apply_defaults()
|
|
134
|
+
return {k: str(v)[:500] for k, v in bound.arguments.items()}
|
|
135
|
+
except Exception:
|
|
136
|
+
return {"args": str(args)[:500], "kwargs": str(kwargs)[:500]}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _capture_output(result: Any) -> dict[str, Any]:
|
|
140
|
+
return {"result": str(result)[:500]}
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _try_append_step(step: dict[str, Any]) -> None:
|
|
144
|
+
try:
|
|
145
|
+
from tracellm.tracer import _current_trace_context
|
|
146
|
+
ctx = _current_trace_context.get()
|
|
147
|
+
if ctx is not None:
|
|
148
|
+
steps = ctx.setdefault("collected_steps", [])
|
|
149
|
+
steps.append(step)
|
|
150
|
+
except Exception:
|
|
151
|
+
pass
|
tracellm/mascot.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""TraceLLM dinosaur mascot."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
from rich.panel import Panel
|
|
8
|
+
from rich.text import Text
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MascotState(Enum):
    """Visual states of the TraceLLM mascot; each state has a color in ``_STYLE``."""

    IDLE = "idle"  # nothing happening
    LOADING = "loading"  # work in progress
    SUCCESS = "success"  # finished successfully
    WARNING = "warning"  # something needs attention
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
DINOSAUR = """\
|
|
19
|
+
__
|
|
20
|
+
/ _)
|
|
21
|
+
.-^^^-/ /
|
|
22
|
+
__/ /
|
|
23
|
+
<__.|_|-|_|"""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Rich style (a color name) applied to the mascot for each state.
_STYLE: dict[MascotState, str] = dict(
    [
        (MascotState.IDLE, "bright_black"),
        (MascotState.LOADING, "cyan"),
        (MascotState.SUCCESS, "green"),
        (MascotState.WARNING, "yellow"),
    ]
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def render(state: MascotState = MascotState.IDLE) -> Text:
    """Return the full ASCII dinosaur as Rich ``Text``, colored for ``state``."""
    color = _STYLE[state]
    return Text(DINOSAUR, style=color)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def header(title: str, state: MascotState = MascotState.IDLE) -> Panel:
    """Return an empty Rich ``Panel`` titled with the 🦖 glyph followed by ``title``.

    The glyph uses the color for ``state``; the title text is bold white and
    the border is bright black.
    """
    glyph = Text("🦖 ", style=_STYLE[state])
    heading = Text(title, style="bold white")
    return Panel("", title=Text.assemble(glyph, heading), border_style="bright_black")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def message(text: str, state: MascotState = MascotState.IDLE) -> Text:
    """Return a one-line mascot message such as "🦖 Trace complete".

    Both the glyph and ``text`` use the color for ``state``.
    """
    color = _STYLE[state]
    return Text.assemble(Text("🦖 ", style=color), Text(text, style=color))
|