mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
"""Utility helpers for dynamic component initialization within the trainer."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib
|
|
8
|
+
import inspect
|
|
9
|
+
from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast, overload
|
|
10
|
+
|
|
11
|
+
# Maps constructor keyword names to either a plain value or a zero-argument
# factory; `instantiate_component` calls callable values to produce the value
# that is injected.
# NOTE: this alias is evaluated at import time (it is not an annotation), so the
# `|` between typing constructs here requires Python 3.10+.
OptionalDefaults = Dict[str, Callable[[], Any] | Any]
# Generic component type shared by `build_component` and `_ensure_expected_type`.
T = TypeVar("T")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_class(path: str) -> type[Any]:
    """Load a class from its fully qualified import path.

    Args:
        path: Dotted import path such as ``"package.module.ClassName"``.

    Returns:
        The attribute named by the last path segment, looked up on the module
        named by everything before it.

    Raises:
        ValueError: If ``path`` is not of the form ``"module.Name"`` (no dot,
            empty module part, or empty attribute part).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, sep, class_name = path.rpartition(".")
    # Reject "Foo", ".Foo" and "pkg.mod." up front with a readable message instead
    # of the opaque tuple-unpacking / empty-name errors they would otherwise cause.
    if not sep or not module_name or not class_name:
        raise ValueError(f"Invalid import path {path!r}: expected a fully qualified 'module.ClassName'.")
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as exc:
        raise AttributeError(
            f"Module {module_name!r} has no attribute {class_name!r} (import path {path!r})."
        ) from exc
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def instantiate_component(
    cls: type[Any],
    provided_kwargs: Optional[Dict[str, Any]] = None,
    optional_defaults: Optional[OptionalDefaults] = None,
) -> Any:
    """Construct ``cls`` from explicit kwargs, topped up with applicable optional defaults.

    Args:
        cls: The class to instantiate.
        provided_kwargs: Keyword arguments passed to the constructor verbatim. The
            mapping is copied, never mutated.
        optional_defaults: Candidate extra keyword arguments. An entry is used only
            when ``cls.__init__`` declares a parameter with exactly that name and
            ``provided_kwargs`` does not already set it. Callable values are called
            with no arguments to produce the value; other values are passed as-is.

    Returns:
        The newly constructed ``cls`` instance.
    """
    call_kwargs = {} if provided_kwargs is None else dict(provided_kwargs)
    if optional_defaults:
        accepted_params = inspect.signature(cls.__init__).parameters
        for param_name, default in optional_defaults.items():
            # Explicit kwargs always win; unknown parameter names are skipped.
            if param_name not in call_kwargs and param_name in accepted_params:
                call_kwargs[param_name] = default() if callable(default) else default
    return cls(**call_kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def instantiate_from_spec(
    spec: Union[str, Dict[str, Any]],
    *,
    spec_name: str,
    optional_defaults: Optional[OptionalDefaults] = None,
    dict_requires_type: bool = True,
    dict_default_cls: type[Any] | None = None,
    registry: Optional[Dict[str, str]] = None,
) -> Any:
    """Instantiate a component from a string or dict spec.

    Resolution rules:

    * ``str``: looked up in ``registry`` (when given), falling back to the string
      itself as a fully qualified import path. The class is constructed with only
      ``optional_defaults`` applied.
    * ``dict``: the ``'type'`` entry selects the class. When ``'type'`` is absent and
      a registry is given, the ``'name'`` entry is used instead (and consumed). In
      both cases the value is resolved through ``registry`` first and otherwise
      treated as an import path. All remaining entries become constructor keyword
      arguments. The caller's dict is copied, never mutated.

    Args:
        spec: String or dict specification.
        spec_name: Human-readable component name used in error messages.
        optional_defaults: Forwarded to `instantiate_component`.
        dict_requires_type: If True, a dict spec that selects no class raises.
        dict_default_cls: Class used for a dict spec that selects no class when
            ``dict_requires_type`` is False.
        registry: Optional mapping of short names to fully qualified import paths.

    Returns:
        The constructed instance (its type is not checked here).

    Raises:
        ValueError: If a dict spec selects no class and no fallback is allowed or
            provided.
        TypeError: If ``spec`` is neither a str nor a dict.
        Errors from `load_class` (bad path, missing module/attribute) propagate.
    """
    if isinstance(spec, str):
        type_path = registry.get(spec, spec) if registry else spec
        cls = load_class(type_path)
        return instantiate_component(cls, optional_defaults=optional_defaults)

    if isinstance(spec, dict):  # pyright: ignore[reportUnnecessaryIsInstance]
        spec_conf = dict(spec)
        type_path = spec_conf.pop("type", None)
        if type_path is None and registry and "name" in spec_conf:
            # Unknown names fall back to being an import path, mirroring how string
            # specs and 'type' values are resolved. Previously an unknown name became
            # None and silently selected `dict_default_cls` or raised a misleading
            # "missing 'type'" error.
            name = spec_conf.pop("name")
            type_path = registry.get(name, name)
        elif registry and type_path is not None:
            type_path = registry.get(type_path, type_path)
        if type_path is None:
            if dict_requires_type:
                raise ValueError(f"{spec_name} dict must have a 'type' key with the class full name")
            if dict_default_cls is None:
                raise ValueError(f"{spec_name} dict missing 'type' and no default class provided")
            cls = dict_default_cls
        else:
            cls = load_class(type_path)
        return instantiate_component(cls, spec_conf, optional_defaults)

    raise TypeError(f"{spec_name} spec must be a string or dict (got {type(spec)}).")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _ensure_expected_type(
    instance: Any,
    expected_type: type[T],
    spec_name: str,
    type_error_fmt: str | None,
) -> T:
    """Return ``instance`` unchanged if it is an ``expected_type``; otherwise raise.

    Args:
        instance: The object produced from a component spec.
        expected_type: Type the object must be an instance of.
        spec_name: Component name used in the default error message.
        type_error_fmt: Optional custom message template, formatted with the
            ``type_name`` and ``expected_type`` placeholders.

    Raises:
        TypeError: If ``instance`` is not an instance of ``expected_type``.
    """
    if isinstance(instance, expected_type):
        return instance
    actual = str(type(instance))  # type: ignore
    expected = expected_type.__name__
    if type_error_fmt:
        message = type_error_fmt.format(type_name=actual, expected_type=expected)
    else:
        message = f"{spec_name} factory returned {actual}, which is not a {expected} subclass."
    raise TypeError(message)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Overload: a `default_factory` is supplied, so a None spec is always replaced
# and the result is typed as non-optional `T`.
@overload
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: Callable[[], T],
    allow_none: bool = ...,
    optional_defaults: Optional[OptionalDefaults] = ...,
    dict_requires_type: bool = ...,
    dict_default_cls: type[T] | None = ...,
    type_error_fmt: str | None = ...,
    invalid_spec_error_fmt: str | None = ...,
    registry: Optional[Dict[str, str]] = ...,
) -> T: ...
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Overload: no `default_factory`, `allow_none` passed explicitly; the result may be None.
# NOTE(review): every call matching this signature also matches the next overload,
# which has the same return type, so this declaration appears redundant — confirm.
@overload
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: None = ...,
    allow_none: bool,
    optional_defaults: Optional[OptionalDefaults] = ...,
    dict_requires_type: bool = ...,
    dict_default_cls: type[T] | None = ...,
    type_error_fmt: str | None = ...,
    invalid_spec_error_fmt: str | None = ...,
    registry: Optional[Dict[str, str]] = ...,
) -> T | None: ...
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# Overload: no `default_factory` (and `allow_none` optional); the result may be None.
@overload
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: None = ...,
    allow_none: bool = ...,
    optional_defaults: Optional[OptionalDefaults] = ...,
    dict_requires_type: bool = ...,
    dict_default_cls: type[T] | None = ...,
    type_error_fmt: str | None = ...,
    invalid_spec_error_fmt: str | None = ...,
    registry: Optional[Dict[str, str]] = ...,
) -> T | None: ...
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def build_component(
    spec: Union[T, str, Dict[str, Any], type[T], Callable[[], T], None],
    *,
    expected_type: type[T],
    spec_name: str,
    default_factory: Callable[[], T] | None = None,
    allow_none: bool = False,
    optional_defaults: Optional[OptionalDefaults] = None,
    dict_requires_type: bool = True,
    dict_default_cls: type[T] | None = None,
    type_error_fmt: str | None = None,
    invalid_spec_error_fmt: str | None = None,
    registry: Optional[Dict[str, str]] = None,
) -> T | None:
    """Build and return a component instance from a flexible specification.

    This function provides a flexible way to create component instances from various
    input formats including direct instances, class types, factory functions, import
    paths, or configuration dictionaries.

    Args:
        spec: The component specification. It is checked in this order:
            - An instance of expected_type (returned as-is)
            - None (uses default_factory or returns None if allow_none=True)
            - A subclass of expected_type (instantiated with optional_defaults)
            - A non-class callable / factory function (called with no arguments)
            - A string import path (e.g., 'module.Class') or registry key
            - A dict with 'type' key (import path or registry key) and constructor kwargs
        expected_type: The type that the resulting instance must be or inherit from.
        spec_name: Descriptive name for the spec, used in error messages.
        default_factory: Optional factory function called when spec is None.
        allow_none: If True, allows None to be returned when spec is None and
            no default_factory is provided.
        optional_defaults: Dict mapping parameter names to default values or factory
            functions that will be injected if the constructor accepts them.
        dict_requires_type: If True, dict specs must include a 'type' key.
        dict_default_cls: Default class to use for dict specs without a 'type' key
            (only used when dict_requires_type=False).
        type_error_fmt: Custom format string for type validation errors. Should include
            {type_name} and {expected_type} placeholders.
        invalid_spec_error_fmt: Custom format string for invalid spec type errors.
            Should include {actual_type} and {expected_type} placeholders. It is also
            used for the "spec is None" error.
        registry: Optional mapping of short names to fully qualified import paths.
            When provided, string specs or dict 'type'/'name' entries are first
            resolved through this registry before attempting to import.

    Returns:
        An instance of expected_type, or None if allow_none=True and spec is None
        without a default_factory.

    Raises:
        TypeError: If the constructed object is not an instance of expected_type.
        ValueError: If spec is None and neither default_factory nor allow_none is set;
            if spec is a class that does not subclass expected_type; if spec has an
            unsupported type; or if a dict spec cannot be resolved to a class.
        ImportError / AttributeError: Propagated when a string or dict spec names a
            module or attribute that cannot be loaded.

    Examples:
        >>> # Direct instance
        >>> optimizer = build_component(AdamW(), expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # String import path
        >>> optimizer = build_component('torch.optim.AdamW', expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # Dict with type and kwargs
        >>> spec = {'type': 'torch.optim.AdamW', 'lr': 0.001}
        >>> optimizer = build_component(spec, expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # Class type
        >>> optimizer = build_component(AdamW, expected_type=Optimizer, spec_name='optimizer')
        >>>
        >>> # Factory function
        >>> optimizer = build_component(lambda: AdamW(lr=0.001), expected_type=Optimizer,
        ...                             spec_name='optimizer')
    """
    if isinstance(spec, expected_type):
        # String form of the type keeps `cast` from evaluating the module-level TypeVar.
        return cast("T", spec)

    if spec is None:
        if default_factory is not None:
            instance = default_factory()
            return _ensure_expected_type(instance, expected_type, spec_name, type_error_fmt)
        if allow_none:
            return None
        raise ValueError(
            invalid_spec_error_fmt.format(actual_type=type(spec), expected_type=expected_type.__name__)
            if invalid_spec_error_fmt
            else f"{spec_name} cannot be None."
        )

    if isinstance(spec, type) and issubclass(spec, expected_type):
        instance = instantiate_component(spec, optional_defaults=optional_defaults)
        return _ensure_expected_type(instance, expected_type, spec_name, type_error_fmt)

    # Classes are excluded so that an unrelated class is reported as invalid below
    # instead of being called as if it were a factory.
    if callable(spec) and not isinstance(spec, type):  # type: ignore
        instance = spec()
        return _ensure_expected_type(instance, expected_type, spec_name, type_error_fmt)

    if isinstance(spec, (str, dict)):
        instance = instantiate_from_spec(
            spec,  # type: ignore
            spec_name=spec_name,
            optional_defaults=optional_defaults,
            dict_requires_type=dict_requires_type,
            dict_default_cls=dict_default_cls,
            registry=registry,
        )
        return _ensure_expected_type(instance, expected_type, spec_name, type_error_fmt)

    if invalid_spec_error_fmt:
        raise ValueError(invalid_spec_error_fmt.format(actual_type=type(spec), expected_type=expected_type.__name__))  # type: ignore

    if isinstance(spec, type):
        # Name the offending class itself; `type(spec)` would only show its
        # metaclass ("<class 'type'>"), which hides what was actually passed.
        raise ValueError(f"Invalid {spec_name} type: {spec!r} is not a subclass of {expected_type.__name__}.")

    type_name = str(type(spec))  # type: ignore
    raise ValueError(f"Invalid {spec_name} type: {type_name}. Expected {expected_type.__name__}, str, dict, or None.")
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# Public API of this module; the `_ensure_expected_type` helper is not exported.
__all__ = ["OptionalDefaults", "build_component", "instantiate_component", "instantiate_from_spec", "load_class"]
|
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import multiprocessing
|
|
6
|
+
import signal
|
|
7
|
+
import time
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import Any, List, Optional, TypeVar, Union
|
|
10
|
+
|
|
11
|
+
from mantisdk.adapter import TraceAdapter, TracerTraceToTriplet
|
|
12
|
+
from mantisdk.algorithm import Algorithm
|
|
13
|
+
from mantisdk.client import MantisdkClient
|
|
14
|
+
from mantisdk.litagent import LitAgent
|
|
15
|
+
from mantisdk.runner import LegacyAgentRunner
|
|
16
|
+
from mantisdk.tracer.base import Tracer
|
|
17
|
+
from mantisdk.types import Dataset, ParallelWorkerBase
|
|
18
|
+
|
|
19
|
+
# Module-level logger for the legacy trainer.
logger = logging.getLogger(__name__)

