themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Project-level definitions for grouping experiments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Mapping, Sequence
|
|
7
|
+
|
|
8
|
+
from themis.experiment.definitions import ExperimentDefinition
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class ProjectExperiment:
    """Immutable record binding a project-local name to an experiment definition.

    Fields cannot be reassigned after construction. ``metadata`` holds
    experiment-specific keys that ``Project.metadata_for_experiment`` layers
    on top of the project's own metadata.
    """

    # Key under which the owning Project indexes this experiment.
    name: str
    # The experiment configuration itself.
    definition: ExperimentDefinition
    # Optional human-readable summary.
    description: str | None = None
    # Per-experiment metadata; every instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Free-form labels; an empty tuple is immutable, so it can be a plain default.
    tags: tuple[str, ...] = ()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class Project:
    """A named collection of experiments sharing project-level metadata.

    Experiment names are unique within a project. ``experiments`` is always
    normalised to a tuple, and a private name -> experiment index mirrors it
    for lookups.
    """

    project_id: str
    name: str
    description: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    tags: tuple[str, ...] = field(default_factory=tuple)
    experiments: Sequence[ProjectExperiment] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        # Index every initially supplied experiment (rejecting duplicate
        # names), then store the collection as an immutable tuple.
        self._experiment_index: dict[str, ProjectExperiment] = {}
        initial = tuple(self.experiments)
        for candidate in initial:
            self._register_experiment(candidate)
        self.experiments = initial

    def add_experiment(self, experiment: ProjectExperiment) -> ProjectExperiment:
        """Register ``experiment`` under its name and append it to the project.

        Returns:
            The same ``experiment`` object, for chaining.

        Raises:
            ValueError: if an experiment with the same name is already registered.
        """
        self._register_experiment(experiment)
        self.experiments = (*self.experiments, experiment)
        return experiment

    def create_experiment(
        self,
        *,
        name: str,
        definition: ExperimentDefinition,
        description: str | None = None,
        metadata: Mapping[str, Any] | None = None,
        tags: Sequence[str] | None = None,
    ) -> ProjectExperiment:
        """Build a ``ProjectExperiment`` from its parts and register it.

        ``metadata`` and ``tags`` are copied into a new dict and tuple, so later
        changes to the caller's objects do not affect the stored experiment.

        Raises:
            ValueError: if ``name`` is already registered.
        """
        built = ProjectExperiment(
            name=name,
            description=description,
            definition=definition,
            metadata=dict(metadata) if metadata else {},
            tags=tuple(tags) if tags else (),
        )
        return self.add_experiment(built)

    def get_experiment(self, name: str) -> ProjectExperiment:
        """Return the experiment registered as ``name``.

        Raises:
            KeyError: if no experiment with that name is registered.
        """
        index = self._experiment_index
        try:
            found = index[name]
        except KeyError as missing:  # pragma: no cover - defensive guard
            message = (
                f"Experiment '{name}' not registered in project '{self.project_id}'"
            )
            raise KeyError(message) from missing
        return found

    def metadata_for_experiment(self, name: str) -> dict[str, Any]:
        """Return project metadata overlaid with the experiment's own metadata.

        Experiment-level keys win on conflict. A new dict is returned, so the
        project's metadata is never mutated.

        Raises:
            KeyError: if ``name`` is not registered.
        """
        merged = dict(self.metadata)
        merged.update(self.get_experiment(name).metadata)
        return merged

    def list_experiment_names(self) -> tuple[str, ...]:
        """Return the registered experiment names in registration order."""
        return tuple(self._experiment_index)

    def _register_experiment(self, experiment: ProjectExperiment) -> None:
        # Insert into the name index, refusing to overwrite an existing entry.
        key = experiment.name
        if key not in self._experiment_index:
            self._experiment_index[key] = experiment
            return
        raise ValueError(
            f"Experiment '{key}' already registered "
            f"in project '{self.project_id}'"
        )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Public API of this module.
