mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic; see the registry's advisory page for more details.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,500 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import threading
7
+ import warnings
8
+ from datetime import datetime, timezone
9
+ from typing import Any, Callable, Dict, Iterator, List
10
+
11
+ import weave.trace.weave_init
12
+ from pydantic import validate_call
13
+ from weave.trace_server import trace_server_interface as tsi
14
+ from weave.trace_server.ids import generate_id
15
+ from weave.trace_server_bindings.client_interface import TraceServerClientInterface
16
+ from weave.trace_server_bindings.models import ServerInfoRes
17
+
18
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Public API exported by ``from ... import *``.
__all__ = ["instrument_weave", "uninstrument_weave", "InMemoryWeaveTraceServer"]
25
+
26
+
27
class InMemoryWeaveTraceServer(TraceServerClientInterface):
    """A minimal in-memory implementation of the TraceServerInterface.

    It stores calls and objects in local dictionaries and returns valid Pydantic
    responses to satisfy the Weave client and FullTraceServerInterface protocol.
    Only the call, file and feedback endpoints keep real state. Every other
    endpoint is a stub that returns an empty or placeholder response.

    Attributes:
        calls: Completed calls (both the start and the end event were received),
            keyed by call id. Guarded by ``_call_threading_lock``.
        partial_calls: Calls for which only one of start/end has arrived so far,
            keyed by call id and stored as plain dicts. Guarded by
            ``_call_threading_lock``.
        objs: Objects passed to ``obj_create``, keyed by the generated digest.
        files: File contents passed to ``file_create``, keyed by the digest that
            was returned to the caller.
        feedback: Feedback requests recorded by ``feedback_create``.
    """

    def __init__(self):
        # Minimal storage to allow basic querying in tests.
        self.calls: Dict[str, tsi.CallSchema] = {}
        self.partial_calls: Dict[str, Dict[str, Any]] = {}
        self.objs: Dict[str, Any] = {}
        self.files: Dict[str, bytes] = {}
        self.feedback: List[tsi.FeedbackCreateReq] = []

        # Protects ``calls`` and ``partial_calls``. Call events are presumably
        # delivered from more than one thread, hence the lock.
        self._call_threading_lock = threading.Lock()

    @classmethod
    def from_env(cls, *args: Any, **kwargs: Any) -> InMemoryWeaveTraceServer:
        """Ignore all arguments and return a fresh, empty server."""
        return cls()

    def server_info(self) -> ServerInfoRes:
        return ServerInfoRes(min_required_weave_python_version="0.52.22")

    def ensure_project_exists(self, entity: str, project: str) -> tsi.EnsureProjectExistsRes:
        return tsi.EnsureProjectExistsRes(project_name=project)

    # --- Call API ---

    @validate_call
    def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
        """Record the start half of a call.

        ``call_start`` and ``call_end`` may arrive in either order. The first
        one is parked in ``partial_calls``. The second one is merged with it
        into a ``tsi.CallSchema`` stored in ``calls``. On key conflicts, the
        end event's fields win.
        """
        # NOTE: It's not necessary that call_end must be called after call_start.
        request_content = req.start.model_dump(exclude_none=True)

        # If id needs to be generated here, it's very likely we won't be able to find the call later.
        # This is just to make the type checker happy.
        call_id = request_content.get("id") or generate_id()
        trace_id = request_content.get("trace_id") or generate_id()
        request_content["id"] = call_id
        request_content["trace_id"] = trace_id

        with self._call_threading_lock:
            if call_id in self.partial_calls:
                # call_end has already been called for this call.
                kwargs = {**request_content, **self.partial_calls[call_id]}
                self.calls[call_id] = tsi.CallSchema(**kwargs)
                del self.partial_calls[call_id]
            else:
                self.partial_calls[call_id] = request_content

        return tsi.CallStartRes(id=call_id, trace_id=trace_id)

    @validate_call
    def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes:
        """Record the end half of a call (see ``call_start`` for the merge rules)."""
        request_content = req.end.model_dump(exclude_none=True)
        call_id = req.end.id

        with self._call_threading_lock:
            if call_id in self.partial_calls:
                # End request always overrides the start request content.
                kwargs = {**self.partial_calls[call_id], **request_content}
                self.calls[call_id] = tsi.CallSchema(**kwargs)
                del self.partial_calls[call_id]
            else:
                self.partial_calls[call_id] = request_content
        return tsi.CallEndRes()

    @validate_call
    def call_start_batch(self, req: tsi.CallCreateBatchReq) -> tsi.CallCreateBatchRes:
        """Dispatch each batch item to ``call_start`` or ``call_end``.

        Batch items presumably wrap the real request in a ``.req`` attribute
        (Weave's start/end batch-item models). Bare ``CallStartReq`` /
        ``CallEndReq`` objects are accepted as well. Items of any other type
        are ignored.

        Returns:
            A ``CallCreateBatchRes`` with one response per dispatched item, in order.
        """
        results: List[Any] = []
        for item in req.batch:
            # Unwrap batch-item models; bare requests have no ``req`` attribute.
            inner = getattr(item, "req", item)
            if isinstance(inner, tsi.CallStartReq):
                results.append(self.call_start(inner))
            elif isinstance(inner, tsi.CallEndReq):
                results.append(self.call_end(inner))
        return tsi.CallCreateBatchRes(res=results)

    @validate_call
    def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes:
        with self._call_threading_lock:
            call_data = self.calls.get(req.id)
        return tsi.CallReadRes(call=call_data)

    @validate_call
    def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes:
        return tsi.CallsQueryRes(calls=list(self.calls_query_stream(req)))

    @validate_call
    def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]:
        """Yield every completed call; the query filters in ``req`` are ignored."""
        # Snapshot under the lock so concurrent call_start/call_end cannot change
        # the dict while we iterate, and so the lock is not held across yields.
        with self._call_threading_lock:
            snapshot = list(self.calls.values())
        yield from snapshot

    @validate_call
    def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
        num_deleted = 0
        with self._call_threading_lock:
            for call_id in req.call_ids:
                if self.calls.pop(call_id, None) is not None:
                    num_deleted += 1
        return tsi.CallsDeleteRes(num_deleted=num_deleted)

    @validate_call
    def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
        return tsi.CallUpdateRes()

    @validate_call
    def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes:
        with self._call_threading_lock:
            count = len(self.calls)
        return tsi.CallsQueryStatsRes(count=count)

    # --- Cost API (stubs) ---

    @validate_call
    def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
        return tsi.CostCreateRes(ids=[(generate_id(), generate_id()) for _ in req.costs])

    @validate_call
    def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes:
        return tsi.CostQueryRes(results=[])

    @validate_call
    def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes:
        return tsi.CostPurgeRes()

    # --- Object API (Legacy V1) ---

    @validate_call
    def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
        digest = generate_id()
        self.objs[digest] = req.obj
        return tsi.ObjCreateRes(digest=digest)

    @validate_call
    def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes:
        return tsi.ObjReadRes(obj=self.objs.get(req.digest, {}))

    @validate_call
    def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
        return tsi.ObjQueryRes(objs=[])

    @validate_call
    def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes:
        return tsi.ObjDeleteRes(num_deleted=0)

    # --- Table API (stubs) ---

    @validate_call
    def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes:
        return tsi.TableCreateRes(digest=generate_id(), row_digests=[])

    @validate_call
    def table_create_from_digests(self, req: tsi.TableCreateFromDigestsReq) -> tsi.TableCreateFromDigestsRes:
        return tsi.TableCreateFromDigestsRes(digest=generate_id())

    @validate_call
    def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes:
        return tsi.TableUpdateRes(digest=generate_id(), updated_row_digests=[])

    @validate_call
    def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
        return tsi.TableQueryRes(rows=[])

    @validate_call
    def table_query_stream(self, req: tsi.TableQueryReq) -> Iterator[tsi.TableRowSchema]:
        yield from []

    @validate_call
    def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes:
        return tsi.TableQueryStatsRes(count=0)

    @validate_call
    def table_query_stats_batch(self, req: tsi.TableQueryStatsBatchReq) -> tsi.TableQueryStatsBatchRes:
        return tsi.TableQueryStatsBatchRes(tables=[])

    # --- Ref API (stub) ---

    @validate_call
    def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes:
        return tsi.RefsReadBatchRes(vals=[])

    # --- File API ---

    def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
        """Store file content and return the digest it can be read back with."""
        digest = generate_id()
        # Key by the returned digest (not ``req.name``): file_content_read
        # looks files up by ``req.digest``.
        self.files[digest] = req.content
        return tsi.FileCreateRes(digest=digest)

    def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
        """Return stored content for ``req.digest``, or placeholder bytes if unknown."""
        return tsi.FileContentReadRes(content=self.files.get(req.digest, b"dummy_content"))

    def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes:
        total_size = sum(len(c) for c in self.files.values())
        return tsi.FilesStatsRes(total_size_bytes=total_size)

    # --- Feedback API ---

    @validate_call
    def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
        """Record a feedback request, assigning an id when it has none."""
        if not req.id:
            # Work on a copy so the caller's request object is not mutated.
            req = req.model_copy(update={"id": generate_id()})
        self.feedback.append(req)
        return tsi.FeedbackCreateRes(
            id=req.id,
            created_at=datetime.now(timezone.utc),
            wb_user_id="dummy_user",
            payload=req.payload,
        )

    def feedback_create_batch(self, req: tsi.FeedbackCreateBatchReq) -> tsi.FeedbackCreateBatchRes:
        results: List[tsi.FeedbackCreateRes] = [self.feedback_create(item) for item in req.batch]
        return tsi.FeedbackCreateBatchRes(res=results)

    @validate_call
    def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes:
        return tsi.FeedbackQueryRes(result=[])

    @validate_call
    def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
        # Drops all stored feedback; the purge query in ``req`` is not evaluated.
        self.feedback.clear()
        return tsi.FeedbackPurgeRes()

    @validate_call
    def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes:
        return tsi.FeedbackReplaceRes(
            id=req.id or generate_id(),
            created_at=datetime.now(timezone.utc),
            wb_user_id="dummy",
            payload={},
        )

    # --- Action API (stub) ---

    @validate_call
    def actions_execute_batch(self, req: tsi.ActionsExecuteBatchReq) -> tsi.ActionsExecuteBatchRes:
        return tsi.ActionsExecuteBatchRes()

    # --- Execute LLM API (canned responses) ---

    @validate_call
    def completions_create(self, req: tsi.CompletionsCreateReq) -> tsi.CompletionsCreateRes:
        return tsi.CompletionsCreateRes(response={"choices": [{"text": "dummy completion"}]})

    @validate_call
    def completions_create_stream(self, req: tsi.CompletionsCreateReq) -> Iterator[dict[str, Any]]:
        yield {"choices": [{"text": "dummy "}]}
        yield {"choices": [{"text": "stream"}]}

    # --- Execute Image Generation API (stub) ---

    @validate_call
    def image_create(self, req: tsi.ImageGenerationCreateReq) -> tsi.ImageGenerationCreateRes:
        return tsi.ImageGenerationCreateRes(response={})

    # --- Project Statistics API (stub) ---

    @validate_call
    def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes:
        return tsi.ProjectStatsRes(
            trace_storage_size_bytes=0,
            objects_storage_size_bytes=0,
            tables_storage_size_bytes=0,
            files_storage_size_bytes=0,
        )

    # --- Thread API (stub) ---

    @validate_call
    def threads_query_stream(self, req: tsi.ThreadsQueryReq) -> Iterator[tsi.ThreadSchema]:
        yield from []

    # --- Evaluation API (V1, stubs) ---

    @validate_call
    def evaluate_model(self, req: tsi.EvaluateModelReq) -> tsi.EvaluateModelRes:
        return tsi.EvaluateModelRes(call_id=generate_id())

    @validate_call
    def evaluation_status(self, req: tsi.EvaluationStatusReq) -> tsi.EvaluationStatusRes:
        return tsi.EvaluationStatusRes(status=tsi.EvaluationStatusNotFound())

    # --- OTEL API (stub) ---

    def otel_export(self, req: tsi.OtelExportReq) -> tsi.OtelExportRes:
        return tsi.OtelExportRes()

    # ==========================================
    # Object Interface (V2 APIs) -- all stubs
    # ==========================================

    # --- Ops ---
    def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes:
        return tsi.OpCreateRes(digest=generate_id(), object_id=generate_id(), version_index=0)

    def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes:
        return tsi.OpReadRes(op=None)  # type: ignore

    def op_list(self, req: tsi.OpListReq) -> Iterator[tsi.OpReadRes]:
        yield from []

    def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes:
        return tsi.OpDeleteRes(num_deleted=0)

    # --- Datasets ---
    def dataset_create(self, req: tsi.DatasetCreateReq) -> tsi.DatasetCreateRes:
        return tsi.DatasetCreateRes(digest=generate_id(), object_id=generate_id(), version_index=0)

    def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes:
        return tsi.DatasetReadRes(dataset=None)  # type: ignore

    def dataset_list(self, req: tsi.DatasetListReq) -> Iterator[tsi.DatasetReadRes]:
        yield from []

    def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes:
        return tsi.DatasetDeleteRes(num_deleted=0)

    # --- Scorers ---
    def scorer_create(self, req: tsi.ScorerCreateReq) -> tsi.ScorerCreateRes:
        return tsi.ScorerCreateRes(digest=generate_id(), object_id=generate_id(), version_index=0, scorer=generate_id())

    def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes:
        return tsi.ScorerReadRes(scorer=None)  # type: ignore

    def scorer_list(self, req: tsi.ScorerListReq) -> Iterator[tsi.ScorerReadRes]:
        yield from []

    def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes:
        return tsi.ScorerDeleteRes(num_deleted=0)

    # --- Evaluations (V2) ---
    def evaluation_create(self, req: tsi.EvaluationCreateReq) -> tsi.EvaluationCreateRes:
        return tsi.EvaluationCreateRes(
            digest=generate_id(), object_id=generate_id(), version_index=0, evaluation_ref=generate_id()
        )

    def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes:
        return tsi.EvaluationReadRes(evaluation=None)  # type: ignore

    def evaluation_list(self, req: tsi.EvaluationListReq) -> Iterator[tsi.EvaluationReadRes]:
        yield from []

    def evaluation_delete(self, req: tsi.EvaluationDeleteReq) -> tsi.EvaluationDeleteRes:
        return tsi.EvaluationDeleteRes(num_deleted=0)

    # --- Models ---
    def model_create(self, req: tsi.ModelCreateReq) -> tsi.ModelCreateRes:
        return tsi.ModelCreateRes(
            digest=generate_id(), object_id=generate_id(), version_index=0, model_ref=generate_id()
        )

    def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes:
        return tsi.ModelReadRes(model=None)  # type: ignore

    def model_list(self, req: tsi.ModelListReq) -> Iterator[tsi.ModelReadRes]:
        yield from []

    def model_delete(self, req: tsi.ModelDeleteReq) -> tsi.ModelDeleteRes:
        return tsi.ModelDeleteRes(num_deleted=0)

    # --- Evaluation Runs ---
    def evaluation_run_create(self, req: tsi.EvaluationRunCreateReq) -> tsi.EvaluationRunCreateRes:
        return tsi.EvaluationRunCreateRes(evaluation_run_id=generate_id())

    def evaluation_run_read(self, req: tsi.EvaluationRunReadReq) -> tsi.EvaluationRunReadRes:
        return tsi.EvaluationRunReadRes(evaluation_run=None)  # type: ignore

    def evaluation_run_list(self, req: tsi.EvaluationRunListReq) -> Iterator[tsi.EvaluationRunReadRes]:
        yield from []

    def evaluation_run_delete(self, req: tsi.EvaluationRunDeleteReq) -> tsi.EvaluationRunDeleteRes:
        return tsi.EvaluationRunDeleteRes(num_deleted=0)

    def evaluation_run_finish(self, req: tsi.EvaluationRunFinishReq) -> tsi.EvaluationRunFinishRes:
        return tsi.EvaluationRunFinishRes(success=True)

    # --- Predictions ---
    def prediction_create(self, req: tsi.PredictionCreateReq) -> tsi.PredictionCreateRes:
        return tsi.PredictionCreateRes(prediction_id=generate_id())

    def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes:
        return tsi.PredictionReadRes(prediction=None)  # type: ignore

    def prediction_list(self, req: tsi.PredictionListReq) -> Iterator[tsi.PredictionReadRes]:
        yield from []

    def prediction_delete(self, req: tsi.PredictionDeleteReq) -> tsi.PredictionDeleteRes:
        return tsi.PredictionDeleteRes(num_deleted=0)

    def prediction_finish(self, req: tsi.PredictionFinishReq) -> tsi.PredictionFinishRes:
        return tsi.PredictionFinishRes(success=True)

    # --- Scores ---
    def score_create(self, req: tsi.ScoreCreateReq) -> tsi.ScoreCreateRes:
        return tsi.ScoreCreateRes(score_id=generate_id())

    def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes:
        return tsi.ScoreReadRes(score=None)  # type: ignore

    def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]:
        yield from []

    def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes:
        return tsi.ScoreDeleteRes(num_deleted=0)
