tracellm-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,186 @@
1
+ import time
2
+ import uuid
3
+ from datetime import datetime, timezone
4
+ from typing import Any, Optional
5
+
6
+ from langchain_core.callbacks import BaseCallbackHandler
7
+ from langchain_core.messages import BaseMessage
8
+
9
+ from tracellm.utils import build_tool_step, estimate_tokens, console
10
+ from tracellm.db import save_trace_payload
11
+
12
+
13
+ def _get_trace_context() -> dict[str, Any] | None:
14
+ try:
15
+ from tracellm.tracer import _current_trace_context
16
+ return _current_trace_context.get()
17
+ except ImportError:
18
+ return None
19
+
20
+
21
def _finalize_lc_trace(trace_data: dict[str, Any]) -> None:
    """Fill project metadata into ``trace_data`` and persist it.

    Project id/name, environment and API key are taken from the active trace
    context when one exists. Keys already present in ``trace_data`` are never
    overwritten. A persistence failure is printed as a warning instead of
    being raised.
    """
    context = _get_trace_context() or {}
    fallbacks = {
        "project_id": context.get("project_id") or "default",
        "project_name": context.get("project_name"),
        "environment": context.get("environment") or "development",
        "api_key": context.get("api_key"),
    }
    for field_name, fallback in fallbacks.items():
        trace_data.setdefault(field_name, fallback)

    try:
        save_trace_payload(trace_data)
    except Exception as exc:
        console.print(f"[yellow]LangChain trace persist skipped: {exc}[/yellow]")
