mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/gepa/lib/api.py
@@ -0,0 +1,375 @@
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+ # https://github.com/gepa-ai/gepa
+
+ import os
+ import random
+ from collections.abc import Sequence
+ from typing import Any, Literal, cast
+
+ from mantisdk.algorithm.gepa.lib.adapters.default_adapter.default_adapter import (
+     ChatCompletionCallable,
+     DefaultAdapter,
+     Evaluator,
+ )
+ from mantisdk.algorithm.gepa.lib.core.adapter import DataInst, GEPAAdapter, RolloutOutput, Trajectory
+ from mantisdk.algorithm.gepa.lib.core.data_loader import DataId, DataLoader, ensure_loader
+ from mantisdk.algorithm.gepa.lib.core.engine import GEPAEngine
+ from mantisdk.algorithm.gepa.lib.core.result import GEPAResult
+ from mantisdk.algorithm.gepa.lib.core.state import EvaluationCache, FrontierType
+ from mantisdk.algorithm.gepa.lib.logging.experiment_tracker import create_experiment_tracker
+ from mantisdk.algorithm.gepa.lib.logging.logger import LoggerProtocol, StdOutLogger
+ from mantisdk.algorithm.gepa.lib.proposer.merge import MergeProposer
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.base import CandidateSelector, LanguageModel, ReflectionComponentSelector
+ from mantisdk.algorithm.gepa.lib.proposer.reflective_mutation.reflective_mutation import ReflectiveMutationProposer
+ from mantisdk.algorithm.gepa.lib.strategies.batch_sampler import BatchSampler, EpochShuffledBatchSampler
+ from mantisdk.algorithm.gepa.lib.strategies.candidate_selector import (
+     CurrentBestCandidateSelector,
+     EpsilonGreedyCandidateSelector,
+     ParetoCandidateSelector,
+ )
+ from mantisdk.algorithm.gepa.lib.strategies.component_selector import (
+     AllReflectionComponentSelector,
+     RoundRobinReflectionComponentSelector,
+ )
+ from mantisdk.algorithm.gepa.lib.strategies.eval_policy import EvaluationPolicy, FullEvaluationPolicy
+ from mantisdk.algorithm.gepa.lib.utils import FileStopper, StopperProtocol
+
+
+ def optimize(
+     seed_candidate: dict[str, str],
+     trainset: list[DataInst] | DataLoader[DataId, DataInst],
+     valset: list[DataInst] | DataLoader[DataId, DataInst] | None = None,
+     adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput] | None = None,
+     task_lm: str | ChatCompletionCallable | None = None,
+     evaluator: Evaluator | None = None,
+     # Reflection-based configuration
+     reflection_lm: LanguageModel | str | None = None,
+     candidate_selection_strategy: CandidateSelector | Literal["pareto", "current_best", "epsilon_greedy"] = "pareto",
+     frontier_type: FrontierType = "instance",
+     skip_perfect_score: bool = True,
+     batch_sampler: BatchSampler | Literal["epoch_shuffled"] = "epoch_shuffled",
+     reflection_minibatch_size: int | None = None,
+     perfect_score: float = 1.0,
+     reflection_prompt_template: str | None = None,
+     # Component selection configuration
+     module_selector: ReflectionComponentSelector | str = "round_robin",
+     # Merge-based configuration
+     use_merge: bool = False,
+     max_merge_invocations: int = 5,
+     merge_val_overlap_floor: int = 5,
+     # Budget and Stop Condition
+     max_metric_calls: int | None = None,
+     stop_callbacks: StopperProtocol | Sequence[StopperProtocol] | None = None,
+     # Logging
+     logger: LoggerProtocol | None = None,
+     run_dir: str | None = None,
+     use_wandb: bool = False,
+     wandb_api_key: str | None = None,
+     wandb_init_kwargs: dict[str, Any] | None = None,
+     use_mlflow: bool = False,
+     mlflow_tracking_uri: str | None = None,
+     mlflow_experiment_name: str | None = None,
+     track_best_outputs: bool = False,
+     display_progress_bar: bool = False,
+     use_cloudpickle: bool = False,
+     # Evaluation caching
+     cache_evaluation: bool = False,
+     # Reproducibility
+     seed: int = 0,
+     raise_on_exception: bool = True,
+     val_evaluation_policy: EvaluationPolicy[DataId, DataInst] | Literal["full_eval"] | None = None,
+ ) -> GEPAResult[RolloutOutput, DataId]:
+     """
+     GEPA is an evolutionary optimizer that evolves (multiple) text components of a complex system to optimize them towards a given metric.
+     GEPA can also leverage rich textual feedback obtained from the system's execution environment, evaluation,
+     and the system's own execution traces to iteratively improve the system's performance.
+
+     Concepts:
+     - System: A harness that uses text components to perform a task. Each text component of the system to be optimized is a named component of the system.
+     - Candidate: A mapping from component names to component text. A concrete instantiation of the system is realized by setting the text of each system component
+       to the text provided by the candidate mapping.
+     - `DataInst`: An (uninterpreted) data type over which the system operates.
+     - `RolloutOutput`: The output of the system on a `DataInst`.
+
+     Each execution of the system produces a `RolloutOutput`, which can be evaluated to produce a score. The execution of the system also produces a trajectory,
+     which consists of the operations performed by different components of the system, including the text of the components that were executed.
+
+     GEPA can be applied to optimize any system that uses text components (e.g., prompts in an AI system, code snippets/code files/functions/classes in a codebase, etc.).
+     In order for GEPA to plug into your system's environment, GEPA requires an adapter, `GEPAAdapter`, to be implemented. The adapter is responsible for:
+     1. Evaluating a proposed candidate on a batch of inputs.
+         - The adapter receives a candidate proposed by GEPA, along with a batch of inputs selected from the training/validation set.
+         - The adapter instantiates the system with the texts proposed in the candidate.
+         - The adapter then evaluates the candidate on the batch of inputs, and returns the scores.
+         - The adapter should also capture relevant information from the execution of the candidate, like system and evaluation traces.
+     2. Identifying textual information relevant to a component of the candidate
+         - Given the trajectories captured during the execution of the candidate, GEPA selects a component of the candidate to update.
+         - The adapter receives the candidate, the batch of inputs, and the trajectories captured during the execution of the candidate.
+         - The adapter is responsible for identifying the textual information relevant to the component to update.
+         - This information is used by GEPA to reflect on the performance of the component, and propose new component texts.
+
+     At each iteration, GEPA proposes a new candidate using one of the following strategies:
+     1. Reflective mutation: GEPA proposes a new candidate by mutating the current candidate, leveraging rich textual feedback.
+     2. Merge: GEPA proposes a new candidate by merging 2 candidates that are on the Pareto frontier.
+
+     GEPA also tracks the Pareto frontier of performance achieved by different candidates on the validation set. This way, it can leverage candidates that
+     work well on a subset of inputs to improve the system's performance on the entire validation set, by evolving from the Pareto frontier.
+
+     Parameters:
+     - seed_candidate: The initial candidate to start with.
+     - trainset: Training data supplied as an in-memory sequence or a `DataLoader` yielding batches for reflective updates.
+     - valset: Validation data source (sequence or `DataLoader`) used for tracking Pareto scores. If not provided, GEPA reuses the trainset.
+     - adapter: A `GEPAAdapter` instance that implements the adapter interface. This allows GEPA to plug into your system's environment. If not provided, GEPA will use a default adapter: `gepa.adapters.default_adapter.default_adapter.DefaultAdapter`, with the model defined by `task_lm`.
+     - task_lm: Optional. The model to use for the task. This is only used if `adapter` is not provided, and is used to initialize the default adapter.
+     - evaluator: Optional. A custom evaluator to use for evaluating the candidate program. If not provided, GEPA will use the default evaluator: `gepa.adapters.default_adapter.default_adapter.ContainsAnswerEvaluator`. Only used if `adapter` is not provided.
+
+     # Reflection-based configuration
+     - reflection_lm: A `LanguageModel` instance that is used to reflect on the performance of the candidate program.
+     - candidate_selection_strategy: The strategy to use for selecting the candidate to update. Supported strategies: 'pareto', 'current_best', 'epsilon_greedy'. Defaults to 'pareto'.
+     - frontier_type: Strategy for tracking Pareto frontiers. 'instance' tracks per validation example, 'objective' tracks per objective metric, 'hybrid' combines both, 'cartesian' tracks per (example, objective) pair. Defaults to 'instance'.
+     - skip_perfect_score: Whether to skip updating the candidate if it achieves a perfect score on the minibatch.
+     - batch_sampler: Strategy for selecting training examples. Can be a [BatchSampler](src/gepa/strategies/batch_sampler.py) instance or a string for a predefined strategy from ['epoch_shuffled']. Defaults to 'epoch_shuffled', which creates an [EpochShuffledBatchSampler](src/gepa/strategies/batch_sampler.py).
+     - reflection_minibatch_size: The number of examples to use for reflection in each proposal step. Defaults to 3. Only valid when batch_sampler='epoch_shuffled' (default), and is ignored otherwise.
+     - perfect_score: The perfect score to achieve.
+     - reflection_prompt_template: The prompt template to use for reflection. If not provided, GEPA will use the default prompt template (see [InstructionProposalSignature](src/gepa/strategies/instruction_proposal.py)). The prompt template must contain the following placeholders, which will be replaced with actual values: `<curr_instructions>` (will be replaced by the instructions to evolve) and `<inputs_outputs_feedback>` (replaced with the inputs, outputs, and feedback generated with the current instruction). This will be ignored if the adapter provides its own `propose_new_texts` method.
+
+     # Component selection configuration
+     - module_selector: Component selection strategy. Can be a ReflectionComponentSelector instance or a string ('round_robin', 'all'). Defaults to 'round_robin'. The 'round_robin' strategy cycles through components in order. The 'all' strategy selects all components for modification in every GEPA iteration.
+
+     # Merge-based configuration
+     - use_merge: Whether to use the merge strategy.
+     - max_merge_invocations: The maximum number of merge invocations to perform.
+     - merge_val_overlap_floor: Minimum number of shared validation ids required between parents before attempting a merge subsample. Only relevant when using a `val_evaluation_policy` other than `full_eval`.
+
+     # Budget and Stop Condition
+     - max_metric_calls: Optional maximum number of metric calls to perform. If not provided, stop_callbacks must be provided.
+     - stop_callbacks: Optional stopper(s) that return True when optimization should stop. Can be a single StopperProtocol or a list or tuple of StopperProtocol instances. Examples: FileStopper, TimeoutStopCondition, SignalStopper, NoImprovementStopper, or custom stopping logic. If not provided, max_metric_calls must be provided.
+
+     # Logging
+     - logger: A `LoggerProtocol` instance that is used to log the progress of the optimization.
+     - run_dir: The directory to save the results to. Optimization state and results will be saved to this directory. If the directory already exists, GEPA will read the state from this directory and resume the optimization from the last saved state. If provided, a FileStopper is automatically created which checks for the presence of "gepa.stop" in this directory, allowing graceful stopping of the optimization process upon its presence.
+     - use_wandb: Whether to use Weights and Biases to log the progress of the optimization.
+     - wandb_api_key: The API key to use for Weights and Biases.
+     - wandb_init_kwargs: Additional keyword arguments to pass to the Weights and Biases initialization.
+     - use_mlflow: Whether to use MLflow to log the progress of the optimization.
+       Both wandb and mlflow can be used simultaneously if desired.
+     - mlflow_tracking_uri: The tracking URI to use for MLflow.
+     - mlflow_experiment_name: The experiment name to use for MLflow.
+     - track_best_outputs: Whether to track the best outputs on the validation set. If True, GEPAResult will contain the best outputs obtained for each task in the validation set.
+     - display_progress_bar: Show a tqdm progress bar over metric calls when enabled.
+     - use_cloudpickle: Use cloudpickle instead of pickle. This can be helpful when the serialized state contains dynamically generated DSPy signatures.
+
+     # Evaluation caching
+     - cache_evaluation: Whether to cache the (score, output, objective_scores) of (candidate, example) pairs. If True and a cache entry exists, GEPA will skip the fitness evaluation and use the cached results. This helps avoid redundant evaluations and saves metric calls. Defaults to False.
+
+     # Reproducibility
+     - seed: The seed to use for the random number generator.
+     - val_evaluation_policy: Strategy controlling which validation ids to score each iteration and which candidate is currently best. Supported strings: "full_eval" (evaluate every id each time). Passing None defaults to "full_eval".
+     - raise_on_exception: Whether to propagate proposer/evaluator exceptions instead of stopping gracefully.
+     """
+     active_adapter: GEPAAdapter[DataInst, Trajectory, RolloutOutput] | None = None
+     if adapter is None:
+         assert task_lm is not None, (
+             "Since no adapter is provided, GEPA requires a task LM to be provided. Please set the `task_lm` parameter."
+         )
+         active_adapter = cast(
+             GEPAAdapter[DataInst, Trajectory, RolloutOutput], DefaultAdapter(model=task_lm, evaluator=evaluator)
+         )
+     else:
+         assert task_lm is None, (
+             "Since an adapter is provided, GEPA does not require a task LM to be provided. Please set the `task_lm` parameter to None."
+         )
+         assert evaluator is None, (
+             "Since an adapter is provided, GEPA does not require an evaluator to be provided. Please set the `evaluator` parameter to None."
+         )
+         active_adapter = adapter
+
+     # Normalize datasets to DataLoader instances
+     train_loader = ensure_loader(trainset)
+     val_loader = ensure_loader(valset) if valset is not None else train_loader
+
+     # Comprehensive stop_callback logic
+     # Convert stop_callbacks to a list if it's not already
+     stop_callbacks_list: list[StopperProtocol] = []
+     if stop_callbacks is not None:
+         if isinstance(stop_callbacks, Sequence):
+             stop_callbacks_list.extend(stop_callbacks)
+         else:
+             stop_callbacks_list.append(stop_callbacks)
+
+     # Add file stopper if run_dir is provided
+     if run_dir is not None:
+         stop_file_path = os.path.join(run_dir, "gepa.stop")
+         file_stopper = FileStopper(stop_file_path)
+         stop_callbacks_list.append(file_stopper)
+
+     # Add max_metric_calls stopper if provided
+     if max_metric_calls is not None:
+         from mantisdk.algorithm.gepa.lib.utils import MaxMetricCallsStopper
+
+         max_calls_stopper = MaxMetricCallsStopper(max_metric_calls)
+         stop_callbacks_list.append(max_calls_stopper)
+
+     # Assert that at least one stopping condition is provided
+     if not stop_callbacks_list:
+         raise ValueError(
+             "The user must provide at least one of stop_callbacks or max_metric_calls to specify a stopping condition."
+         )
+
+     # Create composite stopper if multiple stoppers, or use single stopper
+     stop_callback: StopperProtocol
+     if len(stop_callbacks_list) == 1:
+         stop_callback = stop_callbacks_list[0]
+     else:
+         from mantisdk.algorithm.gepa.lib.utils import CompositeStopper
+
+         stop_callback = CompositeStopper(*stop_callbacks_list)
+
+     if not hasattr(active_adapter, "propose_new_texts"):
+         assert reflection_lm is not None, (
+             f"reflection_lm was not provided. The adapter used '{active_adapter!s}' does not provide a propose_new_texts method, "
+             + "and hence, GEPA will use the default proposer, which requires a reflection_lm to be specified."
+         )
+
+     if isinstance(reflection_lm, str):
+         import litellm
+
+         reflection_lm_name = reflection_lm
+
+         def _reflection_lm(prompt: str) -> str:
+             completion = litellm.completion(model=reflection_lm_name, messages=[{"role": "user", "content": prompt}])
+             return completion.choices[0].message.content  # type: ignore
+
+         reflection_lm = _reflection_lm
+
+     if logger is None:
+         logger = StdOutLogger()
+
+     rng = random.Random(seed)
+
+     candidate_selector: CandidateSelector
+     if isinstance(candidate_selection_strategy, str):
+         factories = {
+             "pareto": lambda: ParetoCandidateSelector(rng=rng),
+             "current_best": lambda: CurrentBestCandidateSelector(),
+             "epsilon_greedy": lambda: EpsilonGreedyCandidateSelector(epsilon=0.1, rng=rng),
+         }
+
+         try:
+             candidate_selector = factories[candidate_selection_strategy]()
+         except KeyError as exc:
+             raise ValueError(
+                 f"Unknown candidate_selector strategy: {candidate_selection_strategy}. "
+                 "Supported strategies: 'pareto', 'current_best', 'epsilon_greedy'"
+             ) from exc
+     elif isinstance(candidate_selection_strategy, CandidateSelector):
+         candidate_selector = candidate_selection_strategy
+     else:
+         raise TypeError(
+             "candidate_selection_strategy must be a supported string strategy or an instance of CandidateSelector."
+         )
+
+     if val_evaluation_policy is None or val_evaluation_policy == "full_eval":
+         val_evaluation_policy = FullEvaluationPolicy()
+     elif not isinstance(val_evaluation_policy, EvaluationPolicy):
+         raise ValueError(
+             f"val_evaluation_policy should be one of 'full_eval' or an instance of EvaluationPolicy, but got {type(val_evaluation_policy)}"
+         )
+
+     if isinstance(module_selector, str):
+         module_selector_cls = {
+             "round_robin": RoundRobinReflectionComponentSelector,
+             "all": AllReflectionComponentSelector,
+         }.get(module_selector)
+
+         assert module_selector_cls is not None, (
+             f"Unknown module_selector strategy: {module_selector}. Supported strategies: 'round_robin', 'all'"
+         )
+
+         module_selector_instance: ReflectionComponentSelector = module_selector_cls()
+     else:
+         module_selector_instance = module_selector
+
+     if batch_sampler == "epoch_shuffled":
+         batch_sampler = EpochShuffledBatchSampler(minibatch_size=reflection_minibatch_size or 3, rng=rng)
+     else:
+         assert reflection_minibatch_size is None, (
+             "reflection_minibatch_size only accepted if batch_sampler is 'epoch_shuffled'"
+         )
+
+     experiment_tracker = create_experiment_tracker(
+         use_wandb=use_wandb,
+         wandb_api_key=wandb_api_key,
+         wandb_init_kwargs=wandb_init_kwargs,
+         use_mlflow=use_mlflow,
+         mlflow_tracking_uri=mlflow_tracking_uri,
+         mlflow_experiment_name=mlflow_experiment_name,
+     )
+
+     if reflection_prompt_template is not None:
+         assert not (adapter is not None and getattr(adapter, "propose_new_texts", None) is not None), (
+             f"Adapter {adapter!s} provides its own propose_new_texts method; reflection_prompt_template will be ignored. "
+             "Set reflection_prompt_template to None."
+         )
+
+     # Create evaluation cache if enabled
+     evaluation_cache: EvaluationCache[RolloutOutput, DataId] | None = None
+     if cache_evaluation:
+         evaluation_cache = EvaluationCache[RolloutOutput, DataId]()
+
+     reflective_proposer = ReflectiveMutationProposer(
+         logger=logger,
+         trainset=train_loader,
+         adapter=active_adapter,
+         candidate_selector=candidate_selector,
+         module_selector=module_selector_instance,
+         batch_sampler=batch_sampler,
+         perfect_score=perfect_score,
+         skip_perfect_score=skip_perfect_score,
+         experiment_tracker=experiment_tracker,
+         reflection_lm=reflection_lm,
+         reflection_prompt_template=reflection_prompt_template,
+     )
+
+     def evaluator_fn(
+         inputs: list[DataInst], prog: dict[str, str]
+     ) -> tuple[list[RolloutOutput], list[float], Sequence[dict[str, float]] | None]:
+         eval_out = active_adapter.evaluate(inputs, prog, capture_traces=False)
+         return eval_out.outputs, eval_out.scores, eval_out.objective_scores
+
+     merge_proposer: MergeProposer | None = None
+     if use_merge:
+         merge_proposer = MergeProposer(
+             logger=logger,
+             valset=val_loader,
+             evaluator=evaluator_fn,
+             use_merge=use_merge,
+             max_merge_invocations=max_merge_invocations,
+             rng=rng,
+             val_overlap_floor=merge_val_overlap_floor,
+         )
+
+     engine = GEPAEngine(
+         adapter=active_adapter,
+         run_dir=run_dir,
+         valset=val_loader,
+         seed_candidate=seed_candidate,
+         perfect_score=perfect_score,
+         seed=seed,
+         reflective_proposer=reflective_proposer,
+         merge_proposer=merge_proposer,
+         frontier_type=frontier_type,
+         logger=logger,
+         experiment_tracker=experiment_tracker,
+         track_best_outputs=track_best_outputs,
+         display_progress_bar=display_progress_bar,
+         raise_on_exception=raise_on_exception,
+         stop_callback=stop_callback,
+         val_evaluation_policy=val_evaluation_policy,
+         use_cloudpickle=use_cloudpickle,
+         evaluation_cache=evaluation_cache,
+     )
+
+     with experiment_tracker:
+         state = engine.run()
+
+     return GEPAResult.from_state(state, run_dir=run_dir, seed=seed)
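The call surface above reduces to: provide a seed candidate, data, something that can run the task (an adapter or `task_lm`), a `reflection_lm`, and at least one stopping condition. A minimal sketch of a call through the default adapter follows; the import path mirrors `mantisdk/algorithm/gepa/lib/api.py`, while the model names, the `instruction_prompt` component name, the example dict layout, and the `best_candidate` accessor on the result are illustrative assumptions rather than guarantees of this package.

    from mantisdk.algorithm.gepa.lib.api import optimize

    # Toy examples; whether DefaultAdapter expects this exact dict shape is an assumption.
    trainset = [
        {"input": "What is 2 + 2?", "answer": "4"},
        {"input": "What is the capital of France?", "answer": "Paris"},
    ]

    result = optimize(
        seed_candidate={"instruction_prompt": "Answer the question concisely."},
        trainset=trainset,
        valset=trainset,                # reuse the trainset here for brevity
        task_lm="openai/gpt-4o-mini",   # used only because no adapter is passed
        reflection_lm="openai/gpt-4o",  # drives the default instruction proposer
        max_metric_calls=100,           # stopping condition (or pass stop_callbacks)
        run_dir="./gepa_runs/demo",     # enables resume and "gepa.stop" file stopping
    )

    print(result.best_candidate)        # assumed GEPAResult accessor for the evolved texts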
mantisdk/algorithm/gepa/lib/core/__init__.py (file without changes)
mantisdk/algorithm/gepa/lib/core/adapter.py
@@ -0,0 +1,180 @@
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
+ # https://github.com/gepa-ai/gepa
+
+ from collections.abc import Mapping, Sequence
+ from dataclasses import dataclass
+ from typing import Any, Generic, Protocol, TypeVar
+
+ # Generic type aliases matching your original
+ RolloutOutput = TypeVar("RolloutOutput")
+ Trajectory = TypeVar("Trajectory")
+ DataInst = TypeVar("DataInst")
+ Candidate = dict[str, str]
+
+
+ @dataclass
+ class EvaluationBatch(Generic[Trajectory, RolloutOutput]):
+     """
+     Container for the result of evaluating a proposed candidate on a batch of data.
+
+     - outputs: raw per-example outputs from executing the candidate. GEPA does not interpret these;
+       they are forwarded to other parts of the user's code or logging as-is.
+     - scores: per-example numeric scores (floats). GEPA sums these for minibatch acceptance
+       and averages them over the full validation set for tracking/pareto fronts.
+     - trajectories: optional per-example traces used by make_reflective_dataset to build
+       a reflective dataset (See `GEPAAdapter.make_reflective_dataset`). If capture_traces=True is passed to `evaluate`, trajectories
+       should be provided and align one-to-one with `outputs` and `scores`.
+     - objective_scores: optional per-example maps of objective name -> score. Leave None when
+       the evaluator does not expose multi-objective metrics.
+     """
+
+     outputs: list[RolloutOutput]
+     scores: list[float]
+     trajectories: list[Trajectory] | None = None
+     objective_scores: list[dict[str, float]] | None = None
+
+
+ class ProposalFn(Protocol):
+     def __call__(
+         self,
+         candidate: dict[str, str],
+         reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
+         components_to_update: list[str],
+     ) -> dict[str, str]:
+         """
+         - Given the current `candidate`, a reflective dataset (as returned by
+           `GEPAAdapter.make_reflective_dataset`), and a list of component names to update,
+           return a mapping component_name -> new component text (str). This allows the user
+           to implement their own instruction proposal logic. For example, the user can use
+           a different LLM, implement DSPy signatures, etc. Another example can be situations
+           where 2 or more components need to be updated together (coupled updates).
+
+         Returns
+         - Dict[str, str] mapping component names to newly proposed component texts.
+         """
+         ...
+
+
+ class GEPAAdapter(Protocol[DataInst, Trajectory, RolloutOutput]):
+     """
+     GEPAAdapter is the single integration point between your system
+     and the GEPA optimization engine. Implementers provide three responsibilities:
+
+     The following are user-defined types that are not interpreted by GEPA but are used by the user's code
+     to define the adapter:
+         DataInst: User-defined type of input data to the program under optimization.
+         Trajectory: User-defined type of trajectory data, which typically captures the
+             different steps of the program candidate execution.
+         RolloutOutput: User-defined type of output data from the program candidate.
+
+     The following are the responsibilities of the adapter:
+     1) Program construction and evaluation (evaluate):
+        Given a batch of DataInst and a "candidate" program (mapping from named components
+        -> component text), execute the program to produce per-example scores and
+        optionally rich trajectories (capturing intermediate states) needed for reflection.
+
+     2) Reflective dataset construction (make_reflective_dataset):
+        Given the candidate, EvaluationBatch (trajectories, outputs, scores), and the list of components to update,
+        produce a small JSON-serializable dataset for each component that you want to update. This
+        dataset is fed to the teacher LM to propose improved component text.
+
+     3) Optional instruction proposal (propose_new_texts):
+        GEPA provides a default implementation (instruction_proposal.py) that serializes the reflective dataset
+        to propose new component texts. However, users can implement their own proposal logic by implementing this method.
+        This method receives the current candidate, the reflective dataset, and the list of components to update,
+        and returns a mapping from component name to new component text.
+
+     Key concepts and contracts:
+     - candidate: Dict[str, str] mapping a named component of the system to its corresponding text.
+     - scores: higher is better. GEPA uses:
+         - minibatch: sum(scores) to compare old vs. new candidate (acceptance test),
+         - full valset: mean(scores) for tracking and Pareto-front selection.
+       Ensure your metric is calibrated accordingly or normalized to a consistent scale.
+     - trajectories: opaque to GEPA (the engine never inspects them). They must be
+       consumable by your own make_reflective_dataset implementation to extract the
+       minimal context needed to produce meaningful feedback for every component of
+       the system under optimization.
+     - error handling: Never raise for individual example failures. Instead:
+         - Return a valid `EvaluationBatch` with per-example failure scores (e.g., 0.0)
+           when formatting/parsing fails. Even better if the trajectories are also populated
+           with the failed example, including the error message, identifying the reason for the failure.
+         - Reserve exceptions for unrecoverable, systemic failures (e.g., missing model,
+           misconfigured program, schema mismatch).
+         - If an exception is raised, the engine will log the error and proceed to the next iteration.
+     """
+
+     def evaluate(
+         self,
+         batch: list[DataInst],
+         candidate: dict[str, str],
+         capture_traces: bool = False,
+     ) -> EvaluationBatch[Trajectory, RolloutOutput]:
+         """
+         Run the program defined by `candidate` on a batch of data.
+
+         Parameters
+         - batch: list of task-specific inputs (DataInst).
+         - candidate: mapping from component name -> component text. You must instantiate
+           your full system with the component text for each component, and execute it on the batch.
+         - capture_traces: when True, you must populate `EvaluationBatch.trajectories`
+           with a per-example trajectory object that your `make_reflective_dataset` can
+           later consume. When False, you may set trajectories=None to save time/memory.
+           capture_traces=True is used by the reflective mutation proposer to build a reflective dataset.
+
+         Returns
+         - EvaluationBatch with:
+             - outputs: raw per-example outputs (opaque to GEPA).
+             - scores: per-example floats, length == len(batch). Higher is better.
+             - trajectories:
+                 - if capture_traces=True: list[Trajectory] with length == len(batch).
+                 - if capture_traces=False: None.
+
+         Scoring semantics
+         - The engine uses sum(scores) on minibatches to decide whether to accept a
+           candidate mutation and average(scores) over the full valset for tracking.
+         - Prefer to return per-example scores that can be aggregated via summation.
+         - If an example fails (e.g., parse error), use a fallback score (e.g., 0.0).
+
+         Correctness constraints
+         - len(outputs) == len(scores) == len(batch)
+         - If capture_traces=True: trajectories must be provided and len(trajectories) == len(batch)
+         - Do not mutate `batch` or `candidate` in-place. Construct a fresh program
+           instance or deep-copy as needed.
+         """
+         ...
+
+     def make_reflective_dataset(
+         self,
+         candidate: dict[str, str],
+         eval_batch: EvaluationBatch[Trajectory, RolloutOutput],
+         components_to_update: list[str],
+     ) -> Mapping[str, Sequence[Mapping[str, Any]]]:
+         """
+         Build a small, JSON-serializable dataset (per component) to drive instruction
+         refinement by a teacher LLM.
+
+         Parameters
+         - candidate: the same candidate evaluated in evaluate().
+         - eval_batch: The result of evaluate(..., capture_traces=True) on
+           the same batch. You should extract everything you need from eval_batch.trajectories
+           (and optionally outputs/scores) to assemble concise, high-signal examples.
+         - components_to_update: subset of component names for which the proposer has
+           requested updates. In each iteration, GEPA identifies a subset of components to update.
+
+         Returns
+         - A dict: component_name -> list of dict records (the "reflective dataset").
+           Each record should be JSON-serializable and is passed verbatim to the
+           instruction proposal prompt. A recommended schema is:
+             {
+                 "Inputs": Dict[str, str],  # Minimal, clean view of the inputs to the component
+                 "Generated Outputs": Dict[str, str] | str,  # Model outputs or raw text
+                 "Feedback": str  # Feedback on the component's performance, including correct answer, error messages, etc.
+             }
+           You may include additional keys (e.g., "score", "rationale", "trace_id") if useful.
+
+         Determinism
+         - If you subsample trace instances, use a seeded RNG to keep runs reproducible.
+         """
+         ...
+
+     propose_new_texts: ProposalFn | None = None
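As a concrete reading of this contract, the sketch below implements the two required methods for a hypothetical single-component system whose candidate holds only a "system_prompt". The `call_model` callable, the question/answer dict shape of each example, and the substring-match scoring are assumptions made for illustration; `propose_new_texts` is left unset so GEPA falls back to its default proposer.

    from dataclasses import dataclass

    from mantisdk.algorithm.gepa.lib.core.adapter import EvaluationBatch

    @dataclass
    class SimpleTrajectory:
        question: str
        model_output: str
        expected: str

    class SingleSystemPromptAdapter:
        """Schematic adapter for a one-prompt system; satisfies GEPAAdapter structurally."""

        def __init__(self, call_model):
            # call_model: user-supplied callable (system_prompt, question) -> model answer text
            self.call_model = call_model

        def evaluate(self, batch, candidate, capture_traces=False):
            outputs, scores = [], []
            trajectories = [] if capture_traces else None
            for example in batch:
                answer = self.call_model(candidate["system_prompt"], example["question"])
                # Fallback score of 0.0 on a miss; never raise for a single bad example.
                score = 1.0 if example["answer"].strip().lower() in answer.lower() else 0.0
                outputs.append(answer)
                scores.append(score)
                if capture_traces:
                    trajectories.append(SimpleTrajectory(example["question"], answer, example["answer"]))
            return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

        def make_reflective_dataset(self, candidate, eval_batch, components_to_update):
            records = [
                {
                    "Inputs": {"question": t.question},
                    "Generated Outputs": t.model_output,
                    "Feedback": f"Expected answer: {t.expected}",
                }
                for t in (eval_batch.trajectories or [])
            ]
            # The same records serve every requested component; here there is only "system_prompt".
            return {name: records for name in components_to_update}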
mantisdk/algorithm/gepa/lib/core/data_loader.py
@@ -0,0 +1,74 @@
+ """Data loader protocols and concrete helpers."""
+
+ from __future__ import annotations
+
+ from typing import Any, Hashable, Protocol, Sequence, TypeVar, cast, runtime_checkable
+
+ from mantisdk.algorithm.gepa.lib.core.adapter import DataInst
+
+
+ class ComparableHashable(Hashable, Protocol):
+     """Protocol requiring hashing and rich comparison support."""
+
+     def __lt__(self, other: Any, /) -> bool: ...
+
+     def __gt__(self, other: Any, /) -> bool: ...
+
+     def __le__(self, other: Any, /) -> bool: ...
+
+     def __ge__(self, other: Any, /) -> bool: ...
+
+
+ DataId = TypeVar("DataId", bound=ComparableHashable)
+ """ Generic for the identifier for data examples """
+
+
+ @runtime_checkable
+ class DataLoader(Protocol[DataId, DataInst]):
+     """Minimal interface for retrieving validation examples keyed by opaque ids."""
+
+     def all_ids(self) -> Sequence[DataId]:
+         """Return the ordered universe of ids currently available. This may change over time."""
+         ...
+
+     def fetch(self, ids: Sequence[DataId]) -> list[DataInst]:
+         """Materialise the payloads corresponding to `ids`, preserving order."""
+         ...
+
+     def __len__(self) -> int:
+         """Return current number of items in the loader."""
+         ...
+
+
+ class MutableDataLoader(DataLoader[DataId, DataInst], Protocol):
+     """A data loader that can be mutated."""
+
+     def add_items(self, items: list[DataInst]) -> None:
+         """Add items to the loader."""
+
+
+ class ListDataLoader(MutableDataLoader[int, DataInst]):
+     """In-memory reference implementation backed by a list."""
+
+     def __init__(self, items: Sequence[DataInst]):
+         self.items = list(items)
+
+     def all_ids(self) -> Sequence[int]:
+         return list(range(len(self.items)))
+
+     def fetch(self, ids: Sequence[int]) -> list[DataInst]:
+         return [self.items[data_id] for data_id in ids]
+
+     def __len__(self) -> int:
+         return len(self.items)
+
+     def add_items(self, items: Sequence[DataInst]) -> None:
+         self.items.extend(items)
+
+
+ def ensure_loader(data_or_loader: Sequence[DataInst] | DataLoader[DataId, DataInst]) -> DataLoader[DataId, DataInst]:
+     if isinstance(data_or_loader, DataLoader):
+         return data_or_loader
+     if isinstance(data_or_loader, Sequence):
+         return cast(DataLoader[DataId, DataInst], ListDataLoader(data_or_loader))
+     raise TypeError(f"Unable to cast to a DataLoader type: {type(data_or_loader)}")
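Because `DataLoader` is a `runtime_checkable` protocol, anything exposing `all_ids`, `fetch`, and `__len__` can be passed in directly, and plain sequences are wrapped by `ensure_loader`. A brief sketch, with a hypothetical JSONL-backed loader as the custom case:

    import json

    from mantisdk.algorithm.gepa.lib.core.data_loader import ensure_loader

    # Plain lists are wrapped automatically; ids are simply list indices.
    loader = ensure_loader([{"question": "2 + 2?", "answer": "4"}])
    print(loader.all_ids())   # [0]
    print(loader.fetch([0]))  # [{'question': '2 + 2?', 'answer': '4'}]

    class JsonlDataLoader:
        """Hypothetical loader that keys examples by line number in a JSONL file."""

        def __init__(self, path: str):
            with open(path) as f:
                self._rows = [json.loads(line) for line in f if line.strip()]

        def all_ids(self):
            return list(range(len(self._rows)))

        def fetch(self, ids):
            return [self._rows[i] for i in ids]

        def __len__(self):
            return len(self._rows)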