mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,263 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Utility helpers for dynamic component initialization within the trainer."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import importlib
8
+ import inspect
9
+ from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast, overload
10
+
11
+ OptionalDefaults = Dict[str, Callable[[], Any] | Any]
12
+ T = TypeVar("T")
13
+
14
+
15
def load_class(path: str) -> type[Any]:
    """Load a class from its fully qualified import path.

    Args:
        path: Dotted path of the form ``"package.module.ClassName"``. Everything before
            the last dot is imported as a module; the last segment is looked up on it.

    Returns:
        The attribute named by the last path segment of the imported module.

    Raises:
        ValueError: If ``path`` contains no module part (no dot, or nothing before it).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the requested name.
    """
    module_name, separator, class_name = path.rpartition(".")
    # Without this guard a dot-less path fails with an opaque tuple-unpacking error.
    if not separator or not module_name:
        raise ValueError(
            f"Invalid class path {path!r}: expected a fully qualified name like 'package.module.ClassName'."
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
20
+
21
+
22
def instantiate_component(
    cls: type[Any],
    provided_kwargs: Optional[Dict[str, Any]] = None,
    optional_defaults: Optional[OptionalDefaults] = None,
) -> Any:
    """Construct ``cls`` from explicit kwargs plus any applicable optional defaults.

    Args:
        cls: The class to instantiate.
        provided_kwargs: Keyword arguments passed to the constructor as given. The
            mapping is copied, never mutated.
        optional_defaults: Extra keyword arguments that are added only when the name is
            not already in ``provided_kwargs`` and appears as a parameter of
            ``cls.__init__``. A callable value is called to produce the argument.

    Returns:
        The newly created instance of ``cls``.
    """
    init_kwargs: Dict[str, Any] = dict(provided_kwargs) if provided_kwargs else {}
    if not optional_defaults:
        return cls(**init_kwargs)

    # The constructor signature is only inspected when there are defaults to place.
    accepted = inspect.signature(cls.__init__).parameters
    for param_name, default in optional_defaults.items():
        if param_name not in init_kwargs and param_name in accepted:
            init_kwargs[param_name] = default() if callable(default) else default
    return cls(**init_kwargs)
36
+
37
+
38
def instantiate_from_spec(
    spec: Union[str, Dict[str, Any]],
    *,
    spec_name: str,
    optional_defaults: Optional[OptionalDefaults] = None,
    dict_requires_type: bool = True,
    dict_default_cls: type[Any] | None = None,
    registry: Optional[Dict[str, str]] = None,
) -> Any:
    """Instantiate a component described by an import path or a config dict.

    Args:
        spec: Either a string (import path or registry key) or a dict whose ``"type"``
            entry names the class and whose remaining entries are constructor kwargs.
        spec_name: Label used in error messages.
        optional_defaults: Forwarded to ``instantiate_component``.
        dict_requires_type: When true, a dict spec that resolves to no class is an error.
        dict_default_cls: Class used for a dict spec that resolves to no class when
            ``dict_requires_type`` is false.
        registry: Optional mapping from short names to import paths. String specs and
            dict ``"type"`` values are looked up here first and fall back to themselves;
            a dict without ``"type"`` may use a ``"name"`` entry instead.

    Returns:
        The instantiated component.

    Raises:
        ValueError: If a dict spec resolves to no class and no fallback applies.
        TypeError: If ``spec`` is neither a string nor a dict.
    """
    if not isinstance(spec, (str, dict)):  # pyright: ignore[reportUnnecessaryIsInstance]
        raise TypeError(f"{spec_name} spec must be a string or dict (got {type(spec)}).")

    if isinstance(spec, str):
        resolved_path = registry.get(spec, spec) if registry else spec
        return instantiate_component(load_class(resolved_path), optional_defaults=optional_defaults)

    # Work on a copy so the caller's dict keeps its "type"/"name" entries.
    ctor_kwargs = dict(spec)
    target = ctor_kwargs.pop("type", None)
    if registry:
        if target is not None:
            target = registry.get(target, target)
        elif "name" in ctor_kwargs:
            target = registry.get(ctor_kwargs.pop("name"))

    if target is not None:
        component_cls = load_class(target)
    elif dict_requires_type:
        raise ValueError(f"{spec_name} dict must have a 'type' key with the class full name")
    elif dict_default_cls is None:
        raise ValueError(f"{spec_name} dict missing 'type' and no default class provided")
    else:
        component_cls = dict_default_cls
    return instantiate_component(component_cls, ctor_kwargs, optional_defaults)
71
+
72
+
73
+ def _ensure_expected_type(
74
+ instance: Any,
75
+ expected_type: type[T],
76
+ spec_name: str,
77
+ type_error_fmt: str | None,
78
+ ) -> T:
79
+ if not isinstance(instance, expected_type):
80
+ type_name = str(type(instance)) # type: ignore
81
+ if type_error_fmt:
82
+ raise TypeError(type_error_fmt.format(type_name=type_name, expected_type=expected_type.__name__))
83
+ raise TypeError(f"{spec_name} factory returned {type_name}, which is not a {expected_type.__name__} subclass.")
84
+ return instance
85
+
86
+
87
@overload
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: Callable[[], T],
    allow_none: bool = ...,
    optional_defaults: Optional[OptionalDefaults] = ...,
    dict_requires_type: bool = ...,
    dict_default_cls: type[T] | None = ...,
    type_error_fmt: str | None = ...,
    invalid_spec_error_fmt: str | None = ...,
    registry: Optional[Dict[str, str]] = ...,
) -> T: ...


