mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/llm_proxy.py ADDED
@@ -0,0 +1,1890 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import asyncio
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ import tempfile
12
+ import threading
13
+ import time
14
+ from contextlib import asynccontextmanager
15
+ from contextvars import ContextVar
16
+ from datetime import datetime
17
+ from typing import (
18
+ Any,
19
+ AsyncGenerator,
20
+ Awaitable,
21
+ Callable,
22
+ Dict,
23
+ Iterable,
24
+ List,
25
+ Literal,
26
+ Optional,
27
+ Sequence,
28
+ Tuple,
29
+ Type,
30
+ TypedDict,
31
+ Union,
32
+ cast,
33
+ )
34
+
35
+ import litellm
36
+ import opentelemetry.trace as trace_api
37
+ import yaml
38
+ from fastapi import Request, Response
39
+ from fastapi.responses import StreamingResponse
40
+ from litellm.integrations.custom_logger import CustomLogger
41
+ from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
42
+ from litellm.proxy.proxy_server import app, save_worker_config # pyright: ignore[reportUnknownVariableType]
43
+ from litellm.types.utils import CallTypes
44
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
45
+ from opentelemetry.sdk.resources import Resource
46
+ from opentelemetry.sdk.trace import ReadableSpan, SpanContext
47
+ from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
48
+ from opentelemetry.trace import Link, Status
49
+ from opentelemetry.util.types import Attributes
50
+ from starlette.middleware.base import BaseHTTPMiddleware
51
+ from starlette.types import Scope
52
+
53
+ from mantisdk.semconv import LightningResourceAttributes
54
+ from mantisdk.types import LLM, ProxyLLM
55
+ from mantisdk.utils.server_launcher import (
56
+ LaunchMode,
57
+ PythonServerLauncher,
58
+ PythonServerLauncherArgs,
59
+ noop_context,
60
+ )
61
+
62
+ from .store.base import LightningStore
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+ # Context variable to store HTTP request headers for LiteLLM callback access
67
+ _request_headers_context: ContextVar[Optional[Dict[str, str]]] = ContextVar("request_headers", default=None)
68
+
69
+ __all__ = [
70
+ "LLMProxy",
71
+ ]
72
+
73
+
74
+ class ModelConfig(TypedDict):
75
+ """LiteLLM model registration entry.
76
+
77
+ This mirrors the items in LiteLLM's `model_list` section.
78
+
79
+ Attributes:
80
+ model_name: Logical model name exposed by the proxy.
81
+ litellm_params: Parameters passed to LiteLLM for this model
82
+ (e.g., backend model id, api_base, additional options).
83
+ """ # Google style kept concise.
84
+
85
+ model_name: str
86
+ litellm_params: Dict[str, Any]
87
+
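For illustration, a hedged sketch of what a `ModelConfig` entry might look like; the model name, backend id, `api_base`, and `api_key` below are placeholders, not values shipped with mantisdk:

```python
# Hypothetical ModelConfig entry mirroring LiteLLM's `model_list` format.
example_model: ModelConfig = {
    "model_name": "local-llama",  # logical name exposed by the proxy
    "litellm_params": {
        # Backend model id and endpoint are placeholders for a local vLLM server.
        "model": "hosted_vllm/meta-llama/Llama-3.1-8B-Instruct",
        "api_base": "http://localhost:8000/v1",
        "api_key": "EMPTY",
    },
}
```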
88
+
89
+ def _get_pre_call_data(args: Any, kwargs: Any) -> Dict[str, Any]:
90
+ """Extract LiteLLM request payload from hook args.
91
+
92
+ The LiteLLM logger hooks receive `(*args, **kwargs)` whose third positional
93
+ argument or `data=` kwarg contains the request payload.
94
+
95
+ Args:
96
+ args: Positional arguments from the hook.
97
+ kwargs: Keyword arguments from the hook.
98
+
99
+ Returns:
100
+ The request payload dict.
101
+
102
+ Raises:
103
+ ValueError: If the payload cannot be located or is not a dict.
104
+ """
105
+ if kwargs.get("data"):
106
+ data = kwargs["data"]
107
+ elif len(args) >= 3:
108
+ data = args[2]
109
+ else:
110
+ raise ValueError(f"Unable to get request data from args or kwargs: {args}, {kwargs}")
111
+ if not isinstance(data, dict):
112
+ raise ValueError(f"Request data is not a dictionary: {data}")
113
+ return cast(Dict[str, Any], data)
114
+
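As a quick illustration (not part of the module itself), the helper accepts either calling convention used by LiteLLM's hooks:

```python
# Minimal sketch: the payload may arrive as the `data=` kwarg or as the third positional arg.
payload = {"model": "local-llama", "messages": [{"role": "user", "content": "hi"}]}

assert _get_pre_call_data((), {"data": payload}) is payload          # keyword form
assert _get_pre_call_data((None, None, payload), {}) is payload      # positional form
```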
115
+
116
+ def _reset_litellm_logging_worker() -> None:
117
+ """Reset LiteLLM's global logging worker to the current event loop.
118
+
119
+ LiteLLM keeps a module-level ``GLOBAL_LOGGING_WORKER`` singleton that owns an
120
+ ``asyncio.Queue``. The queue is bound to the event loop where it was created.
121
+ When the proxy is restarted, Uvicorn spins up a brand new event loop in a new
122
+ thread. If the existing logging worker (and its queue) are reused, LiteLLM
123
+ raises ``RuntimeError: <Queue ...> is bound to a different event loop`` the
124
+ next time it tries to log. Recreating the worker ensures that LiteLLM will
125
+ lazily initialise a fresh queue on the new loop.
126
+ """
127
+
128
+ # ``GLOBAL_LOGGING_WORKER`` is imported in a few LiteLLM modules at runtime.
129
+ # Update any already-imported references so future calls use the fresh worker.
130
+ try:
131
+ import litellm.utils as litellm_utils
132
+ from litellm.litellm_core_utils import logging_worker as litellm_logging_worker
133
+
134
+ litellm_logging_worker.GLOBAL_LOGGING_WORKER = litellm_logging_worker.LoggingWorker()
135
+ litellm_utils.GLOBAL_LOGGING_WORKER = litellm_logging_worker.GLOBAL_LOGGING_WORKER # type: ignore[reportAttributeAccessIssue]
136
+ except Exception: # pragma: no cover - best-effort hygiene
137
+ logger.warning("Unable to propagate LiteLLM logging worker reset.", exc_info=True)
138
+
139
+
140
+ def _reset_litellm_logging_callback_manager() -> None:
141
+ """Reset LiteLLM's global callback manager.
142
+
143
+ Avoids the warning "Cannot add callback - would exceed MAX_CALLBACKS limit of 30."
144
+ when LiteLLM is restarted multiple times in the same process.
145
+
146
+ Note that existing input/output callbacks are not preserved.
147
+ """
148
+
149
+ try:
150
+ litellm.logging_callback_manager._reset_all_callbacks() # pyright: ignore[reportPrivateUsage]
151
+ except Exception: # pragma: no cover - best-effort hygiene
152
+ logger.warning("Unable to reset LiteLLM logging callback manager.", exc_info=True)
153
+
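A hedged sketch of where these resets would typically be invoked; `restart_proxy` below is a hypothetical helper, not a function defined in this module:

```python
def restart_proxy() -> None:
    """Hypothetical restart path: recreate LiteLLM singletons before starting a new loop."""
    _reset_litellm_logging_worker()            # fresh LoggingWorker, so its queue binds to the new loop
    _reset_litellm_logging_callback_manager()  # avoid the MAX_CALLBACKS limit on repeated restarts
    # ...re-register callbacks and relaunch Uvicorn on the new event loop afterwards...
```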
154
+
155
+ class AddReturnTokenIds(CustomLogger):
156
+ """LiteLLM logger hook to request token ids from vLLM.
157
+
158
+ This mutates the outgoing request payload to include `return_token_ids=True`
159
+ for backends that support token id return (e.g., vLLM).
160
+
161
+ See also:
162
+ [vLLM PR #22587](https://github.com/vllm-project/vllm/pull/22587)
163
+ """
164
+
165
+ async def async_pre_call_hook(self, *args: Any, **kwargs: Any) -> Optional[Union[Exception, str, Dict[str, Any]]]:
166
+ """Async pre-call hook to adjust request payload.
167
+
168
+ Args:
169
+ args: Positional args from LiteLLM.
170
+ kwargs: Keyword args from LiteLLM.
171
+
172
+ Returns:
173
+ Either an updated payload dict or an Exception to short-circuit.
174
+ """
175
+ try:
176
+ data = _get_pre_call_data(args, kwargs)
177
+ except Exception as e:
178
+ return e
179
+
180
+ # Ensure token ids are requested from the backend when supported.
181
+ return {**data, "return_token_ids": True}
182
+
183
+
184
+ class AddLogprobs(CustomLogger):
185
+ """LiteLLM logger hook to request logprobs from vLLM.
186
+
187
+ This mutates the outgoing request payload to include `logprobs=1`
188
+ for backends that support logprobs return (e.g., vLLM).
189
+ """
190
+
191
+ async def async_pre_call_hook(self, *args: Any, **kwargs: Any) -> Optional[Union[Exception, str, Dict[str, Any]]]:
192
+ """Async pre-call hook to adjust request payload."""
193
+ try:
194
+ data = _get_pre_call_data(args, kwargs)
195
+ except Exception as e:
196
+ return e
197
+
198
+ # Ensure logprobs are requested from the backend when supported.
199
+ return {**data, "logprobs": 1}
200
+
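A minimal sketch, assuming the hooks are registered as LiteLLM `CustomLogger` callbacks; this is illustrative wiring, not mantisdk's documented setup:

```python
import litellm

# Registering the hooks makes LiteLLM call their async_pre_call_hook before each request,
# so payloads sent to a vLLM backend carry return_token_ids=True and logprobs=1.
litellm.callbacks = [AddReturnTokenIds(), AddLogprobs()]
```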
201
+
202
+ class SpanWithExtraAttributes(ReadableSpan):
203
+ """Wrapper around ReadableSpan that adds extra span attributes.
204
+
205
+ Since ReadableSpan is immutable, this wrapper intercepts the attributes
206
+ property to include additional attributes for Langfuse integration
207
+ (environment, tags, etc.).
208
+ """
209
+
210
+ def __init__(self, wrapped_span: ReadableSpan, extra_attributes: Dict[str, Any]):
211
+ """Initialize wrapper with a span and extra attributes to inject.
212
+
213
+ Args:
214
+ wrapped_span: The original ReadableSpan to wrap.
215
+ extra_attributes: Dictionary of extra attributes to add.
216
+ These should use Langfuse conventions (e.g., "langfuse.environment").
217
+ """
218
+ self._wrapped = wrapped_span
219
+ self._extra_attributes = extra_attributes
220
+
221
+ @property
222
+ def name(self) -> str:
223
+ return self._wrapped.name
224
+
225
+ @property
226
+ def context(self) -> Optional[SpanContext]:
227
+ return self._wrapped.context
228
+
229
+ def get_span_context(self) -> SpanContext:
230
+ return self._wrapped.get_span_context()
231
+
232
+ @property
233
+ def parent(self) -> Optional[SpanContext]:
234
+ return self._wrapped.parent
235
+
236
+ @property
237
+ def start_time(self) -> int:
238
+ return self._wrapped.start_time
239
+
240
+ @property
241
+ def end_time(self) -> int:
242
+ return self._wrapped.end_time
243
+
244
+ @property
245
+ def status(self) -> Status:
246
+ return self._wrapped.status
247
+
248
+ @property
249
+ def attributes(self) -> Attributes:
250
+ """Return original attributes merged with extra attributes."""
251
+ original_attrs = self._wrapped.attributes or {}
252
+ # Create a merged dict with original attrs and extra attrs
253
+ merged = dict(original_attrs)
254
+ merged.update(self._extra_attributes)
255
+ return merged
256
+
257
+ @property
258
+ def events(self) -> tuple:
259
+ return self._wrapped.events
260
+
261
+ @property
262
+ def links(self) -> tuple:
263
+ return self._wrapped.links
264
+
265
+ @property
266
+ def resource(self) -> Resource:
267
+ return self._wrapped.resource
268
+
269
+ @property
270
+ def instrumentation_scope(self):
271
+ return self._wrapped.instrumentation_scope
272
+
273
+ # Also expose _resource for direct modification if needed
274
+ @property
275
+ def _resource(self) -> Resource:
276
+ return self._wrapped._resource # pyright: ignore[reportPrivateUsage]
277
+
278
+ @_resource.setter
279
+ def _resource(self, value: Resource):
280
+ self._wrapped._resource = value # pyright: ignore[reportPrivateUsage]
281
+
282
+ def __getattr__(self, name: str):
283
+ """Delegate all other attribute access to the wrapped span.
284
+
285
+ This ensures OTLPSpanExporter can access all private attributes (_events, _links, etc.)
286
+ that it needs for serialization without us having to enumerate them all.
287
+ """
288
+ return getattr(self._wrapped, name)
289
+
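A self-contained sketch of the wrapper's read-time merge, using the OTEL SDK's in-memory exporter to obtain a finished `ReadableSpan`; the attribute values are placeholders:

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

provider = TracerProvider()
memory_exporter = InMemorySpanExporter()
provider.add_span_processor(SimpleSpanProcessor(memory_exporter))

with provider.get_tracer(__name__).start_as_current_span("demo"):
    pass

finished = memory_exporter.get_finished_spans()[0]  # a ReadableSpan
wrapped = SpanWithExtraAttributes(finished, {"langfuse.environment": "staging"})
assert wrapped.attributes["langfuse.environment"] == "staging"  # extra attribute merged on read
assert wrapped.name == "demo"                                   # everything else is delegated
```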
290
+
291
+ class LightningSpanExporter(SpanExporter):
292
+ """Buffered OTEL span exporter with subtree flushing and training-store sink.
293
+
294
+ Design:
295
+
296
+ * Spans are buffered until a root span's entire subtree is available.
297
+ * A private event loop on a daemon thread runs async flush logic.
298
+ * Rollout/attempt/sequence metadata is reconstructed by merging headers
299
+ from any span within a subtree.
300
+
301
+ Thread-safety:
302
+
303
+ * Buffer access is protected by a re-entrant lock.
304
+ * Export is synchronous to the caller yet schedules an async flush on the
305
+ internal loop, then waits for completion.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ _store: Optional[LightningStore] = None,
311
+ otlp_endpoint: Optional[str] = None,
312
+ otlp_headers: Optional[Dict[str, str]] = None,
313
+ ):
314
+ self._store: Optional[LightningStore] = _store # this is only for testing purposes
315
+ self._otlp_endpoint: Optional[str] = otlp_endpoint # Direct OTLP export endpoint
316
+ self._buffer: List[ReadableSpan] = []
317
+ self._lock: Optional[threading.Lock] = None
318
+ self._loop_lock_pid: Optional[int] = None
319
+
320
+ # Single dedicated event loop running in a daemon thread.
321
+ # This decouples OTEL SDK threads from our async store I/O.
322
+ # Deferred creation until first use.
323
+ self._loop: Optional[asyncio.AbstractEventLoop] = None
324
+ self._loop_thread: Optional[threading.Thread] = None
325
+
326
+ # Initialize OTLP exporter with custom endpoint and headers if provided
327
+ if otlp_endpoint:
328
+ self._otlp_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, headers=otlp_headers or {})
329
+ else:
330
+ self._otlp_exporter = OTLPSpanExporter()
331
+
332
+ def _ensure_loop(self) -> asyncio.AbstractEventLoop:
333
+ """Lazily initialize the event loop and thread on first use.
334
+
335
+ Returns:
336
+ asyncio.AbstractEventLoop: The initialized event loop.
337
+ """
338
+ self._clear_loop_and_lock()
339
+ if self._loop is None:
340
+ self._loop = asyncio.new_event_loop()
341
+ self._loop_thread = threading.Thread(target=self._run_loop, name="LightningSpanExporterLoop", daemon=True)
342
+ self._loop_thread.start()
343
+ return self._loop
344
+
345
+ def _ensure_lock(self) -> threading.Lock:
346
+ """Lazily initialize the lock on first use.
347
+
348
+ Returns:
349
+ threading.Lock: The initialized lock.
350
+ """
351
+ self._clear_loop_and_lock()
352
+ if self._lock is None:
353
+ self._lock = threading.Lock()
354
+ return self._lock
355
+
356
+ def _clear_loop_and_lock(self) -> None:
357
+ """Clear the loop and lock.
358
+ This happens when the exporter was created in one process and is then reused in another process.
359
+
360
+ This should only happen in CI.
361
+ """
362
+ if self._loop_lock_pid is None:
363
+ self._loop_lock_pid = os.getpid()
364
+ elif os.getpid() != self._loop_lock_pid:
365
+ logger.warning("Loop and lock are not owned by the current process. Clearing them.")
366
+ self._loop = None
367
+ self._loop_thread = None
368
+ self._lock = None
369
+ self._loop_lock_pid = os.getpid()
370
+
371
+ def _run_loop(self) -> None:
372
+ """Run the private asyncio loop forever on the exporter thread."""
373
+ assert self._loop is not None, "Loop should be initialized before thread starts"
374
+ asyncio.set_event_loop(self._loop)
375
+ self._loop.run_forever()
376
+
377
+ def shutdown(self) -> None:
378
+ """Shut down the exporter event loop.
379
+
380
+ Safe to call at process exit.
381
+
382
+ """
383
+ if self._loop is None:
384
+ return
385
+
386
+ try:
387
+
388
+ def _stop():
389
+ assert self._loop is not None
390
+ self._loop.stop()
391
+
392
+ self._loop.call_soon_threadsafe(_stop)
393
+ if self._loop_thread is not None:
394
+ self._loop_thread.join(timeout=2.0)
395
+ self._loop.close()
396
+ except Exception:
397
+ logger.exception("Error during exporter shutdown")
398
+
399
+ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
400
+ """Export spans via buffered subtree flush.
401
+
402
+ Appends spans to the internal buffer, then triggers an async flush on the
403
+ private event loop. Blocks until that flush completes.
404
+
405
+ Args:
406
+ spans: Sequence of spans to export.
407
+
408
+ Returns:
409
+ SpanExportResult: SUCCESS on flush success, else FAILURE.
410
+ """
411
+ # Buffer append under lock to protect against concurrent exporters.
412
+ with self._ensure_lock():
413
+ for span in spans:
414
+ self._buffer.append(span)
415
+ default_endpoint = self._otlp_exporter._endpoint # pyright: ignore[reportPrivateUsage]
416
+ try:
417
+ self._maybe_flush()
418
+ except Exception as e:
419
+ logger.exception("Export flush failed: %s", e)
420
+ return SpanExportResult.FAILURE
421
+ finally:
422
+ self._otlp_exporter._endpoint = default_endpoint # pyright: ignore[reportPrivateUsage]
423
+
424
+ return SpanExportResult.SUCCESS
425
+
426
+ def _get_job_id_from_store(self, store: Any) -> Optional[str]:
427
+ """Get the job_id from the store's listeners (if any InsightTracker is attached)."""
428
+ if hasattr(store, "listeners"):
429
+ for listener in store.listeners:
430
+ if hasattr(listener, "job_id"):
431
+ return listener.job_id
432
+ return None
433
+
434
+ def _get_tracing_metadata_from_rollout(
435
+ self, store: Any, rollout_id: str
436
+ ) -> Tuple[Optional[str], Optional[List[str]], Optional[str]]:
437
+ """Fetch tracing metadata (environment, tags, session_id) from a rollout.
438
+
439
+ Args:
440
+ store: The LightningStore instance.
441
+ rollout_id: The rollout ID to fetch.
442
+
443
+ Returns:
444
+ Tuple of (environment, tags, session_id). All may be None if not set.
445
+ """
446
+ print(f"[TracingMetadata] Fetching metadata for rollout {rollout_id}")
447
+ logger.info(f"[TracingMetadata] Fetching metadata for rollout {rollout_id}")
448
+ try:
449
+ loop = self._ensure_loop()
450
+ get_rollout_task = store.get_rollout_by_id(rollout_id)
451
+ fut = asyncio.run_coroutine_threadsafe(get_rollout_task, loop)
452
+ rollout = fut.result(timeout=5.0) # Short timeout for metadata fetch
453
+
454
+ if rollout is None:
455
+ logger.warning(f"[TracingMetadata] Rollout {rollout_id} not found in store")
456
+ return None, None, None
457
+
458
+ logger.info(f"[TracingMetadata] Rollout {rollout_id} found, metadata={rollout.metadata}")
459
+
460
+ if rollout.metadata is None:
461
+ logger.warning(f"[TracingMetadata] Rollout {rollout_id} has no metadata (None)")
462
+ return None, None, None
463
+
464
+ if not rollout.metadata:
465
+ logger.warning(f"[TracingMetadata] Rollout {rollout_id} has empty metadata dict")
466
+ return None, None, None
467
+
468
+ environment = rollout.metadata.get("environment")
469
+ tags = rollout.metadata.get("tags")
470
+ session_id = rollout.metadata.get("session_id")
471
+
472
+ logger.info(f"[TracingMetadata] Rollout {rollout_id}: environment={environment}, tags={tags}, session_id={session_id}")
473
+
474
+ return environment, tags, session_id
475
+ except Exception as e:
476
+ logger.warning(f"[TracingMetadata] Failed to fetch rollout metadata for {rollout_id}: {e}")
477
+ import traceback
478
+ traceback.print_exc()
479
+ return None, None, None
480
+
481
+ def _maybe_flush(self):
482
+ """Flush ready subtrees from the buffer.
483
+
484
+ Strategy:
485
+ We consider a subtree "ready" if we can identify a root span. We
486
+ then take that root and all its descendants out of the buffer and
487
+ try to reconstruct rollout/attempt/sequence headers by merging any
488
+ span's `metadata.requester_custom_headers` within the subtree.
489
+
490
+ Span types:
491
+ - Rollout spans: Have rollout_id/attempt_id/sequence_id headers
492
+ - Job spans: No rollout context, tagged with job_id for experiment tracking
493
+
494
+ Direct OTLP mode:
495
+ When `otlp_endpoint` is configured, spans are exported directly to the
496
+ endpoint without requiring a store or header validation.
497
+
498
+ Raises:
499
+ None directly. Logs and skips malformed spans.
500
+
501
+ """
502
+ # Iterate over current roots. Each iteration pops a whole subtree.
503
+ for root_span_id in self._get_root_span_ids():
504
+ subtree_spans = self._pop_subtrees(root_span_id)
505
+ if not subtree_spans:
506
+ continue
507
+
508
+ # Merge all custom headers found in the subtree.
509
+ # This must happen BEFORE the direct OTLP check so both paths can apply tags.
510
+ headers_merged: Dict[str, Any] = {}
511
+
512
+ for span in subtree_spans:
513
+ if span.attributes is None:
514
+ continue
515
+ headers_str = span.attributes.get("metadata.requester_custom_headers")
516
+ if headers_str is None:
517
+ continue
518
+ if not isinstance(headers_str, str):
519
+ logger.debug(f"metadata.requester_custom_headers is not a string: {headers_str}")
520
+ continue
521
+ if not headers_str.strip():
522
+ continue
523
+ try:
524
+ # Use literal_eval to parse the stringified dict safely.
525
+ headers = ast.literal_eval(headers_str)
526
+ except Exception as e:
527
+ logger.debug(f"Failed to parse metadata.requester_custom_headers: {e}")
528
+ continue
529
+ if isinstance(headers, dict):
530
+ headers_merged.update(cast(Dict[str, Any], headers))
531
+
532
+ # Extract rollout context if available
533
+ rollout_id = headers_merged.get("x-rollout-id")
534
+ attempt_id = headers_merged.get("x-attempt-id")
535
+ sequence_id = headers_merged.get("x-sequence-id")
536
+
537
+ # Determine if we're using OTLP export (either direct or store-based)
538
+ store = self._store or get_active_llm_proxy().get_store()
539
+ otlp_enabled = bool(self._otlp_endpoint) or (store and store.capabilities.get("otlp_traces", False))
540
+
541
+ has_rollout_context = (
542
+ rollout_id
543
+ and attempt_id
544
+ and sequence_id
545
+ and isinstance(rollout_id, str)
546
+ and isinstance(attempt_id, str)
547
+ and isinstance(sequence_id, str)
548
+ and sequence_id.isdigit()
549
+ )
550
+
551
+ if has_rollout_context:
552
+ # Rollout-scoped spans: tag with rollout/attempt/sequence
553
+ sequence_id_decimal = int(sequence_id)
554
+ print(f"[TracingMetadata] Processing rollout {rollout_id} with {len(subtree_spans)} spans, otlp_enabled={otlp_enabled}")
555
+ logger.info(f"[TracingMetadata] Processing rollout {rollout_id} with {len(subtree_spans)} spans, otlp_enabled={otlp_enabled}")
556
+
557
+ # Fetch tracing metadata (environment, tags, session_id) from the rollout
558
+ environment, tags, session_id = self._get_tracing_metadata_from_rollout(store, rollout_id)
559
+ logger.info(f"[TracingMetadata] Fetched: environment={environment}, tags={tags}, session_id={session_id}")
560
+
561
+ if otlp_enabled:
562
+ # Build resource attributes for Mantisdk metadata
563
+ resource_attrs: Dict[str, Any] = {
564
+ LightningResourceAttributes.ROLLOUT_ID.value: rollout_id,
565
+ LightningResourceAttributes.ATTEMPT_ID.value: attempt_id,
566
+ LightningResourceAttributes.SPAN_SEQUENCE_ID.value: sequence_id_decimal,
567
+ LightningResourceAttributes.SPAN_TYPE.value: "rollout",
568
+ }
569
+
570
+ # Build span attributes for Langfuse-expected metadata
571
+ # Per Langfuse docs, use langfuse.* namespace for environment and tags
572
+ span_extra_attrs: Dict[str, Any] = {}
573
+ if session_id:
574
+ span_extra_attrs["session.id"] = session_id
575
+ logger.info(f"[TracingMetadata] Setting session.id={session_id}")
576
+ if environment:
577
+ span_extra_attrs["langfuse.environment"] = environment
578
+ logger.info(f"[TracingMetadata] Setting langfuse.environment={environment}")
579
+
580
+ # Extract call type from headers (set by @gepa.judge, @gepa.agent decorators)
581
+ call_type = headers_merged.get("x-mantis-call-type")
582
+
583
+ # Build final tags list, including call_type if present
584
+ final_tags = list(tags) if tags else []
585
+ if call_type and call_type not in final_tags:
586
+ final_tags.append(call_type)
587
+
588
+ if final_tags:
589
+ # Insight's OTEL ingestion expects tags on the *resource* under `langfuse.trace.tags`.
590
+ # Span-level `langfuse.tags` is not reliably ingested into `traces.tags`.
591
+ resource_attrs["langfuse.trace.tags"] = final_tags
592
+ # Keep span-level tags too for backwards-compat/debuggability
593
+ span_extra_attrs["langfuse.tags"] = final_tags
594
+ logger.info(f"[TracingMetadata] Setting langfuse.trace.tags={final_tags}")
595
+
596
+ # Prepare spans for export
597
+ spans_to_export: List[ReadableSpan] = []
598
+ for span in subtree_spans:
599
+ # Add resource attributes
600
+ span._resource = span._resource.merge( # pyright: ignore[reportPrivateUsage]
601
+ Resource.create(resource_attrs)
602
+ )
603
+ # Wrap with extra span attributes if we have any
604
+ if span_extra_attrs:
605
+ wrapped_span = SpanWithExtraAttributes(span, span_extra_attrs)
606
+ spans_to_export.append(wrapped_span)
607
+ else:
608
+ spans_to_export.append(span)
609
+
610
+ export_result = self._otlp_exporter.export(spans_to_export)
611
+ if export_result != SpanExportResult.SUCCESS:
612
+ logger.error(f"Failed to export rollout spans via OTLP. Result: {export_result}")
613
+ else:
614
+ # The old way: store does not support OTLP endpoint
615
+ for span in subtree_spans:
616
+ loop = self._ensure_loop()
617
+ add_otel_span_task = store.add_otel_span(
618
+ rollout_id=rollout_id,
619
+ attempt_id=attempt_id,
620
+ sequence_id=sequence_id_decimal,
621
+ readable_span=span,
622
+ )
623
+ fut = asyncio.run_coroutine_threadsafe(add_otel_span_task, loop)
624
+ fut.result()
625
+
626
+ elif otlp_enabled:
627
+ # Job-scoped spans (no rollout context): tag with job_id for experiment tracking
628
+ job_id = self._get_job_id_from_store(store)
629
+
630
+ # Extract Mantis tracing metadata from x-mantis-* headers
631
+ mantis_session_id = headers_merged.get("x-mantis-session-id")
632
+ mantis_environment = headers_merged.get("x-mantis-environment")
633
+ mantis_tags_str = headers_merged.get("x-mantis-tags")
634
+ mantis_call_type = headers_merged.get("x-mantis-call-type")
635
+ mantis_tags = None
636
+ if mantis_tags_str:
637
+ try:
638
+ mantis_tags = ast.literal_eval(mantis_tags_str) if mantis_tags_str.startswith("[") else None
639
+ except Exception:
640
+ pass
641
+
642
+ # Add call_type to tags if present
643
+ if mantis_call_type:
644
+ if mantis_tags is None:
645
+ mantis_tags = []
646
+ if mantis_call_type not in mantis_tags:
647
+ mantis_tags.append(mantis_call_type)
648
+
649
+ # Build span attributes for Mantis/Langfuse metadata
650
+ span_extra_attrs: Dict[str, Any] = {}
651
+ job_resource_attrs: Dict[str, Any] = {}
652
+
653
+ if mantis_session_id:
654
+ span_extra_attrs["session.id"] = mantis_session_id
655
+ logger.info(f"[TracingMetadata] Job span: session.id={mantis_session_id}")
656
+ if mantis_environment:
657
+ span_extra_attrs["langfuse.environment"] = mantis_environment
658
+ logger.info(f"[TracingMetadata] Job span: langfuse.environment={mantis_environment}")
659
+ if mantis_tags:
660
+ # Set tags as RESOURCE attributes (required for Insight ingestion)
661
+ job_resource_attrs["langfuse.trace.tags"] = mantis_tags
662
+ # Also set as span attributes for backwards-compat
663
+ span_extra_attrs["langfuse.tags"] = mantis_tags
664
+ logger.info(f"[TracingMetadata] Job span: langfuse.trace.tags={mantis_tags}")
665
+
666
+ # Prepare spans for export
667
+ spans_to_export: List[ReadableSpan] = []
668
+ for span in subtree_spans:
669
+ if job_id:
670
+ span._resource = span._resource.merge( # pyright: ignore[reportPrivateUsage]
671
+ Resource.create(
672
+ {
673
+ LightningResourceAttributes.JOB_ID.value: job_id,
674
+ LightningResourceAttributes.SPAN_TYPE.value: "job",
675
+ }
676
+ )
677
+ )
678
+ # Merge job-level langfuse resource attrs (tags) if present
679
+ try:
680
+ if "job_resource_attrs" in locals() and job_resource_attrs:
681
+ span._resource = span._resource.merge( # pyright: ignore[reportPrivateUsage]
682
+ Resource.create(job_resource_attrs)
683
+ )
684
+ except Exception:
685
+ pass
686
+ # Wrap with extra span attributes if we have any
687
+ if span_extra_attrs:
688
+ wrapped_span = SpanWithExtraAttributes(span, span_extra_attrs)
689
+ spans_to_export.append(wrapped_span)
690
+ else:
691
+ spans_to_export.append(span)
692
+
693
+ export_result = self._otlp_exporter.export(spans_to_export)
694
+ if export_result != SpanExportResult.SUCCESS:
695
+ logger.error(f"Failed to export job spans via OTLP. Result: {export_result}")
696
+ else:
697
+ logger.debug(f"Exported {len(spans_to_export)} job-scoped spans (job_id={job_id}, has_mantis_metadata={bool(span_extra_attrs)})")
698
+ else:
699
+ # No OTLP and no rollout context - log at debug level and skip
700
+ logger.debug(
701
+ f"Skipping {len(subtree_spans)} spans: no rollout context and OTLP not enabled"
702
+ )
703
+
704
+ def _get_root_span_ids(self) -> Iterable[int]:
705
+ """Yield span_ids for root spans currently in the buffer.
706
+
707
+ A root span is defined as one with `parent is None`.
708
+
709
+ Yields:
710
+ int: Span id for each root span found.
711
+ """
712
+ for span in self._buffer:
713
+ if span.parent is None:
714
+ span_context = span.get_span_context()
715
+ if span_context is not None:
716
+ yield span_context.span_id
717
+
718
+ def _get_subtrees(self, root_span_id: int) -> Iterable[int]:
719
+ """Yield span_ids in the subtree rooted at `root_span_id`.
720
+
721
+ Depth-first traversal over the current buffer.
722
+
723
+ Args:
724
+ root_span_id: The span id of the root.
725
+
726
+ Yields:
727
+ int: Span ids including the root and all descendants found.
728
+ """
729
+ # Yield the root span id first.
730
+ yield root_span_id
731
+ for span in self._buffer:
732
+ # Check whether the span's parent is the root_span_id.
733
+ if span.parent is not None and span.parent.span_id == root_span_id:
734
+ span_context = span.get_span_context()
735
+ if span_context is not None:
736
+ # Recursively get child spans.
737
+ yield from self._get_subtrees(span_context.span_id)
738
+
739
+ def _pop_subtrees(self, root_span_id: int) -> List[ReadableSpan]:
740
+ """Remove and return the subtree for a particular root from the buffer.
741
+
742
+ Args:
743
+ root_span_id: Root span id identifying the subtree.
744
+
745
+ Returns:
746
+ list[ReadableSpan]: Spans that were part of the subtree. Order follows buffer order.
747
+ """
748
+ subtree_span_ids = set(self._get_subtrees(root_span_id))
749
+ subtree_spans: List[ReadableSpan] = []
750
+ new_buffer: List[ReadableSpan] = []
751
+ for span in self._buffer:
752
+ span_context = span.get_span_context()
753
+ if span_context is not None and span_context.span_id in subtree_span_ids:
754
+ subtree_spans.append(span)
755
+ else:
756
+ new_buffer.append(span)
757
+ # Replace buffer with remaining spans to avoid re-processing.
758
+ self._buffer = new_buffer
759
+ return subtree_spans
760
+
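A hedged sketch of constructing the exporter in direct-OTLP mode; the endpoint and credentials are placeholders, not mantisdk defaults:

```python
import base64

auth = base64.b64encode(b"public-key:secret-key").decode()
exporter = LightningSpanExporter(
    otlp_endpoint="https://langfuse.example.com/api/public/otel/v1/traces",
    otlp_headers={"Authorization": f"Basic {auth}"},
)
# The OTEL SDK hands finished spans to exporter.export(); once a root span's whole subtree
# is buffered, it is flushed either straight to the OTLP endpoint or into the LightningStore.
```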
761
+
762
+ class LightningOpenTelemetry(OpenTelemetry):
763
+ """OpenTelemetry integration that exports spans to the Lightning store.
764
+
765
+ Responsibilities:
766
+
767
+ * Ensures each request is annotated with a per-attempt sequence id so spans
768
+ are ordered deterministically even with clock skew across nodes.
769
+ * Uses [`LightningSpanExporter`][mantisdk.llm_proxy.LightningSpanExporter] to persist spans for analytics and training.
770
+ * Adds Mantisdk-specific attributes (session_id, tags, environment) to spans.
771
+
772
+ Args:
773
+ otlp_endpoint: Optional OTLP endpoint URL for direct trace export to external
774
+ collectors (e.g., Langfuse/Insight). When set, spans are exported directly without
775
+ requiring a store or rollout/attempt headers.
776
+ otlp_headers: Optional dict of HTTP headers for OTLP authentication (e.g., Basic Auth).
777
+ """
778
+
779
+ def __init__(self, otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None):
780
+ exporter = LightningSpanExporter(otlp_endpoint=otlp_endpoint, otlp_headers=otlp_headers)
781
+ config = OpenTelemetryConfig(exporter=exporter)
782
+
783
+ # Check for tracer initialization
784
+ if _check_tracer_provider():
785
+ logger.error("Tracer is already initialized. OpenTelemetry may not work as expected.")
786
+
787
+ super().__init__(config=config) # pyright: ignore[reportUnknownMemberType]
788
+
789
+ # Store exporter reference for debugging
790
+ self._custom_exporter = exporter
791
+
792
+ def _init_tracing(self, tracer_provider):
793
+ """Override to ensure our span processor is added even when reusing existing TracerProvider.
794
+
795
+ LiteLLM's parent _init_tracing reuses existing TracerProviders but doesn't add
796
+ our custom span processor to them. We override to force adding our processor.
797
+ """
798
+ from opentelemetry import trace as otel_trace_api
799
+ from opentelemetry.sdk.trace import TracerProvider as TracerProviderSDK
800
+ from opentelemetry.trace import SpanKind
801
+
802
+ # Call parent to set up tracer
803
+ super()._init_tracing(tracer_provider) # pyright: ignore[reportUnknownMemberType]
804
+
805
+ # If an existing provider was reused, add our span processor to it
806
+ current_provider = otel_trace_api.get_tracer_provider()
807
+ if isinstance(current_provider, TracerProviderSDK):
808
+ # Check if our processor is already added
809
+ has_our_processor = False
810
+ if hasattr(current_provider, "_active_span_processor"):
811
+ active_processor = current_provider._active_span_processor # pyright: ignore[reportPrivateUsage]
812
+ if hasattr(active_processor, "_span_processors"):
813
+ for proc in active_processor._span_processors: # pyright: ignore[reportPrivateUsage]
814
+ if hasattr(proc, "_exporter") and isinstance(proc._exporter, LightningSpanExporter): # pyright: ignore[reportPrivateUsage]
815
+ has_our_processor = True
816
+ break
817
+
818
+ if not has_our_processor:
819
+ # Add our span processor
820
+ span_processor = self._get_span_processor()
821
+ current_provider.add_span_processor(span_processor)
822
+
823
+ self.span_kind = SpanKind
824
+
825
+ def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
826
+ """Override to ensure our custom exporter is used.
827
+
828
+ LiteLLM's parent class checks if OTEL_EXPORTER has export() method and wraps it
829
+ in SimpleSpanProcessor. We override to add logging and ensure this happens.
830
+ """
831
+ from opentelemetry.sdk.trace.export import SimpleSpanProcessor
832
+
833
+ # Use our custom exporter directly
834
+ if hasattr(self.OTEL_EXPORTER, "export"):
835
+ processor = SimpleSpanProcessor(self.OTEL_EXPORTER)
836
+ return processor
837
+
838
+ # Fallback to parent implementation
839
+ return super()._get_span_processor(dynamic_headers) # pyright: ignore[reportUnknownMemberType]
840
+
841
+ def set_attributes(self, span: Any, kwargs: Dict[str, Any], response_obj: Optional[Any]) -> None:
842
+ """Override to add Mantisdk-specific attributes from metadata.
843
+
844
+ Extracts session_id, tags, and environment from kwargs["metadata"] and sets
845
+ them as OTEL span attributes for visibility in Insight/Langfuse.
846
+
847
+ Also extracts extra_headers from kwargs and sets them as metadata.requester_custom_headers
848
+ so the proxy can read them for tagging.
849
+ """
850
+ # Call parent implementation first
851
+ super().set_attributes(span, kwargs, response_obj) # pyright: ignore[reportUnknownMemberType]
852
+
853
+ # Extract extra_headers from kwargs and set as metadata.requester_custom_headers
854
+ # This allows the proxy to read x-mantis-* headers for tagging
855
+ # Check multiple sources: extra_headers (from OpenAI SDK), kwargs["headers"], and context variable
856
+ extra_headers = kwargs.get("extra_headers") or kwargs.get("extraHeaders")
857
+ request_headers = kwargs.get("headers") or {}
858
+ context_headers = _request_headers_context.get() or {}
859
+
860
+ # Merge all sources: extra_headers, request_headers, and context headers
861
+ merged_headers = {}
862
+ if extra_headers and isinstance(extra_headers, dict):
863
+ merged_headers.update(extra_headers)
864
+ if request_headers and isinstance(request_headers, dict):
865
+ # Extract x-mantis-* headers from HTTP request headers passed by LiteLLM
866
+ for key, value in request_headers.items():
867
+ if isinstance(key, str) and key.lower().startswith("x-mantis-"):
868
+ merged_headers[key] = value
869
+ # Also check context variable (set by MantisHeadersMiddleware)
870
+ if context_headers:
871
+ merged_headers.update(context_headers)
872
+
873
+ if merged_headers:
874
+ # LiteLLM stores this as a stringified dict in metadata.requester_custom_headers
875
+ span.set_attribute("metadata.requester_custom_headers", str(merged_headers))
876
+
877
+ # Extract Mantisdk tracing metadata from kwargs
878
+ metadata = kwargs.get("metadata", {}) if kwargs else {}
879
+ if not metadata:
880
+ return
881
+
882
+ # Set session_id as span attribute (standard Langfuse/Insight attribute)
883
+ session_id = metadata.get("session_id")
884
+ if session_id:
885
+ span.set_attribute("session.id", session_id)
886
+
887
+ # Set tags as span attribute
888
+ tags = metadata.get("tags")
889
+ if tags and isinstance(tags, list):
890
+ # Set as JSON array for Langfuse compatibility
891
+ import json
892
+ span.set_attribute("tags", json.dumps(tags))
893
+ # Also set individual tag attributes for filtering
894
+ for i, tag in enumerate(tags):
895
+ span.set_attribute(f"tag.{i}", str(tag))
896
+
897
+ # Set environment as span attribute
898
+ environment = metadata.get("environment")
899
+ if environment:
900
+ span.set_attribute("environment", environment)
901
+
902
+ async def async_pre_call_deployment_hook(
903
+ self, kwargs: Dict[str, Any], call_type: Optional[CallTypes] = None
904
+ ) -> Optional[Dict[str, Any]]:
905
+ """The root span is sometimes missing (e.g., when Anthropic endpoint is used).
906
+ It is created in an auth module in LiteLLM. If it's missing, we create it here.
907
+ """
908
+ if "metadata" not in kwargs or "litellm_parent_otel_span" not in kwargs["metadata"]:
909
+ parent_otel_span = self.create_litellm_proxy_request_started_span( # type: ignore
910
+ start_time=datetime.now(),
911
+ headers=kwargs.get("headers", {}),
912
+ )
913
+ updated_metadata = {**kwargs.get("metadata", {}), "litellm_parent_otel_span": parent_otel_span}
914
+
915
+ return {**kwargs, "metadata": updated_metadata}
916
+ else:
917
+ return kwargs
918
+
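A hedged sketch (not mantisdk's documented setup) of instantiating the integration and registering it with LiteLLM; the endpoint and header values are placeholders:

```python
otel_integration = LightningOpenTelemetry(
    otlp_endpoint="https://langfuse.example.com/api/public/otel/v1/traces",
    otlp_headers={"Authorization": "Basic <base64 of public:secret>"},
)
litellm.callbacks = [otel_integration]
# Spans for proxied calls now flow through LightningSpanExporter, carrying the
# session/tag/environment attributes added in set_attributes().
```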
919
+
920
+ class RolloutAttemptMiddleware(BaseHTTPMiddleware):
921
+ """
922
+ Rewrites /rollout/{rid}/attempt/{aid}/... -> /...
923
+ and injects x-rollout-id, x-attempt-id, x-sequence-id headers.
924
+
925
+ LLMProxy can update the store later without rebuilding the middleware.
926
+ """
927
+
928
+ async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
929
+ # Decode rollout and attempt from the URL prefix. Example:
930
+ # /rollout/r123/attempt/a456/v1/chat/completions
931
+ # becomes
932
+ # /v1/chat/completions
933
+ # while adding request-scoped headers for trace attribution.
934
+ path = request.url.path
935
+
936
+ match = re.match(r"^/rollout/([^/]+)/attempt/([^/]+)(/.*)?$", path)
937
+ if match:
938
+ rollout_id = match.group(1)
939
+ attempt_id = match.group(2)
940
+ new_path = match.group(3) if match.group(3) is not None else "/"
941
+
942
+ # Rewrite the ASGI scope path so downstream sees a clean OpenAI path.
943
+ request.scope["path"] = new_path
944
+ request.scope["raw_path"] = new_path.encode()
945
+
946
+ store = get_active_llm_proxy().get_store()
947
+ if store is not None:
948
+ # Allocate a monotonic sequence id per (rollout, attempt).
949
+ sequence_id = await store.get_next_span_sequence_id(rollout_id, attempt_id)
950
+
951
+ # Inject headers so downstream components and exporters can retrieve them.
952
+ request.scope["headers"] = list(request.scope["headers"]) + [
953
+ (b"x-rollout-id", rollout_id.encode()),
954
+ (b"x-attempt-id", attempt_id.encode()),
955
+ (b"x-sequence-id", str(sequence_id).encode()),
956
+ ]
957
+ else:
958
+ logger.warning("Store is not set. Skipping sequence id allocation and header injection.")
959
+
960
+ response = await call_next(request)
961
+ return response
962
+
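For illustration, a hedged sketch of how a client might address the proxy so spans are attributed to a rollout attempt; the base URL, ids, and model name are placeholders:

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:4000/rollout/r123/attempt/a456/v1",
    api_key="sk-placeholder",
)
client.chat.completions.create(
    model="local-llama",
    messages=[{"role": "user", "content": "hello"}],
)
# The middleware rewrites the path to /v1/chat/completions and injects
# x-rollout-id, x-attempt-id, and x-sequence-id before LiteLLM handles the request.
```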
963
+
964
+ class MantisHeadersMiddleware(BaseHTTPMiddleware):
965
+ """Middleware to intercept x-mantis-* HTTP headers and store them in context.
966
+
967
+ This allows LiteLLM callbacks to access custom headers (like x-mantis-call-type)
968
+ even if LiteLLM doesn't pass them through kwargs.
969
+ """
970
+
971
+ async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
972
+ # Extract x-mantis-* headers from HTTP request
973
+ mantis_headers: Dict[str, str] = {}
974
+ for header_name, header_value in request.headers.items():
975
+ if isinstance(header_name, str) and header_name.lower().startswith("x-mantis-"):
976
+ mantis_headers[header_name] = header_value
977
+
978
+ # Store in context variable for callback access
979
+ if mantis_headers:
980
+ _request_headers_context.set(mantis_headers)
981
+
982
+ response = await call_next(request)
983
+ return response
984
+
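A minimal sketch of the hand-off the middleware performs; in the real proxy this happens inside `dispatch`, and LiteLLM callbacks later read the same `ContextVar`:

```python
# Values here are placeholders; dispatch() would populate them from incoming HTTP headers.
_request_headers_context.set({"x-mantis-call-type": "judge", "x-mantis-session-id": "s-001"})

captured = _request_headers_context.get() or {}
assert captured["x-mantis-call-type"] == "judge"
```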
985
+
986
+ class MessageInspectionMiddleware(BaseHTTPMiddleware):
987
+ """Middleware to inspect the request and response bodies.
988
+
989
+ It's for debugging purposes. Add it via the "message_inspection" middleware alias.
990
+ """
991
+
992
+ async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
993
+ ti = time.time()
994
+ logger.info(f"Received request with scope: {request.scope}")
995
+ logger.info(f"Received request with body: {await request.body()}")
996
+ response = await call_next(request)
997
+ elapsed = time.time() - ti
998
+ logger.info(f"Response to request took {elapsed} seconds")
999
+ logger.info(f"Received response with status code: {response.status_code}")
1000
+ logger.info(f"Received response with body: {response.body}")
1001
+ return response
1002
+
1003
+
1004
+ class StreamConversionMiddleware(BaseHTTPMiddleware):
1005
+ """Middleware to convert streaming responses to non-streaming responses.
1006
+
1007
+ Useful for backends that only support non-streaming responses.
1008
+
1009
+ LiteLLM's OpenTelemetry is also buggy with streaming responses.
1010
+ The conversion will hopefully bypass the bug.
1011
+ """
1012
+
1013
+ async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
1014
+ # Only process POST requests to completion endpoints
1015
+ if request.method != "POST":
1016
+ return await call_next(request)
1017
+
1018
+ # Check if it's a chat completions or messages endpoint
1019
+ endpoint_format: Literal["openai", "anthropic", "unknown"] = "unknown"
1020
+ if request.url.path.endswith("/chat/completions") or "/chat/completions?" in request.url.path:
1021
+ endpoint_format = "openai"
1022
+ elif request.url.path.endswith("/messages") or "/messages?" in request.url.path:
1023
+ endpoint_format = "anthropic"
1024
+ else:
1025
+ endpoint_format = "unknown"
1026
+
1027
+ if endpoint_format == "unknown":
1028
+ # Directly bypass the middleware
1029
+ return await call_next(request)
1030
+
1031
+ # Read the request body
1032
+ try:
1033
+ json_body = await request.json()
1034
+ except json.JSONDecodeError:
1035
+ logger.warning(f"Request body is not valid JSON: {request.body}")
1036
+ return await call_next(request)
1037
+
1038
+ # Check if streaming is requested
1039
+ is_streaming = json_body.get("stream", False)
1040
+
1041
+ # Simple case: no streaming requested, just return the response
1042
+ if not is_streaming:
1043
+ return await call_next(request)
1044
+
1045
+ # Now the stream case
1046
+ return await self._handle_stream_case(request, json_body, endpoint_format, call_next)
1047
+
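A hedged sketch of exercising the conversion path from a client: the request asks for a stream, the backend is called with `stream=False`, and the middleware re-chunks the buffered JSON as SSE; the URL and model are placeholders:

```python
import httpx

with httpx.stream(
    "POST",
    "http://localhost:4000/v1/chat/completions",
    json={
        "model": "local-llama",
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
    },
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)  # SSE frames synthesized from the non-streaming upstream response
```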
1048
+ async def _handle_stream_case(
1049
+ self,
1050
+ request: Request,
1051
+ json_body: Dict[str, Any],
1052
+ endpoint_format: Literal["openai", "anthropic"],
1053
+ call_next: Callable[[Request], Awaitable[Response]],
1054
+ ) -> Response:
1055
+ # 1) Modify the request body to force stream=False
1056
+ modified_json = dict(json_body)
1057
+ modified_json["stream"] = False
1058
+ modified_body = json.dumps(modified_json).encode("utf-8")
1059
+
1060
+ # 2) Build a new scope + receive that yields our modified body
1061
+ scope: Scope = dict(request.scope)
1062
+ # rewrite headers for accept/content-length
1063
+ new_headers: List[Tuple[bytes, bytes]] = []
1064
+ saw_accept = False
1065
+ for k, v in scope["headers"]:
1066
+ kl = k.lower()
1067
+ if kl == b"accept":
1068
+ saw_accept = True
1069
+ new_headers.append((k, b"application/json"))
1070
+ elif kl == b"content-length":
1071
+ # replace with new length
1072
+ continue
1073
+ else:
1074
+ new_headers.append((k, v))
1075
+ if not saw_accept:
1076
+ new_headers.append((b"accept", b"application/json"))
1077
+ new_headers.append((b"content-length", str(len(modified_body)).encode("ascii")))
1078
+ scope["headers"] = new_headers
1079
+
1080
+ # Directly modify the request body
1081
+ # Creating a new request won't work because request is cached in the base class
1082
+ request._body = modified_body # type: ignore
1083
+
1084
+ response = await call_next(request)
1085
+
1086
+ buffered: Optional[bytes] = None
1087
+ # 4) If OK, buffer the response body (it should be JSON because we forced stream=False)
1088
+ if 200 <= response.status_code < 300:
1089
+ try:
1090
+ if hasattr(response, "body_iterator"):
1091
+ # Buffer body safely
1092
+ body_chunks: List[bytes] = []
1093
+ async for chunk in response.body_iterator: # type: ignore
1094
+ body_chunks.append(chunk) # type: ignore
1095
+ buffered = b"".join(body_chunks)
1096
+ else:
1097
+ buffered = response.body # type: ignore
1098
+
1099
+ data = json.loads(buffered or b"{}")
1100
+
1101
+ if endpoint_format == "anthropic":
1102
+ return StreamingResponse(
1103
+ self.anthropic_stream_generator(data),
1104
+ media_type="text/event-stream",
1105
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
1106
+ )
1107
+ else:
1108
+ # openai format
1109
+ return StreamingResponse(
1110
+ self.openai_stream_generator(data),
1111
+ media_type="text/event-stream",
1112
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
1113
+ )
1114
+ except Exception as e:
1115
+ # If anything goes wrong, fall back to non-streaming JSON
1116
+ logger.exception(f"Error converting to stream; returning non-stream response: {e}")
1117
+ # Rebuild the consumed response
1118
+ return Response(
1119
+ content=buffered if buffered is not None else b"",
1120
+ status_code=response.status_code,
1121
+ headers=dict(response.headers),
1122
+ media_type=response.media_type,
1123
+ background=response.background,
1124
+ )
1125
+ else:
1126
+ return response
1127
+
1128
+ async def anthropic_stream_generator(self, original_response: Dict[str, Any]):
1129
+ """Generate Anthropic SSE-formatted chunks from complete content blocks
1130
+
1131
+ This is a dirty hack for Anthropic-style streaming from non-streaming response.
1132
+ The sse format is subject to change based on Anthropic's implementation.
1133
+ If so, try to use `MessageInspectionMiddleware` to inspect the update and fix accordingly.
1134
+ """
1135
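+        # Event order emitted below: message_start -> ping -> per content block
+        # (content_block_start -> content_block_delta* -> content_block_stop) ->
+        # message_delta -> message_stop.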
+ # Anthropic format - handle multiple content blocks (text + tool_use)
1136
+ content_blocks: List[Dict[str, Any]] = original_response.get("content", [])
1137
+ message_id = original_response.get("id", f"msg_{int(time.time() * 1000)}")
1138
+ model = original_response.get("model", "claude")
1139
+
1140
+ # Send message_start event
1141
+ message_start: Dict[str, Any] = {
1142
+ "type": "message_start",
1143
+ "message": {
1144
+ "id": message_id,
1145
+ "type": "message",
1146
+ "role": "assistant",
1147
+ "content": [],
1148
+ "model": model,
1149
+ "stop_reason": None,
1150
+ "stop_sequence": None,
1151
+ "usage": original_response.get("usage", {"input_tokens": 0, "output_tokens": 0}),
1152
+ },
1153
+ }
1154
+ yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
1155
+
1156
+ # Send ping to keep connection alive
1157
+ ping = {"type": "ping"}
1158
+ yield f"event: ping\ndata: {json.dumps(ping)}\n\n"
1159
+
1160
+ # Process each content block
1161
+ for block_index, block in enumerate(content_blocks):
1162
+ block_type = block.get("type", "text")
1163
+
1164
+ if block_type == "text":
1165
+ # Handle text block
1166
+ content = block.get("text", "")
1167
+
1168
+ # Send content_block_start event
1169
+ content_block_start = {
1170
+ "type": "content_block_start",
1171
+ "index": block_index,
1172
+ "content_block": {"type": "text", "text": ""},
1173
+ }
1174
+ yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
1175
+
1176
+ # Stream text content in chunks
1177
+ if content:
1178
+ words = content.split()
1179
+ chunk_size = 5
1180
+
1181
+ for i in range(0, len(words), chunk_size):
1182
+ chunk_words = words[i : i + chunk_size]
1183
+ text_chunk = " ".join(chunk_words)
1184
+
1185
+ # Add space after chunk unless it's the last one
1186
+ if i + chunk_size < len(words):
1187
+ text_chunk += " "
1188
+
1189
+ content_block_delta = {
1190
+ "type": "content_block_delta",
1191
+ "index": block_index,
1192
+ "delta": {"type": "text_delta", "text": text_chunk},
1193
+ }
1194
+ yield f"event: content_block_delta\ndata: {json.dumps(content_block_delta)}\n\n"
1195
+ await asyncio.sleep(0.02)
1196
+
1197
+ # Send content_block_stop event
1198
+ content_block_stop = {"type": "content_block_stop", "index": block_index}
1199
+ yield f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n"
1200
+
1201
+ elif block_type == "tool_use":
1202
+ # Handle tool_use block
1203
+ tool_name = block.get("name", "")
1204
+ tool_input = block.get("input", {})
1205
+ tool_id = block.get("id", f"toolu_{int(time.time() * 1000)}")
1206
+
1207
+ # Send content_block_start event for tool use
1208
+ content_block_start: Dict[str, Any] = {
1209
+ "type": "content_block_start",
1210
+ "index": block_index,
1211
+ "content_block": {"type": "tool_use", "id": tool_id, "name": tool_name, "input": {}},
1212
+ }
1213
+ yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
1214
+
1215
+ # Stream tool input as JSON string chunks
1216
+ input_json = json.dumps(tool_input)
1217
+ chunk_size = 20 # characters per chunk for JSON
1218
+
1219
+ for i in range(0, len(input_json), chunk_size):
1220
+ json_chunk = input_json[i : i + chunk_size]
1221
+
1222
+ content_block_delta = {
1223
+ "type": "content_block_delta",
1224
+ "index": block_index,
1225
+ "delta": {"type": "input_json_delta", "partial_json": json_chunk},
1226
+ }
1227
+ yield f"event: content_block_delta\ndata: {json.dumps(content_block_delta)}\n\n"
1228
+ await asyncio.sleep(0.01)
1229
+
1230
+ # Send content_block_stop event
1231
+ content_block_stop = {"type": "content_block_stop", "index": block_index}
1232
+ yield f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n"
1233
+
1234
+ # Send message_delta event with stop reason
1235
+ message_delta = {
1236
+ "type": "message_delta",
1237
+ "delta": {"stop_reason": original_response.get("stop_reason", "end_turn"), "stop_sequence": None},
1238
+ "usage": {"output_tokens": original_response.get("usage", {}).get("output_tokens", 0)},
1239
+ }
1240
+ yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n"
1241
+
1242
+ # Send message_stop event
1243
+ message_stop = {"type": "message_stop"}
1244
+ yield f"event: message_stop\ndata: {json.dumps(message_stop)}\n\n"
1245
+
1246
+ async def openai_stream_generator(self, response_json: Dict[str, Any]) -> AsyncGenerator[str, Any]:
1247
+ """
1248
+ Convert a *complete* OpenAI chat.completions choice into a stream of
1249
+ OpenAI-compatible SSE chunks.
1250
+
1251
+ This emits:
1252
+
1253
+ - an initial delta with the role ("assistant"),
1254
+ - a sequence of deltas for message.content (split into small chunks),
1255
+ - deltas for any tool_calls (including id/name and chunked arguments),
1256
+ - a terminal chunk with finish_reason,
1257
+ - and finally the literal '[DONE]'.
1258
+
1259
+ Notes:
1260
+
1261
+ - We only handle a *single* choice (index 0 typically).
1262
+ - We purposefully don't attempt to stream logprobs.
1263
+ - Chunking strategy is simple and conservative to avoid splitting
1264
+ multi-byte characters: we slice on spaces where possible, then fall
1265
+ back to fixed-size substrings.
1266
+ """
1267
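+        # Each yielded SSE line has the shape (illustrative values):
+        #   data: {"id": "chatcmpl-<ts>", "object": "chat.completion.chunk", "created": <ts>,
+        #          "model": "<model>", "choices": [{"index": 0, "delta": {...}, "finish_reason": null}]}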
+ choice = cast(Dict[str, Any], (response_json.get("choices") or [{}])[0])
1268
+ model = response_json.get("model", "unknown")
1269
+ created: int = int(time.time())
1270
+ index: int = choice.get("index", 0)
1271
+
1272
+ message: Dict[str, Any] = choice.get("message", {}) or {}
1273
+ role: str = message.get("role", "assistant")
1274
+ content: str = message.get("content") or ""
1275
+ tool_calls: List[Any] = message.get("tool_calls") or []
1276
+ finish_reason: Optional[str] = choice.get(
1277
+ "finish_reason"
1278
+ ) # e.g., "stop", "length", "tool_calls", "content_filter"
1279
+
1280
+ def sse_chunk(obj: Dict[str, Any]) -> str:
1281
+ return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
1282
+
1283
+ # 1) initial chunk with the role
1284
+ yield sse_chunk(
1285
+ {
1286
+ "id": f"chatcmpl-{created}",
1287
+ "object": "chat.completion.chunk",
1288
+ "created": created,
1289
+ "model": model,
1290
+ "choices": [{"index": index, "delta": {"role": role}, "finish_reason": None}],
1291
+ }
1292
+ )
1293
+
1294
+ # 2) stream textual content as small deltas
1295
+ async def stream_content(text: str):
1296
+ if not text:
1297
+ return
1298
+ # prefer splitting on spaces in ~20–40 char pieces
1299
+ approx = 28
1300
+ start = 0
1301
+ n = len(text)
1302
+ while start < n:
1303
+ end = min(start + approx, n)
1304
+ if end < n:
1305
+                    # prefer breaking at the last space before the tentative end
1306
+ space = text.rfind(" ", start, end)
1307
+ if space > start:
1308
+ end = space + 1
1309
+ delta_text = text[start:end]
1310
+ start = end
1311
+ if not delta_text:
1312
+ break
1313
+ yield sse_chunk(
1314
+ {
1315
+ "id": f"chatcmpl-{created}",
1316
+ "object": "chat.completion.chunk",
1317
+ "created": created,
1318
+ "model": model,
1319
+ "choices": [{"index": index, "delta": {"content": delta_text}, "finish_reason": None}],
1320
+ }
1321
+ )
1322
+ # tiny pause helps some UIs animate smoothly; keep very small
1323
+ await asyncio.sleep(0.0)
1324
+
1325
+ async for piece in stream_content(content): # type: ignore[misc]
1326
+ yield piece # pass through the produced chunks
1327
+
1328
+ # 3) stream tool_calls if present (id/name first, then arguments piecemeal)
1329
+ for tc_index, tc in enumerate(tool_calls):
1330
+ tc_type = tc.get("type", "function")
1331
+ tc_id = tc.get("id") or f"call_{created}_{tc_index}"
1332
+ fn: Dict[str, Any] = (tc.get("function") or {}) if tc_type == "function" else {}
1333
+ fn_name: str = fn.get("name", "")
1334
+ fn_args: str = fn.get("arguments", "") or ""
1335
+
1336
+ # (a) delta that announces the tool call id/type/name
1337
+ yield sse_chunk(
1338
+ {
1339
+ "id": f"chatcmpl-{created}",
1340
+ "object": "chat.completion.chunk",
1341
+ "created": created,
1342
+ "model": model,
1343
+ "choices": [
1344
+ {
1345
+ "index": index,
1346
+ "delta": {
1347
+ "tool_calls": [
1348
+ {"index": tc_index, "id": tc_id, "type": tc_type, "function": {"name": fn_name}}
1349
+ ]
1350
+ },
1351
+ "finish_reason": None,
1352
+ }
1353
+ ],
1354
+ }
1355
+ )
1356
+
1357
+ # (b) stream arguments in small substrings
1358
+ arg_chunk_size = 40
1359
+ for pos in range(0, len(fn_args), arg_chunk_size):
1360
+ partial = fn_args[pos : pos + arg_chunk_size]
1361
+ yield sse_chunk(
1362
+ {
1363
+ "id": f"chatcmpl-{created}",
1364
+ "object": "chat.completion.chunk",
1365
+ "created": created,
1366
+ "model": model,
1367
+ "choices": [
1368
+ {
1369
+ "index": index,
1370
+ "delta": {"tool_calls": [{"index": tc_index, "function": {"arguments": partial}}]},
1371
+ "finish_reason": None,
1372
+ }
1373
+ ],
1374
+ }
1375
+ )
1376
+ await asyncio.sleep(0.0)
1377
+
1378
+ # 4) terminal chunk with finish_reason (default to "stop" if missing)
1379
+ yield sse_chunk(
1380
+ {
1381
+ "id": f"chatcmpl-{created}",
1382
+ "object": "chat.completion.chunk",
1383
+ "created": created,
1384
+ "model": model,
1385
+ "choices": [
1386
+ {
1387
+ "index": index,
1388
+ "delta": {},
1389
+ "finish_reason": finish_reason or ("tool_calls" if tool_calls else "stop"),
1390
+ }
1391
+ ],
1392
+ }
1393
+ )
1394
+
1395
+ # 5) literal DONE sentinel
1396
+ yield "data: [DONE]\n\n"
1397
+
1398
+
1399
+ _MIDDLEWARE_REGISTRY: Dict[str, Type[BaseHTTPMiddleware]] = {
1400
+ "rollout_attempt": RolloutAttemptMiddleware,
1401
+ "stream_conversion": StreamConversionMiddleware,
1402
+ "message_inspection": MessageInspectionMiddleware,
1403
+ "mantis_headers": MantisHeadersMiddleware,
1404
+ }
1405
+
1406
+
1407
+ _CALLBACK_REGISTRY = {
1408
+ "return_token_ids": AddReturnTokenIds,
1409
+ "logprobs": AddLogprobs,
1410
+ "opentelemetry": LightningOpenTelemetry,
1411
+ }
1412
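+    # Aliases in the registries above can be passed to LLMProxy directly, e.g.
+    #   LLMProxy(..., middlewares=["mantis_headers", "rollout_attempt"], callbacks=["return_token_ids"]),
+    # or mixed with imported classes such as MessageInspectionMiddleware.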
+
1413
+
1414
+ class LLMProxy:
1415
+ """Host a LiteLLM OpenAI-compatible proxy bound to a LightningStore.
1416
+
1417
+ The proxy:
1418
+
1419
+ * Serves an OpenAI-compatible API via uvicorn.
1420
+ * Adds rollout/attempt routing and headers via middleware.
1421
+ * Registers OTEL export and token-id callbacks.
1422
+ * Writes a LiteLLM worker config file with `model_list` and settings.
1423
+
1424
+ Lifecycle:
1425
+
1426
+    * [`start()`][mantisdk.LLMProxy.start] writes the LiteLLM config, starts the server via the configured launcher, and waits until ready.
1427
+ * [`stop()`][mantisdk.LLMProxy.stop] tears down the server and removes the temp config file.
1428
+ * [`restart()`][mantisdk.LLMProxy.restart] convenience wrapper to stop then start.
1429
+
1430
+ !!! note
1431
+
1432
+ As the LLM Proxy sets up an OpenTelemetry tracer, it's recommended to run it in a different
1433
+        process from the main runner (which may set up its own tracer for agents). See `launch_mode` for how to change that.
1434
+
1435
+ !!! warning
1436
+
1437
+ By default (or when "stream_conversion" middleware is enabled), the LLM Proxy will convert OpenAI and Anthropic requests with `stream=True`
1438
+ to a non-streaming request before going through the LiteLLM proxy. This is because the OpenTelemetry tracer provided by
1439
+ LiteLLM is buggy with streaming responses. You can disable this by removing the "stream_conversion" middleware.
1440
+ In that case, you might lose some tracing information like token IDs.
1441
+
1442
+ !!! danger
1443
+
1444
+ Do not run LLM proxy in the same process as the main runner. It's easy to cause conflicts in the tracer provider
1445
+ with tracers like [`AgentOpsTracer`][mantisdk.AgentOpsTracer].
1446
+
1447
+ Args:
1448
+ port: TCP port to bind. Will bind to a random port if not provided.
1449
+ model_list: LiteLLM `model_list` entries.
1450
+ store: LightningStore used for span sequence and persistence.
1451
+ host: Publicly reachable host used in resource endpoints. See `host` of `launcher_args` for more details.
1452
+ litellm_config: Extra LiteLLM proxy config merged with `model_list`.
1453
+ num_retries: Default LiteLLM retry count injected into `litellm_settings`.
1454
+        num_workers: Number of worker processes for the server. Only applicable in "mp" launch mode. Cannot be combined with launcher_args.
1455
+ When `num_workers > 1`, the server will be run using [gunicorn](https://gunicorn.org/).
1456
+        launch_mode: Launch mode for the server. Defaults to "mp". Cannot be combined with launcher_args.
1457
+            It's recommended to use `launch_mode="mp"`, which runs the server in a separate process.
1458
+            `launch_mode="thread"` can also be used with caution; it launches the server in a separate thread.
1459
+ `launch_mode="asyncio"` launches the server in the current thread as an asyncio task.
1460
+ It is NOT recommended because it often causes hanging requests. Only use it if you know what you are doing.
1461
+        launcher_args: Arguments for the server launcher. Cannot be combined with port, host, launch_mode, or num_workers; providing both raises a ValueError.
1462
+ middlewares: List of FastAPI middleware classes or strings to register. You can specify the class aliases or classes that have been imported.
1463
+            If not provided, the default middlewares (MantisHeadersMiddleware, RolloutAttemptMiddleware, and StreamConversionMiddleware) will be used.
1464
+            Available middleware aliases are: "mantis_headers", "rollout_attempt", "stream_conversion", "message_inspection".
1465
+ Middlewares are the **first layer** of request processing. They are applied to all requests before the LiteLLM proxy.
1466
+ callbacks: List of LiteLLM callback classes or strings to register. You can specify the class aliases or classes that have been imported.
1467
+ If not provided, the default callbacks (AddReturnTokenIds and LightningOpenTelemetry) will be used.
1468
+ Available callback aliases are: "return_token_ids", "opentelemetry", "logprobs".
1469
+ otlp_endpoint: Optional OTLP endpoint URL for direct trace export to external
1470
+ collectors (e.g., Langfuse/Insight). When set, spans are exported directly without
1471
+ requiring rollout/attempt headers. Format: "http://host:port/api/public/otel/v1/traces"
1472
+ otlp_headers: Optional dict of HTTP headers for OTLP authentication.
1473
+ For Langfuse/Insight, use Basic Auth: {"Authorization": "Basic base64(publicKey:secretKey)"}
1474
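+    Example:
+
+        A minimal sketch; the model entry follows LiteLLM's `model_list` convention,
+        and the `litellm_params` values and the `store` instance shown are illustrative:
+
+        ```python
+        proxy = LLMProxy(
+            model_list=[{
+                "model_name": "gpt-4o-mini",
+                "litellm_params": {"model": "openai/gpt-4o-mini"},
+            }],
+            store=store,  # any LightningStore implementation
+        )
+        await proxy.start()
+        llm = proxy.as_resource(rollout_id="rollout-0", attempt_id="attempt-0")
+        # ... issue OpenAI-compatible requests against llm.endpoint ...
+        await proxy.stop()
+        ```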
+ """
1475
+
1476
+ def __init__(
1477
+ self,
1478
+ port: int | None = None,
1479
+ model_list: List[ModelConfig] | None = None,
1480
+ store: Optional[LightningStore] = None,
1481
+ host: str | None = None,
1482
+ litellm_config: Dict[str, Any] | None = None,
1483
+ num_retries: int = 0,
1484
+ num_workers: int = 1,
1485
+ launch_mode: LaunchMode = "mp",
1486
+ launcher_args: PythonServerLauncherArgs | None = None,
1487
+ middlewares: Sequence[Union[Type[BaseHTTPMiddleware], str]] | None = None,
1488
+ callbacks: Sequence[Union[Type[CustomLogger], str]] | None = None,
1489
+ otlp_endpoint: Optional[str] = None,
1490
+ otlp_headers: Optional[Dict[str, str]] = None,
1491
+ ):
1492
+ self.store = store
1493
+ self._otlp_endpoint = otlp_endpoint
1494
+ self._otlp_headers = otlp_headers
1495
+
1496
+ # Log OTLP configuration for diagnostics
1497
+ if otlp_endpoint:
1498
+ logger.info(f"LLMProxy initialized with OTLP endpoint: {otlp_endpoint}")
1499
+ if otlp_headers:
1500
+ logger.info(f"LLMProxy OTLP headers configured: {list(otlp_headers.keys())}")
1501
+ else:
1502
+ logger.debug("LLMProxy initialized without OTLP endpoint (will use store-based export)")
1503
+
1504
+ if launcher_args is not None and (
1505
+ port is not None or host is not None or launch_mode != "mp" or num_workers != 1
1506
+ ):
1507
+ raise ValueError("port, host, launch_mode, and num_workers cannot be set when launcher_args is provided.")
1508
+
1509
+ self.server_launcher_args = launcher_args or PythonServerLauncherArgs(
1510
+ port=port,
1511
+ host=host,
1512
+ launch_mode=launch_mode,
1513
+ n_workers=num_workers,
1514
+ # NOTE: This /health endpoint can be slow sometimes because it actually probes the backend LLM service.
1515
+ healthcheck_url="/health",
1516
+ startup_timeout=60.0,
1517
+ )
1518
+
1519
+ if self.server_launcher_args.healthcheck_url is None:
1520
+            logger.warning("healthcheck_url is not set. LLM Proxy will not be health-checked after starting.")
1521
+
1522
+ self.model_list = model_list or []
1523
+ self.litellm_config = litellm_config or {}
1524
+
1525
+ # Ensure num_retries is present inside the litellm_settings block.
1526
+ self.litellm_config.setdefault("litellm_settings", {})
1527
+ self.litellm_config["litellm_settings"].setdefault("num_retries", num_retries)
1528
+ self.server_launcher = PythonServerLauncher(app, self.server_launcher_args, noop_context())
1529
+
1530
+ self._config_file = None
1531
+
1532
+ self.middlewares: List[Type[BaseHTTPMiddleware]] = []
1533
+ if middlewares is None:
1534
+ middlewares = ["mantis_headers", "rollout_attempt", "stream_conversion"]
1535
+ for middleware in middlewares:
1536
+ if isinstance(middleware, str):
1537
+ if middleware not in _MIDDLEWARE_REGISTRY:
1538
+ raise ValueError(
1539
+ f"Invalid middleware alias: {middleware}. Available aliases are: {list(_MIDDLEWARE_REGISTRY.keys())}"
1540
+ )
1541
+ middleware = _MIDDLEWARE_REGISTRY[middleware]
1542
+ self.middlewares.append(middleware)
1543
+ else:
1544
+ self.middlewares.append(middleware)
1545
+
1546
+ self.callbacks: List[Type[CustomLogger]] = []
1547
+ if callbacks is None:
1548
+ callbacks = ["return_token_ids", "opentelemetry"]
1549
+ for callback in callbacks:
1550
+ if isinstance(callback, str):
1551
+ if callback not in _CALLBACK_REGISTRY:
1552
+ raise ValueError(
1553
+ f"Invalid callback alias: {callback}. Available aliases are: {list(_CALLBACK_REGISTRY.keys())}"
1554
+ )
1555
+ callback = _CALLBACK_REGISTRY[callback]
1556
+ self.callbacks.append(callback)
1557
+ else:
1558
+ self.callbacks.append(callback)
1559
+
1560
+ def get_store(self) -> Optional[LightningStore]:
1561
+ """Get the store used by the proxy.
1562
+
1563
+ Returns:
1564
+ The store used by the proxy.
1565
+ """
1566
+ return self.store
1567
+
1568
+ def set_store(self, store: LightningStore) -> None:
1569
+ """Set the store for the proxy.
1570
+
1571
+ Args:
1572
+ store: The store to use for the proxy.
1573
+ """
1574
+ self.store = store
1575
+
1576
+ def update_model_list(self, model_list: List[ModelConfig]) -> None:
1577
+ """Replace the in-memory model list.
1578
+
1579
+ Args:
1580
+ model_list: New list of model entries.
1581
+ """
1582
+ self.model_list = model_list
1583
+ logger.info(f"Updating LLMProxy model list to: {model_list}")
1584
+        # The updated list takes effect the next time the proxy is started and the LiteLLM config is rewritten.
1585
+
1586
+ def initialize(self):
1587
+ """Initialize global middleware and LiteLLM callbacks.
1588
+
1589
+ Installs:
1590
+
1591
+ * A FastAPI middleware that rewrites /rollout/{rid}/attempt/{aid}/... paths,
1592
+ injects rollout/attempt/sequence headers, and forwards downstream.
1593
+ * LiteLLM callbacks for token ids and OpenTelemetry export.
1594
+
1595
+        Middleware can only be installed once: after the FastAPI app has started,
1596
+        its middleware stack can no longer be modified.
1597
+
1598
+ This function does not start any server. It only wires global hooks.
1599
+ """
1600
+ if self.store is None:
1601
+ raise ValueError("Store is not set. Please set the store before initializing the LLMProxy.")
1602
+
1603
+ if _global_llm_proxy is not None:
1604
+ logger.warning("A global LLMProxy is already set. Overwriting it with the new instance.")
1605
+
1606
+ # Patch for LiteLLM v1.80.6+: https://github.com/BerriAI/litellm/issues/17243
1607
+ os.environ["USE_OTEL_LITELLM_REQUEST_SPAN"] = "true"
1608
+
1609
+ # Set the global LLMProxy reference for middleware/exporter access.
1610
+ set_active_llm_proxy(self)
1611
+
1612
+ # Install middleware if it's not already installed.
1613
+ installation_status: Dict[Any, bool] = {}
1614
+ for mw in app.user_middleware:
1615
+ installation_status[mw.cls] = True
1616
+
1617
+ for mw in self.middlewares:
1618
+ if mw not in installation_status:
1619
+ logger.info(f"Adding middleware {mw} to the FastAPI app.")
1620
+ app.add_middleware(mw)
1621
+ else:
1622
+ logger.info(f"Middleware {mw} is already installed. Will not install a new one.")
1623
+
1624
+ if not initialize_llm_callbacks(self.callbacks, otlp_endpoint=self._otlp_endpoint, otlp_headers=self._otlp_headers):
1625
+            # If this is not the first time the callbacks have been initialized, also
1626
+ # reset LiteLLM's logging worker so its asyncio.Queue binds to the new loop.
1627
+ _reset_litellm_logging_worker()
1628
+
1629
+ @asynccontextmanager
1630
+ async def _serve_context(self) -> AsyncGenerator[None, None]:
1631
+ """Context manager to serve the proxy server.
1632
+
1633
+ See [`start`][mantisdk.LLMProxy.start] and [`stop`][mantisdk.LLMProxy.stop] for more details.
1634
+ """
1635
+
1636
+ if not self.store:
1637
+ raise ValueError("Store is not set. Please set the store before starting the LLMProxy.")
1638
+
1639
+ # Initialize global middleware and callbacks.
1640
+ self.initialize()
1641
+
1642
+ # Persist a temp worker config for LiteLLM and point the proxy at it.
1643
+ self._config_file = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False).name
1644
+ with open(self._config_file, "w") as fp:
1645
+ yaml.safe_dump(
1646
+ {
1647
+ "model_list": self.model_list,
1648
+ **self.litellm_config,
1649
+ },
1650
+ fp,
1651
+ )
1652
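+        # The resulting config looks roughly like (illustrative model entry):
+        #   model_list:
+        #     - model_name: gpt-4o-mini
+        #       litellm_params: {model: openai/gpt-4o-mini}
+        #   litellm_settings:
+        #     num_retries: 0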
+
1653
+ save_worker_config(config=self._config_file)
1654
+
1655
+        # NOTE: When running _serve_context in the current process, you might encounter the following problems:
1656
+ # Problem 1: in litellm worker, <Queue at 0x70f1d028cd90 maxsize=50000> is bound to a different event loop
1657
+ # Problem 2: Proxy has conflicted opentelemetry setup with the main process.
1658
+
1659
+ # Ready
1660
+ logger.info("LLMProxy preparation is done. Will start the server.")
1661
+ yield
1662
+
1663
+ # Clean up
1664
+
1665
+ logger.info("LLMProxy server is cleaning up.")
1666
+
1667
+ # Remove worker config to avoid stale references.
1668
+ if self._config_file and os.path.exists(self._config_file):
1669
+ os.unlink(self._config_file)
1670
+
1671
+        logger.info("LLMProxy server finished.")
1672
+
1673
+    async def start(self):
1674
+        """Start the proxy server and initialize global wiring.
1675
+
1676
+ Side effects:
1677
+
1678
+ * Sets the module-level global store for middleware/exporter access.
1679
+ * Calls `initialize()` once to register middleware and callbacks.
1680
+ * Writes a temporary YAML config consumed by LiteLLM worker.
1681
+        * Launches the server via the configured launcher (process, thread, or asyncio task) and waits for readiness.
1682
+ """
1683
+ # Refresh the serve context
1684
+ self.server_launcher.serve_context = self._serve_context()
1685
+
1686
+ if self.store is None:
1687
+ raise ValueError("Store is not set. Please set the store before starting the LLMProxy.")
1688
+
1689
+ store_capabilities = self.store.capabilities
1690
+ if self.server_launcher.args.launch_mode == "mp" and not store_capabilities.get("zero_copy", False):
1691
+ raise RuntimeError(
1692
+ "The store does not support zero-copy. Please use another store, or use asyncio or thread mode to launch the server."
1693
+ )
1694
+ elif self.server_launcher.args.launch_mode == "thread" and not store_capabilities.get("thread_safe", False):
1695
+ raise RuntimeError(
1696
+ "The store is not thread-safe. Please use another store, or use asyncio mode to launch the server."
1697
+ )
1698
+ elif self.server_launcher.args.launch_mode == "asyncio" and not store_capabilities.get("async_safe", False):
1699
+ raise RuntimeError("The store is not async-safe. Please use another store.")
1700
+
1701
+ logger.info(
1702
+ f"Starting LLMProxy server in {self.server_launcher.args.launch_mode} mode with store capabilities: {store_capabilities}"
1703
+ )
1704
+
1705
+ await self.server_launcher.start()
1706
+
1707
+ async def stop(self):
1708
+ """Stop the proxy server and clean up temporary artifacts.
1709
+
1710
+ This is a best-effort graceful shutdown with a bounded join timeout.
1711
+ """
1712
+ if not self.is_running():
1713
+ logger.warning("LLMProxy is not running. Nothing to stop.")
1714
+ return
1715
+
1716
+ await self.server_launcher.stop()
1717
+
1718
+ async def restart(self, *, _port: int | None = None) -> None:
1719
+ """Restart the proxy if running, else start it.
1720
+
1721
+ Convenience wrapper calling `stop()` followed by `start()`.
1722
+ """
1723
+ logger.info("Restarting LLMProxy server...")
1724
+ if self.is_running():
1725
+ await self.stop()
1726
+ if _port is not None:
1727
+ self.server_launcher_args.port = _port
1728
+ await self.start()
1729
+
1730
+ def is_running(self) -> bool:
1731
+ """Return whether the uvicorn server is active.
1732
+
1733
+ Returns:
1734
+ bool: True if server was started and did not signal exit.
1735
+ """
1736
+ return self.server_launcher.is_running()
1737
+
1738
+ def as_resource(
1739
+ self,
1740
+ rollout_id: str | None = None,
1741
+ attempt_id: str | None = None,
1742
+ model: str | None = None,
1743
+ sampling_parameters: Dict[str, Any] | None = None,
1744
+ ) -> LLM:
1745
+ """Create an `LLM` resource pointing at this proxy with rollout context.
1746
+
1747
+ The returned endpoint is:
1748
+ `http://{host}:{port}/rollout/{rollout_id}/attempt/{attempt_id}`
1749
+
1750
+ Args:
1751
+ rollout_id: Rollout identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
1752
+ attempt_id: Attempt identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
1753
+ model: Logical model name to use. If omitted and exactly one model
1754
+ is configured or all models have the same name, that model is used.
1755
+ sampling_parameters: Optional default sampling parameters.
1756
+
1757
+ Returns:
1758
+ LLM: Configured resource ready for OpenAI-compatible calls.
1759
+
1760
+ Raises:
1761
+            ValueError: If `model` is omitted and zero models or multiple distinct model names are configured.
1762
+ """
1763
+ if model is None:
1764
+ if len(self.model_list) == 1:
1765
+ model = self.model_list[0]["model_name"]
1766
+ elif len(self.model_list) == 0:
1767
+ raise ValueError("No models found in model_list. Please specify the model.")
1768
+ else:
1769
+ first_model_name = self.model_list[0]["model_name"]
1770
+ if all(model_config["model_name"] == first_model_name for model_config in self.model_list):
1771
+ model = first_model_name
1772
+ else:
1773
+ raise ValueError(
1774
+ f"Multiple models found in model_list: {self.model_list}. Please specify the model."
1775
+ )
1776
+
1777
+ if rollout_id is None and attempt_id is None:
1778
+ return ProxyLLM(
1779
+ endpoint=self.server_launcher.access_endpoint,
1780
+ model=model,
1781
+ sampling_parameters=dict(sampling_parameters or {}),
1782
+ )
1783
+ elif rollout_id is not None and attempt_id is not None:
1784
+ return LLM(
1785
+ endpoint=f"{self.server_launcher.access_endpoint}/rollout/{rollout_id}/attempt/{attempt_id}",
1786
+ model=model,
1787
+ sampling_parameters=dict(sampling_parameters or {}),
1788
+ )
1789
+ else:
1790
+            raise ValueError("Either both rollout_id and attempt_id must be provided, or neither.")
1791
+
1792
+
1793
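+    # Module-level state: the active proxy (used by middleware and exporters) and a
+    # snapshot of `litellm.callbacks` taken after the first initialization, which
+    # later restarts restore to avoid piling up duplicate callbacks.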
+ _global_llm_proxy: Optional[LLMProxy] = None
1794
+ _callbacks_before_litellm_start: Optional[List[Any]] = None
1795
+
1796
+
1797
+ def get_active_llm_proxy() -> LLMProxy:
1798
+ """Get the current global LLMProxy instance.
1799
+
1800
+ Returns:
1801
+        LLMProxy: The current global LLMProxy (raises ValueError if it has not been set).
1802
+ """
1803
+ if _global_llm_proxy is None:
1804
+ raise ValueError("Global LLMProxy is not set. Please call llm_proxy.start() first.")
1805
+ return _global_llm_proxy
1806
+
1807
+
1808
+ def set_active_llm_proxy(proxy: LLMProxy) -> None:
1809
+ """Set the current global LLMProxy instance.
1810
+
1811
+ Args:
1812
+ proxy: The LLMProxy instance to set as global.
1813
+ """
1814
+ global _global_llm_proxy
1815
+ _global_llm_proxy = proxy
1816
+
1817
+
1818
+ def initialize_llm_callbacks(
1819
+ callback_classes: List[Type[CustomLogger]],
1820
+ otlp_endpoint: Optional[str] = None,
1821
+ otlp_headers: Optional[Dict[str, str]] = None,
1822
+    ) -> bool:
1823
+    """Restore `litellm.callbacks` to the state it had right after mantisdk first initialized it.
1824
+
1825
+ When litellm is restarted multiple times in the same process, more and more callbacks
1826
+ will be appended to `litellm.callbacks`, which may exceed the MAX_CALLBACKS limit.
1827
+    This function remembers the initial state of `litellm.callbacks` and always restores that state.
1828
+
1829
+ Args:
1830
+ callback_classes: List of callback classes to register.
1831
+ otlp_endpoint: Optional OTLP endpoint URL for direct trace export.
1832
+ otlp_headers: Optional dict of HTTP headers for OTLP authentication.
1833
+
1834
+ Returns:
1835
+ Whether the callbacks are initialized for the first time.
1836
+ """
1837
+ global _callbacks_before_litellm_start
1838
+
1839
+ def _instantiate_callback(cls: Type[CustomLogger]) -> CustomLogger:
1840
+ """Instantiate callback with appropriate arguments."""
1841
+ if cls is LightningOpenTelemetry:
1842
+ return LightningOpenTelemetry(otlp_endpoint=otlp_endpoint, otlp_headers=otlp_headers)
1843
+ return cls()
1844
+
1845
+ if _callbacks_before_litellm_start is None:
1846
+ litellm.callbacks.extend([_instantiate_callback(cls) for cls in callback_classes]) # type: ignore
1847
+ _callbacks_before_litellm_start = [*litellm.callbacks] # type: ignore
1848
+ return True
1849
+
1850
+ else:
1851
+        # Add any requested callback classes that are missing from the recorded callbacks.
1852
+ for cls in callback_classes:
1853
+ if not any(isinstance(cb, cls) for cb in _callbacks_before_litellm_start):
1854
+ logger.info(f"Adding missing callback {cls} to the existing callbacks.")
1855
+ _callbacks_before_litellm_start.append(_instantiate_callback(cls))
1856
+
1857
+ _reset_litellm_logging_callback_manager()
1858
+
1859
+ if LightningOpenTelemetry in callback_classes:
1860
+            # Check whether the global tracer provider was cleared elsewhere (e.g., by tests).
1861
+ if not _check_tracer_provider():
1862
+ logger.warning(
1863
+ "Global tracer provider might have been cleared outside. Re-initializing OpenTelemetry callback."
1864
+ )
1865
+ _callbacks_before_litellm_start = [
1866
+ cb for cb in _callbacks_before_litellm_start if not isinstance(cb, LightningOpenTelemetry)
1867
+ ] + [LightningOpenTelemetry(otlp_endpoint=otlp_endpoint, otlp_headers=otlp_headers)]
1868
+ else:
1869
+ logger.debug("Global tracer provider is valid. Reusing existing OpenTelemetry callback.")
1870
+ # Otherwise, we just skip the check for opentelemetry and use the existing callback.
1871
+
1872
+ litellm.callbacks.clear() # type: ignore
1873
+ litellm.callbacks.extend(_callbacks_before_litellm_start) # type: ignore
1874
+ return False
1875
+
1876
+
1877
+ def _check_tracer_provider() -> bool:
1878
+ """Check if the global tracer provider is properly initialized.
1879
+
1880
+    This does not guarantee that the provider is the one installed by mantisdk.
1881
+
1882
+ Returns:
1883
+ bool: True if the tracer provider is valid, else False.
1884
+ """
1885
+ if (
1886
+ hasattr(trace_api, "_TRACER_PROVIDER")
1887
+ and trace_api._TRACER_PROVIDER is not None # pyright: ignore[reportPrivateUsage]
1888
+ ):
1889
+ return True
1890
+ return False