mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,182 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Abstract runner interface for executing agent tasks."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ from contextlib import contextmanager
9
+ from typing import TYPE_CHECKING, Any, Generic, Iterator, Optional, Sequence, TypeVar
10
+
11
+ from mantisdk.execution.events import ExecutionEvent
12
+ from mantisdk.litagent import LitAgent
13
+ from mantisdk.store.base import LightningStore
14
+ from mantisdk.types import Hook, NamedResources, ParallelWorkerBase, Rollout, RolloutMode
15
+
16
+ if TYPE_CHECKING:
17
+ from mantisdk.execution.events import ExecutionEvent
18
+
19
+
20
# Type variable for the task payload that a runner hands to its agent.
T_task = TypeVar("T_task")

logger = logging.getLogger(__name__)


class Runner(ParallelWorkerBase, Generic[T_task]):
    """Abstract base class for long-running agent executors.

    Runner implementations coordinate [`LitAgent`][mantisdk.LitAgent]
    instances, acquire work from a [`LightningStore`][mantisdk.LightningStore],
    and emit [`Rollout`][mantisdk.Rollout] objects. Subclasses decide how
    to schedule work (polling, streaming, etc.) while this base class provides a
    minimal lifecycle contract.
    """

    def init(self, agent: LitAgent[T_task], **kwargs: Any) -> None:
        """Prepare the runner to execute tasks for `agent`.

        This method is called only once during the setup for all workers, not for each worker.

        Args:
            agent: Agent instance providing task-specific logic.
            **kwargs: Optional runner-specific configuration.

        Raises:
            NotImplementedError: Subclasses must supply the initialization
                routine.
        """
        raise NotImplementedError()

    def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
        """Configure worker-local state before processing tasks.

        This method is called for **each** worker during the setup.

        Args:
            worker_id: Unique identifier for this worker process or thread.
            store: Shared [`LightningStore`][mantisdk.LightningStore]
                backing task coordination.
            **kwargs: Optional worker-specific configuration.

        Raises:
            NotImplementedError: Subclasses must prepare per-worker resources.
        """
        raise NotImplementedError()

    def run(self, *args: Any, **kwargs: Any) -> None:
        """Deprecated synchronous entry point.

        Use [`iter()`][mantisdk.Runner.iter] or [`step()`][mantisdk.Runner.step] instead.

        Raises:
            RuntimeError: Always raised to direct callers to
                [iter()][mantisdk.Runner.iter] or
                [step()][mantisdk.Runner.step].
        """
        raise RuntimeError("The behavior of run() of Runner is undefined. Use iter() or step() instead.")

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """Release resources acquired during [`init()`][mantisdk.Runner.init].

        Raises:
            NotImplementedError: Subclasses must implement the shutdown routine.
        """
        raise NotImplementedError()

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Release per-worker resources allocated by [`init_worker()`][mantisdk.Runner.init_worker].

        Args:
            worker_id: Identifier of the worker being torn down.

        Raises:
            NotImplementedError: Subclasses must implement the shutdown routine.
        """
        raise NotImplementedError()

    @contextmanager
    def run_context(
        self,
        *,
        agent: LitAgent[T_task],
        store: LightningStore,
        hooks: Optional[Sequence[Hook]] = None,
        worker_id: Optional[int] = None,
    ) -> Iterator[Runner[T_task]]:
        """Initialize and tear down a runner within a simple context manager.

        The helper is primarily intended for debugging runner implementations
        outside of a full [`Trainer`][mantisdk.Trainer] stack.

        Only the lifecycle steps that completed successfully are undone on exit:
        `teardown_worker()` runs only if `init_worker()` succeeded, and
        `teardown()` runs only if `init()` succeeded. Exceptions raised during
        teardown are logged and not re-raised.

        Args:
            agent: Agent executed by this runner.
            store: Backing [`LightningStore`][mantisdk.LightningStore].
                If you don't have one, you can easily create one with
                [`InMemoryLightningStore`][mantisdk.InMemoryLightningStore].
            hooks: Optional sequence of hooks recognised by the runner.
                Not all runners support hooks.
            worker_id: Override the worker identifier used during setup and
                teardown. Defaults to `0`.

        Yields:
            This runner, after `init()` and `init_worker()` have completed.
        """
        # Resolve the worker id once so init_worker and teardown_worker always agree.
        # Previously init_worker was hard-coded to 0 while teardown_worker used the override.
        effective_worker_id = worker_id if worker_id is not None else 0
        initialized = False
        worker_initialized = False
        try:
            self.init(agent=agent, hooks=hooks)
            initialized = True
            self.init_worker(worker_id=effective_worker_id, store=store)
            worker_initialized = True
            yield self
        finally:
            # Tear down in reverse order of setup; each step is best-effort.
            try:
                if worker_initialized:
                    self.teardown_worker(worker_id=effective_worker_id)
            except Exception:
                logger.error("Error during runner worker teardown", exc_info=True)

            try:
                if initialized:
                    self.teardown()
            except Exception:
                logger.error("Error during runner teardown", exc_info=True)

    async def iter(self, *, event: Optional[ExecutionEvent] = None) -> None:
        """Run the runner, continuously iterating over tasks in the store.

        This method runs in a loop, polling the store for new tasks and executing
        them until interrupted by the event or when no more tasks are available.

        Args:
            event: Cooperative stop signal. When set, the runner should complete
                the current unit of work and exit the loop.

        Raises:
            NotImplementedError: Subclasses provide the iteration behavior.
        """
        raise NotImplementedError()

    async def step(
        self,
        input: T_task,
        *,
        resources: Optional[NamedResources] = None,
        mode: Optional[RolloutMode] = None,
        event: Optional[ExecutionEvent] = None,
    ) -> Rollout:
        """Execute a single task with the given input.

        This method provides fine-grained control for executing individual tasks
        directly, bypassing the store's task queue.

        Args:
            input: Task payload consumed by the agent.
            resources: Optional named resources scoped to this invocation.
            mode: Optional rollout mode such as `"train"` or `"eval"`.
            event: Cooperative stop signal for long-running tasks.

        Returns:
            Completed rollout produced by the agent.

        Raises:
            NotImplementedError: Subclasses provide the execution behavior.
        """
        raise NotImplementedError()
@@ -0,0 +1,309 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ import json
4
+ import logging
5
+ import time
6
+ from typing import Any, Dict, List, Optional, cast
7
+
8
+ from opentelemetry.sdk.trace import ReadableSpan
9
+
10
+ from mantisdk.adapter import TracerTraceToTriplet
11
+ from mantisdk.client import MantisdkClient
12
+ from mantisdk.litagent import LitAgent
13
+ from mantisdk.litagent.litagent import is_v0_1_rollout_api
14
+ from mantisdk.tracer.base import Tracer
15
+ from mantisdk.types import RolloutLegacy, RolloutRawResultLegacy, Span, SpanLike, Triplet
16
+
17
+ from .base import Runner
18
+
19
logger = logging.getLogger(__name__)

__all__ = [
    "LegacyAgentRunner",
]


class LegacyAgentRunner(Runner[Any]):
    """Manages the agent's execution loop and integrates with AgentOps.

    This class orchestrates the interaction between the agent (`LitAgent`) and
    the server (`MantisdkClient`). It handles polling for tasks, executing
    the agent's logic, and reporting results back to the server. If enabled,
    it will also automatically trace each rollout using AgentOps.

    Attributes:
        agent: The `LitAgent` instance containing the agent's logic.
        client: The `MantisdkClient` for server communication.
        tracer: The tracer instance for this runner/worker.
        worker_id: An optional identifier for the worker process.
        max_tasks: The maximum number of tasks to process before stopping.
    """

    def __init__(
        self,
        agent: LitAgent[Any],
        client: MantisdkClient,
        tracer: Tracer,
        triplet_exporter: TracerTraceToTriplet,
        worker_id: Optional[int] = None,
        max_tasks: Optional[int] = None,
    ):
        super().__init__()
        self.agent = agent
        self.client = client
        self.tracer = tracer
        self.triplet_exporter = triplet_exporter

        # Worker-specific attributes
        self.worker_id = worker_id
        self.max_tasks = max_tasks

    # These methods are overridden by Runner, getting them back to old behavior.
    def init(self, *args: Any, **kwargs: Any) -> None:
        pass

    def init_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        self.worker_id = worker_id

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        pass

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        pass

    def _log_prefix(self, rollout_id: Optional[str] = None) -> str:
        """Generates a standardized log prefix for the current worker."""
        if self.worker_id is not None:
            if rollout_id:
                return f"[Worker {self.worker_id} | RolloutLegacy {rollout_id}]"
            else:
                return f"[Worker {self.worker_id}]"
        if rollout_id:
            return f"[RolloutLegacy {rollout_id}]"
        return "[Default Worker]"

    def _log_progress(self, num_tasks_processed: int) -> None:
        """Log loop progress after the first task and after every 10th task."""
        if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
            logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}")

    def _to_rollout_object(
        self,
        result: RolloutRawResultLegacy,
        rollout_id: str,
    ) -> RolloutLegacy:
        """Standardizes the agent's return value into a RolloutLegacy object.

        Args:
            result: The output from the agent's rollout method.
            rollout_id: The unique identifier for the current task.

        Returns:
            A standardized `RolloutLegacy` object for reporting to the server.
        """
        trace: Any = None
        final_reward: Optional[float] = None
        triplets: Optional[List[Triplet]] = None
        trace_spans: Optional[List[SpanLike]] = None

        # Handle different types of results from the agent
        # Case 1: result is a numeric final reward. ``int`` is accepted alongside ``float``
        # (type checkers treat int as compatible with a float annotation); ``bool`` is
        # deliberately excluded so ``True``/``False`` are not mistaken for rewards.
        if isinstance(result, (int, float)) and not isinstance(result, bool):
            final_reward = float(result)
        # Case 2: result is a list of Triplets
        if isinstance(result, list) and all(isinstance(t, Triplet) for t in result):
            triplets = result  # type: ignore
        # Case 3.1: result is a list of ReadableSpan (OpenTelemetry spans)
        if isinstance(result, list) and all(isinstance(t, (ReadableSpan)) for t in result):
            trace_spans = result  # type: ignore
            trace = [json.loads(readable_span.to_json()) for readable_span in trace_spans]  # type: ignore
        # Case 3.2: result is a list of Span (Mantisdk spans)
        if isinstance(result, list) and all(isinstance(t, Span) for t in result):
            trace_spans = result  # type: ignore
            trace = [span.model_dump() for span in trace_spans]  # type: ignore
        # Case 4: result is a list of dict (trace JSON)
        if isinstance(result, list) and all(isinstance(t, dict) for t in result):
            trace = result
        # Case 5: result is a RolloutLegacy object
        if isinstance(result, RolloutLegacy):
            final_reward = result.final_reward
            triplets = result.triplets
            trace = result.trace

        # If the agent has tracing enabled, use the tracer's last trace if not already set
        if self.tracer and (trace is None or trace_spans is None):
            trace_spans = self.tracer.get_last_trace()  # type: ignore
            if trace_spans:
                trace = [cast(Span, span).model_dump() for span in trace_spans]

        # Always extract triplets from the trace using TracerTraceToTriplet
        if trace_spans:
            triplets = self.triplet_exporter(trace_spans)  # type: ignore

        # If the agent has triplets, use the last one for final reward if not set
        if triplets and triplets[-1].reward is not None and final_reward is None:
            final_reward = triplets[-1].reward

        # Create the RolloutLegacy object with standardized fields
        result_dict: Dict[str, Any] = {
            "rollout_id": rollout_id,
        }
        if final_reward is not None:
            result_dict["final_reward"] = final_reward
        if triplets is not None:
            result_dict["triplets"] = triplets
        if trace is not None:
            result_dict["trace"] = trace

        if isinstance(result, RolloutLegacy):
            return result.model_copy(update=result_dict)
        return RolloutLegacy(**result_dict)

    def run(self) -> bool:  # type: ignore
        """Poll one task and execute its rollout synchronously.

        Returns:
            True if a task was obtained and a rollout was posted; False if the poll
            returned no task or the task's resources could not be fetched.
        """
        self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

        task = self.client.poll_next_task()
        if task is None:
            logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
            return False
        rollout_id = task.rollout_id

        resources_id = task.resources_id
        resources_update = None
        if resources_id:
            resources_update = self.client.get_resources_by_id(resources_id)
        else:
            logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = self.client.get_latest_resources()
        if not resources_update:
            logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
            return False

        rollout_obj = RolloutLegacy(rollout_id=task.rollout_id, task=task)  # Default empty rollout

        try:
            try:
                self.agent.on_rollout_start(task, self, self.tracer)
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")

            with self.tracer._trace_context_sync(name=f"rollout_{rollout_id}"):  # pyright: ignore[reportPrivateUsage]
                start_time = time.time()
                rollout_method = self.agent.training_rollout if task.mode == "train" else self.agent.validation_rollout
                # Pass the task input, not the whole task object
                if is_v0_1_rollout_api(rollout_method):
                    result = cast(
                        RolloutRawResultLegacy,
                        rollout_method(
                            task.input, rollout_id=rollout_obj.rollout_id, resources=resources_update.resources  # type: ignore
                        ),
                    )  # type: ignore
                else:
                    result = rollout_method(task.input, resources=resources_update.resources, rollout=rollout_obj)  # type: ignore
                rollout_obj = self._to_rollout_object(result, task.rollout_id)  # type: ignore
                end_time = time.time()
                logger.info(
                    f"{self._log_prefix(rollout_id)} Completed in "
                    f"{end_time - start_time:.2f}s. Triplet length: "
                    f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. "
                    f"Reward: {rollout_obj.final_reward}"
                )

        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
        finally:
            # The (possibly default, empty) rollout is always reported, even on failure.
            try:
                self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)  # type: ignore
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
            self.client.post_rollout(rollout_obj)

        return True

    def iter(self) -> int:  # type: ignore
        """Executes the synchronous polling and rollout loop.

        Returns:
            The number of tasks that were successfully processed.
        """
        num_tasks_processed = 0
        logger.info(f"{self._log_prefix()} Started sync rollouts (max: {self.max_tasks or 'unlimited'}).")

        while self.max_tasks is None or num_tasks_processed < self.max_tasks:
            if self.run():
                num_tasks_processed += 1
                # Report progress only when the count changed; otherwise every failed
                # poll would re-log the same count (including "0/...").
                self._log_progress(num_tasks_processed)

        logger.info(f"{self._log_prefix()} Finished sync rollouts. Processed {num_tasks_processed} tasks.")
        return num_tasks_processed

    async def run_async(self) -> bool:
        """Poll one task and execute its rollout asynchronously.

        Returns:
            True if a task was obtained and a rollout was posted; False if the poll
            returned no task or the task's resources could not be fetched.
        """
        self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

        task = await self.client.poll_next_task_async()
        if task is None:
            logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
            return False
        rollout_id = task.rollout_id

        resources_id = task.resources_id
        resources_update = None
        if resources_id:
            resources_update = await self.client.get_resources_by_id_async(resources_id)
        else:
            logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = await self.client.get_latest_resources_async()
        if not resources_update:
            logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
            return False

        rollout_obj = RolloutLegacy(rollout_id=task.rollout_id, task=task)  # Default empty rollout

        try:
            try:
                self.agent.on_rollout_start(task, self, self.tracer)
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_start hook.")

            async with self.tracer.trace_context(name=f"rollout_{rollout_id}"):
                start_time = time.time()
                rollout_method = (
                    self.agent.training_rollout_async if task.mode == "train" else self.agent.validation_rollout_async
                )
                # Pass the task input, not the whole task object
                if is_v0_1_rollout_api(rollout_method):
                    result = cast(
                        RolloutRawResultLegacy,
                        await rollout_method(
                            task.input, rollout_id=rollout_obj.rollout_id, resources=resources_update.resources  # type: ignore
                        ),
                    )  # type: ignore
                else:
                    result = await rollout_method(task.input, resources=resources_update.resources, rollout=rollout_obj)  # type: ignore
                rollout_obj = self._to_rollout_object(result, task.rollout_id)  # type: ignore
                end_time = time.time()
                logger.info(
                    f"{self._log_prefix(rollout_id)} Completed in "
                    f"{end_time - start_time:.2f}s. Triplet length: "
                    f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. "
                    f"Reward: {rollout_obj.final_reward}"
                )
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")
        finally:
            # The (possibly default, empty) rollout is always reported, even on failure.
            try:
                self.agent.on_rollout_end(task, rollout_obj, self, self.tracer)  # type: ignore
            except Exception:
                logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")
            await self.client.post_rollout_async(rollout_obj)

        return True

    async def iter_async(self) -> int:
        """Executes the asynchronous polling and rollout loop.

        Returns:
            The number of tasks that were successfully processed.
        """
        num_tasks_processed = 0
        logger.info(f"{self._log_prefix()} Started async rollouts (max: {self.max_tasks or 'unlimited'}).")

        while self.max_tasks is None or num_tasks_processed < self.max_tasks:
            if await self.run_async():
                num_tasks_processed += 1
                # Report progress only when the count changed; otherwise every failed
                # poll would re-log the same count (including "0/...").
                self._log_progress(num_tasks_processed)
        logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")
        return num_tasks_processed
mantisdk/semconv.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Semantic conventions for Mantisdk spans.
4
+
5
+ Conventions in this file are added on demand. We generally DO NOT add
6
+ new semantic conventions unless it's absolutely needed for certain algorithms or scenarios.
7
+ """
8
+
9
+ from enum import Enum
10
+
11
+ from pydantic import BaseModel
12
+
13
+ AGL_ANNOTATION = "mantisdk.annotation"
14
+ """Mantisdk's standard span name for annotations.
15
+
16
+ Annotations are minimal span units for rewards, tags, and metadatas.
17
+ They are used to "annotate" a specific event or a part of rollout.
18
+ """
19
+
20
+ AGL_MESSAGE = "mantisdk.message"
21
+ """Mantisdk's standard span name for messages and logs."""
22
+
23
+ AGL_OBJECT = "mantisdk.object"
24
+ """Mantisdk's standard span name for customized objects."""
25
+
26
+ AGL_EXCEPTION = "mantisdk.exception"
27
+ """Mantisdk's standard span name for exceptions.
28
+
29
+ Used by the exception emitter to record exception details.
30
+ """
31
+
32
+ AGL_OPERATION = "mantisdk.operation"
33
+ """Mantisdk's standard span name for functions.
34
+ Wrap function or code-blocks as operations.
35
+ """
36
+
37
+ AGL_REWARD = "mantisdk.reward"
38
+ """Mantisdk's standard span name for reward operations."""
39
+
40
+ AGL_VIRTUAL = "mantisdk.virtual"
41
+ """Mantisdk's standard span name for virtual operations.
42
+
43
+ Mostly used in adapter when needing to represent the root or intermediate operations.
44
+ """
45
+
46
+
47
class LightningResourceAttributes(Enum):
    """Keys of the resource attributes that Mantisdk attaches to its spans."""

    ROLLOUT_ID = "mantisdk.rollout_id"
    """Resource attribute key that carries the rollout ID."""

    ATTEMPT_ID = "mantisdk.attempt_id"
    """Resource attribute key that carries the attempt ID."""

    SPAN_SEQUENCE_ID = "mantisdk.span_sequence_id"
    """Resource attribute key that carries the span sequence ID."""

    TRACER_NAME = "mantisdk.tracer.name"
    """Resource attribute key naming the tracer that created the span."""

    JOB_ID = "mantisdk.job_id"
    """Resource attribute key that carries the job ID; intended for spans outside a rollout."""

    SPAN_TYPE = "mantisdk.span_type"
    """Kind of span: ``'rollout'`` for spans in a rollout context, ``'job'`` for job-level spans."""
