mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,677 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import concurrent.futures as futures
7
+ import logging
8
+ import os
9
+ import re
10
+ import weakref
11
+ from contextlib import asynccontextmanager, contextmanager
12
+ from datetime import datetime
13
+ from typing import (
14
+ Any,
15
+ AsyncIterator,
16
+ Callable,
17
+ Dict,
18
+ Iterator,
19
+ List,
20
+ Optional,
21
+ cast,
22
+ )
23
+
24
+ import weave
25
+ from opentelemetry.semconv.attributes import exception_attributes
26
+ from weave.trace.call import Call
27
+ from weave.trace.settings import UserSettings
28
+ from weave.trace.weave_client import WeaveClient
29
+ from weave.trace_server import trace_server_interface as tsi
30
+ from weave.wandb_interface.context import set_wandb_api_context
31
+
32
+ from mantisdk.instrumentation.weave import InMemoryWeaveTraceServer, instrument_weave, uninstrument_weave
33
+ from mantisdk.semconv import LightningResourceAttributes, LightningSpanAttributes
34
+ from mantisdk.store.base import LightningStore
35
+ from mantisdk.types import (
36
+ Attributes,
37
+ OtelResource,
38
+ Span,
39
+ SpanContext,
40
+ SpanCoreFields,
41
+ SpanRecordingContext,
42
+ StatusCode,
43
+ TraceStatus,
44
+ )
45
+ from mantisdk.utils.id import generate_id
46
+ from mantisdk.utils.otel import (
47
+ filter_and_unflatten_attributes,
48
+ flatten_attributes,
49
+ format_exception_attributes,
50
+ sanitize_attributes,
51
+ )
52
+
53
+ from .base import Tracer, with_active_tracer_context
54
+
55
# Module-level logger used by every class and helper in this module.
logger = logging.getLogger(__name__)
56
+
57
+
58
def op_name_to_func_name(op_name: str) -> str:
    """Extract the function name from a Weave operation reference.

    Weave operation names look like
    `weave:///xxx/mantisdk.tracer.weave/op/openai.chat.completions.create:019b10be-...-44d74272569c`.
    The first segment that directly follows a `/` and is terminated by a `:` is
    returned (`openai.chat.completions.create` in the example above). When no such
    segment exists, `op_name` is returned unchanged.
    """
    found = re.search(r"/([^/:]+):", op_name)
    if found is None:
        return op_name
    return found.group(1)
68
+
69
+
70
def random_project_name() -> str:
    """Return a new Weave project name: `msk/weave-` followed by a 12-character generated id."""
    prefix = "msk/weave-"
    return prefix + generate_id(12)
72
+
73
+
74
def get_timestamp_or_throw(date: Optional[datetime], field_name: str) -> float:
    """Convert a required datetime to a POSIX timestamp.

    Args:
        date: The datetime to convert; must not be None.
        field_name: Name of the field, used only in the error message.

    Returns:
        ``date.timestamp()``.

    Raises:
        ValueError: If ``date`` is None.
    """
    if date is not None:
        return date.timestamp()
    raise ValueError(f"{field_name} is required but not set")
