mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/client.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
"""Utilities for interacting with legacy Mantisdk servers.
|
|
4
|
+
|
|
5
|
+
This module contains compatibility shims that speak the deprecated HTTP
|
|
6
|
+
interface used by older Mantisdk deployments. Modern code should prefer
|
|
7
|
+
the store-based APIs exposed by `mantisdk.store`, but keeping these
|
|
8
|
+
clients available makes it easier to migrate existing workflows incrementally.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import logging
|
|
13
|
+
import time
|
|
14
|
+
import urllib.parse
|
|
15
|
+
import warnings
|
|
16
|
+
from typing import Any, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import aiohttp
|
|
19
|
+
import requests
|
|
20
|
+
|
|
21
|
+
from .types import NamedResources, ResourcesUpdate, RolloutLegacy, Task, TaskIfAny, TaskInput
|
|
22
|
+
|
|
23
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MantisdkClient:
    """Client wrapper for the legacy version-aware Mantisdk server.

    The client exposes synchronous and asynchronous helpers for polling tasks,
    retrieving resource bundles, and submitting rollouts. It also maintains a
    simple in-memory cache keyed by the server-provided resource identifier to
    avoid redundant network requests. Entries are never evicted from that cache.

    !!! warning "Deprecated"
        [`MantisdkClient`][mantisdk.client.MantisdkClient] is part of
        the legacy client/server stack. New code should rely on the store-based APIs
        implemented in `mantisdk.store`.

    Attributes:
        endpoint: Base URL of the Mantisdk server.
        poll_interval: Delay in seconds between polling attempts when no task is
            available.
        timeout: Timeout in seconds applied to HTTP requests.
        task_count: Number of tasks claimed during the lifetime of this client.
    """

    # Server routes. They are absolute paths, so ``urljoin`` replaces any path
    # component already present in ``endpoint``.
    _next_task_uri = "/task"
    _resources_uri = "/resources"
    _latest_resources_uri = "/resources/latest"
    _report_rollout_uri = "/rollout"

    def __init__(self, endpoint: str, poll_interval: float = 5.0, timeout: float = 10.0):
        """Initialize the client.

        Emits a ``DeprecationWarning`` attributed to the caller's line.

        Args:
            endpoint: Root URL of the Mantisdk server.
            poll_interval: Seconds to wait between polling attempts.
            timeout: Seconds before a request to the server is considered timed out.
        """
        # stacklevel=2 makes the warning point at the code that instantiated the
        # client rather than at this module, so users can find the call site.
        # NOTE(review): the message recommends ``LightningStoreClient``; confirm
        # that this is the name actually exported by the store package.
        warnings.warn(
            "MantisdkClient is deprecated. Please use LightningStoreClient instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.endpoint = endpoint
        self.task_count = 0
        self.poll_interval = poll_interval
        self.timeout = timeout
        self._resource_cache: Dict[str, ResourcesUpdate] = {}  # TODO: mechanism to evict cache
        self._default_headers = {"X-Mantisdk-Client": "true"}

    # ------------------------------------------------------------------
    # Helpers shared by the synchronous and asynchronous code paths
    # ------------------------------------------------------------------

    def _build_url(self, path: str) -> str:
        """Join ``path`` onto the configured endpoint."""
        return urllib.parse.urljoin(self.endpoint, path)

    def _resource_url(self, resource_id: str) -> str:
        """Return the URL of the resource bundle identified by ``resource_id``.

        The identifier is percent-encoded as a single path segment so that
        characters such as ``/``, ``?``, ``#`` or spaces cannot alter the
        request target.
        """
        segment = urllib.parse.quote(str(resource_id), safe="")
        return self._build_url(f"{self._resources_uri}/{segment}")

    def _take_task(self, response: Optional[Dict[str, Any]]) -> Optional[Task]:
        """Extract a task from a ``/task`` response, if one is available.

        Increments ``task_count`` and logs the rollout ID when a task is claimed.

        Args:
            response: Parsed JSON payload, or ``None``/empty when the request failed.

        Returns:
            The claimed task, or ``None`` when the response carries no task.
        """
        if not response:
            return None
        task_if_any = TaskIfAny.model_validate(response)
        if not (task_if_any.is_available and task_if_any.task):
            return None
        self.task_count += 1
        logger.info("[Task %s Received] ID: %s", self.task_count, task_if_any.task.rollout_id)
        return task_if_any.task

    def _cached_resources(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Return the cached bundle for ``resource_id``, or ``None`` on a cache miss."""
        if resource_id not in self._resource_cache:
            return None
        logger.debug("Found resources '%s' in cache.", resource_id)
        return self._resource_cache[resource_id]

    def _store_resources(
        self, response: Optional[Dict[str, Any]], cache_key: Optional[str] = None
    ) -> Optional[ResourcesUpdate]:
        """Validate a resources payload and remember it in the cache.

        Args:
            response: Parsed JSON payload, or ``None``/empty when the request failed.
            cache_key: Key to cache under. Defaults to the ``resources_id`` carried
                by the validated payload.

        Returns:
            The validated bundle, or ``None`` when ``response`` is empty.
        """
        if not response:
            return None
        resources_update = ResourcesUpdate.model_validate(response)
        key = resources_update.resources_id if cache_key is None else cache_key
        self._resource_cache[key] = resources_update
        return resources_update

    # ------------------------------------------------------------------
    # Asynchronous API
    # ------------------------------------------------------------------

    async def _request_json_async(self, url: str) -> Optional[Dict[str, Any]]:
        """Perform an asynchronous ``GET`` request and parse the JSON payload.

        Args:
            url: Fully qualified URL to query.

        Returns:
            Parsed JSON body as a dictionary if the request succeeds; otherwise ``None``.
            Failures are logged at debug level.
        """
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.get(url, headers=self._default_headers) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception as e:
                # Deliberately best-effort: any failure is reported as ``None`` so
                # that callers (e.g. the polling loop) can simply retry.
                logger.debug("Async GET request failed for %s: %s", url, e)
                return None

    async def _post_json_async(self, url: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Perform an asynchronous ``POST`` request with a JSON body.

        Args:
            url: Fully qualified URL that accepts the payload.
            payload: Dictionary that will be serialized and sent as JSON.

        Returns:
            Parsed JSON body as a dictionary if the request succeeds; otherwise ``None``.
            Failures are logged at debug level.
        """
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.post(url, json=payload, headers=self._default_headers) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except Exception as e:
                # Deliberately best-effort, mirroring ``_request_json_async``.
                logger.debug("Async POST request failed for %s: %s", url, e)
                return None

    async def poll_next_task_async(self) -> Optional[Task]:
        """Poll the server asynchronously until a task becomes available.

        Failed or empty polls are retried every ``poll_interval`` seconds, so this
        coroutine only completes once a task has been claimed. The ``Optional``
        return type is kept for interface compatibility.

        Returns:
            The next [`Task`][mantisdk.Task] exposed by the server.
        """
        url = self._build_url(self._next_task_uri)
        while True:
            task = self._take_task(await self._request_json_async(url))
            if task is not None:
                return task
            logger.debug("No task available yet. Retrying in %s seconds...", self.poll_interval)
            await asyncio.sleep(self.poll_interval)

    async def get_resources_by_id_async(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Fetch a specific resource bundle by identifier.

        Args:
            resource_id: Identifier sourced from the task metadata.

        Returns:
            Cached or freshly downloaded
            [`ResourcesUpdate`][mantisdk.ResourcesUpdate], or
            ``None`` when the request fails.
        """
        cached = self._cached_resources(resource_id)
        if cached is not None:
            return cached

        response = await self._request_json_async(self._resource_url(resource_id))
        resources_update = self._store_resources(response, cache_key=resource_id)
        if resources_update is not None:
            logger.info("Fetched and cached resources for ID: %s", resource_id)
        return resources_update

    async def get_latest_resources_async(self) -> Optional[ResourcesUpdate]:
        """Fetch the most recent resource bundle advertised by the server.

        The result is cached under the ``resources_id`` reported by the server.

        Returns:
            [`ResourcesUpdate`][mantisdk.ResourcesUpdate] for the
            newest version, or ``None`` when unavailable.
        """
        response = await self._request_json_async(self._build_url(self._latest_resources_uri))
        return self._store_resources(response)

    async def post_rollout_async(self, rollout: RolloutLegacy) -> Optional[Dict[str, Any]]:
        """Submit a completed rollout back to the server.

        Args:
            rollout: Legacy rollout payload produced by the executor.

        Returns:
            Parsed JSON response returned by the server, or ``None`` when the request fails.
        """
        url = self._build_url(self._report_rollout_uri)
        payload = rollout.model_dump(mode="json")
        return await self._post_json_async(url, payload)

    # ------------------------------------------------------------------
    # Synchronous API
    # ------------------------------------------------------------------

    def _request_json(self, url: str) -> Optional[Dict[str, Any]]:
        """Perform a blocking ``GET`` request and parse the JSON payload.

        Args:
            url: Fully qualified URL to query.

        Returns:
            Parsed JSON body as a dictionary if the request succeeds; otherwise ``None``.
            Failures are logged at debug level.
        """
        try:
            response = requests.get(url, timeout=self.timeout, headers=self._default_headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.debug("Sync GET request failed for %s: %s", url, e)
            return None

    def _post_json(self, url: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Perform a blocking ``POST`` request with a JSON payload.

        Args:
            url: Fully qualified URL that accepts the payload.
            payload: Dictionary that will be serialized and sent as JSON.

        Returns:
            Parsed JSON body as a dictionary if the request succeeds; otherwise ``None``.
            Failures are logged at debug level.
        """
        try:
            response = requests.post(url, json=payload, timeout=self.timeout, headers=self._default_headers)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.debug("Sync POST request failed for %s: %s", url, e)
            return None

    def poll_next_task(self) -> Optional[Task]:
        """Poll the server synchronously until a task becomes available.

        Failed or empty polls are retried every ``poll_interval`` seconds, so this
        call blocks until a task has been claimed. The ``Optional`` return type is
        kept for interface compatibility.

        Returns:
            The next [`Task`][mantisdk.Task] available for execution.
        """
        url = self._build_url(self._next_task_uri)
        while True:
            task = self._take_task(self._request_json(url))
            if task is not None:
                return task
            logger.debug("No task available yet. Retrying in %s seconds...", self.poll_interval)
            time.sleep(self.poll_interval)

    def get_resources_by_id(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Fetch a specific resource bundle by identifier.

        Args:
            resource_id: Identifier sourced from the task metadata.

        Returns:
            Cached or freshly downloaded
            [`ResourcesUpdate`][mantisdk.ResourcesUpdate], or
            ``None`` when the request fails.
        """
        cached = self._cached_resources(resource_id)
        if cached is not None:
            return cached

        response = self._request_json(self._resource_url(resource_id))
        resources_update = self._store_resources(response, cache_key=resource_id)
        if resources_update is not None:
            logger.info("Fetched and cached resources for ID: %s", resource_id)
        return resources_update

    def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """Fetch the most recent resource bundle advertised by the server.

        The result is cached under the ``resources_id`` reported by the server.

        Returns:
            [`ResourcesUpdate`][mantisdk.ResourcesUpdate] for the
            newest version, or ``None`` when unavailable.
        """
        response = self._request_json(self._build_url(self._latest_resources_uri))
        return self._store_resources(response)

    def post_rollout(self, rollout: RolloutLegacy) -> Optional[Dict[str, Any]]:
        """Submit a completed rollout back to the server.

        Args:
            rollout: Legacy rollout payload produced by the executor.

        Returns:
            Parsed JSON response returned by the server, or ``None`` when the request fails.
        """
        url = self._build_url(self._report_rollout_uri)
        payload = rollout.model_dump(mode="json")
        return self._post_json(url, payload)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class DevTaskLoader(MantisdkClient):
    """In-memory task loader used for development and integration tests.

    The loader mimics the behavior of the legacy HTTP server by storing tasks and
    resources locally. Polling methods simply iterate over the provided collection,
    allowing rapid iteration without provisioning any external infrastructure.

    !!! warning "Deprecated"

        [`DevTaskLoader`][mantisdk.client.DevTaskLoader] is a compatibility shim.
        Prefer [`Trainer.dev`][mantisdk.Trainer.dev] for new code.
    """

    def __init__(
        self,
        tasks: Union[List[TaskInput], List[Task]],
        resources: Union[NamedResources, ResourcesUpdate],
        **kwargs: Any,
    ):
        """Initialize the loader with predefined tasks and resources.

        Args:
            tasks: Sequence of task inputs or preconstructed tasks that will be served in
                order. The sequence is copied, so later mutation by the caller has no effect.
            resources: Static resources returned for any `resources_id` query.
            **kwargs: Additional keyword arguments forwarded to the parent client.

        Raises:
            ValueError: If no tasks are provided or both [`Task`][mantisdk.Task]
                and [`TaskInput`][mantisdk.TaskInput] instances are mixed.
        """
        # stacklevel=2 attributes the warning to the code constructing the loader.
        warnings.warn(
            "DevTaskLoader is deprecated. Please use Trainer.dev instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # "local://" is a placeholder endpoint; every network-facing method is
        # overridden below, so it is never contacted by this class.
        super().__init__(endpoint="local://", **kwargs)

        # list() accepts any sequence (tuple, generator, ...), not only lists.
        self._tasks = list(tasks)
        if not self._tasks:
            raise ValueError("DevTaskLoader requires at least one task to be provided.")

        # Reject a mixture of preconstructed Task objects and raw task inputs.
        is_task = [isinstance(task, Task) for task in self._tasks]
        if any(is_task) and not all(is_task):
            raise ValueError("All tasks must be either Task or TaskInput objects.")

        self._task_index = 0

        if isinstance(resources, ResourcesUpdate):
            self._resources_update = resources
        else:
            # Use one timestamp so that a freshly created bundle has identical
            # creation and update times.
            now = time.time()
            self._resources_update = ResourcesUpdate(
                resources_id="local", resources=resources, create_time=now, update_time=now, version=1
            )

        # Store rollouts posted back to the loader for easy debugging of local runs
        self._rollouts: List[RolloutLegacy] = []

    @property
    def rollouts(self) -> List[RolloutLegacy]:
        """Return the rollouts posted back to the loader during development runs."""
        return self._rollouts

    def poll_next_task(self) -> Optional[Task]:
        """Return the next task from the local queue.

        If [`TaskInput`][mantisdk.TaskInput] instances were provided,
        they are converted into [`Task`][mantisdk.Task] objects on the
        fly. Otherwise, the preconstructed tasks are returned in sequence.
        Once every task has been served the loader wraps around to the first one,
        so generated rollout IDs (which are derived from the list position)
        repeat on each pass.

        Returns:
            Next task to execute.
        """
        if self._task_index >= len(self._tasks):
            # Wrap around so that polling never runs dry.
            self._task_index = 0

        task_or_input = self._tasks[self._task_index]

        if isinstance(task_or_input, Task):
            task = task_or_input
        else:
            rollout_id = f"local_task_{self._task_index + 1:03d}"
            task = Task(
                rollout_id=rollout_id,
                input=task_or_input,
                resources_id=self._resources_update.resources_id,
                create_time=time.time(),
            )

        self._task_index += 1
        self.task_count += 1
        logger.info("[Task %s Received] Task ID: %s", self.task_count, task.rollout_id)
        return task

    def get_resources_by_id(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Return the local resource bundle if ``resource_id`` matches it.

        Args:
            resource_id: Identifier to look up.

        Returns:
            The single resource bundle held by the loader.

        Raises:
            ValueError: If ``resource_id`` differs from the bundle's ``resources_id``.
        """
        logger.debug("DevTaskLoader checking resources for ID: %s", resource_id)
        if resource_id != self._resources_update.resources_id:
            raise ValueError(
                f"Resource ID '{resource_id}' not found. Only '{self._resources_update.resources_id}' is available."
            )
        return self._resources_update

    def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """Return the single resource bundle held by the loader."""
        logger.debug("DevTaskLoader returning latest resources.")
        return self._resources_update

    def post_rollout(self, rollout: RolloutLegacy) -> Optional[Dict[str, Any]]:
        """Record a rollout locally instead of sending it to a server.

        Args:
            rollout: Legacy rollout payload produced by the executor.

        Returns:
            A dictionary with ``"status"`` set to ``"received"`` and the rollout's ID.
        """
        logger.debug("DevTaskLoader received rollout for task: %s", rollout.rollout_id)
        self._rollouts.append(rollout)
        return {"status": "received", "rollout_id": rollout.rollout_id}

    # The asynchronous variants delegate to the synchronous, in-memory
    # implementations above.

    async def poll_next_task_async(self) -> Optional[Task]:
        """Asynchronous counterpart of ``poll_next_task``."""
        return self.poll_next_task()

    async def get_resources_by_id_async(self, resource_id: str) -> Optional[ResourcesUpdate]:
        """Asynchronous counterpart of ``get_resources_by_id``."""
        return self.get_resources_by_id(resource_id)

    async def get_latest_resources_async(self) -> Optional[ResourcesUpdate]:
        """Asynchronous counterpart of ``get_latest_resources``."""
        return self.get_latest_resources()

    async def post_rollout_async(self, rollout: RolloutLegacy) -> Optional[Dict[str, Any]]:
        """Asynchronous counterpart of ``post_rollout``."""
        return self.post_rollout(rollout)

    def __repr__(self):
        """Return a debugging representation with the task count and the resources."""
        return f"DevTaskLoader(num_tasks={len(self._tasks)}, resources={self._resources_update.resources})"
|