# Covariant type variable for the item type shared by `LitAgent[...]` and
# `Dataset[...]` in `fit_v0`'s signature.
T_co = TypeVar("T_co", covariant=True)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TrainerLegacy(ParallelWorkerBase):
|
|
25
|
+
"""Trainer for legacy mode for v0.1 compatibility."""
|
|
26
|
+
|
|
27
|
+
def __init__(self, *args: Any, **kwargs: Any):
    """Initialize the TrainerLegacy.

    This method is mainly to make type checker happy.
    It won't be used in practice.

    Pops the legacy settings (``dev``, ``algorithm``, ``tracer``, ``n_workers``,
    ``max_tasks``, ``daemon``, ``triplet_exporter``) from ``kwargs``, using the
    defaults shown below when absent. Positional ``args`` and any remaining keyword
    arguments are ignored.
    """
    self._dev = kwargs.pop("dev", False)
    self.algorithm: Optional[Algorithm] = kwargs.pop("algorithm", None)
    self.tracer: Tracer = kwargs.pop("tracer", None)
    self.n_workers: int = kwargs.pop("n_workers", None)
    self.max_tasks: Optional[int] = kwargs.pop("max_tasks", None)
    self.daemon: bool = kwargs.pop("daemon", True)
    self.triplet_exporter: TraceAdapter[Any] = kwargs.pop("triplet_exporter", None)
    # The client is created lazily by `init()` -> `_init_client()`. Start it unset so
    # `_init_client()` and `client()` can test it instead of raising AttributeError.
    self._client: Optional[MantisdkClient] = None
