turingpulse-sdk-langchain 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,42 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Virtual environments
+ .venv/
+ venv/
+ ENV/
+
+ # Distribution / packaging
+ dist/
+ build/
+ *.egg-info/
+
+ # Database files
+ *.db
+ *.sqlite3
+
+ # Environment variables
+ .env
+ .env.local
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # Testing
+ .pytest_cache/
+ .coverage
+ htmlcov/
+ .tox/
+
+ # Logs
+ *.log
+ logs/
+
+ # OS files
+ .DS_Store
+ Thumbs.db
@@ -0,0 +1,11 @@
+ Metadata-Version: 2.4
+ Name: turingpulse-sdk-langchain
+ Version: 1.0.0
+ Summary: TuringPulse SDK integration for LangChain
+ License-Expression: LicenseRef-Proprietary
+ Requires-Python: >=3.11
+ Requires-Dist: langchain-core>=0.3.0
+ Requires-Dist: turingpulse-sdk>=1.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
+ Requires-Dist: pytest>=8.0; extra == 'dev'
@@ -0,0 +1,17 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "turingpulse-sdk-langchain"
+ version = "1.0.0"
+ description = "TuringPulse SDK integration for LangChain"
+ requires-python = ">=3.11"
+ license = "LicenseRef-Proprietary"
+ dependencies = [
+     "turingpulse-sdk>=1.0.0",
+     "langchain-core>=0.3.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
@@ -0,0 +1,6 @@
+ """TuringPulse SDK integration for LangChain."""
+
+ from ._wrapper import instrument_langchain
+
+ __version__ = "1.0.0"
+ __all__ = ["instrument_langchain"]
@@ -0,0 +1,693 @@
+ """LangChain instrumentation for TuringPulse SDK.
+
+ Wraps LangChain ``Runnable`` objects (chains, agents, RunnableSequence,
+ AgentExecutor, etc.) for full observability via ``langchain-core``
+ callbacks.
+
+ Captures:
+ - Per-step chain spans with state diffs
+ - LLM calls with prompt, tokens, model, system prompt
+ - Tool executions with input/output/errors
+ - Retriever lookups with query and document counts
+ - Nested sub-agent invocations via parent_run_id tracking
+
+ Usage::
+
+     from turingpulse_sdk_langchain import instrument_langchain
+
+     chain = prompt | llm | output_parser
+     run = instrument_langchain(chain, name="my-chain")
+     result = run.invoke({"input": "Hello"})
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import time
+ from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence
+ from uuid import uuid4
+
+ from turingpulse_sdk.config import MAX_FIELD_SIZE
+ from turingpulse_sdk.context import current_context
+ from turingpulse_sdk import instrument, GovernanceDirective
+ from turingpulse_sdk.integrations.base import emit_child_spans, emit_child_spans_async
+
+ logger = logging.getLogger("turingpulse.sdk.integrations.langchain")
+
+ FRAMEWORK_NAME = "langchain"
+
+ _INPUT_KEYS = (
+     "input", "query", "question", "prompt", "request",
+     "user_input", "message", "text",
+ )
+ _OUTPUT_KEYS = (
+     "output", "result", "answer", "response", "text",
+ )
+
+ _HAS_CALLBACKS = False
+ try:
+     from langchain_core.callbacks import BaseCallbackHandler
+     from langchain_core.messages import BaseMessage
+     _HAS_CALLBACKS = True
+ except ImportError:
+     pass
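+ # NOTE: when langchain-core's callback API is unavailable, the wrapper still
+ # runs in a degraded mode: _make_collector() returns None and only the
+ # top-level workflow span is recorded (no per-step spans).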
+
+
+ def _safe_str(value: Any) -> str:
+     try:
+         return str(value)[:MAX_FIELD_SIZE]
+     except Exception:
+         return ""
+
+ def _extract_system_prompt(messages: List[Any]) -> str:
+     for msg in (messages or []):
+         if hasattr(msg, "type") and msg.type == "system":
+             return _safe_str(msg.content)
+         if isinstance(msg, dict) and msg.get("role") == "system":
+             return _safe_str(msg.get("content", ""))
+     return ""
+
+
+ def _extract_user_query(messages: List[Any]) -> str:
+     for msg in reversed(messages or []):
+         if hasattr(msg, "type") and msg.type == "human":
+             return _safe_str(msg.content)
+         if isinstance(msg, dict) and msg.get("role") == "user":
+             return _safe_str(msg.get("content", ""))
+     return ""
+
+
+ def _messages_to_str(messages: List[Any]) -> str:
+     parts = []
+     for msg in (messages or []):
+         role = getattr(msg, "type", "unknown")
+         content = getattr(msg, "content", str(msg))
+         parts.append(f"[{role}]: {content}")
+     return "\n".join(parts)[:MAX_FIELD_SIZE]
+
+
+ def _condense_state(state: Any) -> Dict[str, Any]:
+     if not isinstance(state, dict):
+         return {"value": _safe_str(state)}
+     condensed = {}
+     for k, v in state.items():
+         s = _safe_str(v)
+         condensed[k] = s if len(s) <= 500 else s[:500] + "…"
+     return condensed
+
+
+ def _state_diff(inputs: Any, outputs: Any) -> Dict[str, Any]:
+     if not isinstance(inputs, dict) or not isinstance(outputs, dict):
+         return outputs if isinstance(outputs, dict) else {"result": _safe_str(outputs)}
+     diff: Dict[str, Any] = {}
+     for key, value in outputs.items():
+         old = inputs.get(key)
+         try:
+             same = old == value
+         except Exception:
+             same = False
+         if key not in inputs or not same:
+             diff[key] = _safe_str(value)
+     return diff if diff else _condense_state(outputs)
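+ # Worked example of the diff semantics above (illustrative values):
+ # _state_diff({"q": "hi", "n": 1}, {"q": "hi", "n": 2, "a": "done"})
+ # keeps only the changed or newly added keys -> {"n": "2", "a": "done"};
+ # when nothing changed, the condensed full output is returned instead.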
+
+
+ def _extract_io(data: Any, keys: tuple) -> str:
+     if isinstance(data, str):
+         return data[:MAX_FIELD_SIZE]
+     if isinstance(data, dict):
+         for k in keys:
+             v = data.get(k)
+             if v:
+                 return _safe_str(v)
+         return _safe_str(data)
+     return _safe_str(data)
+
+
+ if _HAS_CALLBACKS:
+     class _SpanCollector(BaseCallbackHandler):
+         """LangChain callback handler that collects per-step span data.
+
+         Handles chains, LLM calls, tool executions, and retriever lookups.
+         LLM calls inside named chains are merged into the parent chain span.
+         """
+
+         _SKIP_NAMES = frozenset({
+             "RunnableSequence", "RunnableLambda", "RunnableParallel",
+             "RunnablePassthrough", "RunnableBranch", "RunnableWithFallbacks",
+             "ChannelWrite", "ChannelRead",
+         })
+
+         def __init__(self, workflow_name: str, default_model: str, default_provider: str):
+             super().__init__()
+             self.workflow_name = workflow_name
+             self.default_model = default_model
+             self.default_provider = default_provider
+             self.collected_spans: List[Dict[str, Any]] = []
+             self._active_chains: Dict[str, Dict[str, Any]] = {}
+             self._active_llms: Dict[str, Dict[str, Any]] = {}
+             self._active_tools: Dict[str, Dict[str, Any]] = {}
+             self._active_retrievers: Dict[str, Dict[str, Any]] = {}
+             self._parent_map: Dict[str, str] = {}
+             self.total_prompt_tokens = 0
+             self.total_completion_tokens = 0
+
+         def _find_named_ancestor(self, parent_run_id: Optional[str]) -> Optional[str]:
+             visited: set = set()
+             rid = parent_run_id
+             while rid and rid not in visited:
+                 visited.add(rid)
+                 if rid in self._active_chains:
+                     return rid
+                 rid = self._parent_map.get(rid)
+             return None
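+         # The walk above follows _parent_map from an LLM/tool run upward,
+         # skipping anonymous wrappers (RunnableSequence etc.), until it hits
+         # a chain recorded under a real name; `visited` guards against
+         # cycles in a malformed run tree.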
+
+         @staticmethod
+         def _resolve_chain_name(
+             serialized: Optional[Dict[str, Any]],
+             kwargs: Dict[str, Any],
+         ) -> str:
+             kw_name = kwargs.get("name", "")
+             if kw_name:
+                 return str(kw_name)
+             if serialized:
+                 return serialized.get("name", "") or (serialized.get("id", [""]) or [""])[-1]
+             return ""
+
+         def on_chain_start(
+             self, serialized: Optional[Dict[str, Any]], inputs: Dict[str, Any],
+             *, run_id, parent_run_id=None, tags=None, metadata=None, **kwargs,
+         ):
+             rid = str(run_id)
+             pid = str(parent_run_id) if parent_run_id else None
+             name = self._resolve_chain_name(serialized, kwargs)
+
+             if pid:
+                 self._parent_map[rid] = pid
+
+             if name and name not in self._SKIP_NAMES:
+                 self._active_chains[rid] = {
+                     "name": name,
+                     "started_at": time.time(),
+                     "inputs": dict(inputs) if isinstance(inputs, dict) else inputs,
+                     "parent_run_id": pid,
+                     "llm_data": None,
+                     "node_type": "processor",
+                 }
+
+         def on_chain_end(self, outputs: Dict[str, Any], *, run_id, **kwargs):
+             key = str(run_id)
+             data = self._active_chains.pop(key, None)
+             if not data:
+                 return
+             dur = int((time.time() - data["started_at"]) * 1000)
+
+             llm = data.get("llm_data")
+
+             if llm:
+                 node_input = data["inputs"]
+                 node_output = outputs if isinstance(outputs, dict) else {"result": _safe_str(outputs)}
+             else:
+                 node_input = _condense_state(data["inputs"])
+                 raw_out = outputs if isinstance(outputs, dict) else {}
+                 node_output = _state_diff(data["inputs"], raw_out)
+
+             span: Dict[str, Any] = {
+                 "node": data["name"],
+                 "node_type": "llm" if llm else data.get("node_type", "processor"),
+                 "duration_ms": dur,
+                 "status": "success",
+                 "input": node_input,
+                 "output": node_output,
+             }
+
+             if llm:
+                 span["model"] = llm.get("model", self.default_model)
+                 span["provider"] = self.default_provider
+                 span["tokens"] = llm.get("tokens")
+                 if llm.get("system_prompt"):
+                     span["system_prompt"] = llm["system_prompt"]
+                 if llm.get("prompt"):
+                     span["prompt"] = llm["prompt"]
+                 if llm.get("user_query"):
+                     span["input"] = {"user_query": llm["user_query"]}
+                 if llm.get("tool_calls"):
+                     span["tool_calls"] = llm["tool_calls"]
+                 if llm.get("llm_output"):
+                     span["output"] = llm["llm_output"]
+
+             self.collected_spans.append(span)
+
+         def on_chain_error(self, error: BaseException, *, run_id, **kwargs):
+             key = str(run_id)
+             data = self._active_chains.pop(key, None)
+             if not data:
+                 return
+             dur = int((time.time() - data["started_at"]) * 1000)
+             self.collected_spans.append({
+                 "node": data["name"],
+                 "node_type": data.get("node_type", "processor"),
+                 "duration_ms": dur,
+                 "status": "error",
+                 "input": _condense_state(data["inputs"]),
+                 "output": {"error": str(error)[:MAX_FIELD_SIZE]},
+             })
+
+         def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id, parent_run_id=None, **kwargs):
+             rid = str(run_id)
+             if parent_run_id:
+                 self._parent_map[rid] = str(parent_run_id)
+             self._active_llms[rid] = {
+                 "started_at": time.time(),
+                 "prompts": prompts,
+                 "model": (serialized or {}).get("kwargs", {}).get("model_name", self.default_model),
+                 "parent_run_id": str(parent_run_id) if parent_run_id else None,
+             }
+
+         def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Any]], *, run_id, parent_run_id=None, **kwargs):
+             rid = str(run_id)
+             if parent_run_id:
+                 self._parent_map[rid] = str(parent_run_id)
+             flat = messages[0] if messages else []
+             self._active_llms[rid] = {
+                 "started_at": time.time(),
+                 "messages": flat,
+                 "system_prompt": _extract_system_prompt(flat),
+                 "user_query": _extract_user_query(flat),
+                 "prompt": _messages_to_str(flat),
+                 "model": (
+                     kwargs.get("invocation_params", {}).get("model_name")
+                     or kwargs.get("invocation_params", {}).get("model")
+                     or (serialized or {}).get("kwargs", {}).get("model_name")
+                     or self.default_model
+                 ),
+                 "parent_run_id": str(parent_run_id) if parent_run_id else None,
+             }
+
+         def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
+             key = str(run_id)
+             data = self._active_llms.pop(key, None)
+             if not data:
+                 return
+
+             dur = int((time.time() - data["started_at"]) * 1000)
+             content = ""
+             usage: Dict[str, int] = {}
+
+             if hasattr(response, "generations") and response.generations:
+                 gen = response.generations[0]
+                 if gen:
+                     msg = gen[0].message if hasattr(gen[0], "message") else gen[0]
+                     content = getattr(msg, "content", "") or str(gen[0])
+
+             if hasattr(response, "llm_output") and response.llm_output:
+                 token_usage = response.llm_output.get("token_usage", {})
+                 usage = {
+                     "prompt": int(token_usage.get("prompt_tokens", 0)),
+                     "completion": int(token_usage.get("completion_tokens", 0)),
+                 }
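+             # Fallback for newer langchain-core releases: token counts may
+             # live on the message's usage_metadata (input_tokens /
+             # output_tokens) rather than in llm_output["token_usage"].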
+
+             if not usage.get("prompt") and hasattr(response, "generations") and response.generations:
+                 gen = response.generations[0]
+                 if gen:
+                     msg = gen[0].message if hasattr(gen[0], "message") else gen[0]
+                     um = getattr(msg, "usage_metadata", None) or {}
+                     if isinstance(um, dict):
+                         usage = {
+                             "prompt": int(um.get("input_tokens", 0)),
+                             "completion": int(um.get("output_tokens", 0)),
+                         }
+
+             self.total_prompt_tokens += usage.get("prompt", 0)
+             self.total_completion_tokens += usage.get("completion", 0)
+
+             tool_calls_raw = []
+             if hasattr(response, "generations") and response.generations:
+                 gen = response.generations[0]
+                 if gen:
+                     msg = gen[0].message if hasattr(gen[0], "message") else gen[0]
+                     if hasattr(msg, "tool_calls") and msg.tool_calls:
+                         for tc in msg.tool_calls:
+                             tool_calls_raw.append({
+                                 "tool_name": tc.get("name", "unknown") if isinstance(tc, dict) else getattr(tc, "name", "unknown"),
+                                 "tool_args": tc.get("args", {}) if isinstance(tc, dict) else getattr(tc, "args", {}),
+                                 "tool_id": tc.get("id", str(uuid4())) if isinstance(tc, dict) else getattr(tc, "id", str(uuid4())),
+                                 "tool_result": "",
+                                 "success": True,
+                             })
+
+             llm_result = {
+                 "model": data.get("model", self.default_model),
+                 "tokens": usage if usage else None,
+                 "system_prompt": data.get("system_prompt"),
+                 "prompt": data.get("prompt"),
+                 "user_query": data.get("user_query"),
+                 "tool_calls": tool_calls_raw or None,
+                 "llm_output": {"response": _safe_str(content)},
+                 "duration_ms": dur,
+             }
+
+             ancestor_id = self._find_named_ancestor(
+                 data.get("parent_run_id") or (str(parent_run_id) if parent_run_id else None)
+             )
+             if ancestor_id and ancestor_id in self._active_chains:
+                 self._active_chains[ancestor_id]["llm_data"] = llm_result
+             else:
+                 span: Dict[str, Any] = {
+                     "node": f"llm_{len(self.collected_spans)}",
+                     "node_type": "llm",
+                     "duration_ms": dur,
+                     "status": "success",
+                     "model": data.get("model", self.default_model),
+                     "provider": self.default_provider,
+                     "tokens": usage if usage else None,
+                     "output": {"response": _safe_str(content)},
+                 }
+                 if data.get("system_prompt"):
+                     span["system_prompt"] = data["system_prompt"]
+                 if data.get("prompt"):
+                     span["prompt"] = data["prompt"]
+                 if data.get("user_query"):
+                     span["input"] = {"user_query": data["user_query"]}
+                 if tool_calls_raw:
+                     span["tool_calls"] = tool_calls_raw
+                 self.collected_spans.append(span)
+
+         def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id, parent_run_id=None, **kwargs):
+             rid = str(run_id)
+             if parent_run_id:
+                 self._parent_map[rid] = str(parent_run_id)
+             tool_name = (serialized or {}).get("name", "unknown_tool")
+             self._active_tools[rid] = {
+                 "name": tool_name,
+                 "started_at": time.time(),
+                 "input": input_str,
+             }
+
+         def on_tool_end(self, output: str, *, run_id, **kwargs):
+             key = str(run_id)
+             data = self._active_tools.pop(key, None)
+             if not data:
+                 return
+             dur = int((time.time() - data["started_at"]) * 1000)
+             self.collected_spans.append({
+                 "node": data["name"],
+                 "node_type": "tool",
+                 "duration_ms": dur,
+                 "status": "success",
+                 "input": {"tool_input": _safe_str(data["input"])},
+                 "output": {"tool_output": _safe_str(output)},
+                 "tool_calls": [{
+                     "tool_name": data["name"],
+                     "tool_args": _safe_str(data["input"]),
+                     "tool_result": _safe_str(output),
+                     "tool_id": str(uuid4()),
+                     "success": True,
+                 }],
+             })
+
+         def on_tool_error(self, error: BaseException, *, run_id, **kwargs):
+             key = str(run_id)
+             data = self._active_tools.pop(key, None)
+             if not data:
+                 return
+             dur = int((time.time() - data["started_at"]) * 1000)
+             self.collected_spans.append({
+                 "node": data["name"],
+                 "node_type": "tool",
+                 "duration_ms": dur,
+                 "status": "error",
+                 "input": {"tool_input": _safe_str(data["input"])},
+                 "output": {"error": str(error)[:MAX_FIELD_SIZE]},
+                 "tool_calls": [{
+                     "tool_name": data["name"],
+                     "tool_args": _safe_str(data["input"]),
+                     "tool_result": "",
+                     "tool_id": str(uuid4()),
+                     "success": False,
+                     "error_message": str(error)[:MAX_FIELD_SIZE],
+                 }],
+             })
+
+         def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id, parent_run_id=None, **kwargs):
+             rid = str(run_id)
+             if parent_run_id:
+                 self._parent_map[rid] = str(parent_run_id)
+             retriever_name = (serialized or {}).get("name", "retriever")
+             self._active_retrievers[rid] = {
+                 "name": retriever_name,
+                 "started_at": time.time(),
+                 "query": query,
+             }
+
+         def on_retriever_end(self, documents: List[Any], *, run_id, **kwargs):
+             key = str(run_id)
+             data = self._active_retrievers.pop(key, None)
+             if not data:
+                 return
+             dur = int((time.time() - data["started_at"]) * 1000)
+             doc_count = len(documents) if documents else 0
+             self.collected_spans.append({
+                 "node": data["name"],
+                 "node_type": "retriever",
+                 "duration_ms": dur,
+                 "status": "success",
+                 "input": {"query": _safe_str(data["query"])},
+                 "output": {"document_count": doc_count},
+                 "metadata": {"document_count": str(doc_count)},
+             })
+
+
+ class _InstrumentedRunnable:
+     """Wrapper returned by ``instrument_langchain`` with .invoke/.ainvoke/.stream/.astream."""
+
+     def __init__(
+         self,
+         runnable: Any,
+         *,
+         name: str,
+         governance: Optional[GovernanceDirective] = None,
+         model: str = "gpt-4.1-mini",
+         provider: str = "openai",
+         kpis: Optional[Sequence] = None,
+         metadata: Optional[Dict[str, str]] = None,
+     ):
+         self._runnable = runnable
+         self._name = name
+         self._governance = governance
+         self._model = model
+         self._provider = provider
+         self._kpis = kpis
+         self._metadata = metadata or {}
+
+     def _make_collector(self) -> Any:
+         if not _HAS_CALLBACKS:
+             return None
+         return _SpanCollector(self._name, self._model, self._provider)
+
+     def _merge_config(self, config: Optional[Dict[str, Any]], collector: Any) -> Dict[str, Any]:
+         cfg = dict(config) if config else {}
+         if collector is not None:
+             existing = cfg.get("callbacks") or []
+             cfg["callbacks"] = list(existing) + [collector]
+         return cfg
+
+     def _finalize(self, collector: Any, result: Any, duration_ms: int) -> None:
+         ctx = current_context()
+         if ctx is None:
+             return
+
+         ctx.framework = FRAMEWORK_NAME
+         ctx.node_type = "workflow"
+         ctx.set_model(self._model, self._provider)
+
+         if collector is not None:
+             ctx.set_tokens(collector.total_prompt_tokens, collector.total_completion_tokens)
+             if collector.collected_spans:
+                 last_llm = next(
+                     (s for s in reversed(collector.collected_spans) if s.get("model")),
+                     None,
+                 )
+                 if last_llm:
+                     ctx.set_model(last_llm["model"], self._provider)
+
+         input_str = ""
+         output_str = _safe_str(result)
+         if hasattr(result, "content"):
+             output_str = _safe_str(result.content)
+
+         ctx.set_io(input_str, output_str)
+
+         if collector is not None and collector.collected_spans:
+             emit_child_spans(
+                 collector.collected_spans,
+                 run_id=ctx.run_id,
+                 parent_span_id=ctx.span_id,
+                 workflow_name=self._name,
+                 framework=FRAMEWORK_NAME,
+             )
+
+     async def _finalize_async(self, collector: Any, result: Any, duration_ms: int) -> None:
+         ctx = current_context()
+         if ctx is None:
+             return
+
+         ctx.framework = FRAMEWORK_NAME
+         ctx.node_type = "workflow"
+         ctx.set_model(self._model, self._provider)
+
+         if collector is not None:
+             ctx.set_tokens(collector.total_prompt_tokens, collector.total_completion_tokens)
+             if collector.collected_spans:
+                 last_llm = next(
+                     (s for s in reversed(collector.collected_spans) if s.get("model")),
+                     None,
+                 )
+                 if last_llm:
+                     ctx.set_model(last_llm["model"], self._provider)
+
+         output_str = _safe_str(result)
+         if hasattr(result, "content"):
+             output_str = _safe_str(result.content)
+
+         ctx.set_io("", output_str)
+
+         if collector is not None and collector.collected_spans:
+             await emit_child_spans_async(
+                 collector.collected_spans,
+                 run_id=ctx.run_id,
+                 parent_span_id=ctx.span_id,
+                 workflow_name=self._name,
+                 framework=FRAMEWORK_NAME,
+             )
+
+     def invoke(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+         @instrument(name=self._name, governance=self._governance, kpis=self._kpis, metadata=self._metadata)
+         def _run(input_data: Any, cfg: Optional[Dict[str, Any]] = None, **kw) -> Any:
+             collector = self._make_collector()
+             merged_cfg = self._merge_config(cfg, collector)
+             t0 = time.time()
+             result = self._runnable.invoke(input_data, config=merged_cfg, **kw)
+             dur = int((time.time() - t0) * 1000)
+
+             ctx = current_context()
+             if ctx:
+                 input_str = _extract_io(input_data, _INPUT_KEYS)
+                 ctx.set_io(input_str, "")
+
+             self._finalize(collector, result, dur)
+             return result
+
+         return _run(input, cfg=config, **kwargs)
+
+     async def ainvoke(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+         @instrument(name=self._name, governance=self._governance, kpis=self._kpis, metadata=self._metadata)
+         async def _run(input_data: Any, cfg: Optional[Dict[str, Any]] = None, **kw) -> Any:
+             collector = self._make_collector()
+             merged_cfg = self._merge_config(cfg, collector)
+             t0 = time.time()
+             result = await self._runnable.ainvoke(input_data, config=merged_cfg, **kw)
+             dur = int((time.time() - t0) * 1000)
+
+             ctx = current_context()
+             if ctx:
+                 input_str = _extract_io(input_data, _INPUT_KEYS)
+                 ctx.set_io(input_str, "")
+
+             await self._finalize_async(collector, result, dur)
+             return result
+
+         return await _run(input, cfg=config, **kwargs)
+
+     def stream(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs) -> Iterator[Any]:
+         @instrument(name=self._name, governance=self._governance, kpis=self._kpis, metadata=self._metadata)
+         def _run(input_data: Any, cfg: Optional[Dict[str, Any]] = None, **kw) -> Any:
+             collector = self._make_collector()
+             merged_cfg = self._merge_config(cfg, collector)
+             t0 = time.time()
+             chunks = []
+             for chunk in self._runnable.stream(input_data, config=merged_cfg, **kw):
+                 chunks.append(chunk)
+
+             dur = int((time.time() - t0) * 1000)
+             final = chunks[-1] if chunks else None
+
+             ctx = current_context()
+             if ctx:
+                 input_str = _extract_io(input_data, _INPUT_KEYS)
+                 ctx.set_io(input_str, "")
+
+             self._finalize(collector, final, dur)
+             return chunks
+
+         result_chunks = _run(input, cfg=config, **kwargs)
+         yield from result_chunks
+
+     async def astream(self, input: Any, config: Optional[Dict[str, Any]] = None, **kwargs) -> AsyncIterator[Any]:
+         @instrument(name=self._name, governance=self._governance, kpis=self._kpis, metadata=self._metadata)
+         async def _run(input_data: Any, cfg: Optional[Dict[str, Any]] = None, **kw) -> Any:
+             collector = self._make_collector()
+             merged_cfg = self._merge_config(cfg, collector)
+             t0 = time.time()
+             chunks = []
+             async for chunk in self._runnable.astream(input_data, config=merged_cfg, **kw):
+                 chunks.append(chunk)
+
+             dur = int((time.time() - t0) * 1000)
+             final = chunks[-1] if chunks else None
+
+             ctx = current_context()
+             if ctx:
+                 input_str = _extract_io(input_data, _INPUT_KEYS)
+                 ctx.set_io(input_str, "")
+
+             await self._finalize_async(collector, final, dur)
+             return chunks
+
+         result_chunks = await _run(input, cfg=config, **kwargs)
+         for chunk in result_chunks:
+             yield chunk
+
+
+ def instrument_langchain(
+     runnable: Any,
+     *,
+     name: str,
+     governance: Optional[GovernanceDirective] = None,
+     model: str = "gpt-4.1-mini",
+     provider: str = "openai",
+     kpis: Optional[Sequence] = None,
+     metadata: Optional[Dict[str, str]] = None,
+ ) -> _InstrumentedRunnable:
+     """Instrument a LangChain Runnable for TuringPulse observability.
+
+     Wraps any LangChain ``Runnable`` (chains, agents, ``RunnableSequence``,
+     ``AgentExecutor``, etc.) and returns an object with ``.invoke()``,
+     ``.ainvoke()``, ``.stream()``, and ``.astream()`` methods.
+
+     Internally injects a ``BaseCallbackHandler`` to capture per-step spans
+     including LLM calls, tool executions, retriever lookups, and sub-agent
+     invocations.
+
+     Args:
+         runnable: A LangChain ``Runnable`` instance.
+         name: Workflow display name for TuringPulse.
+         governance: Optional governance directive for policy enforcement.
+         model: Default LLM model name (auto-detected from callbacks when possible).
+         provider: LLM provider (default ``openai``).
+         kpis: Optional KPI configurations.
+         metadata: Optional metadata key-value pairs.
+
+     Returns:
+         An ``_InstrumentedRunnable`` with ``.invoke()`` / ``.ainvoke()`` /
+         ``.stream()`` / ``.astream()`` methods.
+     """
+     return _InstrumentedRunnable(
+         runnable,
+         name=name,
+         governance=governance,
+         model=model,
+         provider=provider,
+         kpis=kpis,
+         metadata=metadata,
+     )
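For orientation, a minimal sketch of the wrapper above in use. Only ``instrument_langchain`` and its signature come from this diff; the chain, the ``llm`` object, and the input are illustrative assumptions::

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate
    from turingpulse_sdk_langchain import instrument_langchain

    # Illustrative chain; any Runnable works (agents, AgentExecutor, ...).
    prompt = ChatPromptTemplate.from_template("Summarize: {input}")
    chain = prompt | llm | StrOutputParser()  # `llm`: your chat model instance

    run = instrument_langchain(chain, name="summarizer")
    result = run.invoke({"input": "LangChain observability in one chain."})
    # run.ainvoke(...), run.stream(...), and run.astream(...) are traced too.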