78
+
79
+
80
class WeaveSpanRecordingContext(SpanRecordingContext):
    """Span-recording facade that writes directly onto a live Weave ``Call``.

    Attribute routing performed by ``record_attributes``:

    - the ``OPERATION_INPUT`` key (or keys prefixed with it) are merged into ``call.inputs``;
    - the ``OPERATION_OUTPUT`` key (or keys prefixed with it) replace ``call.output``;
    - ``OPERATION_NAME`` cannot be recorded and is logged and skipped;
    - every other key is stored in ``call.summary``.
    """

    def __init__(self, call: Call) -> None:
        self._call = call

    def record_exception(self, exception: BaseException) -> None:
        """Mark the call as failed because of ``exception`` and record its attributes.

        ``record_status`` rejects an empty ERROR description, yet ``str()`` of an
        exception raised without a message (e.g. ``raise ValueError()``) is empty.
        Fall back to the exception's type name so that recording never raises.
        """
        description = str(exception) or type(exception).__name__
        self.record_status("ERROR", description)
        self.record_attributes(format_exception_attributes(exception))

    def _get_input_from_attributes(self, attributes: Attributes) -> Dict[str, Any]:
        """Extract the operation inputs from ``attributes`` as a dict.

        A value stored directly under the ``OPERATION_INPUT`` key is returned as-is
        (cast, not validated). Otherwise the ``OPERATION_INPUT``-prefixed keys are passed
        through ``filter_and_unflatten_attributes``; a list result is turned into a dict
        keyed by the stringified list index.
        """
        input_key = LightningSpanAttributes.OPERATION_INPUT.value
        if input_key in attributes:
            # Very rare. A non-mapping value is not validated here and will make the
            # ``inputs.update`` in record_attributes raise, which is acceptable.
            return cast(Dict[str, Any], attributes[input_key])
        filtered_attributes = filter_and_unflatten_attributes(attributes, input_key)
        if isinstance(filtered_attributes, list):
            return {str(i): v for i, v in enumerate(filtered_attributes)}
        return filtered_attributes

    def _get_output_from_attributes(self, attributes: Attributes) -> Any:
        """Extract the operation output: the direct ``OPERATION_OUTPUT`` value, or the unflattened prefixed keys."""
        output_key = LightningSpanAttributes.OPERATION_OUTPUT.value
        if output_key in attributes:
            return attributes[output_key]
        return filter_and_unflatten_attributes(attributes, output_key)

    def record_attributes(self, attributes: Attributes) -> None:
        """Route ``attributes`` onto the call's inputs, output and summary (see class docstring)."""
        input_key = LightningSpanAttributes.OPERATION_INPUT.value
        output_key = LightningSpanAttributes.OPERATION_OUTPUT.value
        name_key = LightningSpanAttributes.OPERATION_NAME.value

        input_attributes = self._get_input_from_attributes(attributes)
        if input_attributes:
            self._call.inputs.update(input_attributes)

        output_attributes = self._get_output_from_attributes(attributes)
        # An explicitly set output is recorded even when falsy (0, False, "");
        # an output reconstructed from prefixed keys is recorded only when non-empty.
        if output_key in attributes or output_attributes:
            if self._call.output is not None:
                logger.warning(f"Output is already set. It will be overridden: {self._call.output}")
            self._call.output = output_attributes

        if name_key in attributes:
            logger.error(
                f"Cannot record operation name as an attribute. It will be skipped: {attributes[name_key]}"
            )

        # The rest of the attributes are recorded as summary.
        reserved_keys = (input_key, output_key, name_key)
        reserved_prefixes = (input_key + ".", output_key + ".")
        for key, value in attributes.items():
            if key in reserved_keys or key.startswith(reserved_prefixes):
                continue
            if self._call.summary is None:
                self._call.summary = {}
            self._call.summary[key] = value

    def record_status(self, status_code: StatusCode, description: Optional[str] = None) -> None:
        """Reflect ``status_code`` on the call's ``exception`` field.

        ``ERROR`` stores ``description`` (required); ``OK`` clears the exception;
        any other status code leaves the call untouched.

        Raises:
            ValueError: If ``status_code`` is ``ERROR`` and ``description`` is empty.
        """
        if status_code == "ERROR":
            if not description:
                raise ValueError("Description is required when status code is ERROR")
            self._call.exception = description
        elif status_code == "OK":
            self._call.exception = None
        # Do nothing for other status codes.

    def finalize(self) -> None:
        """No-op: the owning tracer finishes the Weave call itself."""

    def get_recorded_span(self) -> SpanCoreFields:
        """Snapshot the call as ``SpanCoreFields``.

        The name is ``call.op_name``, attributes are flattened, missing start/end times
        become None, and the status is ERROR exactly when ``call.exception`` is set.
        """
        return SpanCoreFields(
            name=self._call.op_name,
            attributes=flatten_attributes(self._call.attributes or {}),
            start_time=self._call.started_at.timestamp() if self._call.started_at else None,
            end_time=self._call.ended_at.timestamp() if self._call.ended_at else None,
            status=TraceStatus(
                status_code="OK" if self._call.exception is None else "ERROR", description=self._call.exception
            ),
        )