|
|
40
|
+
|
|
41
|
+
def _extract_client_from_data(
    self, data: Union[str, MantisdkClient, Dataset[Any]]
) -> Optional[MantisdkClient]:
    """Turn ``data`` into a client when it denotes a server; otherwise return None.

    A string is treated as a server endpoint and wrapped in a new MantisdkClient.
    An existing MantisdkClient is passed through unchanged, and anything else
    (e.g. a dataset) yields None.

    Raises:
        ValueError: If ``data`` is a string that is not an http(s) URL.
    """
    if not isinstance(data, str):
        return data if isinstance(data, MantisdkClient) else None
    if not data.startswith(("http://", "https://")):
        raise ValueError("String data must be a valid URL starting with http:// or https://")
    return MantisdkClient(endpoint=data)
|
|
52
|
+
|
|
53
|
+
def _extract_dataset_from_data(
    self, data: Union[str, MantisdkClient, Dataset[Any]]
) -> Optional[Dataset[Any]]:
    """Return ``data`` when it is a dataset; return None for a URL string or a client."""
    if not isinstance(data, str) and not isinstance(data, MantisdkClient):
        return data
    return None
|
|
60
|
+
|
|
61
|
+
def _determine_backend(
    self,
    train_data: Union[str, MantisdkClient, Dataset[Any]],
    dev_data: Union[str, MantisdkClient, Dataset[Any], None] = None,
) -> Union[str, MantisdkClient]:
    """Determine which backend to use for initialization.

    In dev mode (``self._dev``) the backend must come from ``dev_data``, which has
    to be a URL string or a MantisdkClient. Otherwise ``train_data`` is tried
    first. If it is not a URL or client, the configured algorithm is asked for one
    via ``self.algorithm.get_client()``.

    Args:
        train_data: URL, client, or dataset for training.
        dev_data: URL, client, or dataset for dev mode.

    Returns:
        The client to initialize the trainer with.

    Raises:
        ValueError: If dev mode lacks a usable ``dev_data``, or if ``train_data`` is
            not a URL or client and no algorithm is configured.
    """
    if self._dev:
        if dev_data is None:
            raise ValueError("dev_data must be provided when dev=True.")
        client = self._extract_client_from_data(dev_data)
        if client is None:
            raise ValueError("dev_data must be a string URL or MantisdkClient when dev=True.")
        return client

    client = self._extract_client_from_data(train_data)
    if client is not None:
        return client
    if self.algorithm is None:
        raise ValueError("train_data must be a string URL or MantisdkClient when no algorithm is provided.")
    # Algorithm will be responsible for creating the client.
    # NOTE(review): the result of get_client() is returned without a None check —
    # confirm Algorithm.get_client never returns None.
    client = self.algorithm.get_client()
    logger.info(f"Algorithm created client: {client}")
    return client