__all__ = [
    "Project",
    "ProjectExperiment",
]
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""Reusable experiment patterns for organizing projects."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable, Mapping, Sequence
|
|
8
|
+
|
|
9
|
+
from themis.experiment.definitions import ExperimentDefinition
|
|
10
|
+
from themis.experiment.orchestrator import ExperimentReport
|
|
11
|
+
from themis.project.definitions import Project, ProjectExperiment
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _slugify(value: str) -> str:
|
|
15
|
+
text = value.strip().lower()
|
|
16
|
+
text = re.sub(r"[^a-z0-9]+", "-", text)
|
|
17
|
+
text = text.strip("-")
|
|
18
|
+
return text or "variant"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class AblationVariant:
    """One setting of the parameter being ablated.

    ``value`` is the raw setting. ``label``, when truthy, replaces
    ``str(value)`` for display. ``metadata`` is extra per-variant information.
    """

    value: Any
    label: str | None = None
    metadata: Mapping[str, Any] = field(default_factory=dict)

    def display_label(self) -> str:
        """Return the explicit label when truthy, otherwise ``str(value)``."""
        if self.label:
            return self.label
        return str(self.value)

    def slug(self) -> str:
        """Return the slugified form of :meth:`display_label`."""
        text = self.display_label()
        return _slugify(text)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class AblationChartPoint:
    """A single chart point: one variant paired with its aggregated metric."""

    # Raw variant value (the x-axis position).
    x_value: Any
    # Display label of the variant.
    label: str
    # Aggregated metric value for the variant's experiment.
    metric_value: float
    # Name of the metric the value belongs to.
    metric_name: str
    # Count reported alongside the aggregate by the metric.
    count: int
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True)
class AblationChart:
    """Chart-ready view of an ablation: title, axis labels and ordered points."""

    title: str
    x_label: str
    y_label: str
    metric_name: str
    points: tuple[AblationChartPoint, ...]

    def as_dict(self) -> dict[str, Any]:
        """Return the chart as plain dicts and lists.

        Some keys are renamed. ``metric_name`` becomes ``"metric"``. For each
        point, ``x_value`` becomes ``"x"`` and ``metric_value`` becomes ``"value"``.
        """
        serialized_points: list[dict[str, Any]] = []
        for pt in self.points:
            serialized_points.append(
                {
                    "label": pt.label,
                    "x": pt.x_value,
                    "value": pt.metric_value,
                    "count": pt.count,
                }
            )
        return {
            "title": self.title,
            "x_label": self.x_label,
            "y_label": self.y_label,
            "metric": self.metric_name,
            "points": serialized_points,
        }
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass(frozen=True)
class XAbationPatternApplication:
    """Result of materializing an x-ablation pattern into a project.

    It records which generated experiment name maps to which variant, so the
    ablation chart can be built once reports are available.
    """

    pattern_name: str
    parameter_name: str
    experiments: tuple[ProjectExperiment, ...]
    # Generated experiment name -> the variant it was built from.
    variant_by_name: Mapping[str, AblationVariant]
    # Pattern that produced this application (excluded from repr).
    _pattern: "XAblationPattern" = field(repr=False)

    def build_chart(self, reports: Mapping[str, ExperimentReport]) -> AblationChart:
        """Build the ablation chart from ``reports`` keyed by experiment name.

        The work is delegated to the originating pattern's ``_build_chart``.
        """
        pattern = self._pattern
        return pattern._build_chart(reports, self.variant_by_name)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class XAbationPattern:
    """Vary a single factor across values to compare performance.

    Each supplied value is normalised into an ``AblationVariant``.
    :meth:`materialize` turns every variant into one project experiment, using
    ``definition_builder``. The returned application can later chart
    ``metric_name`` across the variants.
    """

    pattern_type = "x-ablation"

    def __init__(
        self,
        *,
        name: str,
        parameter_name: str,
        values: Sequence[AblationVariant | Any],
        definition_builder: Callable[[AblationVariant], ExperimentDefinition],
        metric_name: str,
        x_axis_label: str | None = None,
        y_axis_label: str | None = None,
        title: str | None = None,
    ) -> None:
        """Configure the pattern.

        Args:
            name: Pattern name, used in default experiment names and the title.
            parameter_name: Name of the factor being varied.
            values: Raw values or ``AblationVariant`` instances (always copied).
            definition_builder: Called with each variant to build its definition.
            metric_name: Metric read from each report and plotted on the y-axis.
            x_axis_label: Defaults to ``parameter_name``.
            y_axis_label: Defaults to ``metric_name``.
            title: Defaults to ``"{name} ({parameter_name} ablation)"``.

        Raises:
            ValueError: if ``values`` is empty.
        """
        if not values:
            raise ValueError("XAblationPattern requires at least one value")
        self.name = name
        self.parameter_name = parameter_name
        self._variants = [self._normalize_variant(value) for value in values]
        self._definition_builder = definition_builder
        self.metric_name = metric_name
        self.x_axis_label = x_axis_label or parameter_name
        self.y_axis_label = y_axis_label or metric_name
        self.title = title or f"{name} ({parameter_name} ablation)"

    def materialize(
        self,
        project: Project,
        *,
        name_template: str | None = None,
        description_template: str | None = None,
        base_tags: Sequence[str] | None = None,
    ) -> XAbationPatternApplication:
        """Register one experiment per variant in ``project``.

        ``name_template`` and ``description_template`` are ``str.format``
        templates. Both receive ``pattern``, ``parameter``, ``value``,
        ``value_label``, ``value_slug`` and ``index``. The default name template
        is ``"{pattern}-{value_slug}"``.

        All names, metadata and definitions are computed and validated before
        anything is added. A name collision, or a failing ``definition_builder``,
        therefore leaves ``project`` unchanged.

        Raises:
            ValueError: if a generated name is already registered in
                ``project`` or is generated twice by this pattern.
        """
        template = name_template or "{pattern}-{value_slug}"
        # Order-preserving de-duplication; the pattern tag is always present.
        tags = tuple(dict.fromkeys((*(base_tags or ()), self.pattern_type)))

        existing_names = set(project.list_experiment_names())
        batch_names: set[str] = set()
        pending: list[tuple[ProjectExperiment, AblationVariant]] = []
        for index, variant in enumerate(self._variants):
            template_fields = self._template_fields(variant, index)
            experiment_name = template.format(**template_fields)
            if experiment_name in existing_names:
                raise ValueError(
                    f"Experiment '{experiment_name}' already registered "
                    f"in project '{project.project_id}'"
                )
            if experiment_name in batch_names:
                raise ValueError(
                    f"Pattern '{self.name}' generates duplicate experiment name "
                    f"'{experiment_name}'; use a name_template that includes "
                    "'{index}' or give variants distinct labels"
                )
            batch_names.add(experiment_name)

            description = (
                None
                if description_template is None
                else description_template.format(**template_fields)
            )
            metadata: dict[str, Any] = {
                "pattern": self.pattern_type,
                "pattern_name": self.name,
                "parameter_name": self.parameter_name,
                "parameter_value": variant.value,
                "parameter_label": variant.display_label(),
                "pattern_index": index,
            }
            # Variant-specific metadata may override the generated keys.
            metadata.update(dict(variant.metadata))
            candidate = ProjectExperiment(
                name=experiment_name,
                description=description,
                definition=self._definition_builder(variant),
                metadata=metadata,
                tags=tags,
            )
            pending.append((candidate, variant))

        # Everything is built and validated, so register in a single pass.
        experiments: list[ProjectExperiment] = []
        variant_map: dict[str, AblationVariant] = {}
        for candidate, variant in pending:
            registered = project.add_experiment(candidate)
            experiments.append(registered)
            variant_map[registered.name] = variant
        return XAbationPatternApplication(
            pattern_name=self.name,
            parameter_name=self.parameter_name,
            experiments=tuple(experiments),
            variant_by_name=variant_map,
            _pattern=self,
        )

    def _template_fields(self, variant: AblationVariant, index: int) -> dict[str, Any]:
        """Return the keyword fields shared by the name and description templates."""
        return {
            "pattern": self.name,
            "parameter": self.parameter_name,
            "value": variant.value,
            "value_label": variant.display_label(),
            "value_slug": variant.slug(),
            "index": index,
        }

    def _build_chart(
        self,
        reports: Mapping[str, ExperimentReport],
        variant_by_name: Mapping[str, AblationVariant],
    ) -> AblationChart:
        """Collect ``metric_name`` from each experiment's report into a chart.

        Raises:
            KeyError: if an experiment in ``variant_by_name`` has no report.
            ValueError: if a report's evaluation metrics lack ``metric_name``.
        """
        points: list[AblationChartPoint] = []
        for experiment, variant in variant_by_name.items():
            report = reports.get(experiment)
            if report is None:
                raise KeyError(
                    f"Missing report for experiment '{experiment}' in pattern '{self.name}'"
                )
            metric = report.evaluation_report.metrics.get(self.metric_name)
            if metric is None:
                raise ValueError(
                    f"Metric '{self.metric_name}' not found for experiment '{experiment}'"
                )
            points.append(
                AblationChartPoint(
                    x_value=variant.value,
                    label=variant.display_label(),
                    metric_value=metric.mean,
                    metric_name=metric.name,
                    count=metric.count,
                )
            )
        ordered_points = self._order_points(points, variant_by_name)
        return AblationChart(
            title=self.title,
            x_label=self.x_axis_label,
            y_label=self.y_axis_label,
            metric_name=self.metric_name,
            points=tuple(ordered_points),
        )

    def _order_points(
        self,
        points: Sequence[AblationChartPoint],
        variant_by_name: Mapping[str, AblationVariant],
    ) -> list[AblationChartPoint]:
        """Sort ``points`` by the position of their value among the variants.

        Values are matched by identity first, then by equality, and are never
        hashed. Unhashable ablation values (lists, dicts, ...) therefore work.
        The sort is stable, and points matching no variant are placed last.
        """
        variants = self._variants

        def position(point: AblationChartPoint) -> int:
            # Identity pass: chart points carry the variant's own value object.
            for index, variant in enumerate(variants):
                if variant.value is point.x_value:
                    return index
            # Equality fallback for values that were copied or rebuilt.
            for index, variant in enumerate(variants):
                if variant.value == point.x_value:
                    return index
            return len(variants)

        return sorted(points, key=position)

    def _normalize_variant(self, value: AblationVariant | Any) -> AblationVariant:
        # Always return a fresh AblationVariant, so caller-owned objects (and
        # their metadata) are not shared with this pattern.
        if isinstance(value, AblationVariant):
            return AblationVariant(
                value=value.value,
                label=value.label,
                metadata=dict(value.metadata),
            )
        return AblationVariant(value=value)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# The pattern classes are defined as ``XAbationPattern`` and