162
+
163
+
164
class WeaveTracerManagedTraceServer(InMemoryWeaveTraceServer):
    """In-memory Weave trace server that reports call lifecycle events to WeaveTracer.

    After the base server handles a ``call_start`` or ``call_end`` request, the call
    is looked up. A completed call (present in ``calls``) is passed to
    ``complete_call_callback`` at most once until ``clear()`` runs. A call that has
    only started (present in ``partial_calls``) is passed to ``partial_call_callback``.
    """

    def __init__(
        self,
        partial_call_callback: Callable[[Dict[str, Any]], None],
        complete_call_callback: Callable[[tsi.CallSchema], None],
    ):
        super().__init__()
        self.partial_call_callback = partial_call_callback
        self.complete_call_callback = complete_call_callback
        # IDs of completed calls whose callback has already fired.
        self._calls_already_invoked: set[str] = set()

    def trigger_callbacks(self, call_id: str) -> None:
        """Dispatch ``call_id`` to the matching callback while holding the server lock."""
        with self._call_threading_lock:
            if call_id in self.calls:
                if call_id in self._calls_already_invoked:
                    logger.info(f"Call {call_id} has callback already invoked. Skipping.")
                    return
                self._calls_already_invoked.add(call_id)
                self.complete_call_callback(self.calls[call_id])
                return
            if call_id in self.partial_calls:
                self.partial_call_callback(self.partial_calls[call_id])
                return
            logger.error(f"Call {call_id} not found in partial_calls or calls")

    def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
        """Store the call start via the base server, then fire callbacks; errors are logged and re-raised."""
        try:
            response = super().call_start(req)
            self.trigger_callbacks(response.id)
        except Exception:
            logger.exception(f"Error calling call_start: {req}", exc_info=True)
            raise
        return response

    def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes:
        """Store the call end via the base server, then fire callbacks; errors are logged and re-raised."""
        try:
            response = super().call_end(req)
            self.trigger_callbacks(req.end.id)
        except Exception:
            logger.exception(f"Error calling call_end: {req}", exc_info=True)
            raise
        return response

    def clear(self) -> None:
        """Forget which completed calls already had their callback invoked."""
        self._calls_already_invoked.clear()
210
+
211
+
212
+ class WeaveTracer(Tracer):
213
+ """Tracer implementation using Weave for telemetry and trace logging.
214
+
215
+ This replaces AgentOpsTracer with a Weave-based manual trace context. It tracks:
216
+
217
+ - Function/method calls
218
+ - Input/Output data
219
+ - Exceptions
220
+
221
+ and logs them to Weave Cloud (W&B backend) or optionally bypasses the network for testing.
222
+ """
223
+
224
def __init__(
    self,
    *,
    project_name: str | None = None,
    weave_user_settings: UserSettings | None = None,
    instrument_managed: bool = True,
):
    """Create a WeaveTracer.

    Args:
        project_name: Weave project to log into. When None, ``init_worker`` generates
            a random ``msk/weave-<id>`` name.
        weave_user_settings: Weave ``UserSettings``; when not given,
            ``UserSettings(use_server_cache=False)`` is used.
        instrument_managed: When True, ``init_worker`` / ``teardown_worker`` install and
            remove the Weave patch that routes traffic to this tracer's managed in-memory
            trace server (intended to bypass real network calls).
    """
    super().__init__()

    # Public configuration.
    self.project_name = project_name
    self.instrument_managed = instrument_managed
    self.weave_user_settings = weave_user_settings or UserSettings(use_server_cache=False)

    # Runtime state. _spans, _calls and both future containers are reset by
    # _init_trace_context; the rollout/attempt ids are set per trace in trace_context.
    self._store: Optional[LightningStore] = None
    self._default_sequence_counter: int = 0
    self._rollout_id: Optional[str] = None
    self._attempt_id: Optional[str] = None
    self._calls: Dict[str, tsi.CallSchema] = {}  # call_id -> completed call
    self._spans: List[Span] = []  # spans collected for the current trace
    # call_id -> future resolving to the sequence id reserved at call start
    self._partial_call_futures: Dict[str, asyncio.Future[int] | futures.Future[int]] = {}
    self._complete_call_futures: List[asyncio.Future[None] | futures.Future[None]] = []
    # Weak reference to the loop that entered the current trace context.
    self._loop: weakref.ReferenceType[asyncio.AbstractEventLoop] | None = None

    # Managed trace server that forwards Weave call start/end events to this tracer.
    self._server = WeaveTracerManagedTraceServer(
        partial_call_callback=self.partial_call_callback,
        complete_call_callback=self.complete_call_callback,
    )