67
+
68
+
69
class LightningSpanAttributes(Enum):
    """Attribute names that commonly appear in Mantisdk spans.

    Exception-related attributes are not listed here because OpenTelemetry's
    official semantic conventions already define them.
    """

    REWARD = "mantisdk.reward"
    """Attribute prefix for reward-related data in reward spans.

    It is meant to be used as a prefix. For example, "mantisdk.reward.0.value"
    can be used to track a specific metric.
    See [RewardAttributes][mantisdk.semconv.RewardAttributes].
    """

    LINK = "mantisdk.link"
    """Attribute name for linking the current span to another span or to other objects such as requests/responses."""

    TAG = "mantisdk.tag"
    """Attribute name for tagging spans with custom strings."""

    MESSAGE_BODY = "mantisdk.message.body"
    """Attribute name for the message text in message spans."""

    OBJECT_TYPE = "mantisdk.object.type"
    """Attribute name for the object type (fully qualified name) in object spans.

    Builtin type names such as str, int, bool, list, and dict are
    self-explanatory and are also acceptable here without qualification.
    """

    OBJECT_LITERAL = "mantisdk.object.literal"
    """Attribute name for the object's literal value in object spans (for str, int, bool, ...)."""

    OBJECT_JSON = "mantisdk.object.json"
    """Attribute name for the object's JSON-serialized value in object spans."""

    OPERATION_NAME = "mantisdk.operation.name"
    """Attribute name for the operation name in operation spans; normally the function name."""

    OPERATION_INPUT = "mantisdk.operation.input"
    """Attribute name for the operation input in operation spans."""

    OPERATION_OUTPUT = "mantisdk.operation.output"
    """Attribute name for the operation output in operation spans."""