@overload
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: None = ...,
    allow_none: bool,
    optional_defaults: Optional[OptionalDefaults] = ...,
    dict_requires_type: bool = ...,
    dict_default_cls: type[T] | None = ...,
    type_error_fmt: str | None = ...,
    invalid_spec_error_fmt: str | None = ...,
    registry: Optional[Dict[str, str]] = ...,
) -> T | None: ...


@overload
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: None = ...,
    allow_none: bool = ...,
    optional_defaults: Optional[OptionalDefaults] = ...,
    dict_requires_type: bool = ...,
    dict_default_cls: type[T] | None = ...,
    type_error_fmt: str | None = ...,
    invalid_spec_error_fmt: str | None = ...,
    registry: Optional[Dict[str, str]] = ...,
) -> T | None: ...


def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: Callable[[], T] | None = None,
    allow_none: bool = False,
    optional_defaults: Optional[OptionalDefaults] = None,
    dict_requires_type: bool = True,
    dict_default_cls: type[T] | None = None,
    type_error_fmt: str | None = None,
    invalid_spec_error_fmt: str | None = None,
    registry: Optional[Dict[str, str]] = None,
) -> T | None:
    """Turn a flexible component specification into an instance of ``expected_type``.

    The spec is resolved by the first rule that matches, in this order:

    1. An instance of ``expected_type`` is returned as-is.
    2. ``None`` yields ``default_factory()`` if a factory is given, else ``None`` when
       ``allow_none`` is true, else an error.
    3. A subclass of ``expected_type`` is instantiated (with ``optional_defaults``).
    4. Any other non-class callable is called with no arguments.
    5. A string or dict is handed to ``instantiate_from_spec``.

    Everything produced by rules 2-5 is checked against ``expected_type``.

    Args:
        spec: The component specification (instance, class, factory, import path or
            registry key, config dict with a 'type' key, or None).
        expected_type: The type the resulting instance must be or inherit from.
        spec_name: Descriptive name for the spec, used in error messages.
        default_factory: Optional factory called when ``spec`` is None.
        allow_none: If True, None is returned when ``spec`` is None and no
            ``default_factory`` is provided.
        optional_defaults: Parameter names mapped to default values or factories,
            injected when the constructor accepts them.
        dict_requires_type: If True, dict specs must include a 'type' key.
        dict_default_cls: Class used for dict specs without a 'type' key (only used
            when ``dict_requires_type`` is False).
        type_error_fmt: Custom format string for type validation errors, with
            {type_name} and {expected_type} placeholders.
        invalid_spec_error_fmt: Custom format string for invalid spec errors, with
            {actual_type} and {expected_type} placeholders.
        registry: Optional mapping of short names to fully qualified import paths,
            consulted for string specs and dict 'type'/'name' entries before importing.

    Returns:
        An instance of ``expected_type``, or None if ``allow_none`` is true and ``spec``
        is None without a ``default_factory``.

    Raises:
        TypeError: If the produced object is not an instance of ``expected_type``.
        ValueError: If ``spec`` is None with neither ``default_factory`` nor
            ``allow_none``, if the spec's type is not supported, or if a dict spec is
            invalid.

    Examples:
        >>> # Direct instance
        >>> optimizer = build_component(AdamW(), expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # String import path
        >>> optimizer = build_component('torch.optim.AdamW', expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # Dict with type and kwargs
        >>> spec = {'type': 'torch.optim.AdamW', 'lr': 0.001}
        >>> optimizer = build_component(spec, expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # Class type
        >>> optimizer = build_component(AdamW, expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # Factory function
        >>> optimizer = build_component(lambda: AdamW(lr=0.001), expected_type=Optimizer,
        ...                             spec_name='optimizer')
    """

    def _validated(candidate: Any) -> T:
        # Single place where every freshly built object is type-checked.
        return _ensure_expected_type(candidate, expected_type, spec_name, type_error_fmt)

    if isinstance(spec, expected_type):
        return cast(T, spec)

    if spec is None:
        if default_factory is not None:
            return _validated(default_factory())
        if allow_none:
            return None
        if invalid_spec_error_fmt:
            none_message = invalid_spec_error_fmt.format(
                actual_type=type(spec), expected_type=expected_type.__name__
            )
        else:
            none_message = f"{spec_name} cannot be None."
        raise ValueError(none_message)

    spec_is_class = isinstance(spec, type)
    if spec_is_class and issubclass(spec, expected_type):
        return _validated(instantiate_component(spec, optional_defaults=optional_defaults))

    # Classes that are not subclasses of expected_type are deliberately not called here.
    if callable(spec) and not spec_is_class:  # type: ignore
        return _validated(spec())

    if isinstance(spec, (str, dict)):
        built = instantiate_from_spec(
            spec,  # type: ignore
            spec_name=spec_name,
            optional_defaults=optional_defaults,
            dict_requires_type=dict_requires_type,
            dict_default_cls=dict_default_cls,
            registry=registry,
        )
        return _validated(built)

    if invalid_spec_error_fmt:
        raise ValueError(invalid_spec_error_fmt.format(actual_type=type(spec), expected_type=expected_type.__name__))  # type: ignore

    actual_name = str(type(spec))  # type: ignore
    raise ValueError(f"Invalid {spec_name} type: {actual_name}. Expected {expected_type.__name__}, str, dict, or None.")