256
+
257
def instrument(self, worker_id: int):
    """Patch Weave so that it talks to this tracer's managed trace server (``worker_id`` is unused)."""
    managed_server = self._server
    instrument_weave(managed_server)
259
+
260
def uninstrument(self, worker_id: int):
    """Undo the Weave patch installed by ``instrument`` (``worker_id`` is unused)."""
    uninstrument_weave()
    return None
262
+
263
def init_worker(self, worker_id: int, store: Optional[LightningStore] = None):
    """Prepare the tracer inside a worker thread/process and initialize Weave.

    Args:
        worker_id: Identifier of the worker (used in log messages).
        store: Optional LightningStore that receives the produced spans.

    Raises:
        RuntimeError: If ``weave.init`` fails.
    """
    super().init_worker(worker_id, store)
    logger.info(f"[Worker {worker_id}] Setting up Weave tracer...")
    self._store = store

    if self.instrument_managed:
        # Route Weave traffic to the managed trace server instead of real endpoints.
        self.instrument(worker_id)

    if os.getenv("WANDB_API_KEY"):
        logger.debug("WANDB_API_KEY is set. Weave will be initialized automatically.")
    else:
        # Workaround: install a placeholder W&B API context so Weave can be initialized
        # without an API key.
        logger.info("WANDB_API_KEY is not set. Initializing Weave a mock context.")
        set_wandb_api_context("msk", api_key=None, headers=None, cookies=None)

    existing_client = weave.get_client()
    if self.project_name is None:
        self.project_name = random_project_name()

    if existing_client is not None:
        logger.warning("Weave client was already initialized. Reentrant calls are at your own risk.")
        if existing_client.project == self.project_name:
            logger.error(
                f"Weave client was already initialized for the same project '{self.project_name}'. It's very likely that weave won't work correctly."
            )

    # Initialize regardless of any pre-existing client.
    try:
        weave.init(project_name=self.project_name, settings=self.weave_user_settings)
    except Exception as exc:
        raise RuntimeError(f"Failed to initialize Weave for project '{self.project_name}'") from exc
    else:
        logger.info(f"[Worker {worker_id}] Weave client initialized.")
303
+
304
def teardown_worker(self, worker_id: int):
    """Release per-worker resources, removing the Weave patch when this tracer manages it.

    Args:
        worker_id: Identifier of the worker (used in log messages).
    """
    super().teardown_worker(worker_id)
    if not self.instrument_managed:
        return
    self.uninstrument(worker_id)
    logger.info(f"[Worker {worker_id}] Instrumentation removed.")