|
|
90
|
+
|
|
91
|
+
def init(self, backend: Union[str, MantisdkClient]) -> None:
    """Prepare the trainer in the main process: create the client, then init the tracer.

    Args:
        backend: Server URL or an existing MantisdkClient, passed to `_init_client`.
    """
    logger.info("Initializing Trainer...")
    self._init_client(backend)  # return value intentionally unused; stored on self
    self.tracer.init()
    logger.info("Trainer main initialization complete.")
|
|
99
|
+
|
|
100
|
+
def teardown(self) -> None:
    """Release main-process resources: tear down the tracer and drop the client reference."""
    logger.info("Cleaning up Trainer...")
    self.tracer.teardown()
    self._client = None
    logger.info("Trainer main cleanup complete.")
|
|
106
|
+
|
|
107
|
+
def client(self) -> MantisdkClient:
    """Return the initialized MantisdkClient.

    Raises:
        RuntimeError: If no client has been set yet (call `init` first).
    """
    current = self._client
    if current is not None:
        return current
    raise RuntimeError("MantisdkClient has not been initialized. Call `init` first.")
|
|
112
|
+
|
|
113
|
+
def _init_client(self, backend: Union[str, MantisdkClient]) -> MantisdkClient:
    """Create the trainer's MantisdkClient once and return it.

    If a client already exists it is kept (a warning is logged) and ``backend`` is
    ignored. Otherwise ``backend`` is either adopted as-is (a MantisdkClient) or
    used as the endpoint URL for a new client.

    Raises:
        ValueError: If ``backend`` is neither a MantisdkClient nor a string, or is a
            string that is not an http(s) URL.
    """
    if self._client is not None:
        logger.warning("MantisdkClient already initialized. Returning existing instance.")
        return self._client

    if isinstance(backend, MantisdkClient):
        logger.info("Using provided MantisdkClient instance.")
        self._client = backend
        return self._client

    logger.info(f"Initializing MantisdkClient with endpoint: {backend}")
    if not isinstance(backend, str):  # type: ignore
        raise ValueError("backend must be a string URL or an MantisdkClient instance.")
    if not backend.startswith(("http://", "https://")):
        raise ValueError("backend must be a valid URL starting with http:// or https://")
    # Initialize the client with the provided backend URL
    self._client = MantisdkClient(endpoint=backend)
    return self._client
