mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/llm_proxy.py
ADDED
|
@@ -0,0 +1,1890 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import ast
|
|
6
|
+
import asyncio
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import tempfile
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from contextlib import asynccontextmanager
|
|
15
|
+
from contextvars import ContextVar
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import (
|
|
18
|
+
Any,
|
|
19
|
+
AsyncGenerator,
|
|
20
|
+
Awaitable,
|
|
21
|
+
Callable,
|
|
22
|
+
Dict,
|
|
23
|
+
Iterable,
|
|
24
|
+
List,
|
|
25
|
+
Literal,
|
|
26
|
+
Optional,
|
|
27
|
+
Sequence,
|
|
28
|
+
Tuple,
|
|
29
|
+
Type,
|
|
30
|
+
TypedDict,
|
|
31
|
+
Union,
|
|
32
|
+
cast,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
import litellm
|
|
36
|
+
import opentelemetry.trace as trace_api
|
|
37
|
+
import yaml
|
|
38
|
+
from fastapi import Request, Response
|
|
39
|
+
from fastapi.responses import StreamingResponse
|
|
40
|
+
from litellm.integrations.custom_logger import CustomLogger
|
|
41
|
+
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
|
|
42
|
+
from litellm.proxy.proxy_server import app, save_worker_config # pyright: ignore[reportUnknownVariableType]
|
|
43
|
+
from litellm.types.utils import CallTypes
|
|
44
|
+
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
|
45
|
+
from opentelemetry.sdk.resources import Resource
|
|
46
|
+
from opentelemetry.sdk.trace import ReadableSpan, SpanContext
|
|
47
|
+
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
|
|
48
|
+
from opentelemetry.trace import Link, Status
|
|
49
|
+
from opentelemetry.util.types import Attributes
|
|
50
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
51
|
+
from starlette.types import Scope
|
|
52
|
+
|
|
53
|
+
from mantisdk.semconv import LightningResourceAttributes
|
|
54
|
+
from mantisdk.types import LLM, ProxyLLM
|
|
55
|
+
from mantisdk.utils.server_launcher import (
|
|
56
|
+
LaunchMode,
|
|
57
|
+
PythonServerLauncher,
|
|
58
|
+
PythonServerLauncherArgs,
|
|
59
|
+
noop_context,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
from .store.base import LightningStore
|
|
63
|
+
|
|
64
|
+
logger = logging.getLogger(__name__)
|
|
65
|
+
|
|
66
|
+
# Context variable to store HTTP request headers for LiteLLM callback access
|
|
67
|
+
_request_headers_context: ContextVar[Optional[Dict[str, str]]] = ContextVar("request_headers", default=None)
|
|
68
|
+
|
|
69
|
+
__all__ = [
|
|
70
|
+
"LLMProxy",
|
|
71
|
+
]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ModelConfig(TypedDict):
    """A single model registration entry for LiteLLM.

    The shape matches one item of the `model_list` section of a LiteLLM
    configuration.

    Attributes:
        model_name: Logical model name exposed by the proxy.
        litellm_params: Parameters passed to LiteLLM for this model
            (e.g., backend model id, api_base, additional options).
    """

    # Logical name under which the proxy exposes the model.
    model_name: str
    # Free-form parameters handed to LiteLLM for this model.
    litellm_params: Dict[str, Any]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _get_pre_call_data(args: Any, kwargs: Any) -> Dict[str, Any]:
|
|
90
|
+
"""Extract LiteLLM request payload from hook args.
|
|
91
|
+
|
|
92
|
+
The LiteLLM logger hooks receive `(*args, **kwargs)` whose third positional
|
|
93
|
+
argument or `data=` kwarg contains the request payload.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
args: Positional arguments from the hook.
|
|
97
|
+
kwargs: Keyword arguments from the hook.
|
|
98
|
+
|
|
99
|
+
Returns:
|
|
100
|
+
The request payload dict.
|
|
101
|
+
|
|
102
|
+
Raises:
|
|
103
|
+
ValueError: If the payload cannot be located or is not a dict.
|
|
104
|
+
"""
|
|
105
|
+
if kwargs.get("data"):
|
|
106
|
+
data = kwargs["data"]
|
|
107
|
+
elif len(args) >= 3:
|
|
108
|
+
data = args[2]
|
|
109
|
+
else:
|
|
110
|
+
raise ValueError(f"Unable to get request data from args or kwargs: {args}, {kwargs}")
|
|
111
|
+
if not isinstance(data, dict):
|
|
112
|
+
raise ValueError(f"Request data is not a dictionary: {data}")
|
|
113
|
+
return cast(Dict[str, Any], data)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _reset_litellm_logging_worker() -> None:
    """Swap LiteLLM's global logging worker for a freshly created one.

    LiteLLM keeps a module-level ``GLOBAL_LOGGING_WORKER`` singleton that owns an
    ``asyncio.Queue``, and that queue is bound to the event loop where it was
    created. When the proxy is restarted, Uvicorn spins up a brand new event
    loop in a new thread; reusing the old worker (and its queue) makes LiteLLM
    raise ``RuntimeError: <Queue ...> is bound to a different event loop`` the
    next time it tries to log. Installing a new worker lets LiteLLM lazily
    initialise a fresh queue on the new loop.

    Failures are logged as warnings and otherwise ignored (best effort).
    """
    try:
        import litellm.utils as litellm_utils
        from litellm.litellm_core_utils import logging_worker as litellm_logging_worker

        fresh_worker = litellm_logging_worker.LoggingWorker()
        # ``GLOBAL_LOGGING_WORKER`` is imported by name into other LiteLLM
        # modules, so every already-bound reference has to be re-pointed at
        # the new worker, not only the defining module's attribute.
        litellm_logging_worker.GLOBAL_LOGGING_WORKER = fresh_worker
        litellm_utils.GLOBAL_LOGGING_WORKER = fresh_worker  # type: ignore[reportAttributeAccessIssue]
    except Exception:  # pragma: no cover - best-effort hygiene
        logger.warning("Unable to propagate LiteLLM logging worker reset.", exc_info=True)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _reset_litellm_logging_callback_manager() -> None:
    """Clear every callback registered with LiteLLM's global callback manager.

    This avoids the warning "Cannot add callback - would exceed MAX_CALLBACKS
    limit of 30." that appears when litellm is restarted multiple times in the
    same process.

    Note that existing input/output callbacks are not preserved. Failures are
    logged as warnings and otherwise ignored (best effort).
    """
    try:
        callback_manager = litellm.logging_callback_manager
        callback_manager._reset_all_callbacks()  # pyright: ignore[reportPrivateUsage]
    except Exception:  # pragma: no cover - best-effort hygiene
        logger.warning("Unable to reset LiteLLM logging callback manager.", exc_info=True)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class AddReturnTokenIds(CustomLogger):
    """LiteLLM logger hook that asks the backend to return token ids.

    The hook returns a copy of the outgoing request payload with
    `return_token_ids=True` added, for backends that support token id return
    (e.g., vLLM).

    See also:
        [vLLM PR #22587](https://github.com/vllm-project/vllm/pull/22587)
    """

    async def async_pre_call_hook(self, *args: Any, **kwargs: Any) -> Optional[Union[Exception, str, Dict[str, Any]]]:
        """Return the request payload with token-id return switched on.

        Args:
            args: Positional args from LiteLLM.
            kwargs: Keyword args from LiteLLM.

        Returns:
            The updated payload dict, or the exception raised while locating
            the payload (returned, not raised).
        """
        try:
            payload = _get_pre_call_data(args, kwargs)
        except Exception as exc:
            return exc

        # Copy first so the caller's dict is left untouched.
        updated = dict(payload)
        updated["return_token_ids"] = True
        return updated
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class AddLogprobs(CustomLogger):
    """LiteLLM logger hook that asks the backend to return logprobs.

    The hook returns a copy of the outgoing request payload with `logprobs=1`
    set, for backends that support logprobs return (e.g., vLLM). Any
    `logprobs` value already present in the payload is overwritten.
    """

    async def async_pre_call_hook(self, *args: Any, **kwargs: Any) -> Optional[Union[Exception, str, Dict[str, Any]]]:
        """Return the request payload with `logprobs` forced to 1.

        Args:
            args: Positional args from LiteLLM.
            kwargs: Keyword args from LiteLLM.

        Returns:
            The updated payload dict, or the exception raised while locating
            the payload (returned, not raised).
        """
        try:
            payload = _get_pre_call_data(args, kwargs)
        except Exception as exc:
            return exc

        # Copy first so the caller's dict is left untouched.
        updated = dict(payload)
        updated["logprobs"] = 1
        return updated
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class SpanWithExtraAttributes(ReadableSpan):
    """Wrapper around ReadableSpan that adds extra span attributes.

    Since ReadableSpan is immutable, this wrapper intercepts the attributes
    property to include additional attributes for Langfuse integration
    (environment, tags, etc.). Every other property is forwarded to the
    wrapped span unchanged.

    ``ReadableSpan.__init__`` is intentionally not called: the wrapper holds
    no span state of its own and resolves everything through the wrapped span.
    """

    def __init__(self, wrapped_span: ReadableSpan, extra_attributes: Dict[str, Any]):
        """Initialize wrapper with a span and extra attributes to inject.

        Args:
            wrapped_span: The original ReadableSpan to wrap.
            extra_attributes: Dictionary of extra attributes to add.
                These should use Langfuse conventions (e.g., "langfuse.environment").
        """
        self._wrapped = wrapped_span
        self._extra_attributes = extra_attributes

    @property
    def name(self) -> str:
        return self._wrapped.name

    @property
    def context(self) -> Optional[SpanContext]:
        return self._wrapped.context

    def get_span_context(self) -> SpanContext:
        return self._wrapped.get_span_context()

    @property
    def parent(self) -> Optional[SpanContext]:
        return self._wrapped.parent

    @property
    def start_time(self) -> int:
        return self._wrapped.start_time

    @property
    def end_time(self) -> int:
        return self._wrapped.end_time

    @property
    def status(self) -> Status:
        return self._wrapped.status

    @property
    def attributes(self) -> Attributes:
        """Return original attributes merged with extra attributes.

        On a key collision the extra attribute wins. A new dict is built on
        every access; the wrapped span's attributes are not modified.
        """
        original_attrs = self._wrapped.attributes or {}
        merged = dict(original_attrs)
        merged.update(self._extra_attributes)
        return merged

    @property
    def events(self) -> tuple:
        return self._wrapped.events

    @property
    def links(self) -> tuple:
        return self._wrapped.links

    @property
    def resource(self) -> Resource:
        return self._wrapped.resource

    @property
    def instrumentation_scope(self):
        return self._wrapped.instrumentation_scope

    # Also expose _resource for direct modification if needed
    @property
    def _resource(self) -> Resource:
        return self._wrapped._resource  # pyright: ignore[reportPrivateUsage]

    @_resource.setter
    def _resource(self, value: Resource):
        self._wrapped._resource = value  # pyright: ignore[reportPrivateUsage]

    def __getattr__(self, name: str):
        """Delegate all other attribute access to the wrapped span.

        This ensures OTLPSpanExporter can access all private attributes (_events, _links, etc.)
        that it needs for serialization without us having to enumerate them all.

        Raises:
            AttributeError: If ``_wrapped`` itself is missing, or the wrapped
                span does not have the attribute.
        """
        # ``__getattr__`` only runs when normal lookup fails. If ``_wrapped``
        # is what is missing (instance created without ``__init__``, e.g.
        # while being copied or unpickled), reading ``self._wrapped`` below
        # would re-enter this method forever and end in RecursionError.
        if name == "_wrapped":
            raise AttributeError(name)
        return getattr(self._wrapped, name)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class LightningSpanExporter(SpanExporter):
|
|
292
|
+
"""Buffered OTEL span exporter with subtree flushing and training-store sink.
|
|
293
|
+
|
|
294
|
+
Design:
|
|
295
|
+
|
|
296
|
+
* Spans are buffered until a root span's entire subtree is available.
|
|
297
|
+
* A private event loop on a daemon thread runs async flush logic.
|
|
298
|
+
* Rollout/attempt/sequence metadata is reconstructed by merging headers
|
|
299
|
+
from any span within a subtree.
|
|
300
|
+
|
|
301
|
+
Thread-safety:
|
|
302
|
+
|
|
303
|
+
* Buffer access is protected by a re-entrant lock.
|
|
304
|
+
* Export is synchronous to the caller yet schedules an async flush on the
|
|
305
|
+
internal loop, then waits for completion.
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
def __init__(
    self,
    _store: Optional[LightningStore] = None,
    otlp_endpoint: Optional[str] = None,
    otlp_headers: Optional[Dict[str, str]] = None,
):
    """Set up an empty buffer and the underlying OTLP exporter.

    The lock, event loop and loop thread are left unset here and created
    lazily on first use.

    Args:
        _store: Store override; this is only for testing purposes.
        otlp_endpoint: Direct OTLP export endpoint. When falsy, the OTLP
            exporter is created with its own defaults.
        otlp_headers: Headers for the OTLP exporter. Only used when
            `otlp_endpoint` is given; a falsy value becomes an empty dict.
    """
    self._store: Optional[LightningStore] = _store  # this is only for testing purposes
    self._otlp_endpoint: Optional[str] = otlp_endpoint  # Direct OTLP export endpoint
    self._buffer: List[ReadableSpan] = []

    # Lock plus the pid recorded for it and for the loop; both start unset.
    self._lock: Optional[threading.Lock] = None
    self._loop_lock_pid: Optional[int] = None

    # Single dedicated event loop running in a daemon thread.
    # This decouples OTEL SDK threads from our async store I/O.
    # Deferred creation until first use.
    self._loop: Optional[asyncio.AbstractEventLoop] = None
    self._loop_thread: Optional[threading.Thread] = None

    # Custom endpoint and headers are only applied when an endpoint is provided.
    if not otlp_endpoint:
        self._otlp_exporter = OTLPSpanExporter()
    else:
        self._otlp_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, headers=otlp_headers or {})
|
|
331
|
+
|
|
332
|
+
def _ensure_loop(self) -> asyncio.AbstractEventLoop:
|
|
333
|
+
"""Lazily initialize the event loop and thread on first use.
|
|
334
|
+
|
|
335
|
+
Returns:
|
|
336
|
+
asyncio.AbstractEventLoop: The initialized event loop.
|
|
337
|
+
"""
|
|
338
|
+
self._clear_loop_and_lock()
|
|
339
|
+
if self._loop is None:
|
|
340
|
+
self._loop = asyncio.new_event_loop()
|
|
341
|
+
self._loop_thread = threading.Thread(target=self._run_loop, name="LightningSpanExporterLoop", daemon=True)
|
|
342
|
+
self._loop_thread.start()
|
|
343
|
+
return self._loop
|
|
344
|
+
|
|
345
|
+
def _ensure_lock(self) -> threading.Lock:
|
|
346
|
+
"""Lazily initialize the lock on first use.
|
|
347
|
+
|
|
348
|
+
Returns:
|
|
349
|
+
threading.Lock: The initialized lock.
|
|
350
|
+
"""
|
|
351
|
+
self._clear_loop_and_lock()
|
|
352
|
+
if self._lock is None:
|
|
353
|
+
self._lock = threading.Lock()
|
|
354
|
+
return self._lock
|
|
355
|
+
|
|
356
|
+
def _clear_loop_and_lock(self) -> None:
|
|
357
|
+
"""Clear the loop and lock.
|
|
358
|
+
This happens if the exporter was used in a process then used in another process.
|
|
359
|
+
|
|
360
|
+
This should only happen in CI.
|
|
361
|
+
"""
|
|
362
|
+
if os.getpid() != self._loop_lock_pid:
|
|
363
|
+
logger.warning("Loop and lock are not owned by the current process. Clearing them.")
|
|
364
|
+
self._loop = None
|
|
365
|
+
self._loop_thread = None
|
|
366
|
+
self._lock = None
|
|
367
|
+
self._loop_lock_pid = os.getpid()
|
|
368
|
+
elif self._loop_lock_pid is None:
|
|
369
|
+
self._loop_lock_pid = os.getpid()
|
|
370
|
+
|
|
371
|
+
def _run_loop(self) -> None:
|
|
372
|
+
"""Run the private asyncio loop forever on the exporter thread."""
|
|
373
|
+
assert self._loop is not None, "Loop should be initialized before thread starts"
|
|
374
|
+
asyncio.set_event_loop(self._loop)
|
|
375
|
+
self._loop.run_forever()
|
|
376
|
+
|
|
377
|
+
def shutdown(self) -> None:
|
|
378
|
+
"""Shut down the exporter event loop.
|
|
379
|
+
|
|
380
|
+
Safe to call at process exit.
|
|
381
|
+
|
|
382
|
+
"""
|
|
383
|
+
if self._loop is None:
|
|
384
|
+
return
|
|
385
|
+
|
|
386
|
+
try:
|
|
387
|
+
|
|
388
|
+
def _stop():
|
|
389
|
+
assert self._loop is not None
|
|
390
|
+
self._loop.stop()
|
|
391
|
+
|
|
392
|
+
self._loop.call_soon_threadsafe(_stop)
|
|
393
|
+
if self._loop_thread is not None:
|
|
394
|
+
self._loop_thread.join(timeout=2.0)
|
|
395
|
+
self._loop.close()
|
|
396
|
+
except Exception:
|
|
397
|
+
logger.exception("Error during exporter shutdown")
|
|
398
|
+
|
|
399
|
+
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
    """Buffer the given spans and flush whatever subtrees are ready.

    The spans are appended to the internal buffer and `_maybe_flush()` is
    invoked; the call returns only after that flush attempt has finished.
    The OTLP exporter's `_endpoint` is captured before the flush and restored
    afterwards, whether or not the flush succeeded.

    Args:
        spans: Sequence of spans to export.

    Returns:
        SpanExportResult: SUCCESS on flush success, else FAILURE.
    """
    # NOTE(review): block nesting was reconstructed from syntax; the flush is
    # assumed to run while the lock is held - confirm against the packaged file.
    with self._ensure_lock():
        # Buffer append under lock to protect against concurrent exporters.
        self._buffer.extend(spans)

        saved_endpoint = self._otlp_exporter._endpoint  # pyright: ignore[reportPrivateUsage]
        try:
            self._maybe_flush()
        except Exception as e:
            logger.exception("Export flush failed: %s", e)
            outcome = SpanExportResult.FAILURE
        else:
            outcome = SpanExportResult.SUCCESS
        finally:
            self._otlp_exporter._endpoint = saved_endpoint  # pyright: ignore[reportPrivateUsage]

    return outcome
|
|
425
|
+
|
|
426
|
+
def _get_job_id_from_store(self, store: Any) -> Optional[str]:
|
|
427
|
+
"""Get the job_id from the store's listeners (if any InsightTracker is attached)."""
|
|
428
|
+
if hasattr(store, "listeners"):
|
|
429
|
+
for listener in store.listeners:
|
|
430
|
+
if hasattr(listener, "job_id"):
|
|
431
|
+
return listener.job_id
|
|
432
|
+
return None
|
|
433
|
+
|
|
434
|
+
def _get_tracing_metadata_from_rollout(
    self, store: Any, rollout_id: str
) -> Tuple[Optional[str], Optional[List[str]], Optional[str]]:
    """Fetch tracing metadata (environment, tags, session_id) from a rollout.

    The rollout is looked up by running `store.get_rollout_by_id(rollout_id)`
    on the exporter's private event loop and waiting up to five seconds for
    the result. Any failure is logged and reported as "no metadata"; this
    method never raises.

    Args:
        store: The LightningStore instance.
        rollout_id: The rollout ID to fetch.

    Returns:
        Tuple of (environment, tags, session_id). All may be None if not set,
        if the rollout is missing or has no metadata, or if the lookup failed.
    """
    logger.info("[TracingMetadata] Fetching metadata for rollout %s", rollout_id)

    fut = None
    try:
        loop = self._ensure_loop()
        get_rollout_task = store.get_rollout_by_id(rollout_id)
        # NOTE(review): this blocks the calling thread while the private loop
        # does the work; presumably it is never called from that loop's own
        # thread, where it would stall until the timeout - confirm.
        fut = asyncio.run_coroutine_threadsafe(get_rollout_task, loop)
        rollout = fut.result(timeout=5.0)  # Short timeout for metadata fetch

        if rollout is None:
            logger.warning("[TracingMetadata] Rollout %s not found in store", rollout_id)
            return None, None, None

        metadata = rollout.metadata
        logger.info("[TracingMetadata] Rollout %s found, metadata=%s", rollout_id, metadata)

        if metadata is None:
            logger.warning("[TracingMetadata] Rollout %s has no metadata (None)", rollout_id)
            return None, None, None

        if not metadata:
            logger.warning("[TracingMetadata] Rollout %s has empty metadata dict", rollout_id)
            return None, None, None

        environment = metadata.get("environment")
        tags = metadata.get("tags")
        session_id = metadata.get("session_id")

        logger.info(
            "[TracingMetadata] Rollout %s: environment=%s, tags=%s, session_id=%s",
            rollout_id,
            environment,
            tags,
            session_id,
        )

        return environment, tags, session_id
    except Exception as e:
        # A timed-out lookup is still scheduled on the loop; cancel it so it
        # does not keep running after the caller has given up.
        if fut is not None and not fut.done():
            fut.cancel()
        # Route the traceback through logging instead of printing to stderr.
        logger.warning(
            "[TracingMetadata] Failed to fetch rollout metadata for %s: %s",
            rollout_id,
            e,
            exc_info=True,
        )
        return None, None, None
|
|
480
|
+
|
|
481
|
+
def _maybe_flush(self):
    """Flush ready subtrees from the buffer.

    Strategy:
        A subtree is considered "ready" as soon as a root span (one without
        a parent) is present in the buffer. The root and all of its
        descendants are removed from the buffer, and the
        rollout/attempt/sequence headers are reconstructed by merging every
        span's `metadata.requester_custom_headers` within the subtree.

    Span types:
        - Rollout spans: Have rollout_id/attempt_id/sequence_id headers.
        - Job spans: No rollout context, tagged with job_id for experiment tracking.

    Direct OTLP mode:
        When `otlp_endpoint` is configured, spans are exported directly to the
        endpoint without requiring a store or header validation.

    Raises:
        None directly. Malformed headers are logged and skipped; subtrees that
        cannot be exported or persisted are logged and dropped.
    """
    # Iterate over current roots. Each iteration pops a whole subtree.
    for root_span_id in self._get_root_span_ids():
        subtree_spans = self._pop_subtrees(root_span_id)
        if not subtree_spans:
            continue

        # Merge all custom headers found in the subtree.
        # This must happen BEFORE the OTLP check so both paths can apply tags.
        headers_merged: Dict[str, Any] = {}
        for span in subtree_spans:
            if span.attributes is None:
                continue
            headers_str = span.attributes.get("metadata.requester_custom_headers")
            if headers_str is None:
                continue
            if not isinstance(headers_str, str):
                logger.debug(f"metadata.requester_custom_headers is not a string: {headers_str}")
                continue
            if not headers_str.strip():
                continue
            try:
                # Use literal_eval to parse the stringified dict safely.
                headers = ast.literal_eval(headers_str)
            except Exception as e:
                logger.debug(f"Failed to parse metadata.requester_custom_headers: {e}")
                continue
            if isinstance(headers, dict):
                headers_merged.update(cast(Dict[str, Any], headers))

        # Extract rollout context if available.
        rollout_id = headers_merged.get("x-rollout-id")
        attempt_id = headers_merged.get("x-attempt-id")
        sequence_id = headers_merged.get("x-sequence-id")

        # Determine if we're using OTLP export (either direct or store-based).
        store = self._store or get_active_llm_proxy().get_store()
        otlp_enabled = bool(self._otlp_endpoint) or bool(store and store.capabilities.get("otlp_traces", False))

        # `isdecimal()` (rather than `isdigit()`) guarantees that the `int()`
        # conversion below cannot fail: `isdigit()` also accepts characters
        # such as superscript digits, which `int()` rejects.
        has_rollout_context = (
            isinstance(rollout_id, str)
            and isinstance(attempt_id, str)
            and isinstance(sequence_id, str)
            and bool(rollout_id)
            and bool(attempt_id)
            and sequence_id.isdecimal()
        )

        if has_rollout_context:
            # Rollout-scoped spans: tag with rollout/attempt/sequence.
            sequence_id_decimal = int(sequence_id)
            logger.info(f"[TracingMetadata] Processing rollout {rollout_id} with {len(subtree_spans)} spans, otlp_enabled={otlp_enabled}")

            if otlp_enabled:
                # Fetch tracing metadata (environment, tags, session_id) from the rollout.
                # Only the OTLP path consumes it, so the store round-trip is skipped otherwise.
                environment, tags, session_id = self._get_tracing_metadata_from_rollout(store, rollout_id)
                logger.info(f"[TracingMetadata] Fetched: environment={environment}, tags={tags}, session_id={session_id}")

                # Build resource attributes for Mantisdk metadata.
                resource_attrs: Dict[str, Any] = {
                    LightningResourceAttributes.ROLLOUT_ID.value: rollout_id,
                    LightningResourceAttributes.ATTEMPT_ID.value: attempt_id,
                    LightningResourceAttributes.SPAN_SEQUENCE_ID.value: sequence_id_decimal,
                    LightningResourceAttributes.SPAN_TYPE.value: "rollout",
                }

                # Build span attributes for Langfuse-expected metadata.
                # Per Langfuse docs, use langfuse.* namespace for environment and tags.
                span_extra_attrs: Dict[str, Any] = {}
                if session_id:
                    span_extra_attrs["session.id"] = session_id
                    logger.info(f"[TracingMetadata] Setting session.id={session_id}")
                if environment:
                    span_extra_attrs["langfuse.environment"] = environment
                    logger.info(f"[TracingMetadata] Setting langfuse.environment={environment}")

                # Extract call type from headers (set by @gepa.judge, @gepa.agent decorators).
                call_type = headers_merged.get("x-mantis-call-type")

                # Build final tags list, including call_type if present.
                final_tags = list(tags) if tags else []
                if call_type and call_type not in final_tags:
                    final_tags.append(call_type)

                if final_tags:
                    # Insight's OTEL ingestion expects tags on the *resource* under `langfuse.trace.tags`.
                    # Span-level `langfuse.tags` is not reliably ingested into `traces.tags`.
                    resource_attrs["langfuse.trace.tags"] = final_tags
                    # Keep span-level tags too for backwards-compat/debuggability.
                    span_extra_attrs["langfuse.tags"] = final_tags
                    logger.info(f"[TracingMetadata] Setting langfuse.trace.tags={final_tags}")

                # The extra resource is identical for every span of the subtree: build it once.
                rollout_resource = Resource.create(resource_attrs)

                # Prepare spans for export.
                spans_to_export: List[ReadableSpan] = []
                for span in subtree_spans:
                    span._resource = span._resource.merge(rollout_resource)  # pyright: ignore[reportPrivateUsage]
                    # Wrap with extra span attributes if we have any.
                    if span_extra_attrs:
                        spans_to_export.append(SpanWithExtraAttributes(span, span_extra_attrs))
                    else:
                        spans_to_export.append(span)

                export_result = self._otlp_exporter.export(spans_to_export)
                if export_result != SpanExportResult.SUCCESS:
                    logger.error(f"Failed to export rollout spans via OTLP. Result: {export_result}")
            else:
                # The old way: store does not support OTLP endpoint.
                if store is None:
                    # Without a store there is nowhere to persist the spans; log
                    # instead of failing with an AttributeError on `None`.
                    logger.error(
                        f"Dropping {len(subtree_spans)} spans for rollout {rollout_id}: "
                        "no store is available and OTLP is not enabled"
                    )
                    continue
                loop = self._ensure_loop()
                for span in subtree_spans:
                    add_otel_span_task = store.add_otel_span(
                        rollout_id=rollout_id,
                        attempt_id=attempt_id,
                        sequence_id=sequence_id_decimal,
                        readable_span=span,
                    )
                    fut = asyncio.run_coroutine_threadsafe(add_otel_span_task, loop)
                    # Block until the span is stored so spans are written in order.
                    fut.result()

        elif otlp_enabled:
            # Job-scoped spans (no rollout context): tag with job_id for experiment tracking.
            job_id = self._get_job_id_from_store(store)

            # Extract Mantis tracing metadata from x-mantis-* headers.
            mantis_session_id = headers_merged.get("x-mantis-session-id")
            mantis_environment = headers_merged.get("x-mantis-environment")
            mantis_tags_str = headers_merged.get("x-mantis-tags")
            mantis_call_type = headers_merged.get("x-mantis-call-type")

            # Tags are only honoured when the header is a stringified list.
            mantis_tags: Optional[List[Any]] = None
            if isinstance(mantis_tags_str, str) and mantis_tags_str.startswith("["):
                try:
                    parsed_tags = ast.literal_eval(mantis_tags_str)
                except Exception as e:
                    logger.debug(f"Failed to parse x-mantis-tags: {e}")
                else:
                    if isinstance(parsed_tags, list):
                        mantis_tags = parsed_tags

            # Add call_type to tags if present.
            if mantis_call_type:
                if mantis_tags is None:
                    mantis_tags = []
                if mantis_call_type not in mantis_tags:
                    mantis_tags.append(mantis_call_type)

            # Build span attributes for Mantis/Langfuse metadata.
            span_extra_attrs: Dict[str, Any] = {}
            job_resource_attrs: Dict[str, Any] = {}

            if mantis_session_id:
                span_extra_attrs["session.id"] = mantis_session_id
                logger.info(f"[TracingMetadata] Job span: session.id={mantis_session_id}")
            if mantis_environment:
                span_extra_attrs["langfuse.environment"] = mantis_environment
                logger.info(f"[TracingMetadata] Job span: langfuse.environment={mantis_environment}")
            if mantis_tags:
                # Set tags as RESOURCE attributes (required for Insight ingestion).
                job_resource_attrs["langfuse.trace.tags"] = mantis_tags
                # Also set as span attributes for backwards-compat.
                span_extra_attrs["langfuse.tags"] = mantis_tags
                logger.info(f"[TracingMetadata] Job span: langfuse.trace.tags={mantis_tags}")

            # Both extra resources are identical for every span: build them once.
            job_resource = None
            if job_id:
                job_resource = Resource.create(
                    {
                        LightningResourceAttributes.JOB_ID.value: job_id,
                        LightningResourceAttributes.SPAN_TYPE.value: "job",
                    }
                )
            tags_resource = None
            if job_resource_attrs:
                # Tagging is best-effort: a failure here must not prevent the export.
                try:
                    tags_resource = Resource.create(job_resource_attrs)
                except Exception as e:
                    logger.warning(f"[TracingMetadata] Failed to build job tag resource: {e}")

            # Prepare spans for export.
            spans_to_export: List[ReadableSpan] = []
            for span in subtree_spans:
                if job_resource is not None:
                    span._resource = span._resource.merge(job_resource)  # pyright: ignore[reportPrivateUsage]
                if tags_resource is not None:
                    span._resource = span._resource.merge(tags_resource)  # pyright: ignore[reportPrivateUsage]
                # Wrap with extra span attributes if we have any.
                if span_extra_attrs:
                    spans_to_export.append(SpanWithExtraAttributes(span, span_extra_attrs))
                else:
                    spans_to_export.append(span)

            export_result = self._otlp_exporter.export(spans_to_export)
            if export_result != SpanExportResult.SUCCESS:
                logger.error(f"Failed to export job spans via OTLP. Result: {export_result}")
            else:
                logger.debug(f"Exported {len(spans_to_export)} job-scoped spans (job_id={job_id}, has_mantis_metadata={bool(span_extra_attrs)})")
        else:
            # No OTLP and no rollout context: the spans cannot be attributed, so drop them.
            logger.debug(
                f"Skipping {len(subtree_spans)} spans: no rollout context and OTLP not enabled"
            )
|
|
703
|
+
|
|
704
|
+
def _get_root_span_ids(self) -> Iterable[int]:
    """Lazily yield the span id of every root span held in the buffer.

    A span counts as a root when its `parent` is None. Roots whose span
    context is unavailable are silently skipped.

    Yields:
        int: Span id of each root span, in buffer order.
    """
    for candidate in self._buffer:
        if candidate.parent is not None:
            # Has a parent, so it belongs to somebody else's subtree.
            continue
        context = candidate.get_span_context()
        if context is None:
            continue
        yield context.span_id
|
|
717
|
+
|
|
718
|
+
def _get_subtrees(self, root_span_id: int) -> Iterable[int]:
    """Yield span_ids in the subtree rooted at `root_span_id`.

    Pre-order depth-first traversal over the current buffer: the root comes
    first, then each child (in buffer order) followed by its own descendants.

    The buffer is indexed once and walked with an explicit stack, so the
    cost is linear in the buffer size and arbitrarily deep traces cannot hit
    the interpreter's recursion limit. Every span id is yielded at most once,
    which also keeps malformed parent links (cycles, duplicated ids) from
    looping forever.

    Args:
        root_span_id: The span id of the root.

    Yields:
        int: Span ids including the root and all descendants found.
    """
    # parent span id -> child span ids, preserving buffer order.
    children: Dict[int, List[int]] = {}
    for span in self._buffer:
        if span.parent is None:
            continue
        span_context = span.get_span_context()
        if span_context is None:
            continue
        children.setdefault(span.parent.span_id, []).append(span_context.span_id)

    visited = set()
    pending = [root_span_id]
    while pending:
        span_id = pending.pop()
        if span_id in visited:
            continue
        visited.add(span_id)
        yield span_id
        # Push children reversed so the first child in buffer order is popped first.
        pending.extend(reversed(children.get(span_id, ())))
|
|
738
|
+
|
|
739
|
+
def _pop_subtrees(self, root_span_id: int) -> List[ReadableSpan]:
    """Detach the subtree rooted at `root_span_id` from the buffer.

    Args:
        root_span_id: Span id of the root whose subtree should be removed.

    Returns:
        list[ReadableSpan]: The spans belonging to the subtree, in buffer
        order. Spans without a span context are never part of a subtree and
        stay in the buffer.
    """
    wanted_ids = set(self._get_subtrees(root_span_id))
    extracted: List[ReadableSpan] = []
    remaining: List[ReadableSpan] = []
    for candidate in self._buffer:
        context = candidate.get_span_context()
        in_subtree = context is not None and context.span_id in wanted_ids
        (extracted if in_subtree else remaining).append(candidate)
    # Rebind (rather than mutate) the buffer so the extracted spans are not processed again.
    self._buffer = remaining
    return extracted
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
class LightningOpenTelemetry(OpenTelemetry):
    """OpenTelemetry integration that exports spans to the Lightning store.

    Responsibilities:

    * Ensures each request is annotated with a per-attempt sequence id so spans
      are ordered deterministically even with clock skew across nodes.
    * Uses [`LightningSpanExporter`][mantisdk.llm_proxy.LightningSpanExporter] to persist spans for analytics and training.
    * Adds Mantisdk-specific attributes (session_id, tags, environment) to spans.

    Args:
        otlp_endpoint: Optional OTLP endpoint URL for direct trace export to external
            collectors (e.g., Langfuse/Insight). When set, spans are exported directly without
            requiring a store or rollout/attempt headers.
        otlp_headers: Optional dict of HTTP headers for OTLP authentication (e.g., Basic Auth).
    """

    def __init__(self, otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None):
        exporter = LightningSpanExporter(otlp_endpoint=otlp_endpoint, otlp_headers=otlp_headers)
        config = OpenTelemetryConfig(exporter=exporter)

        # Check for tracer initialization
        if _check_tracer_provider():
            logger.error("Tracer is already initialized. OpenTelemetry may not work as expected.")

        super().__init__(config=config)  # pyright: ignore[reportUnknownMemberType]

        # Store exporter reference for debugging
        self._custom_exporter = exporter

    def _init_tracing(self, tracer_provider):
        """Override to ensure our span processor is added even when reusing existing TracerProvider.

        LiteLLM's parent _init_tracing reuses existing TracerProviders but doesn't add
        our custom span processor to them. We override to force adding our processor.

        Args:
            tracer_provider: Forwarded unchanged to the parent implementation.
        """
        from opentelemetry import trace as otel_trace_api
        from opentelemetry.sdk.trace import TracerProvider as TracerProviderSDK
        from opentelemetry.trace import SpanKind

        # Call parent to set up tracer
        super()._init_tracing(tracer_provider)  # pyright: ignore[reportUnknownMemberType]

        # If an existing provider was reused, add our span processor to it
        current_provider = otel_trace_api.get_tracer_provider()
        if isinstance(current_provider, TracerProviderSDK):
            # Look through the provider's private processor list for one that
            # already wraps a LightningSpanExporter, so it is not added twice.
            active_processor = getattr(current_provider, "_active_span_processor", None)
            registered_processors = getattr(active_processor, "_span_processors", ())
            has_our_processor = any(
                isinstance(getattr(proc, "_exporter", None), LightningSpanExporter)
                for proc in registered_processors
            )

            if not has_our_processor:
                current_provider.add_span_processor(self._get_span_processor())

        self.span_kind = SpanKind

    def _get_span_processor(self, dynamic_headers: Optional[dict] = None):
        """Override to ensure our custom exporter is used.

        When `self.OTEL_EXPORTER` exposes an `export()` method it is wrapped in a
        `SimpleSpanProcessor` directly; otherwise the parent implementation decides.

        Args:
            dynamic_headers: Only used by the parent fallback.
        """
        from opentelemetry.sdk.trace.export import SimpleSpanProcessor

        # Use our custom exporter directly
        if hasattr(self.OTEL_EXPORTER, "export"):
            return SimpleSpanProcessor(self.OTEL_EXPORTER)

        # Fallback to parent implementation
        return super()._get_span_processor(dynamic_headers)  # pyright: ignore[reportUnknownMemberType]

    def set_attributes(self, span: Any, kwargs: Dict[str, Any], response_obj: Optional[Any]) -> None:
        """Override to add Mantisdk-specific attributes from metadata.

        Extracts session_id, tags, and environment from kwargs["metadata"] and sets
        them as OTEL span attributes for visibility in Insight/Langfuse.

        Also extracts extra_headers from kwargs and sets them as metadata.requester_custom_headers
        so the proxy can read them for tagging.

        Args:
            span: The span to annotate.
            kwargs: LiteLLM call kwargs; `extra_headers`/`extraHeaders`, `headers`
                and `metadata` are read from it.
            response_obj: Forwarded unchanged to the parent implementation.
        """
        # Call parent implementation first
        super().set_attributes(span, kwargs, response_obj)  # pyright: ignore[reportUnknownMemberType]

        # Extract extra_headers from kwargs and set as metadata.requester_custom_headers
        # This allows the proxy to read x-mantis-* headers for tagging
        # Check multiple sources: extra_headers (from OpenAI SDK), kwargs["headers"], and context variable
        extra_headers = kwargs.get("extra_headers") or kwargs.get("extraHeaders")
        request_headers = kwargs.get("headers") or {}
        context_headers = _request_headers_context.get() or {}

        # Merge all sources; later sources override earlier ones.
        merged_headers: Dict[str, Any] = {}
        if extra_headers and isinstance(extra_headers, dict):
            merged_headers.update(extra_headers)
        if request_headers and isinstance(request_headers, dict):
            # Only x-mantis-* headers are taken from the HTTP request headers passed by LiteLLM
            for key, value in request_headers.items():
                if isinstance(key, str) and key.lower().startswith("x-mantis-"):
                    merged_headers[key] = value
        # Also check context variable (set by MantisHeadersMiddleware)
        if context_headers:
            merged_headers.update(context_headers)

        if merged_headers:
            # LiteLLM stores this as a stringified dict in metadata.requester_custom_headers
            span.set_attribute("metadata.requester_custom_headers", str(merged_headers))

        # Extract Mantisdk tracing metadata from kwargs.
        # A missing, empty or non-dict metadata value carries nothing to copy.
        metadata = kwargs.get("metadata")
        if not metadata or not isinstance(metadata, dict):
            return

        # Set session_id as span attribute (standard Langfuse/Insight attribute)
        session_id = metadata.get("session_id")
        if session_id:
            span.set_attribute("session.id", session_id)

        # Set tags as span attribute
        tags = metadata.get("tags")
        if tags and isinstance(tags, list):
            # Set as JSON array for Langfuse compatibility
            span.set_attribute("tags", json.dumps(tags))
            # Also set individual tag attributes for filtering
            for i, tag in enumerate(tags):
                span.set_attribute(f"tag.{i}", str(tag))

        # Set environment as span attribute
        environment = metadata.get("environment")
        if environment:
            span.set_attribute("environment", environment)

    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes] = None
    ) -> Optional[Dict[str, Any]]:
        """The root span is sometimes missing (e.g., when Anthropic endpoint is used).
        It is created in an auth module in LiteLLM. If it's missing, we create it here.

        Args:
            kwargs: LiteLLM call kwargs.
            call_type: Unused here; part of the hook signature.

        Returns:
            `kwargs` unchanged when a parent span is already present, otherwise a
            shallow copy whose `metadata` carries a newly created parent span.
        """
        # `metadata` may be absent or explicitly None; both mean "no parent span yet".
        metadata = kwargs.get("metadata") or {}
        if "litellm_parent_otel_span" in metadata:
            return kwargs

        parent_otel_span = self.create_litellm_proxy_request_started_span(  # type: ignore
            start_time=datetime.now(),
            headers=kwargs.get("headers", {}),
        )
        updated_metadata = {**metadata, "litellm_parent_otel_span": parent_otel_span}
        return {**kwargs, "metadata": updated_metadata}
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
class RolloutAttemptMiddleware(BaseHTTPMiddleware):
    """Strip the `/rollout/{rid}/attempt/{aid}` prefix from incoming request paths.

    The remaining path is forwarded downstream, and `x-rollout-id`,
    `x-attempt-id` and `x-sequence-id` headers are injected so the request can
    be attributed later.

    The store is looked up on every request, so LLMProxy can swap it without
    rebuilding the middleware.
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        # Example: /rollout/r123/attempt/a456/v1/chat/completions
        # is forwarded as /v1/chat/completions, with request-scoped headers
        # added for trace attribution.
        matched = re.match(r"^/rollout/([^/]+)/attempt/([^/]+)(/.*)?$", request.url.path)
        if matched is not None:
            rollout_id, attempt_id, remainder = matched.groups()
            forwarded_path = "/" if remainder is None else remainder

            # Rewrite the ASGI scope path so downstream sees a clean OpenAI path.
            request.scope["path"] = forwarded_path
            request.scope["raw_path"] = forwarded_path.encode()

            store = get_active_llm_proxy().get_store()
            if store is None:
                logger.warning("Store is not set. Skipping sequence id allocation and header injection.")
            else:
                # The store hands out the next sequence id for this (rollout, attempt).
                sequence_id = await store.get_next_span_sequence_id(rollout_id, attempt_id)

                # Append the attribution headers so downstream components and exporters can read them.
                attribution_headers = [
                    (b"x-rollout-id", rollout_id.encode()),
                    (b"x-attempt-id", attempt_id.encode()),
                    (b"x-sequence-id", str(sequence_id).encode()),
                ]
                request.scope["headers"] = [*request.scope["headers"], *attribution_headers]

        return await call_next(request)
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
class MantisHeadersMiddleware(BaseHTTPMiddleware):
    """Capture `x-mantis-*` request headers into a context variable.

    LiteLLM callbacks can then read custom headers (such as
    `x-mantis-call-type`) from the context even when LiteLLM does not forward
    them through kwargs.
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        # Keep only the headers whose name starts with "x-mantis-" (case-insensitive).
        mantis_headers: Dict[str, str] = {
            name: value
            for name, value in request.headers.items()
            if isinstance(name, str) and name.lower().startswith("x-mantis-")
        }

        # The context variable is left untouched when the request carries no such header.
        if mantis_headers:
            _request_headers_context.set(mantis_headers)

        return await call_next(request)
