mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic; see the package registry page for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
import warnings
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Callable, Dict, Iterator, List
|
|
10
|
+
|
|
11
|
+
import weave.trace.weave_init
|
|
12
|
+
from pydantic import validate_call
|
|
13
|
+
from weave.trace_server import trace_server_interface as tsi
|
|
14
|
+
from weave.trace_server.ids import generate_id
|
|
15
|
+
from weave.trace_server_bindings.client_interface import TraceServerClientInterface
|
|
16
|
+
from weave.trace_server_bindings.models import ServerInfoRes
|
|
17
|
+
|
|
18
|
+
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# Names exported by `from <this module> import *`.
__all__ = ["instrument_weave", "uninstrument_weave", "InMemoryWeaveTraceServer"]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class InMemoryWeaveTraceServer(TraceServerClientInterface):
    """A minimal in-memory implementation of the TraceServerInterface.

    It stores calls and objects in local dictionaries and returns valid Pydantic
    responses to satisfy the Weave client and FullTraceServerInterface protocol.

    Only the Call API keeps meaningful state: the start and end halves of a call may
    arrive in either order, and once both have been seen they are merged into a
    ``tsi.CallSchema`` stored in ``self.calls``. Legacy objects, files and feedback
    are also kept, but most other endpoints are stubs that return empty results or
    freshly generated ids.

    ``self.calls`` and ``self.partial_calls`` are only touched while holding
    ``self._call_threading_lock``.
    """

    def __init__(self):
        # Minimal storage to allow basic querying in tests
        # Completed calls (both start and end seen), keyed by call id.
        self.calls: Dict[str, tsi.CallSchema] = {}
        # Half-finished calls: whichever of start/end arrived first, as a dumped dict.
        self.partial_calls: Dict[str, Dict[str, Any]] = {}
        # Legacy V1 objects keyed by the random digest handed out by obj_create.
        self.objs: Dict[str, Any] = {}
        # File contents keyed by the digest returned from file_create.
        self.files: Dict[str, bytes] = {}
        # Every feedback request received, in arrival order.
        self.feedback: List[tsi.FeedbackCreateReq] = []

        # Guards self.calls and self.partial_calls.
        self._call_threading_lock = threading.Lock()

    @classmethod
    def from_env(cls, *args: Any, **kwargs: Any) -> InMemoryWeaveTraceServer:
        # All arguments are ignored; the server needs no configuration.
        return cls()

    def server_info(self) -> ServerInfoRes:
        return ServerInfoRes(min_required_weave_python_version="0.52.22")

    def ensure_project_exists(self, entity: str, project: str) -> tsi.EnsureProjectExistsRes:
        return tsi.EnsureProjectExistsRes(project_name=project)

    # --- Call API ---

    @validate_call
    def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
        """Record the start half of a call, completing it if its end already arrived."""
        # NOTE: It's not necessary that call_end must be called after call_start.
        request_content = req.start.model_dump(exclude_none=True)

        # If id needs to be generated here, it's very likely we won't be able to find the call later.
        # This is just to make the type checker happy.
        call_id = request_content.get("id") or generate_id()
        trace_id = request_content.get("trace_id") or generate_id()
        request_content["id"] = call_id
        request_content["trace_id"] = trace_id

        with self._call_threading_lock:
            if call_id in self.partial_calls:
                # call_end has already been called for this call; its fields take
                # precedence over the start fields.
                kwargs = {**request_content, **self.partial_calls[call_id]}
                self.calls[call_id] = tsi.CallSchema(**kwargs)
                del self.partial_calls[call_id]
            else:
                self.partial_calls[call_id] = request_content

        return tsi.CallStartRes(id=call_id, trace_id=trace_id)

    @validate_call
    def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes:
        """Record the end half of a call, completing it if its start already arrived."""
        request_content = req.end.model_dump(exclude_none=True)
        call_id = req.end.id

        with self._call_threading_lock:
            if call_id in self.partial_calls:
                # End request always override the start request content.
                kwargs = {**self.partial_calls[call_id], **request_content}
                self.calls[call_id] = tsi.CallSchema(**kwargs)
                del self.partial_calls[call_id]
            else:
                self.partial_calls[call_id] = request_content
        return tsi.CallEndRes()

    @validate_call
    def call_start_batch(self, req: tsi.CallCreateBatchReq) -> tsi.CallCreateBatchRes:
        """Apply a batch of call start/end requests in order.

        Returns the per-item responses from ``call_start`` / ``call_end``.
        """
        results: List[Any] = []
        for item in req.batch:
            # Batch entries are wrappers (CallBatchStartMode / CallBatchEndMode in the
            # weave versions we know of) carrying the real request in `.req`; testing
            # the wrapper itself against CallStartReq/CallEndReq never matches.
            # Bare requests are still accepted.
            # NOTE(review): confirm the wrapper shape against the pinned weave version.
            inner = getattr(item, "req", item)
            if isinstance(inner, tsi.CallStartReq):
                results.append(self.call_start(inner))
            elif isinstance(inner, tsi.CallEndReq):
                results.append(self.call_end(inner))
        return tsi.CallCreateBatchRes(res=results)

    @validate_call
    def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes:
        """Return the completed call with ``req.id``, or ``call=None`` if unknown/incomplete."""
        with self._call_threading_lock:
            call_data = self.calls.get(req.id)
        return tsi.CallReadRes(call=call_data)

    @validate_call
    def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes:
        """Return all completed calls; filters in ``req`` are ignored."""
        return tsi.CallsQueryRes(calls=list(self.calls_query_stream(req)))

    @validate_call
    def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]:
        """Yield every completed call; filters in ``req`` are ignored."""
        # Snapshot under the lock: iterating the live dict while another thread's
        # call_start/call_end inserts into it raises "dictionary changed size during
        # iteration". The lock is not held while yielding to the consumer.
        with self._call_threading_lock:
            snapshot = list(self.calls.values())
        yield from snapshot

    @validate_call
    def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
        """Delete the completed calls listed in ``req.call_ids``; unknown ids are skipped."""
        num_deleted = 0
        with self._call_threading_lock:
            for call_id in req.call_ids:
                if call_id in self.calls:
                    del self.calls[call_id]
                    num_deleted += 1
        return tsi.CallsDeleteRes(num_deleted=num_deleted)

    @validate_call
    def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
        # Stub: updates are accepted but not applied.
        return tsi.CallUpdateRes()

    @validate_call
    def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes:
        # Counts all completed calls; filters in `req` are ignored.
        with self._call_threading_lock:
            count = len(self.calls)
        return tsi.CallsQueryStatsRes(count=count)

    # --- Cost API ---
    # Stubs: nothing is persisted.

    @validate_call
    def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
        return tsi.CostCreateRes(ids=[(generate_id(), generate_id()) for _ in req.costs])

    @validate_call
    def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes:
        return tsi.CostQueryRes(results=[])

    @validate_call
    def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes:
        return tsi.CostPurgeRes()

    # --- Object API (Legacy V1) ---

    @validate_call
    def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
        # The digest is random rather than content-derived.
        digest = generate_id()
        self.objs[digest] = req.obj
        return tsi.ObjCreateRes(digest=digest)

    @validate_call
    def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes:
        return tsi.ObjReadRes(obj=self.objs.get(req.digest, {}))

    @validate_call
    def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
        return tsi.ObjQueryRes(objs=[])

    @validate_call
    def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes:
        return tsi.ObjDeleteRes(num_deleted=0)

    # --- Table API ---
    # Stubs: tables are never stored; digests are random.

    @validate_call
    def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes:
        return tsi.TableCreateRes(digest=generate_id(), row_digests=[])

    @validate_call
    def table_create_from_digests(self, req: tsi.TableCreateFromDigestsReq) -> tsi.TableCreateFromDigestsRes:
        return tsi.TableCreateFromDigestsRes(digest=generate_id())

    @validate_call
    def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes:
        return tsi.TableUpdateRes(digest=generate_id(), updated_row_digests=[])

    @validate_call
    def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
        return tsi.TableQueryRes(rows=[])

    @validate_call
    def table_query_stream(self, req: tsi.TableQueryReq) -> Iterator[tsi.TableRowSchema]:
        yield from []

    @validate_call
    def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes:
        return tsi.TableQueryStatsRes(count=0)

    @validate_call
    def table_query_stats_batch(self, req: tsi.TableQueryStatsBatchReq) -> tsi.TableQueryStatsBatchRes:
        return tsi.TableQueryStatsBatchRes(tables=[])

    # --- Ref API ---

    @validate_call
    def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes:
        return tsi.RefsReadBatchRes(vals=[])

    # --- File API ---

    def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
        """Store the file content under the digest returned to the caller."""
        # Keyed by digest (not by req.name) so file_content_read, which looks files up
        # by digest, can actually find the content again.
        digest = generate_id()
        self.files[digest] = req.content
        return tsi.FileCreateRes(digest=digest)

    def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
        # Unknown digests get placeholder bytes instead of an error.
        return tsi.FileContentReadRes(content=self.files.get(req.digest, b"dummy_content"))

    def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes:
        total_size = sum(len(c) for c in self.files.values())
        return tsi.FilesStatsRes(total_size_bytes=total_size)

    # --- Feedback API ---

    @validate_call
    def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
        # Assigns an id in place on `req` when it has none, then keeps the request.
        req.id = req.id or generate_id()
        self.feedback.append(req)
        return tsi.FeedbackCreateRes(
            id=req.id,
            created_at=datetime.now(timezone.utc),
            wb_user_id="dummy_user",
            payload=req.payload,
        )

    def feedback_create_batch(self, req: tsi.FeedbackCreateBatchReq) -> tsi.FeedbackCreateBatchRes:
        results: List[tsi.FeedbackCreateRes] = []
        for item in req.batch:
            res = self.feedback_create(item)
            results.append(res)
        return tsi.FeedbackCreateBatchRes(res=results)

    @validate_call
    def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes:
        # Stub: stored feedback is not returned.
        return tsi.FeedbackQueryRes(result=[])

    @validate_call
    def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
        # Drops all stored feedback; the query in `req` is ignored.
        self.feedback.clear()
        return tsi.FeedbackPurgeRes()

    @validate_call
    def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes:
        return tsi.FeedbackReplaceRes(
            id=req.id or generate_id(),
            created_at=datetime.now(timezone.utc),
            wb_user_id="dummy",
            payload={},
        )

    # --- Action API ---

    @validate_call
    def actions_execute_batch(self, req: tsi.ActionsExecuteBatchReq) -> tsi.ActionsExecuteBatchRes:
        return tsi.ActionsExecuteBatchRes()

    # --- Execute LLM API ---
    # Canned responses; no model is called.

    @validate_call
    def completions_create(self, req: tsi.CompletionsCreateReq) -> tsi.CompletionsCreateRes:
        return tsi.CompletionsCreateRes(response={"choices": [{"text": "dummy completion"}]})

    @validate_call
    def completions_create_stream(self, req: tsi.CompletionsCreateReq) -> Iterator[dict[str, Any]]:
        yield {"choices": [{"text": "dummy "}]}
        yield {"choices": [{"text": "stream"}]}

    # --- Execute Image Generation API ---

    @validate_call
    def image_create(self, req: tsi.ImageGenerationCreateReq) -> tsi.ImageGenerationCreateRes:
        return tsi.ImageGenerationCreateRes(response={})

    # --- Project Statistics API ---

    @validate_call
    def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes:
        return tsi.ProjectStatsRes(
            trace_storage_size_bytes=0,
            objects_storage_size_bytes=0,
            tables_storage_size_bytes=0,
            files_storage_size_bytes=0,
        )

    # --- Thread API ---

    @validate_call
    def threads_query_stream(self, req: tsi.ThreadsQueryReq) -> Iterator[tsi.ThreadSchema]:
        yield from []

    # --- Evaluation API (V1) ---

    @validate_call
    def evaluate_model(self, req: tsi.EvaluateModelReq) -> tsi.EvaluateModelRes:
        return tsi.EvaluateModelRes(call_id=generate_id())

    @validate_call
    def evaluation_status(self, req: tsi.EvaluationStatusReq) -> tsi.EvaluationStatusRes:
        return tsi.EvaluationStatusRes(status=tsi.EvaluationStatusNotFound())

    # --- OTEL API ---

    def otel_export(self, req: tsi.OtelExportReq) -> tsi.OtelExportRes:
        return tsi.OtelExportRes()

    # ==========================================
    # Object Interface (V2 APIs)
    # ==========================================
    # All V2 endpoints are stubs: creates return random ids, reads return None,
    # lists are empty and deletes report zero rows.

    # --- Ops ---
    def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes:
        return tsi.OpCreateRes(digest=generate_id(), object_id=generate_id(), version_index=0)

    def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes:
        return tsi.OpReadRes(op=None)  # type: ignore

    def op_list(self, req: tsi.OpListReq) -> Iterator[tsi.OpReadRes]:
        yield from []

    def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes:
        return tsi.OpDeleteRes(num_deleted=0)

    # --- Datasets ---
    def dataset_create(self, req: tsi.DatasetCreateReq) -> tsi.DatasetCreateRes:
        return tsi.DatasetCreateRes(digest=generate_id(), object_id=generate_id(), version_index=0)

    def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes:
        return tsi.DatasetReadRes(dataset=None)  # type: ignore

    def dataset_list(self, req: tsi.DatasetListReq) -> Iterator[tsi.DatasetReadRes]:
        yield from []

    def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes:
        return tsi.DatasetDeleteRes(num_deleted=0)

    # --- Scorers ---
    def scorer_create(self, req: tsi.ScorerCreateReq) -> tsi.ScorerCreateRes:
        return tsi.ScorerCreateRes(digest=generate_id(), object_id=generate_id(), version_index=0, scorer=generate_id())

    def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes:
        return tsi.ScorerReadRes(scorer=None)  # type: ignore

    def scorer_list(self, req: tsi.ScorerListReq) -> Iterator[tsi.ScorerReadRes]:
        yield from []

    def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes:
        return tsi.ScorerDeleteRes(num_deleted=0)

    # --- Evaluations (V2) ---
    def evaluation_create(self, req: tsi.EvaluationCreateReq) -> tsi.EvaluationCreateRes:
        return tsi.EvaluationCreateRes(
            digest=generate_id(), object_id=generate_id(), version_index=0, evaluation_ref=generate_id()
        )

    def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes:
        return tsi.EvaluationReadRes(evaluation=None)  # type: ignore

    def evaluation_list(self, req: tsi.EvaluationListReq) -> Iterator[tsi.EvaluationReadRes]:
        yield from []

    def evaluation_delete(self, req: tsi.EvaluationDeleteReq) -> tsi.EvaluationDeleteRes:
        return tsi.EvaluationDeleteRes(num_deleted=0)

    # --- Models ---
    def model_create(self, req: tsi.ModelCreateReq) -> tsi.ModelCreateRes:
        return tsi.ModelCreateRes(
            digest=generate_id(), object_id=generate_id(), version_index=0, model_ref=generate_id()
        )

    def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes:
        return tsi.ModelReadRes(model=None)  # type: ignore

    def model_list(self, req: tsi.ModelListReq) -> Iterator[tsi.ModelReadRes]:
        yield from []

    def model_delete(self, req: tsi.ModelDeleteReq) -> tsi.ModelDeleteRes:
        return tsi.ModelDeleteRes(num_deleted=0)

    # --- Evaluation Runs ---
    def evaluation_run_create(self, req: tsi.EvaluationRunCreateReq) -> tsi.EvaluationRunCreateRes:
        return tsi.EvaluationRunCreateRes(evaluation_run_id=generate_id())

    def evaluation_run_read(self, req: tsi.EvaluationRunReadReq) -> tsi.EvaluationRunReadRes:
        return tsi.EvaluationRunReadRes(evaluation_run=None)  # type: ignore

    def evaluation_run_list(self, req: tsi.EvaluationRunListReq) -> Iterator[tsi.EvaluationRunReadRes]:
        yield from []

    def evaluation_run_delete(self, req: tsi.EvaluationRunDeleteReq) -> tsi.EvaluationRunDeleteRes:
        return tsi.EvaluationRunDeleteRes(num_deleted=0)

    def evaluation_run_finish(self, req: tsi.EvaluationRunFinishReq) -> tsi.EvaluationRunFinishRes:
        return tsi.EvaluationRunFinishRes(success=True)

    # --- Predictions ---
    def prediction_create(self, req: tsi.PredictionCreateReq) -> tsi.PredictionCreateRes:
        return tsi.PredictionCreateRes(prediction_id=generate_id())

    def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes:
        return tsi.PredictionReadRes(prediction=None)  # type: ignore

    def prediction_list(self, req: tsi.PredictionListReq) -> Iterator[tsi.PredictionReadRes]:
        yield from []

    def prediction_delete(self, req: tsi.PredictionDeleteReq) -> tsi.PredictionDeleteRes:
        return tsi.PredictionDeleteRes(num_deleted=0)

    def prediction_finish(self, req: tsi.PredictionFinishReq) -> tsi.PredictionFinishRes:
        return tsi.PredictionFinishRes(success=True)

    # --- Scores ---
    def score_create(self, req: tsi.ScoreCreateReq) -> tsi.ScoreCreateRes:
        return tsi.ScoreCreateRes(score_id=generate_id())

    def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes:
        return tsi.ScoreReadRes(score=None)  # type: ignore

    def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]:
        yield from []

    def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes:
        return tsi.ScoreDeleteRes(num_deleted=0)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# Module-level storage for originals.