261
+
262
+
263
+ __all__ = ["OptionalDefaults", "build_component", "instantiate_component", "instantiate_from_spec", "load_class"]
@@ -0,0 +1,367 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ import asyncio
4
+ import logging
5
+ import multiprocessing
6
+ import signal
7
+ import time
8
+ import warnings
9
+ from typing import Any, List, Optional, TypeVar, Union
10
+
11
+ from mantisdk.adapter import TraceAdapter, TracerTraceToTriplet
12
+ from mantisdk.algorithm import Algorithm
13
+ from mantisdk.client import MantisdkClient
14
+ from mantisdk.litagent import LitAgent
15
+ from mantisdk.runner import LegacyAgentRunner
16
+ from mantisdk.tracer.base import Tracer
17
+ from mantisdk.types import Dataset, ParallelWorkerBase
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ T_co = TypeVar("T_co", covariant=True)
22
+
23
+
24
+ class TrainerLegacy(ParallelWorkerBase):
25
+ """Trainer for legacy mode for v0.1 compatibility."""
26
+
27
+ def __init__(self, *args: Any, **kwargs: Any):
28
+ """Initialize the TrainerLegacy.
29
+
30
+ This method is mainly to make type checker happy.
31
+ It won't be used in practice.
32
+ """
33
+ self._dev = kwargs.pop("dev", False)
34
+ self.algorithm: Optional[Algorithm] = kwargs.pop("algorithm", None)
35
+ self.tracer: Tracer = kwargs.pop("tracer", None)
36
+ self.n_workers: int = kwargs.pop("n_workers", None)
37
+ self.max_tasks: Optional[int] = kwargs.pop("max_tasks", None)
38
+ self.daemon: bool = kwargs.pop("daemon", True)
39
+ self.triplet_exporter: TraceAdapter[Any] = kwargs.pop("triplet_exporter", None)
40
+
41
+ def _extract_client_from_data(
42
+ self, data: Union[str, MantisdkClient, Dataset[Any]]
43
+ ) -> Optional[MantisdkClient]:
44
+ """Extract client from data if it's a string URL or MantisdkClient."""
45
+ if isinstance(data, str):
46
+ if not data.startswith("http://") and not data.startswith("https://"):
47
+ raise ValueError("String data must be a valid URL starting with http:// or https://")
48
+ return MantisdkClient(endpoint=data)
49
+ elif isinstance(data, MantisdkClient):
50
+ return data
51
+ return None
52
+
53
+ def _extract_dataset_from_data(
54
+ self, data: Union[str, MantisdkClient, Dataset[Any]]
55
+ ) -> Optional[Dataset[Any]]:
56
+ """Extract dataset from data if it's a Dataset."""
57
+ if isinstance(data, str) or isinstance(data, MantisdkClient):
58
+ return None
59
+ return data
60
+
61
+ def _determine_backend(
62
+ self,
63
+ train_data: Union[str, MantisdkClient, Dataset[Any]],
64
+ dev_data: Union[str, MantisdkClient, Dataset[Any], None] = None,
65
+ ) -> Union[str, MantisdkClient]:
66
+ """Determine which backend to use for initialization."""
67
+ if self._dev:
68
+ if dev_data is None:
69
+ raise ValueError("dev_data must be provided when dev=True.")
70
+ client = self._extract_client_from_data(dev_data)
71
+ if client is None:
72
+ raise ValueError("dev_data must be a string URL or MantisdkClient when dev=True.")
73
+ return client
74
+ else:
75
+ client = self._extract_client_from_data(train_data)
76
+ if client is None and self.algorithm is None:
77
+ raise ValueError(
78
+ "train_data must be a string URL or MantisdkClient when no algorithm is provided."
79
+ )
80
+ elif client is None and self.algorithm is not None:
81
+ # Algorithm will be responsible for creating the client
82
+ client = self.algorithm.get_client()
83
+ logger.info(f"Algorithm created client: {client}")
84
+ return client
85
+ if client is None:
86
+ raise ValueError(
87
+ "train_data must be a string URL or MantisdkClient when no algorithm is provided."
88
+ )
89
+ return client
90
+
91
def init(self, backend: Union[str, MantisdkClient]) -> None:
    """Set up the client for ``backend``, then initialize the tracer."""
    logger.info("Initializing Trainer...")
    self._init_client(backend)
    self.tracer.init()
    logger.info("Trainer main initialization complete.")
