mermaid-trace: mermaid_trace-0.5.3.post0-py3-none-any.whl → mermaid_trace-0.6.1.post0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
mermaid_trace/cli.py CHANGED
@@ -232,7 +232,7 @@ def _create_handler(
     return Handler
 
 
-def serve(filename: str, port: int = 8000) -> None:
+def serve(filename: str, port: int = 8000, master: bool = False) -> None:
     """
     Starts the local HTTP server and file watcher to preview a Mermaid diagram.
 
@@ -249,7 +249,30 @@ def serve(filename: str, port: int = 8000) -> None:
     Args:
         filename (str): The path to the .mmd file to serve.
        port (int): The port number to bind the server to (default: 8000).
+        master (bool): Whether to use the enhanced Master Preview Server (FastAPI + SSE).
     """
+    # 1. Enhanced Master Mode
+    if master:
+        try:
+            from .server import run_server, HAS_SERVER_DEPS
+
+            if HAS_SERVER_DEPS:
+                # For master mode, we watch the directory of the file
+                target_dir = os.path.dirname(os.path.abspath(filename))
+                if os.path.isdir(filename):
+                    target_dir = os.path.abspath(filename)
+
+                # Open browser first
+                webbrowser.open(f"http://localhost:{port}")
+                run_server(target_dir, port)
+                return
+            else:
+                print(
+                    "Warning: FastAPI/Uvicorn not found. Falling back to basic server."
+                )
+        except ImportError:
+            print("Warning: server module not found. Falling back to basic server.")
+
     # Create a Path object for robust file path handling
     path = Path(filename)
 
@@ -342,6 +365,11 @@ def main() -> None:
     serve_parser.add_argument(
         "--port", type=int, default=8000, help="Port to bind to (default: 8000)"
     )
+    serve_parser.add_argument(
+        "--master",
+        action="store_true",
+        help="Use enhanced Master Preview (requires FastAPI)",
+    )
 
     # Parse the arguments provided by the user
     args = parser.parse_args()
@@ -349,7 +377,7 @@ def main() -> None:
     # Dispatch logic
     if args.command == "serve":
         # Invoke the serve function with parsed arguments
-        serve(args.file, args.port)
+        serve(args.file, args.port, args.master)
 
 
 if __name__ == "__main__":
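
The new flag threads straight through: argparse maps --master onto serve(args.file, args.port, args.master). A minimal sketch of exercising it from Python, assuming mermaid-trace 0.6.1.post0 is installed and a diagram.mmd exists in the working directory (the filename is illustrative):

    # Hypothetical invocation of the new serve() signature from the diff above.
    # If FastAPI/Uvicorn are missing, the warning branches shown above fall
    # back to the basic server, so the call degrades gracefully.
    from mermaid_trace.cli import serve

    serve("diagram.mmd", port=8000, master=True)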
mermaid_trace/integrations/__init__.py CHANGED
@@ -2,3 +2,8 @@
 MermaidTrace integrations package.
 Contains middleware and adapters for third-party frameworks.
 """
+
+from .fastapi import MermaidTraceMiddleware
+from .langchain import MermaidTraceCallbackHandler
+
+__all__ = ["MermaidTraceMiddleware", "MermaidTraceCallbackHandler"]
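
With these re-exports in place, both integrations resolve from the package root. A minimal sketch, assuming the optional dependencies each submodule needs are importable:

    # Both names are re-exported via the __all__ added above.
    from mermaid_trace.integrations import (
        MermaidTraceCallbackHandler,
        MermaidTraceMiddleware,
    )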
mermaid_trace/integrations/langchain.py ADDED
@@ -0,0 +1,447 @@
+"""
+LangChain Integration Module for MermaidTrace.
+
+This module provides a LangChain Callback Handler that allows you to automatically
+generate Mermaid sequence diagrams for your LangChain chains, LLM calls, and tool usage.
+"""
+
+from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING
+import uuid
+
+from ..core.events import FlowEvent
+from ..core.context import LogContext
+from ..core.decorators import get_flow_logger
+
+if TYPE_CHECKING:
+    from langchain_core.callbacks import BaseCallbackHandler
+    from langchain_core.outputs import LLMResult
+    from langchain_core.agents import AgentAction, AgentFinish
+    from langchain_core.documents import Document
+else:
+    try:
+        from langchain_core.callbacks import BaseCallbackHandler
+        from langchain_core.outputs import LLMResult
+        from langchain_core.agents import AgentAction, AgentFinish
+        from langchain_core.documents import Document
+    except ImportError:
+        BaseCallbackHandler = object
+        LLMResult = Any
+        AgentAction = Any
+        AgentFinish = Any
+        Document = Any
+
+
+class MermaidTraceCallbackHandler(BaseCallbackHandler):
+    """LangChain Callback Handler that records execution flow as Mermaid sequence diagrams.
+
+    This handler intercepts LangChain events (Chain, LLM, Tool, Agent) and logs them as
+    FlowEvents, which are then processed by MermaidTrace to generate diagrams.
+    """
+
+    def __init__(self, host_name: str = "LangChain"):
+        """Initialize the callback handler.
+
+        Args:
+            host_name: The name of the host participant in the diagram.
+                Defaults to "LangChain".
+        """
+        if BaseCallbackHandler is object:
+            raise ImportError(
+                "langchain-core is required to use MermaidTraceCallbackHandler. "
+                "Install it with `pip install langchain-core`."
+            )
+        self.host_name = host_name
+        self.logger = get_flow_logger()
+        self._participant_stack: List[str] = []
+
+    def _get_current_source(self) -> str:
+        """Get the current source participant from stack or context."""
+        if self._participant_stack:
+            return self._participant_stack[-1]
+        return str(LogContext.get("current_participant", self.host_name))
+
+    def on_chain_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        inputs: Dict[str, Any],
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when chain starts running."""
+        target = (
+            (serialized.get("name") if serialized else None)
+            or kwargs.get("name")
+            or "Chain"
+        )
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Run Chain",
+            message=f"Start Chain: {target}",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            params=str(inputs),
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+        self._participant_stack.append(target)
+
+    def on_chain_end(
+        self,
+        outputs: Dict[str, Any],
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when chain ends running."""
+        if not self._participant_stack:
+            return
+
+        target = self._participant_stack.pop()
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=target,
+            target=source,
+            action="Finish Chain",
+            message="Chain Complete",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            result=str(outputs),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{target} -> {source}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_llm_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        prompts: List[str],
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when LLM starts running."""
+        target = (serialized.get("name") if serialized else None) or "LLM"
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Prompt",
+            message="LLM Request",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            params=str(prompts),
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+        self._participant_stack.append(target)
+
+    def on_chat_model_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        messages: List[List[Any]],
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when Chat Model starts running."""
+        target = (serialized.get("name") if serialized else None) or "ChatModel"
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Chat",
+            message="ChatModel Request",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            params=str(messages),
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+        self._participant_stack.append(target)
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when LLM ends running."""
+        if not self._participant_stack:
+            return
+
+        source = self._participant_stack.pop()
+        target = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Response",
+            message="LLM/Chat Completion",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            result=str(response.generations),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_llm_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when LLM errors."""
+        if not self._participant_stack:
+            return
+        target = self._participant_stack.pop()
+        source = self._get_current_source()
+        event = FlowEvent(
+            source=target,
+            target=source,
+            action="Error",
+            message=f"LLM Error: {type(error).__name__}",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            is_error=True,
+            error_message=str(error),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{target} -> {source}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_retriever_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        query: str,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when Retriever starts running."""
+        target = (serialized.get("name") if serialized else None) or "Retriever"
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Retrieve",
+            message=f"Query: {query[:50]}...",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            params=query,
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+        self._participant_stack.append(target)
+
+    def on_retriever_end(
+        self,
+        documents: Sequence[Document],
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when Retriever ends running."""
+        if not self._participant_stack:
+            return
+
+        target = self._participant_stack.pop()
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=target,
+            target=source,
+            action="Documents",
+            message=f"Retrieved {len(documents)} docs",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            result=f"Count: {len(documents)}",
+            is_return=True,
+        )
+        self.logger.info(
+            f"{target} -> {source}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_tool_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        input_str: str,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool starts running."""
+        target = (serialized.get("name") if serialized else None) or "Tool"
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Call Tool",
+            message=f"Tool: {target}",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            params=input_str,
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+        self._participant_stack.append(target)
+
+    def on_tool_end(
+        self,
+        output: Any,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool ends running."""
+        if not self._participant_stack:
+            return
+
+        target = self._participant_stack.pop()
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=target,
+            target=source,
+            action="Finish Tool",
+            message="Tool Complete",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            result=str(output),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{target} -> {source}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_agent_action(
+        self,
+        action: AgentAction,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on agent action."""
+        target = action.tool
+        source = self._get_current_source()
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Agent Action",
+            message=f"Decided to use: {target}",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            params=str(action.tool_input),
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_agent_finish(
+        self,
+        finish: AgentFinish,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on agent finish."""
+        source = self._get_current_source()
+        target = "User"  # Usually agents finish by returning to the user
+
+        event = FlowEvent(
+            source=source,
+            target=target,
+            action="Agent Finish",
+            message="Final Answer Ready",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            result=str(finish.return_values),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{source} -> {target}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_chain_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when chain errors."""
+        if not self._participant_stack:
+            return
+        target = self._participant_stack.pop()
+        source = self._get_current_source()
+        event = FlowEvent(
+            source=target,
+            target=source,
+            action="Error",
+            message=f"Chain Error: {type(error).__name__}",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            is_error=True,
+            error_message=str(error),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{target} -> {source}: {event.action}", extra={"flow_event": event}
+        )
+
+    def on_tool_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: Any = None,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool errors."""
+        if not self._participant_stack:
+            return
+        target = self._participant_stack.pop()
+        source = self._get_current_source()
+        event = FlowEvent(
+            source=target,
+            target=source,
+            action="Error",
+            message=f"Tool Error: {type(error).__name__}",
+            trace_id=LogContext.get("trace_id", str(uuid.uuid4())),
+            is_error=True,
+            error_message=str(error),
+            is_return=True,
+        )
+        self.logger.info(
+            f"{target} -> {source}: {event.action}", extra={"flow_event": event}
+        )
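
The handler hooks into LangChain's standard callbacks mechanism, so it can be attached to any runnable via config. A minimal usage sketch, assuming langchain-core plus a model binding such as langchain-openai are installed (ChatOpenAI, the model name, and the prompt are illustrative, not part of this package):

    from langchain_openai import ChatOpenAI

    from mermaid_trace.integrations import MermaidTraceCallbackHandler

    # host_name becomes the root participant in the generated sequence diagram.
    handler = MermaidTraceCallbackHandler(host_name="MyApp")
    llm = ChatOpenAI(model="gpt-4o-mini")

    # Chain/LLM/tool events fired during this call are logged as FlowEvents,
    # which MermaidTrace renders as a Mermaid sequence diagram.
    result = llm.invoke("Hello!", config={"callbacks": [handler]})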