|
|
129
|
+
|
|
130
|
+
def _worker_main_loop(self, agent: LitAgent[Any], worker_id: int, is_async: bool):
    """The main function for each worker process.

    When running multi-process (``self.n_workers > 1``) it configures
    process-specific settings: it ignores SIGINT and sets the process title. It
    then attaches the agent to this trainer, validates the triplet exporter, sets
    up the per-worker tracer environment, and drives a `LegacyAgentRunner`. The
    runner uses the client already created by `init`, fetched through
    `self.client()`; this method does not create a client itself.

    Args:
        agent: The `LitAgent` instance to run.
        worker_id: The unique ID for this worker.
        is_async: A boolean indicating if the async loop should be run.

    Returns:
        The value returned by ``loop.iter()`` / ``loop.iter_async()``, stored as
        ``num_processed``. This is presumably a processed-task count; confirm in
        LegacyAgentRunner. Returns ``0`` if the runner section raised.

    Raises:
        ValueError: If ``self.triplet_exporter`` is not a `TracerTraceToTriplet`.
            This is raised before the worker environment is set up. Exceptions
            inside the runner section (including a missing client) are logged and
            swallowed; the worker environment is always torn down afterwards.
    """
    if self.n_workers > 1:
        # Third-party import kept local so it is only needed for multi-process runs.
        import setproctitle

        # Ignore Ctrl+C in worker processes; the main process handles it
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        setproctitle.setproctitle(multiprocessing.current_process().name)

    # Now we are in child processes, so we can safely set up the environment.
    # NOTE(review): with n_workers <= 1 this presumably runs in the main process — confirm.
    agent.set_trainer(self)  # type: ignore
    if not isinstance(self.triplet_exporter, TracerTraceToTriplet):  # type: ignore
        raise ValueError("triplet_exporter must be a TracerTraceToTriplet for the legacy trainer.")
    # TODO: this should be set elsewhere
    # Forward the agent's `trained_agents` filter to the exporter's `agent_match`.
    if agent.trained_agents:
        self.triplet_exporter.agent_match = agent.trained_agents
    self._initialize_worker_env(worker_id)

    mode = "Async" if is_async else "Sync"
    logger.info(f"[Worker {worker_id}] {mode} worker process started.")

    # Stays 0 if the runner section below raises.
    num_processed = 0

    try:
        client = self.client()
        loop = LegacyAgentRunner(
            agent=agent,
            client=client,
            tracer=self.tracer,
            triplet_exporter=self.triplet_exporter,
            max_tasks=self.max_tasks,
            worker_id=worker_id,
        )
        loop.init_worker(worker_id)  # type: ignore
        if is_async:
            num_processed = asyncio.run(loop.iter_async())
        else:
            num_processed = loop.iter()
    except Exception:
        logger.exception(f"[Worker {worker_id}] Unhandled exception in worker loop.")
    finally:
        self._teardown_worker_env(worker_id)

    return num_processed