99
+
100
def teardown(self) -> None:
    """Tear down the tracer and drop the reference to the client."""
    logger.info("Cleaning up Trainer...")
    self.tracer.teardown()
    self._client = None
    logger.info("Trainer main cleanup complete.")
106
+
107
+ def client(self) -> MantisdkClient:
108
+ """Returns the MantisdkClient instance."""
109
+ if self._client is None:
110
+ raise RuntimeError("MantisdkClient has not been initialized. Call `init` first.")
111
+ return self._client
112
+
113
def _init_client(self, backend: Union[str, MantisdkClient]) -> MantisdkClient:
    """Store and return the client for ``backend`` unless one is already set.

    Args:
        backend: An existing ``MantisdkClient``, or an ``http://`` / ``https://`` URL
            from which a new client is created.

    Returns:
        The client held in ``self._client``; an already-set client is returned as-is
        (with a warning) and ``backend`` is ignored.

    Raises:
        ValueError: If ``backend`` is neither a client nor an HTTP(S) URL string.
    """
    if self._client is not None:
        logger.warning("MantisdkClient already initialized. Returning existing instance.")
        return self._client

    if isinstance(backend, MantisdkClient):
        logger.info("Using provided MantisdkClient instance.")
        self._client = backend
        return self._client

    logger.info(f"Initializing MantisdkClient with endpoint: {backend}")
    if not isinstance(backend, str):  # type: ignore
        raise ValueError("backend must be a string URL or an MantisdkClient instance.")
    is_http_url = backend.startswith("http://") or backend.startswith("https://")
    if not is_http_url:
        raise ValueError("backend must be a valid URL starting with http:// or https://")
    # Initialize the client with the provided backend URL
    self._client = MantisdkClient(endpoint=backend)
    return self._client