|
|
984
|
+
|
|
985
|
+
|
|
986
|
+
class MessageInspectionMiddleware(BaseHTTPMiddleware):
    """Middleware to inspect the request and response bodies.

    It's for debugging purposes. Add it via "message_inspection" middleware alias.
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        # perf_counter is monotonic, so the measured duration is not affected by clock adjustments.
        started = time.perf_counter()
        logger.info(f"Received request with scope: {request.scope}")
        logger.info(f"Received request with body: {await request.body()}")
        response = await call_next(request)
        elapsed = time.perf_counter() - started
        logger.info(f"Response to request took {elapsed} seconds")
        logger.info(f"Received response with status code: {response.status_code}")
        # Not every response object has a buffered `body`: responses exposing a
        # `body_iterator` are streamed. Reading the attribute unconditionally
        # would make this debugging aid fail the request itself.
        # NOTE(review): presumably `call_next` always returns a streaming
        # response here - confirm against the Starlette version in use.
        body = getattr(response, "body", None)
        if body is None:
            logger.info("Received response with body: <streamed; not buffered>")
        else:
            logger.info(f"Received response with body: {body}")
        return response
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
class StreamConversionMiddleware(BaseHTTPMiddleware):
|
|
1005
|
+
"""Middleware to convert streaming responses to non-streaming responses.
|
|
1006
|
+
|
|
1007
|
+
Useful for backend that only supports non-streaming responses.
|
|
1008
|
+
|
|
1009
|
+
LiteLLM's OpenTelemetry is also buggy with streaming responses.
|
|
1010
|
+
The conversion will hopefully bypass the bug.
|
|
1011
|
+
"""
|
|
1012
|
+
|
|
1013
|
+
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
    """Route streaming completion requests through the stream-conversion path.

    Every other request is passed through untouched: non-POST requests,
    paths that are neither chat-completions nor messages endpoints, bodies
    that are not a JSON object, and requests that do not ask for streaming.

    Args:
        request: The incoming request.
        call_next: Invokes the rest of the middleware/application chain.

    Returns:
        The downstream response, or the converted streaming response built by
        `_handle_stream_case` when streaming was requested.
    """
    # Only process POST requests to completion endpoints
    if request.method != "POST":
        return await call_next(request)

    # Check if it's a chat completions or messages endpoint
    endpoint_format: Literal["openai", "anthropic", "unknown"] = "unknown"
    if request.url.path.endswith("/chat/completions") or "/chat/completions?" in request.url.path:
        endpoint_format = "openai"
    elif request.url.path.endswith("/messages") or "/messages?" in request.url.path:
        endpoint_format = "anthropic"

    if endpoint_format == "unknown":
        # Directly bypass the middleware
        return await call_next(request)

    # Read the request body
    try:
        json_body = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Log the actual bytes (not the bound `request.body` method).
        logger.warning(f"Request body is not valid JSON: {await request.body()!r}")
        return await call_next(request)

    # A JSON array/scalar cannot carry a `stream` flag; let downstream validate it.
    if not isinstance(json_body, dict):
        return await call_next(request)

    # Simple case: no streaming requested, just return the response
    if not json_body.get("stream", False):
        return await call_next(request)

    # Now the stream case
    return await self._handle_stream_case(request, json_body, endpoint_format, call_next)
|
|
1047
|
+
|
|
1048
|
+
async def _handle_stream_case(
    self,
    request: Request,
    json_body: Dict[str, Any],
    endpoint_format: Literal["openai", "anthropic"],
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Serve a streaming request from a non-streaming downstream call.

    The request is forwarded with `stream` forced to False; a successful JSON
    response is then replayed to the client as a server-sent event stream in
    the format selected by `endpoint_format`.

    Args:
        request: The incoming request; its cached body and scope headers are rewritten.
        json_body: The already parsed request payload.
        endpoint_format: Which SSE generator to use for the replay.
        call_next: Invokes the rest of the middleware/application chain.

    Returns:
        A `StreamingResponse` on success, the untouched downstream response
        for non-2xx statuses, or a plain `Response` carrying whatever was
        buffered if the conversion fails.
    """
    # 1) Modify the request body to force stream=False
    modified_json = dict(json_body)
    modified_json["stream"] = False
    modified_body = json.dumps(modified_json).encode("utf-8")

    # 2) Rewrite the accept/content-length headers to match the modified body.
    new_headers: List[Tuple[bytes, bytes]] = []
    saw_accept = False
    for k, v in request.scope["headers"]:
        kl = k.lower()
        if kl == b"accept":
            saw_accept = True
            new_headers.append((k, b"application/json"))
        elif kl == b"content-length":
            # Dropped here; re-added below with the new length.
            continue
        else:
            new_headers.append((k, v))
    if not saw_accept:
        new_headers.append((b"accept", b"application/json"))
    new_headers.append((b"content-length", str(len(modified_body)).encode("ascii")))
    # Write to the live scope: assigning to a *copy* of the scope would never
    # reach downstream and leave it with a stale Content-Length.
    request.scope["headers"] = new_headers

    # 3) Directly modify the request body.
    # Creating a new request won't work because request is cached in the base class
    request._body = modified_body  # type: ignore

    response = await call_next(request)

    # Non-2xx responses are passed through unchanged.
    if not 200 <= response.status_code < 300:
        return response

    buffered: Optional[bytes] = None
    # 4) Buffer the response body (it should be JSON because we forced stream=False)
    try:
        if hasattr(response, "body_iterator"):
            # Buffer body safely
            body_chunks: List[bytes] = []
            async for chunk in response.body_iterator:  # type: ignore
                body_chunks.append(chunk)  # type: ignore
            buffered = b"".join(body_chunks)
        else:
            buffered = response.body  # type: ignore

        data = json.loads(buffered or b"{}")

        # 5) Replay the complete response as an SSE stream.
        sse_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
        if endpoint_format == "anthropic":
            generator = self.anthropic_stream_generator(data)
        else:
            # openai format
            generator = self.openai_stream_generator(data)
        return StreamingResponse(generator, media_type="text/event-stream", headers=sse_headers)
    except Exception as e:
        # If anything goes wrong, fall back to non-streaming JSON
        logger.exception(f"Error converting to stream; returning non-stream response: {e}")
        # Rebuild the consumed response. Content-Length is left out so it is
        # derived from the content actually returned, which may be shorter than
        # the original body if buffering failed part-way.
        fallback_headers = {
            name: value for name, value in response.headers.items() if name.lower() != "content-length"
        }
        return Response(
            content=buffered if buffered is not None else b"",
            status_code=response.status_code,
            headers=fallback_headers,
            media_type=response.media_type,
            background=response.background,
        )
