bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/adapters/anthropic.py
@@ -0,0 +1,224 @@
+"""Anthropic API adapter for item construction.
+
+This module provides a ModelAdapter implementation for Anthropic's Claude API,
+supporting natural language inference via prompting. Note that the Claude API
+does not provide direct access to log probabilities or embeddings.
+"""
+
+from __future__ import annotations
+
+import os
+
+import numpy as np
+
+try:
+    import anthropic
+except ImportError as e:
+    raise ImportError(
+        "anthropic package is required for Anthropic adapter. "
+        "Install it with: pip install anthropic"
+    ) from e
+
+from bead.items.adapters.api_utils import rate_limit, retry_with_backoff
+from bead.items.adapters.base import ModelAdapter
+from bead.items.cache import ModelOutputCache
+
+
+class AnthropicAdapter(ModelAdapter):
+    """Adapter for Anthropic Claude API models.
+
+    Provides access to Claude models for prompted natural language inference.
+    Note that the Claude API does not support log probability computation or
+    embeddings, so those methods will raise NotImplementedError.
+
+    Parameters
+    ----------
+    model_name : str
+        Claude model identifier (default: "claude-3-5-sonnet-20241022").
+    api_key : str | None
+        Anthropic API key. If None, uses ANTHROPIC_API_KEY environment variable.
+    cache : ModelOutputCache | None
+        Cache for model outputs. If None, creates in-memory cache.
+    model_version : str
+        Model version for cache tracking (default: "latest").
+
+    Attributes
+    ----------
+    model_name : str
+        Claude model identifier (e.g., "claude-3-5-sonnet-20241022").
+    client : anthropic.Anthropic
+        Anthropic API client.
+
+    Raises
+    ------
+    ValueError
+        If no API key is provided and ANTHROPIC_API_KEY is not set.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "claude-3-5-sonnet-20241022",
+        api_key: str | None = None,
+        cache: ModelOutputCache | None = None,
+        model_version: str = "latest",
+    ) -> None:
+        if cache is None:
+            cache = ModelOutputCache(backend="memory")
+
+        super().__init__(
+            model_name=model_name, cache=cache, model_version=model_version
+        )
+
+        # Get API key from parameter or environment
+        if api_key is None:
+            api_key = os.environ.get("ANTHROPIC_API_KEY")
+        if api_key is None:
+            raise ValueError(
+                "Anthropic API key must be provided via api_key parameter "
+                "or ANTHROPIC_API_KEY environment variable"
+            )
+
+        self.client = anthropic.Anthropic(api_key=api_key)
+
+    def compute_log_probability(self, text: str) -> float:
+        """Compute log probability of text.
+
+        Not supported by Anthropic API.
+
+        Raises
+        ------
+        NotImplementedError
+            Always raised - Claude API does not provide log probabilities.
+        """
+        raise NotImplementedError(
+            "Log probability computation is not supported by Anthropic Claude API. "
+            "Claude does not provide access to token-level probabilities."
+        )
+
+    def compute_perplexity(self, text: str) -> float:
+        """Compute perplexity of text.
+
+        Not supported by Anthropic API (requires log probabilities).
+
+        Raises
+        ------
+        NotImplementedError
+            Always raised - requires log probability support.
+        """
+        raise NotImplementedError(
+            "Perplexity computation is not supported by Anthropic Claude API. "
+            "This operation requires log probabilities, which Claude does not provide."
+        )
+
+    def get_embedding(self, text: str) -> np.ndarray:
+        """Get embedding vector for text.
+
+        Not supported by Anthropic API.
+
+        Raises
+        ------
+        NotImplementedError
+            Always raised - Claude API does not provide embeddings.
+        """
+        raise NotImplementedError(
+            "Embedding computation is not supported by Anthropic Claude API. "
+            "Claude does not provide embedding vectors. "
+            "Consider using OpenAI's text-embedding models or sentence transformers."
+        )
+
+    @retry_with_backoff(
+        max_retries=3,
+        initial_delay=1.0,
+        backoff_factor=2.0,
+        exceptions=(
+            anthropic.APIError,
+            anthropic.APIConnectionError,
+            anthropic.RateLimitError,
+        ),
+    )
+    @rate_limit(calls_per_minute=60)
+    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
+        """Compute natural language inference scores via prompting.
+
+        Uses Claude's messages API with a prompt to classify the relationship
+        between premise and hypothesis.
+
+        Parameters
+        ----------
+        premise : str
+            Premise text.
+        hypothesis : str
+            Hypothesis text.
+
+        Returns
+        -------
+        dict[str, float]
+            Dictionary with keys "entailment", "neutral", "contradiction"
+            mapping to probability scores.
+        """
+        # Check cache
+        cached = self.cache.get(
+            model_name=self.model_name,
+            operation="nli",
+            premise=premise,
+            hypothesis=hypothesis,
+        )
+        if cached is not None:
+            return dict(cached)
+
+        # Construct prompt
+        prompt = (
+            "Given the following premise and hypothesis, "
+            "determine the relationship between them.\n\n"
+            f"Premise: {premise}\n"
+            f"Hypothesis: {hypothesis}\n\n"
+            "Choose one of the following:\n"
+            "- entailment: The hypothesis is definitely true given the premise\n"
+            "- neutral: The hypothesis might be true given the premise\n"
+            "- contradiction: The hypothesis is definitely false given the premise\n\n"
+            "Respond with only one word: entailment, neutral, or contradiction."
+        )
+
+        # Call API
+        response = self.client.messages.create(
+            model=self.model_name,
+            max_tokens=10,
+            temperature=0.0,
+            messages=[{"role": "user", "content": prompt}],
+        )
+
+        # Parse response
+        if not response.content or len(response.content) == 0:
+            raise ValueError("API response did not include content")
+
+        # Get text from first content block
+        answer = response.content[0].text.strip().lower()
+
+        # Map to scores
+        scores: dict[str, float] = {
+            "entailment": 0.0,
+            "neutral": 0.0,
+            "contradiction": 0.0,
+        }
+
+        if "entailment" in answer:
+            scores["entailment"] = 1.0
+        elif "neutral" in answer:
+            scores["neutral"] = 1.0
+        elif "contradiction" in answer:
+            scores["contradiction"] = 1.0
+        else:
+            # Default to neutral if unclear
+            scores["neutral"] = 1.0
+
+        # Cache result
+        self.cache.set(
+            model_name=self.model_name,
+            operation="nli",
+            result=scores,
+            model_version=self.model_version,
+            premise=premise,
+            hypothesis=hypothesis,
+        )
+
+        return scores
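
A minimal usage sketch of the adapter above, assuming a valid ANTHROPIC_API_KEY in the environment; the premise and hypothesis strings are invented for illustration.

from bead.items.adapters.anthropic import AnthropicAdapter

adapter = AnthropicAdapter()  # defaults to "claude-3-5-sonnet-20241022"

# One API round trip: Claude's one-word reply is mapped to a one-hot
# score dict and cached, so an identical second call is served from
# the in-memory cache without hitting the API again.
scores = adapter.compute_nli(
    premise="The cat sat on the mat.",
    hypothesis="An animal is on the mat.",
)
label = max(scores, key=scores.get)  # same argmax as get_nli_label()

# Operations Claude cannot serve fail loudly instead of guessing.
try:
    adapter.compute_perplexity("The cat sat on the mat.")
except NotImplementedError as err:
    print(err)
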
bead/items/adapters/api_utils.py
@@ -0,0 +1,167 @@
+"""Utilities for API-based model adapters.
+
+This module provides shared utilities for API-based model adapters,
+including retry logic with exponential backoff and rate limiting.
+"""
+
+from __future__ import annotations
+
+import time
+from collections.abc import Callable
+from functools import wraps
+from typing import ParamSpec, TypeVar
+
+P = ParamSpec("P")
+
+T = TypeVar("T")
+
+
+def retry_with_backoff(
+    max_retries: int = 3,
+    initial_delay: float = 1.0,
+    backoff_factor: float = 2.0,
+    exceptions: tuple[type[Exception], ...] = (Exception,),
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """Decorate function with retry logic and exponential backoff.
+
+    Retries a function call on specified exceptions with exponential backoff
+    between attempts. The delay between retries grows exponentially:
+    delay = initial_delay * (backoff_factor ** attempt).
+
+    Parameters
+    ----------
+    max_retries : int
+        Maximum number of retry attempts (default: 3).
+    initial_delay : float
+        Initial delay in seconds before first retry (default: 1.0).
+    backoff_factor : float
+        Multiplicative factor for delay between retries (default: 2.0).
+    exceptions : tuple[type[Exception], ...]
+        Tuple of exception types to catch and retry on (default: (Exception,)).
+
+    Returns
+    -------
+    Callable
+        Decorated function with retry logic.
+
+    Examples
+    --------
+    >>> @retry_with_backoff(max_retries=3, initial_delay=1.0)
+    ... def call_api():
+    ...     # May raise transient errors
+    ...     return api.get_data()
+    """
+
+    def decorator(func: Callable[P, T]) -> Callable[P, T]:
+        @wraps(func)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            last_exception: Exception | None = None
+
+            for attempt in range(max_retries + 1):
+                try:
+                    return func(*args, **kwargs)
+                except exceptions as e:
+                    last_exception = e
+                    if attempt < max_retries:
+                        delay = initial_delay * (backoff_factor**attempt)
+                        time.sleep(delay)
+                    else:
+                        # last attempt failed, re-raise
+                        raise
+
+            # should never reach here, but for type checker
+            if last_exception is not None:
+                raise last_exception
+            raise RuntimeError("Unexpected state in retry_with_backoff")
+
+        return wrapper
+
+    return decorator
+
+
+class RateLimiter:
+    """Rate limiter for API calls.
+
+    Tracks call timestamps and enforces a maximum rate of calls per minute.
+    Uses a sliding window algorithm to ensure the rate limit is respected.
+
+    Parameters
+    ----------
+    calls_per_minute : int
+        Maximum number of calls allowed per minute (default: 60).
+
+    Attributes
+    ----------
+    calls_per_minute : int
+        Maximum number of calls allowed per minute.
+    call_times : list[float]
+        Timestamps of recent API calls.
+    """
+
+    def __init__(self, calls_per_minute: int = 60) -> None:
+        self.calls_per_minute = calls_per_minute
+        self.call_times: list[float] = []
+
+    def wait_if_needed(self) -> None:
+        """Wait if rate limit would be exceeded.
+
+        Checks if making a call now would exceed the rate limit.
+        If so, sleeps until enough time has passed.
+        """
+        now = time.time()
+
+        # remove calls older than 1 minute
+        cutoff_time = now - 60.0
+        self.call_times = [t for t in self.call_times if t > cutoff_time]
+
+        # if at rate limit, wait until oldest call expires
+        if len(self.call_times) >= self.calls_per_minute:
+            oldest_call = self.call_times[0]
+            wait_time = 60.0 - (now - oldest_call)
+            if wait_time > 0:
+                time.sleep(wait_time)
+                # clean up again after waiting
+                now = time.time()
+                cutoff_time = now - 60.0
+                self.call_times = [t for t in self.call_times if t > cutoff_time]
+
+        # record this call
+        self.call_times.append(time.time())
+
+
+def rate_limit(
+    calls_per_minute: int = 60,
+) -> Callable[[Callable[P, T]], Callable[P, T]]:
+    """Decorate function with rate limiting for API calls.
+
+    Enforces a maximum rate of API calls per minute using a shared
+    RateLimiter instance. Calls that would exceed the rate limit
+    will block until the limit resets.
+
+    Parameters
+    ----------
+    calls_per_minute : int
+        Maximum number of calls allowed per minute (default: 60).
+
+    Returns
+    -------
+    Callable
+        Decorated function with rate limiting.
+
+    Examples
+    --------
+    >>> @rate_limit(calls_per_minute=30)
+    ... def call_api():
+    ...     return api.get_data()
+    """
+    limiter = RateLimiter(calls_per_minute=calls_per_minute)
+
+    def decorator(func: Callable[P, T]) -> Callable[P, T]:
+        @wraps(func)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            limiter.wait_if_needed()
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
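
A self-contained sketch of how the two decorators compose (retry outermost, rate limiting innermost, mirroring their order on compute_nli above); flaky_call and its failure counter are invented for the demonstration.

import time

from bead.items.adapters.api_utils import rate_limit, retry_with_backoff

_attempts = {"n": 0}


@retry_with_backoff(
    max_retries=3,
    initial_delay=0.1,
    backoff_factor=2.0,
    exceptions=(ConnectionError,),
)
@rate_limit(calls_per_minute=120)
def flaky_call() -> str:
    # Fails twice, then succeeds; the retry decorator sleeps 0.1 s and
    # then 0.2 s between attempts (initial_delay * backoff_factor ** attempt).
    # Every attempt, including retries, passes through the rate limiter.
    _attempts["n"] += 1
    if _attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


start = time.monotonic()
print(flaky_call())                        # "ok" after two retries
print(f"{time.monotonic() - start:.1f}s")  # ~0.3 s spent in backoff
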
bead/items/adapters/base.py
@@ -0,0 +1,216 @@
+"""Base class for model adapters used in item construction.
+
+This module defines the abstract ModelAdapter interface that all model adapters
+must implement to support judgment prediction operations during Stage 3
+(Item Construction).
+
+This is SEPARATE from template filling model adapters
+(bead.templates.models.adapter), which are used in Stage 2.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+
+from bead.items.cache import ModelOutputCache
+
+
+class ModelAdapter(ABC):
+    """Base class for model adapters used in item construction.
+
+    All model adapters must implement this interface to support
+    judgment prediction operations during Stage 3 (Item Construction).
+
+    This is SEPARATE from template filling model adapters
+    (bead.templates.models.adapter), which are used in Stage 2.
+
+    Parameters
+    ----------
+    model_name : str
+        Model identifier (e.g., "gpt2", "roberta-large-mnli").
+    cache : ModelOutputCache
+        Cache instance for storing model outputs.
+    model_version : str
+        Version of the model for cache tracking.
+
+    Attributes
+    ----------
+    model_name : str
+        Model identifier (e.g., "gpt2", "roberta-large-mnli").
+    model_version : str
+        Version of the model.
+    cache : ModelOutputCache
+        Cache for model outputs.
+    """
+
+    def __init__(
+        self, model_name: str, cache: ModelOutputCache, model_version: str = "unknown"
+    ) -> None:
+        self.model_name = model_name
+        self.model_version = model_version
+        self.cache = cache
+
+    @abstractmethod
+    def compute_log_probability(self, text: str) -> float:
+        """Compute log probability of text under language model.
+
+        Required for language model constraints. Should raise NotImplementedError
+        if not supported by model type.
+
+        Parameters
+        ----------
+        text : str
+            Text to compute log probability for.
+
+        Returns
+        -------
+        float
+            Log probability of the text.
+
+        Raises
+        ------
+        NotImplementedError
+            If this operation is not supported by the model type.
+        """
+        pass
+
+    @abstractmethod
+    def compute_perplexity(self, text: str) -> float:
+        """Compute perplexity of text.
+
+        Required for complexity-based filtering. Should raise NotImplementedError
+        if not supported by model type.
+
+        Parameters
+        ----------
+        text : str
+            Text to compute perplexity for.
+
+        Returns
+        -------
+        float
+            Perplexity of the text (must be positive).
+
+        Raises
+        ------
+        NotImplementedError
+            If this operation is not supported by the model type.
+        """
+        pass
+
+    @abstractmethod
+    def get_embedding(
+        self, text: str
+    ) -> np.ndarray[tuple[int, ...], np.dtype[np.float64]]:
+        """Get embedding vector for text.
+
+        Required for similarity computations and semantic clustering.
+        Should raise NotImplementedError if not supported by model type.
+
+        Parameters
+        ----------
+        text : str
+            Text to embed.
+
+        Returns
+        -------
+        np.ndarray
+            Embedding vector for the text.
+
+        Raises
+        ------
+        NotImplementedError
+            If this operation is not supported by the model type.
+        """
+        pass
+
+    @abstractmethod
+    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
+        """Compute natural language inference scores.
+
+        Must return dict with keys: "entailment", "neutral", "contradiction".
+        Required for inference-based constraints. Should raise NotImplementedError
+        if not supported by model type.
+
+        Parameters
+        ----------
+        premise : str
+            Premise text.
+        hypothesis : str
+            Hypothesis text.
+
+        Returns
+        -------
+        dict[str, float]
+            Dictionary with keys "entailment", "neutral", "contradiction"
+            mapping to probability scores that sum to ~1.0.
+
+        Raises
+        ------
+        NotImplementedError
+            If this operation is not supported by the model type.
+        """
+        pass
+
+    def compute_similarity(self, text1: str, text2: str) -> float:
+        """Compute similarity between two texts.
+
+        Default implementation using cosine similarity of embeddings.
+        Can be overridden for specialized similarity computation.
+
+        Parameters
+        ----------
+        text1 : str
+            First text.
+        text2 : str
+            Second text.
+
+        Returns
+        -------
+        float
+            Similarity score in [-1, 1] (cosine similarity).
+
+        Raises
+        ------
+        NotImplementedError
+            If embeddings are not supported by the model type.
+        """
+        emb1 = self.get_embedding(text1)
+        emb2 = self.get_embedding(text2)
+
+        # Cosine similarity
+        dot_product = np.dot(emb1, emb2)
+        norm1 = np.linalg.norm(emb1)
+        norm2 = np.linalg.norm(emb2)
+
+        if norm1 == 0 or norm2 == 0:
+            return 0.0
+
+        return float(dot_product / (norm1 * norm2))
+
+    def get_nli_label(self, premise: str, hypothesis: str) -> str:
+        """Get predicted NLI label (max score).
+
+        Default implementation using argmax over compute_nli() scores.
+
+        Parameters
+        ----------
+        premise : str
+            Premise text.
+        hypothesis : str
+            Hypothesis text.
+
+        Returns
+        -------
+        str
+            Predicted label: "entailment", "neutral", or "contradiction".
+
+        Raises
+        ------
+        NotImplementedError
+            If NLI is not supported by the model type.
+        """
+        scores = self.compute_nli(premise, hypothesis)
+        return max(scores, key=scores.get)  # type: ignore[arg-type, return-value]
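
To make the contract concrete, here is a sketch of a minimal subclass; the deterministic pseudo-embeddings and fixed NLI scores are invented stand-ins, used only to exercise the inherited compute_similarity and get_nli_label defaults.

import numpy as np

from bead.items.adapters.base import ModelAdapter
from bead.items.cache import ModelOutputCache


class ToyAdapter(ModelAdapter):
    """Toy subclass: only embeddings and NLI are "supported"."""

    def compute_log_probability(self, text: str) -> float:
        raise NotImplementedError("toy adapter has no language model")

    def compute_perplexity(self, text: str) -> float:
        raise NotImplementedError("toy adapter has no language model")

    def get_embedding(self, text: str) -> np.ndarray:
        # Deterministic pseudo-embedding: the same text yields the same
        # vector within a process (hash() is salted per process).
        rng = np.random.default_rng(abs(hash(text)) % 2**32)
        return rng.standard_normal(8)

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        return {"entailment": 0.2, "neutral": 0.6, "contradiction": 0.2}


adapter = ToyAdapter("toy-model", cache=ModelOutputCache(backend="memory"))
print(adapter.compute_similarity("same text", "same text"))  # ~1.0
print(adapter.get_nli_label("a premise", "a hypothesis"))    # "neutral"
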