426
+
427
+
428
# Saved references to the original ``weave.trace.weave_init`` functions.
# ``instrument_weave`` fills these slots before patching; ``uninstrument_weave``
# puts the saved functions back and resets each slot to ``None``.
_original_init_weave_get_server: Callable[..., Any] | None
_original_get_entity_project_from_project_name: Callable[..., Any] | None
_original_get_username: Callable[..., Any] | None
_original_init_weave_get_server = _original_get_entity_project_from_project_name = _original_get_username = None
432
+
433
+
434
def init_weave_get_server_factory(server: InMemoryWeaveTraceServer) -> Callable[..., Any]:
    """Build a stand-in for ``weave.trace.weave_init.init_weave_get_server``.

    The stand-in discards whatever arguments it receives and hands back
    ``server``, bypassing the usage of a remote Weave server.

    Args:
        server: The in-memory trace server to return on every call.

    Returns:
        A callable accepting any positional/keyword arguments and returning ``server``.
    """
    captured = server

    def init_weave_get_server(*_args: Any, **_kwargs: Any) -> InMemoryWeaveTraceServer:
        # Every invocation yields the same captured server instance.
        return captured

    return init_weave_get_server
440
+
441
+
442
def get_entity_project_from_project_name_factory(entity_name: str) -> tuple[str, str]:
    """Replacement for ``weave.trace.weave_init.get_entity_project_from_project_name``.

    Despite the ``_factory`` suffix, this is the patched function itself, not a
    builder. It delegates to the original Weave function saved by
    ``instrument_weave``. It falls back to ``("msk", "weave")`` in three cases:
    no original was saved, the saved "original" is this very function (the
    integration was instrumented repeatedly), or the original raises
    ``WeaveWandbAuthenticationException``.

    Args:
        entity_name: The argument Weave passes in; forwarded unchanged to the
            original function.

    Returns:
        An ``(entity, project)`` tuple.
    """
    # Bypass the usage of API
    original = _original_get_entity_project_from_project_name
    if original is None:
        # Explicit check instead of ``assert`` so it still works under ``python -O``.
        warnings.warn("W&B integration is not instrumented. Using default entity/project.")
        return "msk", "weave"
    if original is get_entity_project_from_project_name_factory:
        # Calling ourselves would recurse forever.
        warnings.warn("W&B integration might have been repeatedly/recursively instrumented.")
        return "msk", "weave"
    try:
        return original(entity_name)
    except weave.trace.weave_init.WeaveWandbAuthenticationException:
        # In case API is not available.
        return "msk", "weave"
