bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""OpenAI API adapter for item construction.
|
|
2
|
+
|
|
3
|
+
This module provides a ModelAdapter implementation for OpenAI's API,
|
|
4
|
+
supporting GPT models for various NLP tasks including log probability
|
|
5
|
+
computation, embeddings, and natural language inference via prompting.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import openai
|
|
16
|
+
except ImportError as e:
|
|
17
|
+
raise ImportError(
|
|
18
|
+
"openai package is required for OpenAI adapter. "
|
|
19
|
+
"Install it with: pip install openai"
|
|
20
|
+
) from e
|
|
21
|
+
|
|
22
|
+
from bead.items.adapters.api_utils import rate_limit, retry_with_backoff
|
|
23
|
+
from bead.items.adapters.base import ModelAdapter
|
|
24
|
+
from bead.items.cache import ModelOutputCache
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OpenAIAdapter(ModelAdapter):
    """Adapter exposing OpenAI API models through the ModelAdapter interface.

    Wraps OpenAI's completions, embeddings, and chat endpoints so GPT
    models can be used for log-probability scoring, perplexity
    estimation, text embeddings, and prompted natural language inference.
    All operations consult a ModelOutputCache before hitting the network.

    Parameters
    ----------
    model_name : str
        OpenAI model identifier (default: "gpt-3.5-turbo").
    api_key : str | None
        OpenAI API key. If None, uses OPENAI_API_KEY environment variable.
    cache : ModelOutputCache | None
        Cache for model outputs. If None, creates in-memory cache.
    model_version : str
        Model version for cache tracking (default: "latest").
    embedding_model : str
        Model to use for embeddings (default: "text-embedding-ada-002").

    Attributes
    ----------
    model_name : str
        OpenAI model identifier (e.g., "gpt-3.5-turbo", "gpt-4").
    client : openai.OpenAI
        OpenAI API client.
    embedding_model : str
        Model to use for embeddings (default: "text-embedding-ada-002").

    Raises
    ------
    ValueError
        If no API key is provided and OPENAI_API_KEY is not set.
    """

    def __init__(
        self,
        model_name: str = "gpt-3.5-turbo",
        api_key: str | None = None,
        cache: ModelOutputCache | None = None,
        model_version: str = "latest",
        embedding_model: str = "text-embedding-ada-002",
    ) -> None:
        # Fall back to an in-memory cache so every operation can cache.
        if cache is None:
            cache = ModelOutputCache(backend="memory")

        super().__init__(
            model_name=model_name, cache=cache, model_version=model_version
        )

        # Resolve the key explicitly with `is None` (an empty string passed
        # by the caller is deliberately NOT replaced by the env variable).
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "OpenAI API key must be provided via api_key parameter "
                "or OPENAI_API_KEY environment variable"
            )

        self.client = openai.OpenAI(api_key=api_key)
        self.embedding_model = embedding_model

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text using OpenAI completions API.

        Requests token-level log probabilities via the legacy completions
        endpoint (``echo=True, max_tokens=0``) and sums them.
        NOTE(review): this endpoint presumably only works for completion
        models, not chat models such as gpt-3.5-turbo — confirm before use.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Log probability of the text (sum of token log probabilities).

        Raises
        ------
        ValueError
            If the API response carries no logprobs payload.
        """
        hit = self.cache.get(
            model_name=self.model_name, operation="log_probability", text=text
        )
        if hit is not None:
            return float(hit)

        completion = self.client.completions.create(
            model=self.model_name,
            prompt=text,
            max_tokens=0,   # score the prompt only; generate nothing
            echo=True,      # return logprobs for the prompt tokens
            logprobs=1,
        )

        lp_info = completion.choices[0].logprobs
        if lp_info is None or lp_info.token_logprobs is None:
            raise ValueError("API response did not include logprobs")

        # The first token's logprob may be None; skip missing entries.
        log_prob_sum = sum(lp for lp in lp_info.token_logprobs if lp is not None)

        self.cache.set(
            model_name=self.model_name,
            operation="log_probability",
            result=log_prob_sum,
            model_version=self.model_version,
            text=text,
        )

        return float(log_prob_sum)

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Perplexity is computed as exp(-log_prob / num_tokens), where the
        token count is a rough character-based estimate.

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (must be positive).
        """
        hit = self.cache.get(
            model_name=self.model_name, operation="perplexity", text=text
        )
        if hit is not None:
            return float(hit)

        log_prob = self.compute_log_probability(text)

        # Crude token-count estimate (~4 characters per token), floored at 1
        # so the division below is always defined.
        token_estimate = max(1, len(text) // 4)
        ppl = float(np.exp(-log_prob / token_estimate))

        self.cache.set(
            model_name=self.model_name,
            operation="perplexity",
            result=ppl,
            model_version=self.model_version,
            text=text,
        )

        return ppl

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text using OpenAI embeddings API.

        Cached under ``self.embedding_model`` (not the chat model name).

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        hit = self.cache.get(
            model_name=self.embedding_model, operation="embedding", text=text
        )
        if hit is not None:
            return np.array(hit)

        reply = self.client.embeddings.create(model=self.embedding_model, input=text)
        vector = np.array(reply.data[0].embedding)

        # Store a plain list so the cache backend never has to serialize
        # a numpy array.
        self.cache.set(
            model_name=self.embedding_model,
            operation="embedding",
            result=vector.tolist(),
            model_version=self.model_version,
            text=text,
        )

        return vector

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores via prompting.

        Asks the chat model (temperature 0) to pick one of three labels
        and converts the one-word answer into hard 0/1 scores.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        dict[str, float]
            Dictionary with keys "entailment", "neutral", "contradiction"
            mapping to probability scores.

        Raises
        ------
        ValueError
            If the API response contains no message content.
        """
        hit = self.cache.get(
            model_name=self.model_name,
            operation="nli",
            premise=premise,
            hypothesis=hypothesis,
        )
        if hit is not None:
            return dict(hit)

        prompt = (
            "Given the following premise and hypothesis, "
            "determine the relationship between them.\n\n"
            f"Premise: {premise}\n"
            f"Hypothesis: {hypothesis}\n\n"
            "Choose one of the following:\n"
            "- entailment: The hypothesis is definitely true given the premise\n"
            "- neutral: The hypothesis might be true given the premise\n"
            "- contradiction: The hypothesis is definitely false given the premise\n\n"
            "Respond with only one word: entailment, neutral, or contradiction."
        )

        reply = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,   # deterministic label choice
            max_tokens=10,     # one word is expected
        )

        raw = reply.choices[0].message.content
        if raw is None:
            raise ValueError("API response did not include content")
        answer = raw.strip().lower()

        # One-hot scores: first matching label wins; anything else is
        # treated as "neutral".
        scores: dict[str, float] = dict.fromkeys(
            ("entailment", "neutral", "contradiction"), 0.0
        )
        for label in ("entailment", "neutral", "contradiction"):
            if label in answer:
                scores[label] = 1.0
                break
        else:
            scores["neutral"] = 1.0

        self.cache.set(
            model_name=self.model_name,
            operation="nli",
            result=scores,
            model_version=self.model_version,
            premise=premise,
            hypothesis=hypothesis,
        )

        return scores
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""Model adapter registry for centralized adapter management.
|
|
2
|
+
|
|
3
|
+
This module provides a registry for managing all model adapters,
|
|
4
|
+
both local (HuggingFace) and API-based (OpenAI, Anthropic, etc.).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Unpack
|
|
10
|
+
|
|
11
|
+
from typing_extensions import TypedDict
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from bead.items.cache import ModelOutputCache
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AdapterKwargs(TypedDict, total=False):
    """Keyword arguments for adapter initialization.

    All keys are optional (``total=False``); each adapter class consumes
    only the subset it understands when passed through
    ``ModelAdapterRegistry.get_adapter``.
    """

    # API key for hosted providers (OpenAI, Anthropic, etc.).
    api_key: str
    # Device spec for local adapters — presumably "cpu"/"cuda"; confirm
    # against the HuggingFace adapter constructors.
    device: str
    # Model version string used for cache tracking.
    model_version: str
    # Separate model used for embedding calls (e.g. OpenAI adapters).
    embedding_model: str
    # Whether sentence-transformer embeddings are L2-normalized —
    # TODO confirm against the sentence_transformers adapter.
    normalize_embeddings: bool
    # Pre-built output cache shared across adapters.
    cache: ModelOutputCache
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
from bead.items.adapters.base import ModelAdapter # noqa: E402
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ModelAdapterRegistry:
    """Registry for all model adapters (local and API-based).

    Maps adapter type names to adapter classes and memoizes constructed
    instances so repeated lookups do not re-initialize a model.

    Attributes
    ----------
    adapters : dict[str, type[ModelAdapter]]
        Registered adapter classes keyed by adapter type name.
    instances : dict[str, ModelAdapter]
        Cached adapter instances keyed by unique identifier.
    """

    def __init__(self) -> None:
        self.adapters: dict[str, type[ModelAdapter]] = {}
        self.instances: dict[str, ModelAdapter] = {}

    def register(self, name: str, adapter_class: type[ModelAdapter]) -> None:
        """Register an adapter class under a type name.

        Parameters
        ----------
        name : str
            Unique name for the adapter type (e.g., "openai", "huggingface_lm").
        adapter_class : type[ModelAdapter]
            Adapter class to register (must inherit from ModelAdapter).

        Raises
        ------
        ValueError
            If adapter class does not inherit from ModelAdapter.
        """
        # Reject classes outside the ModelAdapter hierarchy up front so a
        # bad registration fails loudly, not at first get_adapter call.
        if not issubclass(adapter_class, ModelAdapter):  # type: ignore[misc]
            raise ValueError(
                f"Adapter class {adapter_class.__name__} must inherit from ModelAdapter"
            )
        self.adapters[name] = adapter_class

    def get_adapter(
        self, adapter_type: str, model_name: str, **kwargs: Unpack[AdapterKwargs]
    ) -> ModelAdapter:
        """Get or create adapter instance (with caching).

        Instances are memoized by ``adapter_type`` and ``model_name`` only;
        differing ``**kwargs`` for the same pair return the first-built
        instance.

        Parameters
        ----------
        adapter_type
            Type of adapter (must be registered).
        model_name
            Model identifier for the adapter.
        **kwargs
            Additional keyword arguments to pass to adapter constructor
            (api_key, device, model_version, embedding_model, etc.).

        Returns
        -------
        ModelAdapter
            Adapter instance (cached or newly created).

        Raises
        ------
        ValueError
            If adapter type is not registered.

        Examples
        --------
        >>> registry = ModelAdapterRegistry()
        >>> registry.register("openai", OpenAIAdapter)
        >>> adapter = registry.get_adapter("openai", "gpt-4", api_key="...")
        """
        if adapter_type not in self.adapters:
            raise ValueError(
                f"Unknown adapter type: {adapter_type}. "
                f"Available types: {list(self.adapters.keys())}"
            )

        # Memoization key deliberately excludes kwargs (see docstring).
        key = f"{adapter_type}:{model_name}"

        instance = self.instances.get(key)
        if instance is None:
            factory = self.adapters[adapter_type]
            instance = factory(model_name=model_name, **kwargs)  # type: ignore[misc]
            self.instances[key] = instance
        return instance

    def clear_cache(self) -> None:
        """Drop every memoized adapter instance.

        Useful for testing or when adapters must be rebuilt with
        different constructor parameters.
        """
        self.instances.clear()

    def list_adapters(self) -> list[str]:
        """List all registered adapter types.

        Returns
        -------
        list[str]
            List of registered adapter type names.
        """
        return [*self.adapters]
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Create default registry with all built-in adapters.
# Each adapter family below is imported inside try/except so that a
# missing optional dependency (transformers, openai, anthropic, ...)
# disables only that family instead of breaking this module's import.
# NOTE: the `from ... import X` statements also bind each adapter class
# into this module's namespace when the import succeeds.
default_registry = ModelAdapterRegistry()

# Register HuggingFace adapters (local models; need transformers stack).
try:
    from bead.items.adapters.huggingface import (
        HuggingFaceLanguageModel,
        HuggingFaceMaskedLanguageModel,
        HuggingFaceNLI,
    )

    default_registry.register("huggingface_lm", HuggingFaceLanguageModel)
    default_registry.register("huggingface_mlm", HuggingFaceMaskedLanguageModel)
    default_registry.register("huggingface_nli", HuggingFaceNLI)
except ImportError:
    # HuggingFace adapters not available (missing dependencies)
    pass

# Register sentence transformers (embedding models).
try:
    from bead.items.adapters.sentence_transformers import HuggingFaceSentenceTransformer

    default_registry.register("sentence_transformer", HuggingFaceSentenceTransformer)
except ImportError:
    # Sentence transformers not available
    pass

# Register API adapters (these are optional; each needs its provider SDK).
try:
    from bead.items.adapters.openai import OpenAIAdapter

    default_registry.register("openai", OpenAIAdapter)
except ImportError:
    # OpenAI not available
    pass

try:
    from bead.items.adapters.anthropic import AnthropicAdapter

    default_registry.register("anthropic", AnthropicAdapter)
except ImportError:
    # Anthropic not available
    pass

try:
    from bead.items.adapters.google import GoogleAdapter

    default_registry.register("google", GoogleAdapter)
except ImportError:
    # Google not available
    pass

try:
    from bead.items.adapters.togetherai import TogetherAIAdapter

    default_registry.register("togetherai", TogetherAIAdapter)
except ImportError:
    # Together AI not available
    pass
|