|
|
1127
|
+
|
|
1128
|
+
async def anthropic_stream_generator(self, original_response: Dict[str, Any]):
    """Replay a complete (non-streaming) Anthropic message as SSE events.

    Events are emitted in this order: `message_start`, `ping`, then for every
    `text` or `tool_use` content block a `content_block_start`, zero or more
    `content_block_delta` events and a `content_block_stop`, and finally
    `message_delta` and `message_stop`. Content blocks of any other type are
    not emitted.

    This is a dirty hack for Anthropic-style streaming from a non-streaming response.
    The SSE format is subject to change based on Anthropic's implementation.
    If so, try to use `MessageInspectionMiddleware` to inspect the update and fix accordingly.

    Args:
        original_response: Parsed JSON body of the non-streaming Anthropic response.

    Yields:
        One SSE frame per item: an `event:` line followed by a `data:` line
        carrying the JSON payload, terminated by a blank line.
    """
    # Function-scope import: only needed for the whitespace-preserving chunker below.
    import re

    def sse(event: str, payload: Dict[str, Any]) -> str:
        return f"event: {event}\ndata: {json.dumps(payload)}\n\n"

    # `or` guards: a JSON `null` for these keys must not crash the generator.
    content_blocks: List[Dict[str, Any]] = original_response.get("content") or []
    message_id = original_response.get("id", f"msg_{int(time.time() * 1000)}")
    model = original_response.get("model", "claude")
    usage = original_response.get("usage")

    message_start: Dict[str, Any] = {
        "type": "message_start",
        "message": {
            "id": message_id,
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": model,
            "stop_reason": None,
            "stop_sequence": None,
            "usage": usage if usage is not None else {"input_tokens": 0, "output_tokens": 0},
        },
    }
    yield sse("message_start", message_start)

    # Send ping to keep connection alive
    yield sse("ping", {"type": "ping"})

    words_per_chunk = 5
    json_chars_per_chunk = 20

    for block_index, block in enumerate(content_blocks):
        block_type = block.get("type", "text")

        if block_type == "text":
            content = block.get("text") or ""

            yield sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": block_index,
                    "content_block": {"type": "text", "text": ""},
                },
            )

            if content:
                # Each token is one word together with the whitespace around it, so
                # concatenating all tokens reproduces `content` exactly (newlines,
                # indentation and repeated spaces included). Whitespace-only content
                # yields no token and is sent as a single chunk.
                tokens = re.findall(r"\s*\S+\s*", content) or [content]
                for i in range(0, len(tokens), words_per_chunk):
                    text_chunk = "".join(tokens[i : i + words_per_chunk])
                    yield sse(
                        "content_block_delta",
                        {
                            "type": "content_block_delta",
                            "index": block_index,
                            "delta": {"type": "text_delta", "text": text_chunk},
                        },
                    )
                    await asyncio.sleep(0.02)

            yield sse("content_block_stop", {"type": "content_block_stop", "index": block_index})

        elif block_type == "tool_use":
            tool_name = block.get("name", "")
            tool_input = block.get("input", {})
            tool_id = block.get("id", f"toolu_{int(time.time() * 1000)}")

            yield sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": block_index,
                    "content_block": {"type": "tool_use", "id": tool_id, "name": tool_name, "input": {}},
                },
            )

            # The tool input is streamed as raw slices of its JSON serialization;
            # the client reassembles the slices before parsing.
            input_json = json.dumps(tool_input)
            for i in range(0, len(input_json), json_chars_per_chunk):
                yield sse(
                    "content_block_delta",
                    {
                        "type": "content_block_delta",
                        "index": block_index,
                        "delta": {
                            "type": "input_json_delta",
                            "partial_json": input_json[i : i + json_chars_per_chunk],
                        },
                    },
                )
                await asyncio.sleep(0.01)

            yield sse("content_block_stop", {"type": "content_block_stop", "index": block_index})

        # NOTE: blocks of any other type are skipped without emitting events.

    # Send message_delta event with stop reason
    message_delta = {
        "type": "message_delta",
        "delta": {"stop_reason": original_response.get("stop_reason", "end_turn"), "stop_sequence": None},
        "usage": {"output_tokens": (usage or {}).get("output_tokens", 0)},
    }
    yield sse("message_delta", message_delta)

    # Send message_stop event
    yield sse("message_stop", {"type": "message_stop"})