|
|
184
|
+
|
|
185
|
+
def _initialize_worker_env(self, worker_id: int):
    """Per-worker setup: initialize the tracer for ``worker_id``."""
    # worker_id included in process name
    logger.info("[Worker %s] Setting up trainer environment...", worker_id)
    self.tracer.init_worker(worker_id)
|
|
188
|
+
|
|
189
|
+
def _teardown_worker_env(self, worker_id: int):
    """Per-worker cleanup: tear down the tracer for ``worker_id``, logging before and after."""
    logger.info("[Worker %s] Cleaning up trainer environment...", worker_id)
    self.tracer.teardown_worker(worker_id)
    logger.info("[Worker %s] Environment cleanup complete.", worker_id)
|
|
193
|
+
|
|
194
|
+
@staticmethod
def kill_orphaned_processes() -> None:
    """
    Kill any orphaned processes that may have been left behind by previous runs.
    This is useful for cleaning up after crashes or unexpected exits.

    A process is targeted when its name starts with ``"Mantisdk-"``. The calling
    process itself is never killed. Some processes exit between listing and
    inspection, are zombies, or cannot be inspected or killed due to permissions;
    these are skipped instead of aborting the sweep.
    """
    import os

    import psutil

    own_pid = os.getpid()
    for proc in psutil.process_iter():  # type: ignore
        try:
            # check whether the process name matches (and never kill ourselves)
            if proc.pid == own_pid or not proc.name().startswith("Mantisdk-"):
                continue
            proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # The process vanished or is a zombie (ZombieProcess subclasses
            # NoSuchProcess), or it belongs to another user; nothing to do.
            continue