129
+
130
def _worker_main_loop(self, agent: LitAgent[Any], worker_id: int, is_async: bool):
    """The main function for each worker process.

    This function fetches the client, builds the runner loop, then starts the
    execution. It also configures process-specific settings like the
    process title and signal handling.

    Args:
        agent: The `LitAgent` instance to run.
        worker_id: The unique ID for this worker.
        is_async: A boolean indicating if the async loop should be run.

    Returns:
        Whatever the runner's `iter()` / `iter_async()` returns, or 0 when the
        worker loop raised (the exception is logged, not propagated).

    Raises:
        ValueError: If `self.triplet_exporter` is not a `TracerTraceToTriplet`.
    """
    # Process title and SIGINT handling are only touched when more than one worker is configured.
    if self.n_workers > 1:
        import setproctitle

        # Ignore Ctrl+C in worker processes; the main process handles it
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        setproctitle.setproctitle(multiprocessing.current_process().name)

    # Now we are in child processes, so we can safely set up the environment.
    agent.set_trainer(self)  # type: ignore
    if not isinstance(self.triplet_exporter, TracerTraceToTriplet):  # type: ignore
        raise ValueError("triplet_exporter must be a TracerTraceToTriplet for the legacy trainer.")
    # TODO: this should be set elsewhere
    if agent.trained_agents:
        self.triplet_exporter.agent_match = agent.trained_agents
    # NOTE(review): this runs outside the try/finally below, so a failure here skips
    # `_teardown_worker_env` — confirm that is intended.
    self._initialize_worker_env(worker_id)

    mode = "Async" if is_async else "Sync"
    logger.info(f"[Worker {worker_id}] {mode} worker process started.")

    # Stays 0 if the loop below raises before producing a count.
    num_processed = 0

    try:
        client = self.client()
        loop = LegacyAgentRunner(
            agent=agent,
            client=client,
            tracer=self.tracer,
            triplet_exporter=self.triplet_exporter,
            max_tasks=self.max_tasks,
            worker_id=worker_id,
        )
        loop.init_worker(worker_id)  # type: ignore
        if is_async:
            num_processed = asyncio.run(loop.iter_async())
        else:
            num_processed = loop.iter()
    except Exception:
        # Errors are logged with traceback and swallowed; the method still returns normally.
        logger.exception(f"[Worker {worker_id}] Unhandled exception in worker loop.")
    finally:
        self._teardown_worker_env(worker_id)

    return num_processed
184
+
185
def _initialize_worker_env(self, worker_id: int):
    """Log the start of worker setup and call the tracer's ``init_worker`` hook."""
    worker_tag = f"[Worker {worker_id}]"  # worker_id included in process name
    logger.info(f"{worker_tag} Setting up trainer environment...")
    self.tracer.init_worker(worker_id)
188
+
189
def _teardown_worker_env(self, worker_id: int):
    """Call the tracer's ``teardown_worker`` hook, logging before and after."""
    worker_tag = f"[Worker {worker_id}]"
    logger.info(f"{worker_tag} Cleaning up trainer environment...")
    self.tracer.teardown_worker(worker_id)
    logger.info(f"{worker_tag} Environment cleanup complete.")