|
|
1245
|
+
|
|
1246
|
+
async def openai_stream_generator(self, response_json: Dict[str, Any]) -> AsyncGenerator[str, Any]:
    """Replay a *complete* OpenAI chat.completions response as SSE chunks.

    The stream consists of, in order:

    - one chunk announcing the role,
    - chunks carrying `message.content` in small pieces,
    - for each tool call, one chunk with its id/type/name followed by chunks
      carrying slices of its arguments string,
    - one terminal chunk with `finish_reason`,
    - the literal `[DONE]` sentinel.

    Notes:

    - Only the first entry of `choices` is replayed.
    - Logprobs are not streamed.
    - Content is cut at roughly 28 characters, preferring to end a piece just
      after a space; slices are taken on `str`, so characters are never split.

    Args:
        response_json: Parsed JSON body of the non-streaming response.

    Yields:
        SSE `data:` frames, each terminated by a blank line.
    """
    first_choice = cast(Dict[str, Any], (response_json.get("choices") or [{}])[0])
    model_name = response_json.get("model", "unknown")
    created_at: int = int(time.time())
    choice_index: int = first_choice.get("index", 0)

    msg: Dict[str, Any] = first_choice.get("message", {}) or {}
    speaker: str = msg.get("role", "assistant")
    text: str = msg.get("content") or ""
    calls: List[Any] = msg.get("tool_calls") or []
    # e.g., "stop", "length", "tool_calls", "content_filter"
    reported_finish: Optional[str] = first_choice.get("finish_reason")

    def frame(delta: Dict[str, Any], finish: Optional[str] = None) -> str:
        # Every chunk shares the same envelope; only delta/finish_reason vary.
        payload: Dict[str, Any] = {
            "id": f"chatcmpl-{created_at}",
            "object": "chat.completion.chunk",
            "created": created_at,
            "model": model_name,
            "choices": [{"index": choice_index, "delta": delta, "finish_reason": finish}],
        }
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    # 1) role announcement
    yield frame({"role": speaker})

    # 2) textual content, in pieces that end after a space where possible
    window = 28
    cursor = 0
    total = len(text)
    while cursor < total:
        stop = min(cursor + window, total)
        if stop < total:
            last_space = text.rfind(" ", cursor, stop)
            if last_space > cursor:
                stop = last_space + 1
        piece = text[cursor:stop]
        cursor = stop
        if not piece:
            break
        yield frame({"content": piece})
        # Zero-length sleep: hands control back to the event loop between chunks.
        await asyncio.sleep(0.0)

    # 3) tool calls: header chunk first, then the arguments string in slices
    args_window = 40
    for call_pos, call in enumerate(calls):
        call_type = call.get("type", "function")
        call_id = call.get("id") or f"call_{created_at}_{call_pos}"
        if call_type == "function":
            fn_spec: Dict[str, Any] = call.get("function") or {}
        else:
            fn_spec = {}
        fn_name: str = fn_spec.get("name", "")
        fn_args: str = fn_spec.get("arguments", "") or ""

        yield frame(
            {
                "tool_calls": [
                    {"index": call_pos, "id": call_id, "type": call_type, "function": {"name": fn_name}}
                ]
            }
        )

        for offset in range(0, len(fn_args), args_window):
            args_slice = fn_args[offset : offset + args_window]
            yield frame({"tool_calls": [{"index": call_pos, "function": {"arguments": args_slice}}]})
            await asyncio.sleep(0.0)

    # 4) terminal chunk; the finish reason is inferred when the response had none
    if reported_finish:
        final_reason = reported_finish
    elif calls:
        final_reason = "tool_calls"
    else:
        final_reason = "stop"
    yield frame({}, final_reason)

    # 5) literal DONE sentinel
    yield "data: [DONE]\n\n"