|
|
206
|
+
|
|
207
|
+
def _terminate_processes(self, processes: List[multiprocessing.Process]) -> None:
    """Stop worker processes, escalating from ``terminate()`` to ``kill()``.

    Does nothing unless this trainer runs more than one worker and at least
    one process object was given.

    The method works in two phases:

    1. Every live worker is asked to stop with ``terminate()``. Dead workers
       are only logged.
    2. Every worker that is still alive gets a grace period via
       ``join(timeout=10)``. A worker that is still alive after that is
       ``kill()``-ed and joined again with the same timeout.

    Args:
        processes: Worker processes previously created by ``fit_v0``.
    """
    if self.n_workers <= 1 or len(processes) == 0:
        return

    grace_seconds = 10

    # Phase 1: request a graceful stop from every live worker.
    for idx, proc in enumerate(processes):
        if not proc.is_alive():
            logger.info(f"Worker {idx} (name: {proc.name}, PID: {proc.pid}) is not alive or has already terminated.")
            continue
        logger.info(f"Terminating worker {idx} (name: {proc.name}, PID: {proc.pid})...")
        proc.terminate()

    # Phase 2: wait for each survivor, then force-kill any that ignored the request.
    for idx, proc in enumerate(processes):
        if not proc.is_alive():
            continue
        proc.join(timeout=grace_seconds)
        if not proc.is_alive():
            continue
        logger.warning(
            f"Worker {idx} (name: {proc.name}, PID: {proc.pid}) did not terminate gracefully, killing..."
        )
        proc.kill()
        proc.join(timeout=grace_seconds)  # reap the killed process