112
+
113
+
114
class RewardAttributes(Enum):
    """Per-dimension key suffixes for multi-dimensional reward attributes.

    Multi-dimensional reward attributes look like:

    ```json
    {"mantisdk.reward.0.name": "efficiency", "mantisdk.reward.0.value": 0.75}
    ```

    The first reward in the reward list automatically becomes the primary reward.
    A reward list with more than one entry is the multi-dimensional case.
    """

    REWARD_NAME = "name"
    """Key suffix holding the name of each dimension in multi-dimensional reward spans."""

    REWARD_VALUE = "value"
    """Key suffix holding the value of each dimension in multi-dimensional reward spans."""
130
+
131
+
132
class RewardPydanticModel(BaseModel):
    """Validated, typed counterpart of RewardAttributes, used by the otel helpers.

    Each instance describes a single reward dimension.
    """

    name: str
    """Label identifying the reward dimension."""

    value: float
    """Numeric value recorded for the reward dimension."""
140
+
141
+
142
class LinkAttributes(Enum):
    """Standard link types used in Mantisdk spans.

    These links are more powerful than
    [OpenTelemetry links](https://opentelemetry.io/docs/specs/otel/trace/api/#link)
    because they support linking to a queryset of spans. They can even link to a
    span object that has not been emitted yet.
    """

    KEY_MATCH = "key_match"
    """Linking to spans with matching attribute keys.

    `trace_id` and `span_id` are reserved and are used to link to specific spans directly.

    For example, the key can be `gen_ai.response.id` when the intent is to link to
    a chat completion response span, or `span_id` to link to a specific span by its ID.
    """

    VALUE_MATCH = "value_match"
    """Linking to spans with the corresponding attribute values on those keys."""
161
+
162
+
163
class LinkPydanticModel(BaseModel):
    """Validated, typed counterpart of LinkAttributes, used by the otel helpers.

    A link is expressed as a key/value pair that target spans are matched against.
    """

    key_match: str
    """Attribute key looked up on the target spans."""

    value_match: str
    """Attribute value that the target spans must carry under ``key_match``."""