316
+
317
@with_active_tracer_context
@asynccontextmanager
async def trace_context(
    self,
    name: Optional[str] = None,
    *,
    rollout_id: Optional[str] = None,
    attempt_id: Optional[str] = None,
    **kwargs: Any,
) -> AsyncIterator[Any]:
    """Open a root Weave call for one trace and collect the spans it produces.

    Per-trace state is reset on entry. On exit, the following happens in order:
    1. The root call is finished, carrying the exception if the body raised
       (including cancellation).
    2. The Weave client is flushed.
    3. All scheduled span-completion handlers are awaited.
    4. The rollout/attempt ids and the server's callback bookkeeping are cleared.

    Args:
        name: Operation name of the root call; defaults to the Weave client's project.
        rollout_id: Rollout ID. Spans are written to the store only when both ids are
            given and a store is configured.
        attempt_id: Attempt ID; must be provided together with ``rollout_id``.
        **kwargs: Accepted for interface compatibility; not used.

    Yields:
        The Weave ``Call`` object of the root trace.

    Raises:
        ValueError: If exactly one of ``rollout_id`` / ``attempt_id`` is provided.
        RuntimeError: If the Weave client is not initialized.
    """

    if rollout_id is not None and attempt_id is not None:
        self._rollout_id = rollout_id
        self._attempt_id = attempt_id
    elif rollout_id is None and attempt_id is None:
        logger.info("No rollout_id or attempt_id provided. Skipping writing to store.")
        self._rollout_id = self._attempt_id = None
    else:
        raise ValueError("rollout_id and attempt_id must be either both provided or both None")

    await self._init_trace_context()

    weave_client = self._get_weave_client()

    if weave_client.server is not self._server:
        logger.error(
            "Weave client is not using the correct trace server. You might have multiple WeaveTracer instances running in the same process. "
            f"Expected {self._server}, got {weave_client.server}"
        )

    arg_op = name or weave_client.project
    arg_inputs: dict[str, str] = {}
    if rollout_id is not None:
        arg_inputs[LightningResourceAttributes.ROLLOUT_ID.value] = rollout_id
    if attempt_id is not None:
        arg_inputs[LightningResourceAttributes.ATTEMPT_ID.value] = attempt_id

    try:
        # Create a new trace call object in Weave
        trace_call = weave_client.create_call(  # pyright: ignore[reportUnknownMemberType]
            op=arg_op, inputs=arg_inputs
        )

        try:
            yield trace_call
        except BaseException as exc:
            # BaseException so that cancellation (asyncio.CancelledError) and
            # KeyboardInterrupt also close the root call instead of leaving it open.
            weave_client.finish_call(trace_call, exception=exc)  # pyright: ignore[reportUnknownMemberType]
            logger.error(f"Trace failed for rollout_id={rollout_id}, attempt_id={attempt_id}, error={exc}")
            raise
        else:
            # Outside the ``try`` so a failing finish_call is not retried with an exception.
            weave_client.finish_call(trace_call)  # pyright: ignore[reportUnknownMemberType]

    finally:
        try:
            weave_client.flush()
            # Completion handlers may have been scheduled from a dedicated Weave thread
            # pool; wrap_future makes every one of them awaitable on this loop.
            await asyncio.gather(*[asyncio.wrap_future(future) for future in self._complete_call_futures])

        finally:
            # Mandatory cleanup
            self._rollout_id = None
            self._attempt_id = None
            self._server.clear()
392
+
393
def create_span(
    self,
    name: str,
    attributes: Optional[Attributes] = None,
    timestamp: Optional[float] = None,
    status: Optional[TraceStatus] = None,
) -> SpanCoreFields:
    """Record an instantaneous span as a Weave call that is finished immediately.

    Args:
        name: Operation name passed to ``create_call``.
        attributes: Attributes attached to the Weave call.
        timestamp: Ignored (with a warning); Weave stamps the call itself.
        status: Optional status. An ``ERROR`` status is reported to Weave as an
            exception carrying ``status.description``. Other status codes leave the
            call successful.

    Returns:
        The span's core fields. Propagation to the trace server is not awaited.
    """
    if timestamp is not None:
        logger.warning("Weave doesn't support customizing the start time of a call. Timestamp is ignored.")
    weave_client = self._get_weave_client()
    trace_call = weave_client.create_call(  # pyright: ignore[reportUnknownMemberType]
        op=name,
        attributes=attributes,
        inputs={},
    )

    # finish_call accepts a failure as an exception object (as trace_context does),
    # so wrap an ERROR status description in one. Previously ``status`` was dropped.
    error: Optional[BaseException] = None
    if status is not None and status.status_code == "ERROR":
        error = Exception(status.description or "ERROR")

    # Immediately finish the call
    weave_client.finish_call(trace_call, exception=error)  # pyright: ignore[reportUnknownMemberType]
    # We don't wait for the call to be propagated to the server.
    start_time = trace_call.started_at.timestamp() if trace_call.started_at else None
    end_time = trace_call.ended_at.timestamp() if trace_call.ended_at else None
    if trace_call.exception is not None:
        trace_status = TraceStatus(status_code="ERROR", description=trace_call.exception)
    elif error is not None:
        trace_status = TraceStatus(status_code="ERROR", description=str(error))
    else:
        trace_status = TraceStatus(status_code="OK")
    return SpanCoreFields(
        name=name,
        attributes=flatten_attributes(trace_call.attributes or {}),
        start_time=start_time,
        end_time=end_time,
        status=trace_status,
    )