|
|
224
|
+
|
|
225
|
+
def fit_v0(
    self,
    agent: LitAgent[T_co],
    train_data: Union[str, MantisdkClient, Dataset[T_co]],
    *,
    val_data: Union[str, MantisdkClient, Dataset[T_co], None] = None,
    dev_data: Union[str, MantisdkClient, Dataset[T_co], None] = None,
    dev_backend: Union[str, MantisdkClient, None] = None,
):
    """Train the agent using the provided data (legacy entry point).

    Each data argument can be a string URL connecting to a mantisdk server,
    a MantisdkClient instance connecting to a server (or mock server), or a
    dataset. If no algorithm was provided when instantiating the trainer,
    the data must point to a server. Otherwise, a dataset is also allowed
    and is passed to the algorithm.

    If the algorithm is instantiated and no URL/client is provided, the
    algorithm is responsible for creating a client that connects to itself.
    It can also create a mock client if the algorithm does not require a
    server.

    Execution depends on ``self.n_workers`` and ``self.daemon``:

    * ``n_workers == 1``: the worker loop runs in this process. If an
      algorithm is set and a training dataset was extracted, the algorithm
      runs only after the worker loop returns.
    * ``n_workers > 1`` with ``daemon``: workers are spawned as child
      processes. If an algorithm is set, it runs here and the workers are
      terminated afterwards. In either case every worker is then joined.
    * ``n_workers > 1`` without ``daemon``: workers are spawned and then
      dropped from multiprocessing's child registry, so this process does
      not wait for them. If an algorithm is set, it runs here and the
      workers are terminated afterwards.

    On ``KeyboardInterrupt`` or any other exception, spawned workers are
    terminated and the exception is re-raised. ``self.teardown()`` runs on
    exit only when ``self.daemon`` is true.

    Args:
        agent: The agent executed by the worker loop(s).
        train_data: Server URL, client, or training dataset.
        val_data: Optional server URL, client, or validation dataset.
        dev_data: Optional data/backend, forwarded to ``_determine_backend``
            together with ``train_data``.
        dev_backend: Deprecated alias of ``dev_data``.

    Raises:
        ValueError: If both ``dev_data`` and ``dev_backend`` are provided.
    """
    if dev_backend is not None:
        # stacklevel=2 attributes the warning to the caller of fit_v0.
        # Otherwise every call site shares one location and only the first
        # is ever reported under the default warning filters.
        warnings.warn("dev_backend is deprecated. Use dev_data instead.", stacklevel=2)
        if dev_data is not None:
            raise ValueError("dev_data and dev_backend cannot be provided at the same time.")
        dev_data = dev_backend

    # Extract datasets for algorithm if available
    train_dataset = self._extract_dataset_from_data(train_data)
    val_dataset = self._extract_dataset_from_data(val_data) if val_data else None

    # Initialize the algorithm with trainer if provided
    if self.algorithm is not None:
        self.algorithm.set_trainer(self)  # type: ignore
        # DO NOT RUN TRAINING HERE. Need to spawn the worker first.

    # Determine the backend to use for client-server mode
    backend = self._determine_backend(train_data, dev_data)

    if self._dev:
        logger.warning(f"Running in dev mode. Using dev backend: {backend}")
    else:
        logger.debug(f"Running in non-dev mode. Using backend: {backend}")

    self.init(backend)

    processes: List[multiprocessing.Process] = []

    # Query the agent once, so the logged mode and the mode handed to every
    # worker are guaranteed to agree.
    is_async = agent.is_async()
    mode = "asynchronous" if is_async else "synchronous"

    try:
        if self.n_workers == 1:
            logger.info(f"Running with n_workers=1 ({mode} in main process).")

            # Warn if algorithm is set with single worker mode
            if self.algorithm is not None:
                logger.warning(
                    "Algorithm is set but using single worker mode. Algorithm will never get the chance to run."
                )
            # Ideally the single worker should be run in a separate thread or process.

            num_tasks = self._worker_main_loop(agent, 0, is_async)
            logger.info(f"Single worker mode finished. Tasks processed: {num_tasks}")

            # If algorithm is provided and we have datasets, run algorithm after worker completes
            if self.algorithm is not None and train_dataset is not None:
                logger.info("Running algorithm training after worker completion.")
                self.algorithm.run(
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                )
        else:
            logger.info(f"Running with n_workers={self.n_workers} ({mode} multiprocessing).")
            for i in range(self.n_workers):
                process_name = f"Mantisdk-Worker-{i}"
                p = multiprocessing.Process(
                    target=self._worker_main_loop,
                    args=(agent, i, is_async),
                    daemon=self.daemon,
                    name=process_name,
                )
                # Appended before start() so the error handlers below see every created worker.
                processes.append(p)
                logger.info(f"Starting worker process {i} (name: {process_name})...")
                p.start()

            if self.daemon:
                # If algorithm is provided and we have datasets, pass them to the algorithm
                if self.algorithm is not None:
                    logger.info("All workers have been spawned. Running algorithm training with provided datasets.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
                    logger.info("Algorithm exits. Killing the workers.")
                    self._terminate_processes(processes)

                for i, p in enumerate(processes):
                    p.join()  # Wait for the process to complete
                    logger.info(
                        f"Worker process {i} (name: {p.name}, PID: {p.pid}) joined with exit code {p.exitcode}."
                    )
                    if p.exitcode != 0:
                        logger.warning(
                            f"Worker process {i} (name: {p.name}, PID: {p.pid}) exited with non-zero code: {p.exitcode}."
                        )

                logger.info(f"All {self.n_workers} worker processes have completed.")
            else:
                logger.info("All worker processes started. Main process will not wait.")

                # A hack to stop the main process from waiting for child processes to finish.
                # ``_children`` is a private multiprocessing registry; clearing it keeps
                # interpreter exit from joining the (non-daemon) workers.
                time.sleep(1)  # Give workers time to start
                import multiprocessing.process as multiprocessing_process

                multiprocessing_process._children.clear()  # type: ignore

                if self.algorithm is not None:
                    logger.info("Main process continues to run algorithm.")
                    self.algorithm.run(
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                    )
                    logger.info("Algorithm exits. Killing the workers.")
                    self._terminate_processes(processes)

    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt received. Killing the workers.")
        self._terminate_processes(processes)
        logger.info("Workers terminated or single worker interrupted.")
        raise
    except Exception:
        logger.exception("Unhandled exception in fit method.")
        self._terminate_processes(processes)
        logger.info("Workers terminated or single worker interrupted.")
        raise
    finally:
        if self.daemon:
            self.teardown()
        else:
            logger.info("Main process exiting. Please use Trainer.kill_orphaned_processes() for cleanup.")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
"""Put components in this file to make them available to the Trainer.
|
|
4
|
+
|
|
5
|
+
Currently only used to register ExecutionStrategy implementations under short names.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Short strategy name -> dotted import path of the ExecutionStrategy class.
# An "ipc" entry (mantisdk.execution.inter_process.InterProcessExecutionStrategy)
# is intentionally not registered at the moment.
ExecutionStrategyRegistry = dict(
    shm="mantisdk.execution.shared_memory.SharedMemoryExecutionStrategy",
    cs="mantisdk.execution.client_server.ClientServerExecutionStrategy",
)
|