454
+
455
+
456
def get_username() -> str:
    """Replacement for ``weave.trace.weave_init.get_username``.

    Delegates to the original Weave function saved by ``instrument_weave`` and
    returns the default ``"msk"`` instead whenever that is not possible:
    nothing was saved, the saved "original" is this very function (repeated
    instrumentation), or the original raises.

    Returns:
        The username reported by the original function, or ``"msk"``.
    """
    # Bypass the usage of API
    original = _original_get_username
    if original is None or original is get_username:
        # Explicit guard instead of ``assert`` (stripped under ``python -O``);
        # also avoids recursing into ourselves.
        warnings.warn("W&B integration is not instrumented or was instrumented repeatedly. Using default username.")
        return "msk"
    try:
        return original()
    except RuntimeError:
        # Presumably raised when no W&B login/API is available.
        return "msk"
    except Exception as exc:
        warnings.warn(f"Unexpected error in get_username. Using default username. Error: {exc}")
        return "msk"
466
+
467
+
468
def instrument_weave(server: InMemoryWeaveTraceServer):
    """Patch the Weave/W&B integration to bypass actual network calls for testing.

    Replaces three functions in ``weave.trace.weave_init``:

    * ``init_weave_get_server`` -> always returns ``server``;
    * ``get_entity_project_from_project_name`` -> ``get_entity_project_from_project_name_factory``;
    * ``get_username`` -> ``get_username`` from this module.

    The originals are saved in module globals so ``uninstrument_weave`` can put
    them back. If this is called again while already instrumented, the
    originals saved by the first call are kept and only the patches
    (including the new ``server``) are reinstalled. Otherwise the installed
    patches would be saved as "originals" and the real Weave functions could
    never be restored.

    Args:
        server: The in-memory trace server Weave should use.
    """
    global _original_init_weave_get_server, _original_get_entity_project_from_project_name, _original_get_username
    weave_init = weave.trace.weave_init

    if _original_init_weave_get_server is not None:
        logger.debug("Weave/W&B integration is already instrumented; keeping the originally saved functions.")

    # Save each original only once.
    if _original_init_weave_get_server is None:
        _original_init_weave_get_server = weave_init.init_weave_get_server
    if _original_get_entity_project_from_project_name is None:
        _original_get_entity_project_from_project_name = weave_init.get_entity_project_from_project_name
    if _original_get_username is None:
        _original_get_username = weave_init.get_username

    weave_init.init_weave_get_server = init_weave_get_server_factory(server)
    weave_init.get_entity_project_from_project_name = get_entity_project_from_project_name_factory
    weave_init.get_username = get_username