31
+
32
+
33
class TracellmCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler that turns chain, tool and retriever runs into trace steps.

    Every ``on_*_start`` callback opens a pending step. The matching
    ``on_*_end`` / ``on_*_error`` callback closes it, records its duration in
    milliseconds, and appends it to :attr:`steps`. LLM prompts and outputs are
    collected separately and summarised by :meth:`flush_trace`, which builds
    and persists the final trace payload.
    """

    def __init__(self) -> None:
        super().__init__()
        self.trace_id = str(uuid.uuid4())
        self.steps: list[dict[str, Any]] = []
        self.start_time = time.perf_counter()
        self.start_dt = datetime.now(timezone.utc)
        # Pending steps keyed by LangChain's ``run_id``. A keyed map (rather
        # than a single LIFO stack) lets overlapping runs close the step they
        # actually opened instead of whichever step happens to be on top.
        self._open_steps: dict[Any, dict[str, Any]] = {}
        self._retry_count = 0
        self._error: str | None = None
        self._llm_inputs: list[str] = []
        self._llm_outputs: list[str] = []

    @property
    def always_verbose(self) -> bool:
        return True

    # ------------------------------------------------------------------
    # Step bookkeeping
    # ------------------------------------------------------------------

    def _open_step(self, tool_name: str, step_input: dict[str, Any], run_id: Any = None) -> None:
        """Register a pending step, timed from now, under ``run_id``.

        Calls without a ``run_id`` get a private key; such steps are closed
        most-recent-first by end/error callbacks that also lack a ``run_id``.
        """
        key = run_id if run_id is not None else uuid.uuid4()
        self._open_steps[key] = {
            "step_id": str(uuid.uuid4()),
            "tool_name": tool_name,
            "input": step_input,
            "output": {},
            "duration": 0.0,
            "success": True,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "_start": time.perf_counter(),
        }

    def _close_step(self, run_id: Any, output: dict[str, Any], success: bool = True) -> None:
        """Finish the pending step for ``run_id`` and append it to :attr:`steps`.

        Without a ``run_id`` the most recently opened step is closed. A
        ``run_id`` with no pending step (its start was never seen) is ignored
        rather than mis-attributed to another step.
        """
        if run_id is None:
            if not self._open_steps:
                return
            _, step = self._open_steps.popitem()
        else:
            step = self._open_steps.pop(run_id, None)
            if step is None:
                return
        step["duration"] = round((time.perf_counter() - step.pop("_start")) * 1000, 2)
        step["output"] = output
        step["success"] = success
        self.steps.append(step)

    @staticmethod
    def _response_text(response: Any) -> str:
        """Extract readable output text from an LLM callback response.

        When ``response.generations`` exists, the ``text`` of every generation
        is joined (assumes LangChain's list-of-lists ``LLMResult`` layout —
        confirm for the pinned version); if no generation carries text, the
        ``str`` of ``generations`` is used instead.
        """
        if hasattr(response, "generations"):
            texts = [
                getattr(generation, "text", "") or ""
                for batch in response.generations
                for generation in batch
            ]
            joined = "\n".join(text for text in texts if text)
            return joined or str(response.generations)
        if hasattr(response, "text"):
            return str(response.text)
        return ""

    # ------------------------------------------------------------------
    # LLM callbacks
    # ------------------------------------------------------------------

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
        """Remember the prompts sent to the LLM (as one string per call)."""
        self._llm_inputs.append(str(prompts))

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Remember up to 500 characters of the LLM output."""
        try:
            content = self._response_text(response)
        except Exception:
            content = str(response)
        self._llm_outputs.append(content[:500])

    def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        """Mark the trace as failed with this error's message."""
        self._error = str(error)

    # ------------------------------------------------------------------
    # Chain callbacks
    # ------------------------------------------------------------------

    def on_chain_start(
        self, serialized: dict[str, Any] | None, inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        """Open a step named after the chain's serialized id or its ``name`` kwarg."""
        # ``serialized`` may be None or lack an "id"; fall back to the run
        # name passed as a keyword argument instead of raising.
        ids = (serialized or {}).get("id") or []
        tool_name = ids[-1] if ids else (kwargs.get("name") or "unknown")
        self._open_step(tool_name, {"inputs": inputs}, kwargs.get("run_id"))

    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        self._close_step(kwargs.get("run_id"), {"outputs": outputs})

    def on_chain_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self._close_step(kwargs.get("run_id"), {"error": str(error)}, success=False)
        self._error = str(error)

    # ------------------------------------------------------------------
    # Tool callbacks
    # ------------------------------------------------------------------

    def on_tool_start(self, serialized: dict[str, Any] | None, input_str: str, **kwargs: Any) -> None:
        tool_name = (serialized or {}).get("name") or kwargs.get("name") or "unknown_tool"
        self._open_step(tool_name, {"input": input_str}, kwargs.get("run_id"))

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        # ``output`` may not be a plain str depending on the LangChain
        # version, so stringify before truncating.
        self._close_step(kwargs.get("run_id"), {"output": str(output)[:500]})

    def on_tool_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self._close_step(kwargs.get("run_id"), {"error": str(error)}, success=False)
        # Tool failures are reported through the trace's ``retry_count``.
        self._retry_count += 1

    # ------------------------------------------------------------------
    # Retriever callbacks
    # ------------------------------------------------------------------

    def on_retriever_start(self, serialized: dict[str, Any] | None, query: str, **kwargs: Any) -> None:
        self._open_step("retriever", {"query": query}, kwargs.get("run_id"))

    def on_retriever_end(self, documents: list[Any], **kwargs: Any) -> None:
        # Keep the payload small: previews of the first five documents only.
        previews = [str(doc)[:200] for doc in documents[:5]]
        self._close_step(kwargs.get("run_id"), {"documents": previews, "count": len(documents)})

    def on_retriever_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
        self._close_step(kwargs.get("run_id"), {"error": str(error)}, success=False)

    # ------------------------------------------------------------------
    # Finalisation
    # ------------------------------------------------------------------

    def flush_trace(
        self,
        prompt: str | None = None,
        response: str | None = None,
        status: str = "success",
    ) -> dict[str, Any]:
        """Build the trace payload for everything recorded so far, persist it, and return it.

        Steps that have started but not yet ended are not included.

        Args:
            prompt: Fallback prompt text, used only when no LLM input was captured.
            response: Fallback response text, used only when no LLM output was captured.
            status: Reported status; replaced by ``"failed"`` if an error was recorded.

        Returns:
            The payload dict, after project/environment defaults were filled in
            by ``_finalize_lc_trace``.
        """
        latency = (time.perf_counter() - self.start_time) * 1000
        total_input = " ".join(self._llm_inputs) if self._llm_inputs else (prompt or "")
        total_output = " ".join(self._llm_outputs) if self._llm_outputs else (response or "")

        trace_data = {
            "trace_id": self.trace_id,
            "prompt": total_input,
            "response": total_output,
            "latency": round(latency, 2),
            "token_count": estimate_tokens(total_input, total_output),
            "model_name": "langchain",
            "status": "failed" if self._error else status,
            # Snapshot so callbacks firing after the flush cannot mutate the
            # persisted/returned payload.
            "steps": list(self.steps),
            "retry_count": self._retry_count,
            "failure_reason": self._error,
            "slow_request": latency >= 1500.0,
            "created_at": self.start_dt.isoformat(),
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        _finalize_lc_trace(trace_data)
        return trace_data
@@ -0,0 +1,234 @@
1
+ import functools
2
+ import time
3
+ import uuid
4
+ from datetime import datetime, timezone
5
+ from typing import Any, Iterator
6
+
7
+ from openai import OpenAI, Stream
8
+ from openai.types.chat import ChatCompletion, ChatCompletionChunk
9
+
10
+ from tracellm.utils import (
11
+ console,
12
+ estimate_tokens,
13
+ build_tool_step,
14
+ )
15
+ from tracellm.db import save_trace_payload
16
+
17
+
18
+ def _get_openai_trace_context() -> dict[str, Any] | None:
19
+ try:
20
+ from tracellm.tracer import _current_trace_context
21
+ return _current_trace_context.get()
22
+ except ImportError:
23
+ return None
24
+
25
+
26
def _build_openai_step(
    tool_name: str,
    input_data: dict[str, Any],
    output_data: dict[str, Any],
    duration: float,
    success: bool = True,
) -> dict[str, Any]:
    """Create one OpenAI trace step using the shared ``build_tool_step`` helper.

    All arguments are forwarded positionally and unchanged; callers in this
    module pass ``duration`` in milliseconds.
    """
    step = build_tool_step(tool_name, input_data, output_data, duration, success)
    return step
34
+
35
+
36
def _finalize_trace(
    trace_id: str,
    prompt: str,
    response: str,
    model: str,
    latency: float,
    token_count: int,
    steps: list[dict[str, Any]],
    status: str,
    error: str | None = None,
    retry_count: int = 0,
) -> dict[str, Any]:
    """Assemble the trace payload for one OpenAI call, persist it, and return it.

    Project id/name, API key and environment come from the active trace
    context when available, defaulting to ``"default"`` / ``None`` / ``None``
    / ``"development"``. ``latency`` is in milliseconds; calls at or above
    1500 ms are flagged as ``slow_request``. Both ``created_at`` and
    ``updated_at`` are the time this function runs. A persistence failure is
    printed as a warning rather than raised.
    """
    context = _get_openai_trace_context() or {}
    stamp = datetime.now(timezone.utc).isoformat()

    payload: dict[str, Any] = {
        "trace_id": trace_id,
        "prompt": prompt,
        "response": response,
        "latency": round(latency, 2),
        "token_count": token_count,
        "model_name": model,
        "project_id": context.get("project_id") or "default",
        "project_name": context.get("project_name") or None,
        "api_key": context.get("api_key") or None,
        "environment": context.get("environment") or "development",
        "status": status,
        "steps": steps,
        "retry_count": retry_count,
        "failure_reason": error,
        "slow_request": latency >= 1500.0,
        "created_at": stamp,
        "updated_at": stamp,
    }

    try:
        save_trace_payload(payload)
    except Exception as exc:
        console.print(f"[yellow]Trace persist skipped: {exc}[/yellow]")
    return payload
74
+
75
+
76
def _extract_token_usage(completion: ChatCompletion) -> dict[str, int]:
    """Return the prompt/completion/total token counts reported on ``completion``.

    An empty dict is returned when the completion carries no ``usage``.
    """
    usage = completion.usage
    if not usage:
        return {}
    fields = ("prompt_tokens", "completion_tokens", "total_tokens")
    return {field: getattr(usage, field) for field in fields}
85
+
86
+
87
+ def _estimate_streaming_tokens(chunks: list[ChatCompletionChunk]) -> dict[str, int]:
88
+ total = 0
89
+ for chunk in chunks:
90
+ if chunk.choices and chunk.choices[0].delta.content:
91
+ total += 1
92
+ return {"total_tokens": total}
93
+
94
+
95
def wrap_openai(client: OpenAI) -> OpenAI:
    """Patch ``client.chat.completions.create`` so every call is traced.

    The client is modified in place and returned for convenience. Wrapping an
    already-wrapped client is a no-op, so calls are never traced twice (e.g.
    ``wrap_openai(TraceOpenAI())``).
    """
    completions = client.chat.completions
    original_create = completions.create
    if getattr(original_create, "_tracellm_wrapped", False):
        return client

    @functools.wraps(original_create)
    def traced_create(*args: Any, **kwargs: Any) -> Any:
        # NOTE: positional arguments are not forwarded (unchanged behaviour);
        # the OpenAI SDK's ``create`` is keyword-only — confirm for the pinned version.
        return _traced_chat_completion(client, kwargs, original_create)

    # Marker checked above to keep wrapping idempotent.
    traced_create._tracellm_wrapped = True  # type: ignore[attr-defined]
    completions.create = traced_create
    return client
104
+
105
+
106
def _traced_chat_completion(client: OpenAI, kwargs: dict[str, Any], original_create: Any) -> Any:
    """Call ``original_create`` with retries and persist a trace of the call.

    This is a plain function, not a generator. Non-streaming calls return the
    SDK result directly. ``stream=True`` calls return an iterator that
    re-yields every chunk and persists the trace once the stream is exhausted.

    Only the ``create`` call itself is retried, with exponential back-off
    (0.5 s, 1 s, 2 s, ... capped at 5 s). Errors while post-processing a
    successful response never trigger a second request.

    Args:
        client: The wrapped client. Unused; kept for signature compatibility.
        kwargs: Keyword arguments given to ``create``. ``max_retries`` is a
            wrapper option: it is consumed here and not forwarded to the SDK.
        original_create: The un-traced ``chat.completions.create`` callable.

    Returns:
        The completion for non-streaming calls, or a chunk iterator for streams.

    Raises:
        Exception: Whatever ``original_create`` raised on the final attempt,
            after a failed trace has been persisted.
    """
    trace_id = str(uuid.uuid4())
    started_at = time.perf_counter()

    call_kwargs = dict(kwargs)
    max_retries = call_kwargs.pop("max_retries", 0) or 0
    # Always make at least one attempt, even for a negative max_retries.
    max_attempts = max(0, max_retries) + 1

    messages = call_kwargs.get("messages", [])
    model = call_kwargs.get("model", "unknown")
    stream = call_kwargs.get("stream", False)
    prompt_text = str(messages)
    steps: list[dict[str, Any]] = []
    error: str | None = None
    last_exception: Exception | None = None

    for attempt in range(max_attempts):
        step_start = time.perf_counter()
        try:
            result = original_create(**call_kwargs)
        except Exception as exc:
            last_exception = exc
            error = str(exc)
            is_last = attempt == max_attempts - 1
            steps.append(_build_openai_step(
                "openai_chat" if is_last else f"openai_chat_retry_{attempt + 1}",
                {"model": model, "messages": messages, "attempt": attempt + 1},
                {"error": error},
                (time.perf_counter() - step_start) * 1000,
                success=False,
            ))
            if not is_last:
                time.sleep(min(0.5 * (2 ** attempt), 5.0))
            continue

        # The index of the successful attempt equals the number of retries used.
        if stream:
            return _traced_stream(
                result,
                trace_id=trace_id,
                started_at=started_at,
                step_start=step_start,
                prompt_text=prompt_text,
                model=model,
                messages=messages,
                steps=steps,
                retry_count=attempt,
            )

        step_duration = (time.perf_counter() - step_start) * 1000
        first_choice = result.choices[0] if result.choices else None
        content = first_choice.message.content if first_choice else ""
        token_usage = _extract_token_usage(result)
        if "total_tokens" in token_usage:
            token_count = token_usage["total_tokens"]
        else:
            token_count = estimate_tokens(prompt_text, content or "")

        steps.append(_build_openai_step(
            "openai_chat",
            {"model": model, "messages": messages},
            {
                "content": content,
                "finish_reason": first_choice.finish_reason if first_choice else None,
                "usage": token_usage,
            },
            step_duration,
        ))
        _finalize_trace(
            trace_id=trace_id,
            prompt=prompt_text,
            # content is None for e.g. tool-call-only responses.
            response=content or "",
            model=model,
            latency=(time.perf_counter() - started_at) * 1000,
            token_count=token_count,
            steps=steps,
            status="success",
            retry_count=attempt,
        )
        return result

    # Every attempt failed.
    _finalize_trace(
        trace_id=trace_id,
        prompt=prompt_text,
        response="",
        model=model,
        latency=(time.perf_counter() - started_at) * 1000,
        token_count=0,
        steps=steps,
        status="failed",
        error=error,
        retry_count=max_attempts - 1,
    )
    if last_exception is not None:
        raise last_exception
    return None


def _traced_stream(
    chunks: Any,
    *,
    trace_id: str,
    started_at: float,
    step_start: float,
    prompt_text: str,
    model: str,
    messages: Any,
    steps: list[dict[str, Any]],
    retry_count: int,
) -> Iterator[ChatCompletionChunk]:
    """Re-yield ``chunks`` unchanged and persist the trace when the stream ends.

    A stream that raises part-way is recorded as failed and the error is
    re-raised; it is not retried because chunks were already handed to the
    caller. A stream the caller stops consuming early is not persisted.
    """
    collected_chunks: list[ChatCompletionChunk] = []
    parts: list[str] = []
    try:
        for chunk in chunks:
            collected_chunks.append(chunk)
            if chunk.choices and chunk.choices[0].delta.content:
                parts.append(chunk.choices[0].delta.content)
            yield chunk
    except Exception as exc:
        steps.append(_build_openai_step(
            "openai_chat_stream",
            {"model": model, "messages": messages},
            {"error": str(exc), "chunks": len(collected_chunks)},
            (time.perf_counter() - step_start) * 1000,
            success=False,
        ))
        _finalize_trace(
            trace_id=trace_id,
            prompt=prompt_text,
            response="".join(parts),
            model=model,
            latency=(time.perf_counter() - started_at) * 1000,
            token_count=0,
            steps=steps,
            status="failed",
            error=str(exc),
            retry_count=retry_count,
        )
        raise

    full_content = "".join(parts)
    steps.append(_build_openai_step(
        "openai_chat_stream",
        {"model": model, "messages": messages},
        {"content_preview": full_content[:200], "chunks": len(collected_chunks)},
        (time.perf_counter() - step_start) * 1000,
    ))
    _finalize_trace(
        trace_id=trace_id,
        prompt=prompt_text,
        response=full_content,
        model=model,
        latency=(time.perf_counter() - started_at) * 1000,
        token_count=sum(_estimate_streaming_tokens(collected_chunks).values()),
        steps=steps,
        status="success",
        retry_count=retry_count,
    )
229
+
230
+
231
class TraceOpenAI(OpenAI):
    """``OpenAI`` client subclass whose chat completions are traced from construction."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the client exactly like ``OpenAI``, then patch it via ``wrap_openai``."""
        super().__init__(*args, **kwargs)
        wrap_openai(client=self)
@@ -0,0 +1,151 @@
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+ import time
5
+ import uuid
6
+ from datetime import datetime, timezone
7
+ from typing import Any, Callable, TypeVar
8
+
9
+ from tracellm.utils import console
10
+ from tracellm.db import save_trace_payload
11
+
12
# Type of the decorated callable; trace_tool's decorator returns the same type F,
# so the wrapped function keeps its signature for type checkers.
F = TypeVar("F", bound=Callable[..., Any])
13
+
14
+
15
def trace_tool(
    name: str | None = None,
    max_retries: int = 0,
    capture_input: bool = True,
    capture_output: bool = True,
) -> Callable[[F], F]:
    """Decorate a sync or async tool function so each call is recorded as a trace step.

    A failing call is retried up to ``max_retries`` extra times with
    exponential back-off (0.5 s, 1 s, 2 s, ... capped at 5 s). Each call
    appends exactly one step to the active trace context. On success the step
    holds the captured output. Once every attempt has failed, the step holds
    the error text and the last exception is re-raised. The step's duration
    spans all attempts, including back-off waits.

    Recording is best-effort: a failure while capturing or appending the step
    is printed as a warning and never re-runs or fails the wrapped function.

    Args:
        name: Step name; defaults to the function's ``__name__``.
        max_retries: Extra attempts after a failure. Negative values are
            treated as 0, so the function is always called at least once.
        capture_input: Record the call's arguments in the step.
        capture_output: Record the return value in the step.
    """
    attempts = max(0, max_retries) + 1

    def backoff_seconds(attempt: int) -> float:
        # Exponential back-off between attempts, capped at 5 seconds.
        return min(0.5 * (2 ** attempt), 5.0)

    def decorator(func: F) -> F:
        tool_name = name or func.__name__

        def announce_retry(attempt: int) -> None:
            console.print(f" [yellow]retry {attempt}/{max_retries}[/yellow] {tool_name}")

        def record_step(
            step_id: str,
            started_at: float,
            args: tuple[Any, ...],
            kwargs: dict[str, Any],
            result: Any = None,
            error: BaseException | None = None,
        ) -> None:
            """Append the finished step for one call; tracing errors never reach the caller."""
            duration = (time.perf_counter() - started_at) * 1000
            try:
                if error is not None:
                    output = {"error": str(error)}
                else:
                    output = _capture_output(result) if capture_output else {}
                _try_append_step({
                    "step_id": step_id,
                    "tool_name": tool_name,
                    "input": _capture_args(func, args, kwargs) if capture_input else {},
                    "output": output,
                    "duration": round(duration, 2),
                    "success": error is None,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
            except Exception as exc:
                console.print(f"[yellow]Tool step capture skipped: {exc}[/yellow]")

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                step_id = str(uuid.uuid4())
                started_at = time.perf_counter()
                for attempt in range(attempts):
                    if attempt > 0:
                        announce_retry(attempt)
                    try:
                        result = await func(*args, **kwargs)
                    except Exception as exc:
                        if attempt < attempts - 1:
                            await asyncio.sleep(backoff_seconds(attempt))
                            continue
                        record_step(step_id, started_at, args, kwargs, error=exc)
                        raise
                    # Recorded outside the try so a tracing problem cannot trigger a retry.
                    record_step(step_id, started_at, args, kwargs, result=result)
                    return result

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            step_id = str(uuid.uuid4())
            started_at = time.perf_counter()
            for attempt in range(attempts):
                if attempt > 0:
                    announce_retry(attempt)
                try:
                    result = func(*args, **kwargs)
                except Exception as exc:
                    if attempt < attempts - 1:
                        time.sleep(backoff_seconds(attempt))
                        continue
                    record_step(step_id, started_at, args, kwargs, error=exc)
                    raise
                # Recorded outside the try so a tracing problem cannot trigger a retry.
                record_step(step_id, started_at, args, kwargs, result=result)
                return result

        return sync_wrapper  # type: ignore

    return decorator
127
+
128
+
129
+ def _capture_args(func: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
130
+ try:
131
+ sig = inspect.signature(func)
132
+ bound = sig.bind(*args, **kwargs)
133
+ bound.apply_defaults()
134
+ return {k: str(v)[:500] for k, v in bound.arguments.items()}
135
+ except Exception:
136
+ return {"args": str(args)[:500], "kwargs": str(kwargs)[:500]}
137
+
138
+
139
+ def _capture_output(result: Any) -> dict[str, Any]:
140
+ return {"result": str(result)[:500]}
141
+
142
+
143
+ def _try_append_step(step: dict[str, Any]) -> None:
144
+ try:
145
+ from tracellm.tracer import _current_trace_context
146
+ ctx = _current_trace_context.get()
147
+ if ctx is not None:
148
+ steps = ctx.setdefault("collected_steps", [])
149
+ steps.append(step)
150
+ except Exception:
151
+ pass
tracellm/mascot.py ADDED
@@ -0,0 +1,49 @@
1
+ """TraceLLM dinosaur mascot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+
7
+ from rich.panel import Panel
8
+ from rich.text import Text
9
+
10
+
11
class MascotState(Enum):
    """Display states of the TraceLLM mascot; each one selects a colour style."""

    IDLE = "idle"
    LOADING = "loading"
    SUCCESS = "success"
    WARNING = "warning"
16
+
17
+
18
# Multi-line ASCII dinosaur used by ``render``. The backslash after the
# opening quotes keeps the string from starting with a newline.
DINOSAUR = """\
__
/ _)
.-^^^-/ /
__/ /
<__.|_|-|_|"""

# rich style name applied to the mascot for each MascotState.
_STYLE: dict[MascotState, str] = {
    MascotState.IDLE: "bright_black",
    MascotState.LOADING: "cyan",
    MascotState.SUCCESS: "green",
    MascotState.WARNING: "yellow",
}
32
+
33
+
34
def render(state: MascotState = MascotState.IDLE) -> Text:
    """Return the full ASCII dinosaur as rich ``Text`` styled for ``state``."""
    colour = _STYLE[state]
    art = Text(DINOSAUR, style=colour)
    return art
37
+
38
+
39
def header(title: str, state: MascotState = MascotState.IDLE) -> Panel:
    """Return an empty, grey-bordered Panel whose title is the mascot emoji plus ``title``.

    The emoji takes the style for ``state``; the title text is bold white.
    """
    mascot_style = _STYLE[state]
    heading = Text.assemble(
        Text("🦖 ", style=mascot_style),
        Text(title, style="bold white"),
    )
    return Panel("", title=heading, border_style="bright_black")
44
+
45
+
46
def message(text: str, state: MascotState = MascotState.IDLE) -> Text:
    """Return a one-line mascot message such as "🦖 Trace complete".

    Both the emoji and ``text`` use the style for ``state``.
    """
    style = _STYLE[state]
    pieces = (Text("🦖 ", style=style), Text(text, style=style))
    return Text.assemble(*pieces)