# ``XAbationPatternApplication``. These aliases provide the correctly spelled
# names that ``__all__`` exports and that the pattern's error message uses.
# Without them, ``from ... import *`` fails on a missing attribute.
XAblationPattern = XAbationPattern
XAblationPatternApplication = XAbationPatternApplication

# Public API of this module.
__all__ = [
    "AblationChart",
    "AblationChartPoint",
    "AblationVariant",
    "XAblationPattern",
    "XAblationPatternApplication",
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Simple registry for ModelProvider factories."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Callable, Dict
|
|
6
|
+
|
|
7
|
+
from themis.interfaces import ModelProvider
|
|
8
|
+
|
|
9
|
+
# A callable that returns a ModelProvider; the registry invokes it with the
# keyword options passed to create_provider.
ProviderFactory = Callable[..., ModelProvider]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _ProviderRegistry:
|
|
13
|
+
def __init__(self) -> None:
|
|
14
|
+
self._factories: Dict[str, ProviderFactory] = {}
|
|
15
|
+
|
|
16
|
+
def register(self, name: str, factory: ProviderFactory) -> None:
|
|
17
|
+
key = name.lower()
|
|
18
|
+
self._factories[key] = factory
|
|
19
|
+
|
|
20
|
+
def create(self, name: str, **options) -> ModelProvider:
|
|
21
|
+
key = name.lower()
|
|
22
|
+
factory = self._factories.get(key)
|
|
23
|
+
if factory is None:
|
|
24
|
+
raise KeyError(f"No provider registered under name '{name}'")
|
|
25
|
+
return factory(**options)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Module-level registry shared by register_provider() and create_provider().
_REGISTRY = _ProviderRegistry()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def register_provider(name: str, factory: ProviderFactory) -> None:
    """Register ``factory`` under ``name`` in the module-level registry.

    Names are matched case-insensitively, and re-registering a name replaces
    the previous factory.
    """
    registry = _REGISTRY
    registry.register(name=name, factory=factory)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def create_provider(name: str, **options) -> ModelProvider:
    """Instantiate a provider from the module-level registry.

    ``options`` are forwarded to the registered factory as keyword arguments.

    Raises:
        KeyError: if ``name`` has not been registered.
    """
    registry = _REGISTRY
    return registry.create(name, **options)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Public API; the registry class and its singleton instance stay private.
__all__ = ["register_provider", "create_provider", "ProviderFactory"]
|