478
+
479
+
480
def uninstrument_weave():
    """Restore the original Weave/W&B functions saved by ``instrument_weave``.

    Every saved original is put back on ``weave.trace.weave_init`` and its
    module-global slot is reset to ``None``.

    Raises:
        RuntimeError: If any of the three originals was not saved, i.e. the
            integration was not (fully) instrumented. Originals that were saved
            are still restored before raising, so no patch is left half-undone.
    """
    global _original_init_weave_get_server, _original_get_entity_project_from_project_name, _original_get_username
    weave_init = weave.trace.weave_init
    missing = False

    if _original_init_weave_get_server is not None:
        weave_init.init_weave_get_server = _original_init_weave_get_server
        _original_init_weave_get_server = None
    else:
        missing = True

    if _original_get_entity_project_from_project_name is not None:
        weave_init.get_entity_project_from_project_name = _original_get_entity_project_from_project_name
        _original_get_entity_project_from_project_name = None
    else:
        missing = True

    if _original_get_username is not None:
        weave_init.get_username = _original_get_username
        _original_get_username = None
    else:
        missing = True

    if missing:
        raise RuntimeError("Weave/W&B integration was not instrumented.")
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from .decorator import *
4
+ from .litagent import *
5
+
6
+ __all__ = [
7
+ "LitAgent",
8
+ "llm_rollout",
9
+ "prompt_rollout",
10
+ "rollout",
11
+ ]