425
+
426
@contextmanager
def operation_context(
    self,
    name: str,
    attributes: Optional[Attributes] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
) -> Iterator[SpanRecordingContext]:
    """Run an operation inside a Weave call, yielding a recorder for it.

    Args:
        name: Operation name passed to ``create_call``.
        attributes: Attributes attached to the Weave call.
        start_time: Ignored (with a warning).
        end_time: Ignored (with a warning).

    Yields:
        A ``WeaveSpanRecordingContext`` bound to the new call.

    An exception raised in the body is recorded on the call and re-raised. The call is
    always finished on exit.
    """
    if start_time is not None:
        logger.warning("Weave doesn't support customizing the start time of a call. Timestamp is ignored.")
    if end_time is not None:
        logger.warning("Weave doesn't support customizing the end time of a call. Timestamp is ignored.")
    weave_client = self._get_weave_client()
    trace_call = weave_client.create_call(  # pyright: ignore[reportUnknownMemberType]
        op=name,
        attributes=attributes,
        inputs={},
    )
    recording_context = WeaveSpanRecordingContext(trace_call)
    error: Optional[BaseException] = None
    try:
        yield recording_context
    except Exception as exc:
        error = exc
        recording_context.record_exception(exc)
        raise
    finally:
        # NOTE(review): Weave's finish_call takes the output/exception to report from its
        # arguments (per the WeaveClient API; verify against the pinned weave version).
        # Forward what was recorded instead of letting them default to None.
        weave_client.finish_call(  # pyright: ignore[reportUnknownMemberType]
            trace_call, output=trace_call.output, exception=error
        )
452
+
453
async def _init_trace_context(self) -> None:
    """Reset per-trace collections and remember the running loop as the trace's main loop."""
    for container in (self._spans, self._calls, self._partial_call_futures, self._complete_call_futures):
        container.clear()
    # Weak reference so the tracer never keeps a finished loop alive.
    self._loop = weakref.ref(asyncio.get_running_loop())
460
+
461
def _get_weave_client(self) -> WeaveClient:
    """Return the active Weave client.

    Raises:
        RuntimeError: If ``weave.get_client()`` returns no client.
    """
    weave_client = weave.get_client()
    if weave_client:
        return weave_client
    raise RuntimeError("Weave client is not initialized. Call init_worker() first.")
467
+
468
def _ensure_loop(self) -> tuple[asyncio.AbstractEventLoop, bool]:
    """Pick the event loop that callbacks should be scheduled on.

    The trace's main loop (weakly referenced in ``self._loop``) wins when it is still
    alive; otherwise the loop running in the current thread is used.

    Returns:
        ``(loop, is_running_here)``, where ``is_running_here`` is True when ``loop``
        is the loop running in the calling thread.

    Raises:
        RuntimeError: If neither a main loop nor a running loop is available.
    """
    try:
        current = asyncio.get_running_loop()
    except RuntimeError:
        current = None

    main = self._loop() if self._loop is not None else None

    if main is not None:
        return main, main is current
    if current is not None:
        return current, True
    raise RuntimeError("No running event loop found. This should not happen.")
491
+
492
def get_last_trace(self) -> List[Span]:
    """Return the spans collected for the most recent trace.

    A shallow copy is returned. The internal list keeps receiving spans while
    completion handlers run and is cleared in place when the next trace starts, so
    handing out the live list would let a caller's result change or empty underneath it.
    """
    return list(self._spans)