|
|
1397
|
+
|
|
1398
|
+
|
|
1399
|
+
# String alias -> middleware class. `LLMProxy` resolves aliases passed in its
# `middlewares` argument through this table.
_MIDDLEWARE_REGISTRY: Dict[str, Type[BaseHTTPMiddleware]] = dict(
    rollout_attempt=RolloutAttemptMiddleware,
    stream_conversion=StreamConversionMiddleware,
    message_inspection=MessageInspectionMiddleware,
    mantis_headers=MantisHeadersMiddleware,
)
|
|
1405
|
+
|
|
1406
|
+
|
|
1407
|
+
# String alias -> LiteLLM callback class. `LLMProxy` resolves aliases passed in
# its `callbacks` argument through this table.
_CALLBACK_REGISTRY = dict(
    return_token_ids=AddReturnTokenIds,
    logprobs=AddLogprobs,
    opentelemetry=LightningOpenTelemetry,
)
|
|
1412
|
+
|
|
1413
|
+
|
|
1414
|
+
class LLMProxy:
    """Host a LiteLLM OpenAI-compatible proxy bound to a LightningStore.

    The proxy:

    * Serves an OpenAI-compatible API via uvicorn.
    * Adds rollout/attempt routing and headers via middleware.
    * Registers OTEL export and token-id callbacks.
    * Writes a LiteLLM worker config file with `model_list` and settings.

    Lifecycle:

    * [`start()`][mantisdk.LLMProxy.start] writes config, launches the server, and waits until ready.
    * [`stop()`][mantisdk.LLMProxy.stop] tears down the server and removes the temp config file.
    * [`restart()`][mantisdk.LLMProxy.restart] convenience wrapper to stop then start.

    !!! note

        As the LLM Proxy sets up an OpenTelemetry tracer, it's recommended to run it in a different
        process from the main runner (i.e., tracer from agents). See `launch_mode` for how to change that.

    !!! warning

        By default (or when "stream_conversion" middleware is enabled), the LLM Proxy will convert
        OpenAI and Anthropic requests with `stream=True` to a non-streaming request before going
        through the LiteLLM proxy. This is because the OpenTelemetry tracer provided by LiteLLM is
        buggy with streaming responses. You can disable this by removing the "stream_conversion"
        middleware. In that case, you might lose some tracing information like token IDs.

    !!! danger

        Do not run LLM proxy in the same process as the main runner. It's easy to cause conflicts
        in the tracer provider with tracers like [`AgentOpsTracer`][mantisdk.AgentOpsTracer].

    Args:
        port: TCP port to bind. Will bind to a random port if not provided.
        model_list: LiteLLM `model_list` entries.
        store: LightningStore used for span sequence and persistence.
        host: Publicly reachable host used in resource endpoints. See `host` of `launcher_args`
            for more details.
        litellm_config: Extra LiteLLM proxy config merged with `model_list`. The mapping is copied;
            the caller's dict is not modified.
        num_retries: Default LiteLLM retry count injected into `litellm_settings` unless the
            config already sets `num_retries` there.
        num_workers: Number of workers to run in the server. Only applicable for "mp" launch mode.
            Cannot be used together with `launcher_args`.
            When `num_workers > 1`, the server will be run using [gunicorn](https://gunicorn.org/).
        launch_mode: Launch mode for the server. Defaults to "mp". Cannot be used together with
            `launcher_args`.
            It's recommended to use `launch_mode="mp"` to launch the proxy, which will launch the
            server in a separate process.
            `launch_mode="thread"` can also be used with caution. It will launch the server in a
            separate thread.
            `launch_mode="asyncio"` launches the server in the current thread as an asyncio task.
            It is NOT recommended because it often causes hanging requests. Only use it if you know
            what you are doing.
        launcher_args: Arguments for the server launcher. Cannot be used together with `port`,
            `host`, `launch_mode`, and `num_workers`; a `ValueError` is raised if any of them is
            set to a non-default value.
        middlewares: List of FastAPI middleware classes or strings to register. You can specify the
            class aliases or classes that have been imported.
            If not provided, the default middlewares ("mantis_headers", "rollout_attempt" and
            "stream_conversion") will be used.
            Available middleware aliases are: "rollout_attempt", "stream_conversion",
            "message_inspection", "mantis_headers".
            Middlewares are the **first layer** of request processing. They are applied to all
            requests before the LiteLLM proxy.
        callbacks: List of LiteLLM callback classes or strings to register. You can specify the
            class aliases or classes that have been imported.
            If not provided, the default callbacks (AddReturnTokenIds and LightningOpenTelemetry)
            will be used.
            Available callback aliases are: "return_token_ids", "opentelemetry", "logprobs".
        otlp_endpoint: Optional OTLP endpoint URL for direct trace export to external
            collectors (e.g., Langfuse/Insight). When set, spans are exported directly without
            requiring rollout/attempt headers. Format: "http://host:port/api/public/otel/v1/traces"
        otlp_headers: Optional dict of HTTP headers for OTLP authentication.
            For Langfuse/Insight, use Basic Auth: {"Authorization": "Basic base64(publicKey:secretKey)"}

    Raises:
        ValueError: If `launcher_args` is combined with `port`, `host`, `launch_mode` or
            `num_workers`, or if a middleware/callback alias is unknown.
    """

    def __init__(
        self,
        port: int | None = None,
        model_list: List[ModelConfig] | None = None,
        store: Optional[LightningStore] = None,
        host: str | None = None,
        litellm_config: Dict[str, Any] | None = None,
        num_retries: int = 0,
        num_workers: int = 1,
        launch_mode: LaunchMode = "mp",
        launcher_args: PythonServerLauncherArgs | None = None,
        middlewares: Sequence[Union[Type[BaseHTTPMiddleware], str]] | None = None,
        callbacks: Sequence[Union[Type[CustomLogger], str]] | None = None,
        otlp_endpoint: Optional[str] = None,
        otlp_headers: Optional[Dict[str, str]] = None,
    ):
        self.store = store
        self._otlp_endpoint = otlp_endpoint
        self._otlp_headers = otlp_headers

        # Log OTLP configuration for diagnostics. Only header *names* are logged,
        # never their values, because they may carry credentials.
        if otlp_endpoint:
            logger.info(f"LLMProxy initialized with OTLP endpoint: {otlp_endpoint}")
            if otlp_headers:
                logger.info(f"LLMProxy OTLP headers configured: {list(otlp_headers.keys())}")
        else:
            logger.debug("LLMProxy initialized without OTLP endpoint (will use store-based export)")

        if launcher_args is not None and (
            port is not None or host is not None or launch_mode != "mp" or num_workers != 1
        ):
            raise ValueError("port, host, launch_mode, and num_workers cannot be set when launcher_args is provided.")

        self.server_launcher_args = launcher_args or PythonServerLauncherArgs(
            port=port,
            host=host,
            launch_mode=launch_mode,
            n_workers=num_workers,
            # NOTE: This /health endpoint can be slow sometimes because it actually probes the backend LLM service.
            healthcheck_url="/health",
            startup_timeout=60.0,
        )

        if self.server_launcher_args.healthcheck_url is None:
            logger.warning("healthcheck_url is not set. LLM Proxy will not be checked for healthiness after starting.")

        self.model_list = model_list or []

        # Work on copies so the caller's `litellm_config` (and its nested
        # `litellm_settings`) is never mutated by the defaults injected here.
        self.litellm_config: Dict[str, Any] = dict(litellm_config or {})
        litellm_settings = dict(self.litellm_config.get("litellm_settings") or {})
        litellm_settings.setdefault("num_retries", num_retries)
        self.litellm_config["litellm_settings"] = litellm_settings

        self.server_launcher = PythonServerLauncher(app, self.server_launcher_args, noop_context())

        self._config_file = None

        if middlewares is None:
            middlewares = ["mantis_headers", "rollout_attempt", "stream_conversion"]
        self.middlewares: List[Type[BaseHTTPMiddleware]] = self._resolve_aliases(
            middlewares, _MIDDLEWARE_REGISTRY, "middleware"
        )

        if callbacks is None:
            callbacks = ["return_token_ids", "opentelemetry"]
        self.callbacks: List[Type[CustomLogger]] = self._resolve_aliases(callbacks, _CALLBACK_REGISTRY, "callback")

    @staticmethod
    def _resolve_aliases(items: Sequence[Any], registry: Dict[str, Any], kind: str) -> List[Any]:
        """Replace string aliases in `items` with their registry class; pass classes through.

        Raises:
            ValueError: If a string item is not a key of `registry`.
        """
        resolved: List[Any] = []
        for item in items:
            if isinstance(item, str):
                if item not in registry:
                    raise ValueError(
                        f"Invalid {kind} alias: {item}. Available aliases are: {list(registry.keys())}"
                    )
                item = registry[item]
            resolved.append(item)
        return resolved

    def get_store(self) -> Optional[LightningStore]:
        """Get the store used by the proxy.

        Returns:
            The store used by the proxy, or None if no store has been set.
        """
        return self.store

    def set_store(self, store: LightningStore) -> None:
        """Set the store for the proxy.

        Args:
            store: The store to use for the proxy.
        """
        self.store = store

    def update_model_list(self, model_list: List[ModelConfig]) -> None:
        """Replace the in-memory model list.

        Only the attribute is updated here; this method does not touch the server.

        Args:
            model_list: New list of model entries.
        """
        self.model_list = model_list
        logger.info(f"Updating LLMProxy model list to: {model_list}")
        # Do nothing if the server is not running.

    def initialize(self):
        """Initialize global middleware and LiteLLM callbacks.

        Installs:

        * The configured middlewares on the FastAPI app (each class at most once).
        * The configured LiteLLM callbacks.

        The middleware can only be installed once because once the FastAPI app has started,
        the middleware cannot be changed any more.

        This function does not start any server. It only wires global hooks.

        Raises:
            ValueError: If the store has not been set.
        """
        if self.store is None:
            raise ValueError("Store is not set. Please set the store before initializing the LLMProxy.")

        if _global_llm_proxy is not None:
            logger.warning("A global LLMProxy is already set. Overwriting it with the new instance.")

        # Patch for LiteLLM v1.80.6+: https://github.com/BerriAI/litellm/issues/17243
        os.environ["USE_OTEL_LITELLM_REQUEST_SPAN"] = "true"

        # Set the global LLMProxy reference for middleware/exporter access.
        set_active_llm_proxy(self)

        # Install middleware if it's not already installed.
        installed = {mw.cls for mw in app.user_middleware}
        for mw in self.middlewares:
            if mw not in installed:
                logger.info(f"Adding middleware {mw} to the FastAPI app.")
                app.add_middleware(mw)
                # Remember it so a class listed twice is not registered twice.
                installed.add(mw)
            else:
                logger.info(f"Middleware {mw} is already installed. Will not install a new one.")

        if not initialize_llm_callbacks(
            self.callbacks, otlp_endpoint=self._otlp_endpoint, otlp_headers=self._otlp_headers
        ):
            # If it's not the first time to initialize the callbacks, also
            # reset LiteLLM's logging worker so its asyncio.Queue binds to the new loop.
            _reset_litellm_logging_worker()

    @asynccontextmanager
    async def _serve_context(self) -> AsyncGenerator[None, None]:
        """Context manager to serve the proxy server.

        On entry it initializes the global hooks and writes the temporary LiteLLM
        worker config; on exit (normal or exceptional) it removes that file.

        See [`start`][mantisdk.LLMProxy.start] and [`stop`][mantisdk.LLMProxy.stop] for more details.

        Raises:
            ValueError: If the store has not been set.
        """
        if not self.store:
            raise ValueError("Store is not set. Please set the store before starting the LLMProxy.")

        # Initialize global middleware and callbacks.
        self.initialize()

        # Persist a temp worker config for LiteLLM and point the proxy at it.
        # The `with` block closes the handle; `delete=False` keeps the file on disk
        # until the cleanup below removes it.
        with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as fp:
            self._config_file = fp.name
            yaml.safe_dump(
                {
                    "model_list": self.model_list,
                    **self.litellm_config,
                },
                fp,
            )

        try:
            save_worker_config(config=self._config_file)

            # NOTE: When running the _serve_context in current process, you might encounter the following problems:
            # Problem 1: in litellm worker, <Queue at 0x70f1d028cd90 maxsize=50000> is bound to a different event loop
            # Problem 2: Proxy has conflicted opentelemetry setup with the main process.

            # Ready
            logger.info("LLMProxy preparation is done. Will start the server.")
            yield
        finally:
            # Clean up even when the served body raised, so the temp file never leaks.
            logger.info("LLMProxy server is cleaning up.")

            # Remove worker config to avoid stale references.
            if self._config_file and os.path.exists(self._config_file):
                os.unlink(self._config_file)

            logger.info("LLMProxy server finishes.")

    async def start(self):
        """Start the proxy server through the configured launcher.

        Steps:

        * Installs a fresh serve context on the launcher; when entered it calls
          `initialize()` and writes the temporary YAML config consumed by LiteLLM.
        * Verifies that the store's capabilities match the launch mode.
        * Awaits the launcher's `start()`.

        Raises:
            ValueError: If the store has not been set.
            RuntimeError: If the store lacks the capability required by the launch mode
                ("zero_copy" for "mp", "thread_safe" for "thread", "async_safe" for "asyncio").
        """
        # Refresh the serve context
        self.server_launcher.serve_context = self._serve_context()

        if self.store is None:
            raise ValueError("Store is not set. Please set the store before starting the LLMProxy.")

        store_capabilities = self.store.capabilities
        launch_mode = self.server_launcher.args.launch_mode
        if launch_mode == "mp" and not store_capabilities.get("zero_copy", False):
            raise RuntimeError(
                "The store does not support zero-copy. Please use another store, or use asyncio or thread mode to launch the server."
            )
        elif launch_mode == "thread" and not store_capabilities.get("thread_safe", False):
            raise RuntimeError(
                "The store is not thread-safe. Please use another store, or use asyncio mode to launch the server."
            )
        elif launch_mode == "asyncio" and not store_capabilities.get("async_safe", False):
            raise RuntimeError("The store is not async-safe. Please use another store.")

        logger.info(
            f"Starting LLMProxy server in {self.server_launcher.args.launch_mode} mode with store capabilities: {store_capabilities}"
        )

        await self.server_launcher.start()

    async def stop(self):
        """Stop the proxy server.

        Logs a warning and returns if the server is not running; otherwise
        delegates to the launcher's `stop()`.
        """
        if not self.is_running():
            logger.warning("LLMProxy is not running. Nothing to stop.")
            return

        await self.server_launcher.stop()

    async def restart(self, *, _port: int | None = None) -> None:
        """Restart the proxy if running, else start it.

        Convenience wrapper calling `stop()` followed by `start()`.

        Args:
            _port: If given, stored on the launcher args as the port before starting again.
        """
        logger.info("Restarting LLMProxy server...")
        if self.is_running():
            await self.stop()
        if _port is not None:
            self.server_launcher_args.port = _port
        await self.start()

    def is_running(self) -> bool:
        """Return whether the server is active, as reported by the launcher.

        Returns:
            bool: The launcher's `is_running()` result.
        """
        return self.server_launcher.is_running()

    def as_resource(
        self,
        rollout_id: str | None = None,
        attempt_id: str | None = None,
        model: str | None = None,
        sampling_parameters: Dict[str, Any] | None = None,
    ) -> LLM:
        """Create an `LLM` resource pointing at this proxy with rollout context.

        With both ids given, the returned endpoint is:
            `{access_endpoint}/rollout/{rollout_id}/attempt/{attempt_id}`

        Args:
            rollout_id: Rollout identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
            attempt_id: Attempt identifier used for span attribution. If None, will instantiate a ProxyLLM resource.
            model: Logical model name to use. If omitted and exactly one model
                is configured or all models have the same name, that model is used.
            sampling_parameters: Optional default sampling parameters (copied into the resource).

        Returns:
            LLM: Configured resource ready for OpenAI-compatible calls.

        Raises:
            ValueError: If `model` is omitted and zero or multiple distinct model names are
                configured, or if only one of `rollout_id` / `attempt_id` is provided.
        """
        if model is None:
            if not self.model_list:
                raise ValueError("No models found in model_list. Please specify the model.")
            first_model_name = self.model_list[0]["model_name"]
            if any(model_config["model_name"] != first_model_name for model_config in self.model_list):
                raise ValueError(f"Multiple models found in model_list: {self.model_list}. Please specify the model.")
            model = first_model_name

        if rollout_id is None and attempt_id is None:
            return ProxyLLM(
                endpoint=self.server_launcher.access_endpoint,
                model=model,
                sampling_parameters=dict(sampling_parameters or {}),
            )
        elif rollout_id is not None and attempt_id is not None:
            return LLM(
                endpoint=f"{self.server_launcher.access_endpoint}/rollout/{rollout_id}/attempt/{attempt_id}",
                model=model,
                sampling_parameters=dict(sampling_parameters or {}),
            )
        else:
            raise ValueError("Either rollout_id and attempt_id must be provided, or neither.")