193
+
194
@staticmethod
def kill_orphaned_processes() -> None:
    """
    Kill any orphaned processes that may have been left behind by previous runs.
    This is useful for cleaning up after crashes or unexpected exits.

    Every process whose name starts with ``"Mantisdk-"`` is killed. A process that
    exits during the scan, or that cannot be inspected or signalled, is skipped so
    one such process does not abort the whole cleanup.
    """
    import psutil

    for proc in psutil.process_iter():  # type: ignore
        try:
            # check whether the process name matches
            if proc.name().startswith("Mantisdk-"):
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied) as exc:
            # The process table changes while we iterate; both errors are expected here.
            logger.debug("Skipping process %s during orphan cleanup: %s", proc.pid, exc)
206
+
207
+ def _terminate_processes(self, processes: List[multiprocessing.Process]) -> None:
208
+ if self.n_workers > 1 and len(processes) > 0:
209
+ for i, p in enumerate(processes):
210
+ if p.is_alive():
211
+ logger.info(f"Terminating worker {i} (name: {p.name}, PID: {p.pid})...")
212
+ p.terminate()
213
+ else:
214
+ logger.info(f"Worker {i} (name: {p.name}, PID: {p.pid}) is not alive or has already terminated.")
215
+ for i, p in enumerate(processes):
216
+ if p.is_alive():
217
+ p.join(timeout=10) # Give some time to terminate
218
+ if p.is_alive(): # If still alive, kill
219
+ logger.warning(
220
+ f"Worker {i} (name: {p.name}, PID: {p.pid}) did not terminate gracefully, killing..."
221
+ )
222
+ p.kill()
223
+ p.join(timeout=10) # Ensure it's reaped
224
+
225
def fit_v0(
    self,
    agent: LitAgent[T_co],
    train_data: Union[str, MantisdkClient, Dataset[T_co]],
    *,
    val_data: Union[str, MantisdkClient, Dataset[T_co], None] = None,
    dev_data: Union[str, MantisdkClient, Dataset[T_co], None] = None,
    dev_backend: Union[str, MantisdkClient, None] = None,
):
    """Train the agent using the provided data.

    Each data argument can be a string URL pointing to a mantisdk server,
    a MantisdkClient instance connected to a server (or mock server), or a
    dataset. If no algorithm was provided when instantiating the trainer, the
    data must identify a server to connect to. Otherwise a dataset is also
    allowed and is passed to the algorithm.

    If an algorithm is set and no URL/client is provided, the algorithm is
    responsible for creating a client that connects to itself. It can also
    create a mock client if the algorithm does not require a server.

    Args:
        agent: The agent run by the worker loop(s).
        train_data: Training data source (URL, client, or dataset).
        val_data: Optional validation data source.
        dev_data: Optional dev-mode data source.
        dev_backend: Deprecated alias for ``dev_data``.

    Raises:
        ValueError: If both ``dev_data`` and ``dev_backend`` are provided.
        KeyboardInterrupt: Re-raised after the workers have been terminated.
        Exception: Any other error is logged, workers are terminated, and the
            error is re-raised.
    """

    if dev_backend is not None:
        # stacklevel=2 attributes the warning to the caller of fit_v0 rather
        # than to this line.
        warnings.warn("dev_backend is deprecated. Use dev_data instead.", stacklevel=2)
        if dev_data is not None:
            raise ValueError("dev_data and dev_backend cannot be provided at the same time.")
        dev_data = dev_backend

    # Extract datasets for algorithm if available.
    train_dataset = self._extract_dataset_from_data(train_data)
    # Compare against None explicitly: relying on truthiness would discard an
    # empty-but-provided dataset and raises for dataset objects whose truth
    # value is ambiguous (e.g. array- or DataFrame-like containers).
    val_dataset = self._extract_dataset_from_data(val_data) if val_data is not None else None

    # Initialize the algorithm with trainer if provided
    if self.algorithm is not None:
        self.algorithm.set_trainer(self)  # type: ignore
        # DO NOT RUN TRAINING HERE. Need to spawn the worker first.

    # Determine the backend to use for client-server mode
    backend = self._determine_backend(train_data, dev_data)

    if self._dev:
        logger.warning(f"Running in dev mode. Using dev backend: {backend}")
    else:
        logger.debug(f"Running in non-dev mode. Using backend: {backend}")

    self.init(backend)

    processes: List[multiprocessing.Process] = []

    # Determine once whether the agent is asynchronous, so the logged mode and
    # the flag handed to every worker always agree.
    is_async = agent.is_async()
    mode = "asynchronous" if is_async else "synchronous"

    try:
        if self.n_workers == 1:
            logger.info(f"Running with n_workers=1 ({mode} in main process).")

            # Warn if algorithm is set with single worker mode
            if self.algorithm is not None:
                # NOTE(review): the message says the algorithm never runs, yet
                # it is run below once the worker loop returns and a training
                # dataset exists — confirm which behavior is intended.
                logger.warning(
                    "Algorithm is set but using single worker mode. Algorithm will never get the chance to run."
                )
                # Ideally the single worker should be run in a separate thread or process.

            num_tasks = self._worker_main_loop(agent, 0, is_async)
            logger.info(f"Single worker mode finished. Tasks processed: {num_tasks}")

            # If algorithm is provided and we have datasets, run algorithm after worker completes
            if self.algorithm is not None and train_dataset is not None:
                logger.info("Running algorithm training after worker completion.")
                self.algorithm.run(
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                )
        else:
            logger.info(f"Running with n_workers={self.n_workers} ({mode} multiprocessing).")
            for i in range(self.n_workers):
                process_name = f"Mantisdk-Worker-{i}"
                p = multiprocessing.Process(
                    target=self._worker_main_loop,
                    args=(agent, i, is_async),
                    daemon=self.daemon,
                    name=process_name,
                )
                # Track the process before starting it so the cleanup paths
                # below can see it.
                processes.append(p)
                logger.info(f"Starting worker process {i} (name: {process_name})...")
                p.start()

            if self.daemon:
                # If algorithm is provided, run it now that all workers are spawned
                if self.algorithm is not None:
                    logger.info("All workers have been spawned. Running algorithm training with provided datasets.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
                    logger.info("Algorithm exits. Killing the workers.")
                    self._terminate_processes(processes)

                for i, p in enumerate(processes):
                    p.join()  # Wait for the process to complete
                    logger.info(
                        f"Worker process {i} (name: {p.name}, PID: {p.pid}) joined with exit code {p.exitcode}."
                    )
                    if p.exitcode != 0:
                        logger.warning(
                            f"Worker process {i} (name: {p.name}, PID: {p.pid}) exited with non-zero code: {p.exitcode}."
                        )

                logger.info(f"All {self.n_workers} worker processes have completed.")
            else:
                logger.info("All worker processes started. Main process will not wait.")

                # A hack to stop the main process from waiting for child processes to finish.
                # NOTE(review): this clears a private attribute of the
                # multiprocessing module; it may break across Python versions.
                time.sleep(1)  # Give workers time to start
                import multiprocessing.process as multiprocessing_process

                multiprocessing_process._children.clear()  # type: ignore

                if self.algorithm is not None:
                    logger.info("Main process continues to run algorithm.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
                    logger.info("Algorithm exits. Killing the workers.")
                    self._terminate_processes(processes)

    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt received. Killing the workers.")
        self._terminate_processes(processes)
        logger.info("Workers terminated or single worker interrupted.")
        raise
    except Exception:
        logger.exception("Unhandled exception in fit method.")
        self._terminate_processes(processes)
        logger.info("Workers terminated or single worker interrupted.")
        raise
    finally:
        # Teardown only happens in daemon mode; otherwise workers may outlive
        # this call and must be cleaned up separately.
        if self.daemon:
            self.teardown()
        else:
            logger.info("Main process exiting. Please use Trainer.kill_orphaned_processes() for cleanup.")
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Put components in this file to make them available to the Trainer.
4
+
5
+ Currently only used for ExecutionStrategy.
6
+ """
7
+
8
# Short alias -> dotted path string for each registered execution strategy.
# (The values look like import paths; how they are resolved is not shown here.)
ExecutionStrategyRegistry = dict(
    shm="mantisdk.execution.shared_memory.SharedMemoryExecutionStrategy",
    # Disabled entry: ipc="mantisdk.execution.inter_process.InterProcessExecutionStrategy"
    cs="mantisdk.execution.client_server.ClientServerExecutionStrategy",
)