494
+
495
def partial_call_callback(self, request_content: Dict[str, Any]) -> None:
    """Schedule ``partial_call_handler`` for a call that has just started.

    This may be invoked from a dedicated Weave thread pool, but the handler must run on
    the loop chosen by ``_ensure_loop``. The resulting future is stored under the call
    id. Scheduling errors are logged, not raised.

    Raises:
        ValueError: If the request has no ``id``, or a start future already exists for it.
    """
    call_id = request_content.get("id")
    if call_id is None:
        raise ValueError("Call ID is required even for partial calls")
    if call_id in self._partial_call_futures:
        raise ValueError(f"Call {call_id} already has a start future")

    try:
        loop, on_current_loop = self._ensure_loop()
        coro = self.partial_call_handler(request_content)
        if on_current_loop:
            scheduled = loop.create_task(coro)
        else:
            # Hand the coroutine to the other loop from this thread.
            scheduled = asyncio.run_coroutine_threadsafe(coro, loop)
        self._partial_call_futures[call_id] = scheduled
    except Exception as exc:
        logger.exception(f"Error creating call start task: {exc}", exc_info=True)
515
+
516
def complete_call_callback(self, call: tsi.CallSchema) -> None:
    """Schedule ``complete_call_handler`` for a finished call and remember its future.

    The handler becomes a task when the loop chosen by ``_ensure_loop`` is running in
    this thread. Otherwise it is submitted to that loop thread-safely. Scheduling
    errors are logged, not raised.
    """
    try:
        loop, on_current_loop = self._ensure_loop()
        coro = self.complete_call_handler(call)
        scheduled = loop.create_task(coro) if on_current_loop else asyncio.run_coroutine_threadsafe(coro, loop)
        self._complete_call_futures.append(scheduled)
    except Exception as exc:
        logger.exception(f"Error creating call finish task: {exc}", exc_info=True)
527
+
528
async def _get_next_sequence_id(self) -> int:
    """Return the next span sequence ID.

    When rollout id, attempt id and store are all set, the store allocates the ID.
    Otherwise a local counter is incremented and its new value returned.
    """
    store = self._store
    if not (self._rollout_id and self._attempt_id and store):
        self._default_sequence_counter += 1
        return self._default_sequence_counter
    return await store.get_next_span_sequence_id(self._rollout_id, self._attempt_id)
538
+
539
async def partial_call_handler(self, request_content: Dict[str, Any]) -> int:
    """Allocate a sequence ID for a Weave call that has just started.

    Args:
        request_content: The partial Weave call payload. Its contents are not
            inspected here.

    Returns:
        The sequence ID obtained from ``self._get_next_sequence_id()``.
    """
    return await self._get_next_sequence_id()
550
+
551
async def complete_call_handler(self, call: tsi.CallSchema) -> None:
    """Handler called when a Weave Call finishes.

    Steps:

    1. Resolve a sequence ID for the call.
    2. Record the call in ``self._calls``.
    3. Convert it into one span via ``convert_call_to_span``.
    4. Append the span to ``self._spans``.
    5. When a store and rollout/attempt IDs are configured, add the span to the
       store. Store failures are logged, not raised.

    The sequence ID normally comes from the start task registered in
    ``self._partial_call_futures``. If that task is missing, or awaiting it raises,
    a fresh sequence ID is fetched instead so the span is still recorded.

    Args:
        call: The finished Weave call.
    """
    # Pop up-front so the entry is removed even if awaiting the start task raises.
    start_future = self._partial_call_futures.pop(call.id, None)
    sequence_id: Optional[int] = None
    if start_future is not None:
        try:
            # wrap_future returns asyncio futures/tasks unchanged and adapts
            # concurrent.futures.Future (from run_coroutine_threadsafe).
            sequence_id = await asyncio.wrap_future(start_future)
        except Exception as exc:
            logger.warning(f"Start task for call {call.id} failed: {exc}. Fetching a new sequence ID.")
    elif call.id in self._calls:
        logger.warning(
            f"Call {call.id} is already in calls. The call is already completed. Overwriting the call."
        )
    else:
        logger.warning(f"Call {call.id} has no start future. Fetching a new sequence ID.")

    if sequence_id is None:
        # The call start was missing or failed; fetch a fresh sequence ID.
        sequence_id = await self._get_next_sequence_id()

    # Recorded before conversion: convert_call_to_span looks up parent calls in self._calls.
    self._calls[call.id] = call

    span = await self.convert_call_to_span(call, self._rollout_id, self._attempt_id, sequence_id)
    self._spans.append(span)
    if self._store and self._rollout_id and self._attempt_id:
        try:
            await self._store.add_span(span)
        except Exception as exc:
            logger.exception(f"Error adding span to store: {exc}")