|
|
1791
|
+
|
|
1792
|
+
|
|
1793
|
+
# Module-level reference to the active proxy; written by `set_active_llm_proxy`
# and read by `get_active_llm_proxy`.
_global_llm_proxy: Optional[LLMProxy] = None
# Declared `global` by `initialize_llm_callbacks`.
# NOTE(review): presumably a snapshot of `litellm.callbacks` taken before the first
# initialization — confirm against that function's body.
_callbacks_before_litellm_start: Optional[List[Any]] = None
|
|
1795
|
+
|
|
1796
|
+
|
|
1797
|
+
def get_active_llm_proxy() -> LLMProxy:
    """Return the current global LLMProxy instance.

    Returns:
        LLMProxy: The proxy currently stored in the module-level
        `_global_llm_proxy` slot. This function never returns `None`.

    Raises:
        ValueError: If no global LLMProxy has been set.
    """
    if _global_llm_proxy is None:
        raise ValueError("Global LLMProxy is not set. Please call llm_proxy.start() first.")
    return _global_llm_proxy
|
|
1806
|
+
|
|
1807
|
+
|
|
1808
|
+
def set_active_llm_proxy(proxy: LLMProxy) -> None:
    """Register `proxy` as the process-wide active LLMProxy.

    Args:
        proxy: Instance to store in the module-level `_global_llm_proxy` slot,
            replacing whatever value was held there before.
    """
    global _global_llm_proxy
    _global_llm_proxy = proxy