# instrument_weave() saves here the functions it replaces in weave.trace.weave_init;
# uninstrument_weave() writes them back and resets each slot to None, and raises
# RuntimeError when it finds a slot that is still None.
_original_init_weave_get_server: Callable[..., Any] | None
_original_get_entity_project_from_project_name: Callable[..., Any] | None
_original_get_username: Callable[..., Any] | None
_original_init_weave_get_server = _original_get_entity_project_from_project_name = _original_get_username = None
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def init_weave_get_server_factory(server: InMemoryWeaveTraceServer) -> Callable[..., Any]:
    """Build a drop-in replacement for ``weave.trace.weave_init.init_weave_get_server``.

    The replacement bypasses Weave's remote trace server: whatever arguments it is
    called with are ignored, and it always hands back ``server``.

    Args:
        server: The in-memory trace server to return on every call.

    Returns:
        A callable that returns ``server`` regardless of its arguments.
    """

    def _always_in_memory(*_args: Any, **_kwargs: Any) -> InMemoryWeaveTraceServer:
        # Arguments are deliberately ignored.
        return server

    return _always_in_memory
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def get_entity_project_from_project_name_factory(entity_name: str) -> tuple[str, str]:
    """Replacement for ``weave.trace.weave_init.get_entity_project_from_project_name``.

    Despite the ``_factory`` suffix this is the patched function itself, not a factory:
    :func:`instrument_weave` installs it directly. It delegates to the saved original
    and falls back to the placeholder pair ``("msk", "weave")`` when that is not
    possible.

    Args:
        entity_name: Passed through unchanged to the original function.

    Returns:
        Whatever the original returns, or ``("msk", "weave")`` if the original is
        unset, is this very function (repeated/recursive instrumentation), or raises
        ``WeaveWandbAuthenticationException``.
    """
    # Bypass the usage of API
    original = _original_get_entity_project_from_project_name
    if original is None:
        # Reached through a stale reference while not instrumented. Checked explicitly
        # rather than with `assert`, which escaped as AssertionError (or a TypeError
        # from calling None under `python -O`).
        warnings.warn("W&B integration is not instrumented; using default entity and project.")
        return "msk", "weave"
    if original is get_entity_project_from_project_name_factory:
        # Delegating would recurse into ourselves forever.
        warnings.warn("W&B integration might have been repeatedly/recursively instrumented.")
        return "msk", "weave"
    try:
        return original(entity_name)
    except weave.trace.weave_init.WeaveWandbAuthenticationException:
        # In case API is not available.
        return "msk", "weave"
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def get_username() -> str:
    """Replacement for ``weave.trace.weave_init.get_username``.

    Delegates to the original saved by :func:`instrument_weave`. Any failure falls
    back to ``"msk"``: a ``RuntimeError`` silently, any other exception (including a
    missing original) with a warning.
    """
    # Bypass the usage of API
    original = _original_get_username
    try:
        assert original is not None
        username = original()
    except RuntimeError:
        username = "msk"
    except Exception as err:
        warnings.warn(f"Unexpected error in get_username. Using default username. Error: {err}")
        username = "msk"
    return username
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def instrument_weave(server: InMemoryWeaveTraceServer):
    """Patch the Weave/W&B integration to bypass actual network calls for testing.

    Replaces ``init_weave_get_server``, ``get_entity_project_from_project_name`` and
    ``get_username`` on ``weave.trace.weave_init`` with this module's replacements,
    saving the originals in module globals so :func:`uninstrument_weave` can restore
    them.

    Calling this again while already instrumented only swaps in the new ``server``;
    the originals saved by the first call are kept, so a later ``uninstrument_weave``
    still restores the genuine Weave functions.

    Args:
        server: The in-memory trace server that ``init_weave_get_server`` will return.
    """

    global _original_init_weave_get_server, _original_get_entity_project_from_project_name, _original_get_username
    weave_init = weave.trace.weave_init

    # Only capture originals into empty slots. On a repeated call the module attributes
    # hold our own patches; saving those would make uninstrument_weave() "restore" the
    # patches and make the entity/username wrappers delegate to themselves.
    if _original_init_weave_get_server is None:
        _original_init_weave_get_server = weave_init.init_weave_get_server
    if _original_get_entity_project_from_project_name is None:
        _original_get_entity_project_from_project_name = weave_init.get_entity_project_from_project_name
    if _original_get_username is None:
        _original_get_username = weave_init.get_username

    weave_init.init_weave_get_server = init_weave_get_server_factory(server)
    weave_init.get_entity_project_from_project_name = get_entity_project_from_project_name_factory
    weave_init.get_username = get_username
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def uninstrument_weave():
    """Restore the Weave/W&B functions patched by ``instrument_weave``.

    Writes each saved original back onto ``weave.trace.weave_init`` and clears its
    slot, in the order init_weave_get_server, get_entity_project_from_project_name,
    get_username.

    Raises:
        RuntimeError: If a saved original is missing, i.e. the integration is not
            instrumented. Functions handled before the missing one have already been
            restored when this is raised.
    """
    global _original_init_weave_get_server, _original_get_entity_project_from_project_name, _original_get_username

    if _original_init_weave_get_server is None:
        raise RuntimeError("Weave/W&B integration was not instrumented.")
    weave.trace.weave_init.init_weave_get_server = _original_init_weave_get_server
    _original_init_weave_get_server = None

    if _original_get_entity_project_from_project_name is None:
        raise RuntimeError("Weave/W&B integration was not instrumented.")
    weave.trace.weave_init.get_entity_project_from_project_name = _original_get_entity_project_from_project_name
    _original_get_entity_project_from_project_name = None

    if _original_get_username is None:
        raise RuntimeError("Weave/W&B integration was not instrumented.")
    weave.trace.weave_init.get_username = _original_get_username
    _original_get_username = None
|