579
+
580
async def convert_call_to_span(
    self,
    call: tsi.CallSchema,
    rollout_id: Optional[str] = None,
    attempt_id: Optional[str] = None,
    sequence_id: Optional[int] = None,
) -> Span:
    """Convert a single Weave Call into a Mantisdk Span.

    Only ``call`` itself is converted; this method does not traverse child calls.

    `rollout_id` and `attempt_id` are required to attach the span to the store.
    When they are not given, the placeholders ``"rollout-dummy"`` and
    ``"attempt-dummy"`` are used.

    Args:
        call: The Weave Call object.
        rollout_id: Optional rollout ID to attach to the span.
        attempt_id: Optional attempt ID to attach to the span.
        sequence_id: Optional sequence ID to attach to the span (0 if not given).

    Returns:
        The converted span.
    """
    rollout_id = rollout_id or "rollout-dummy"
    attempt_id = attempt_id or "attempt-dummy"
    sequence_id = sequence_id or 0

    start_ts: float = call.started_at.timestamp()
    end_ts: Optional[float] = call.ended_at.timestamp() if call.ended_at else None

    if call.exception:
        status = TraceStatus(status_code="ERROR", description=call.exception)
    else:
        status = TraceStatus(status_code="OK")

    attributes: Dict[str, Any] = {
        LightningSpanAttributes.OPERATION_NAME.value: call.op_name,
        # op_name can be overridden by the call's own attributes.
        **call.attributes,
    }
    if call.inputs:
        attributes[LightningSpanAttributes.OPERATION_INPUT.value] = call.inputs
    # Compare against None so legitimate falsy outputs (0, 0.0, False, "") are kept.
    if call.output is not None:
        attributes[LightningSpanAttributes.OPERATION_OUTPUT.value] = call.output
    if call.summary:
        # Entries already set above can be overridden by the summary.
        attributes.update(call.summary)
    if call.exception:
        attributes[exception_attributes.EXCEPTION_MESSAGE] = call.exception

    sanitized_attributes = sanitize_attributes(flatten_attributes(attributes, expand_leaf_lists=False))

    context = SpanContext(
        trace_id=call.trace_id,
        span_id=call.id,
        is_remote=False,
        trace_state={},
    )

    # The parent context is only built when the parent call is already recorded in
    # self._calls; otherwise it stays None, although parent_id is still set below.
    parent_context: Optional[SpanContext] = None
    if call.parent_id:
        parent_call = self._calls.get(call.parent_id)
        if parent_call:
            parent_context = SpanContext(
                trace_id=parent_call.trace_id,
                span_id=parent_call.id,
                is_remote=False,
                trace_state={},
            )

    # Build the Span object
    return Span(
        rollout_id=rollout_id,
        attempt_id=attempt_id,
        sequence_id=sequence_id,
        trace_id=call.trace_id,
        span_id=call.id,
        parent_id=call.parent_id,
        name=op_name_to_func_name(call.op_name),
        status=status,
        attributes=sanitized_attributes,
        events=[],  # Weave calls do not generate events
        links=[],  # Weave calls do not generate links
        start_time=start_ts,
        end_time=end_ts,
        context=context,
        parent=parent_context,
        resource=OtelResource(
            attributes={
                LightningResourceAttributes.ROLLOUT_ID.value: rollout_id,
                LightningResourceAttributes.ATTEMPT_ID.value: attempt_id,
                LightningResourceAttributes.SPAN_SEQUENCE_ID.value: sequence_id,
                LightningResourceAttributes.TRACER_NAME.value: "weave",
            },
            schema_url="",
        ),
    )
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from .init_utils import build_component
4
+ from .trainer import Trainer
5
+
6
# Names re-exported as this package's public API.
__all__ = [
    "Trainer",
    "build_component",
]