|
|
1816
|
+
|
|
1817
|
+
|
|
1818
|
+
def initialize_llm_callbacks(
    callback_classes: List[Type[CustomLogger]],
    otlp_endpoint: Optional[str] = None,
    otlp_headers: Optional[Dict[str, str]] = None,
) -> bool:
    """Bring `litellm.callbacks` back to the state first set up by mantisdk.

    Restarting litellm several times in one process would otherwise keep
    appending callbacks to `litellm.callbacks`, which may exceed the
    MAX_CALLBACKS limit. The first call registers the requested callbacks and
    remembers the resulting list; every later call resets `litellm.callbacks`
    to that remembered list (plus any requested classes that were missing).

    Args:
        callback_classes: List of callback classes to register.
        otlp_endpoint: Optional OTLP endpoint URL for direct trace export.
        otlp_headers: Optional dict of HTTP headers for OTLP authentication.

    Returns:
        True when this call performed the first-time initialization, False when
        it restored a previously remembered callback list.
    """
    global _callbacks_before_litellm_start

    def _build(callback_cls: Type[CustomLogger]) -> CustomLogger:
        """Create one callback instance, passing OTLP settings where they apply."""
        # Only the OpenTelemetry callback receives the OTLP export configuration.
        if callback_cls is LightningOpenTelemetry:
            return LightningOpenTelemetry(otlp_endpoint=otlp_endpoint, otlp_headers=otlp_headers)
        return callback_cls()

    # First call in this process: register everything and remember the result.
    if _callbacks_before_litellm_start is None:
        # Build every instance before touching litellm.callbacks, so a failing
        # constructor leaves the list untouched.
        fresh_callbacks = [_build(cls) for cls in callback_classes]
        litellm.callbacks.extend(fresh_callbacks)  # type: ignore
        _callbacks_before_litellm_start = list(litellm.callbacks)  # type: ignore
        return True

    # Later calls: add instances only for requested classes the remembered list lacks.
    # The membership test runs against the growing list, so an instance appended
    # here also satisfies later entries of `callback_classes`.
    for cls in callback_classes:
        already_present = any(isinstance(cb, cls) for cb in _callbacks_before_litellm_start)
        if already_present:
            continue
        logger.info(f"Adding missing callback {cls} to the existing callbacks.")
        _callbacks_before_litellm_start.append(_build(cls))

    _reset_litellm_logging_callback_manager()

    if LightningOpenTelemetry in callback_classes:
        # The global tracer provider can be cleared from outside (e.g. in tests);
        # in that case the remembered OpenTelemetry callback is replaced.
        if _check_tracer_provider():
            logger.debug("Global tracer provider is valid. Reusing existing OpenTelemetry callback.")
        else:
            logger.warning(
                "Global tracer provider might have been cleared outside. Re-initializing OpenTelemetry callback."
            )
            retained = [cb for cb in _callbacks_before_litellm_start if not isinstance(cb, LightningOpenTelemetry)]
            retained.append(LightningOpenTelemetry(otlp_endpoint=otlp_endpoint, otlp_headers=otlp_headers))
            # Rebind the global to the new list instead of mutating the old one.
            _callbacks_before_litellm_start = retained
    # When OpenTelemetry was not requested, the tracer-provider check is skipped
    # and the remembered callbacks are used as they are.

    litellm.callbacks.clear()  # type: ignore
    litellm.callbacks.extend(_callbacks_before_litellm_start)  # type: ignore
    return False
|
|
1875
|
+
|
|
1876
|
+
|
|
1877
|
+
def _check_tracer_provider() -> bool:
    """Report whether a global tracer provider is currently installed.

    Only checks that `trace_api` has a private `_TRACER_PROVIDER` attribute
    that is not `None`; it does not verify that the provider is the one
    installed by this package.

    Returns:
        bool: True if the tracer provider is present, else False.
    """
    # A missing attribute and an attribute set to None both count as "not initialized".
    provider = getattr(trace_api, "_TRACER_PROVIDER", None)
    return provider is not None
|