bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,224 @@
1
+ """Anthropic API adapter for item construction.
2
+
3
+ This module provides a ModelAdapter implementation for Anthropic's Claude API,
4
+ supporting natural language inference via prompting. Note that Claude API does
5
+ not provide direct access to log probabilities or embeddings.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+
12
+ import numpy as np
13
+
14
+ try:
15
+ import anthropic
16
+ except ImportError as e:
17
+ raise ImportError(
18
+ "anthropic package is required for Anthropic adapter. "
19
+ "Install it with: pip install anthropic"
20
+ ) from e
21
+
22
+ from bead.items.adapters.api_utils import rate_limit, retry_with_backoff
23
+ from bead.items.adapters.base import ModelAdapter
24
+ from bead.items.cache import ModelOutputCache
25
+
26
+
27
+ class AnthropicAdapter(ModelAdapter):
28
+ """Adapter for Anthropic Claude API models.
29
+
30
+ Provides access to Claude models for prompted natural language inference.
31
+ Note that Claude API does not support log probability computation or
32
+ embeddings, so those methods will raise NotImplementedError.
33
+
34
+ Parameters
35
+ ----------
36
+ model_name : str
37
+ Claude model identifier (default: "claude-3-5-sonnet-20241022").
38
+ api_key : str | None
39
+ Anthropic API key. If None, uses ANTHROPIC_API_KEY environment variable.
40
+ cache : ModelOutputCache | None
41
+ Cache for model outputs. If None, creates in-memory cache.
42
+ model_version : str
43
+ Model version for cache tracking (default: "latest").
44
+
45
+ Attributes
46
+ ----------
47
+ model_name : str
48
+ Claude model identifier (e.g., "claude-3-5-sonnet-20241022").
49
+ client : anthropic.Anthropic
50
+ Anthropic API client.
51
+
52
+ Raises
53
+ ------
54
+ ValueError
55
+ If no API key is provided and ANTHROPIC_API_KEY is not set.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ model_name: str = "claude-3-5-sonnet-20241022",
61
+ api_key: str | None = None,
62
+ cache: ModelOutputCache | None = None,
63
+ model_version: str = "latest",
64
+ ) -> None:
65
+ if cache is None:
66
+ cache = ModelOutputCache(backend="memory")
67
+
68
+ super().__init__(
69
+ model_name=model_name, cache=cache, model_version=model_version
70
+ )
71
+
72
+ # Get API key from parameter or environment
73
+ if api_key is None:
74
+ api_key = os.environ.get("ANTHROPIC_API_KEY")
75
+ if api_key is None:
76
+ raise ValueError(
77
+ "Anthropic API key must be provided via api_key parameter "
78
+ "or ANTHROPIC_API_KEY environment variable"
79
+ )
80
+
81
+ self.client = anthropic.Anthropic(api_key=api_key)
82
+
83
+ def compute_log_probability(self, text: str) -> float:
84
+ """Compute log probability of text.
85
+
86
+ Not supported by Anthropic API.
87
+
88
+ Raises
89
+ ------
90
+ NotImplementedError
91
+ Always raised - Claude API does not provide log probabilities.
92
+ """
93
+ raise NotImplementedError(
94
+ "Log probability computation is not supported by Anthropic Claude API. "
95
+ "Claude does not provide access to token-level probabilities."
96
+ )
97
+
98
+ def compute_perplexity(self, text: str) -> float:
99
+ """Compute perplexity of text.
100
+
101
+ Not supported by Anthropic API (requires log probabilities).
102
+
103
+ Raises
104
+ ------
105
+ NotImplementedError
106
+ Always raised - requires log probability support.
107
+ """
108
+ raise NotImplementedError(
109
+ "Perplexity computation is not supported by Anthropic Claude API. "
110
+ "This operation requires log probabilities, which Claude does not provide."
111
+ )
112
+
113
+ def get_embedding(self, text: str) -> np.ndarray:
114
+ """Get embedding vector for text.
115
+
116
+ Not supported by Anthropic API.
117
+
118
+ Raises
119
+ ------
120
+ NotImplementedError
121
+ Always raised - Claude API does not provide embeddings.
122
+ """
123
+ raise NotImplementedError(
124
+ "Embedding computation is not supported by Anthropic Claude API. "
125
+ "Claude does not provide embedding vectors. "
126
+ "Consider using OpenAI's text-embedding models or sentence transformers."
127
+ )
128
+
129
+ @retry_with_backoff(
130
+ max_retries=3,
131
+ initial_delay=1.0,
132
+ backoff_factor=2.0,
133
+ exceptions=(
134
+ anthropic.APIError,
135
+ anthropic.APIConnectionError,
136
+ anthropic.RateLimitError,
137
+ ),
138
+ )
139
+ @rate_limit(calls_per_minute=60)
140
+ def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
141
+ """Compute natural language inference scores via prompting.
142
+
143
+ Uses Claude's messages API with a prompt to classify the relationship
144
+ between premise and hypothesis.
145
+
146
+ Parameters
147
+ ----------
148
+ premise : str
149
+ Premise text.
150
+ hypothesis : str
151
+ Hypothesis text.
152
+
153
+ Returns
154
+ -------
155
+ dict[str, float]
156
+ Dictionary with keys "entailment", "neutral", "contradiction"
157
+ mapping to probability scores.
158
+ """
159
+ # Check cache
160
+ cached = self.cache.get(
161
+ model_name=self.model_name,
162
+ operation="nli",
163
+ premise=premise,
164
+ hypothesis=hypothesis,
165
+ )
166
+ if cached is not None:
167
+ return dict(cached)
168
+
169
+ # Construct prompt
170
+ prompt = (
171
+ "Given the following premise and hypothesis, "
172
+ "determine the relationship between them.\n\n"
173
+ f"Premise: {premise}\n"
174
+ f"Hypothesis: {hypothesis}\n\n"
175
+ "Choose one of the following:\n"
176
+ "- entailment: The hypothesis is definitely true given the premise\n"
177
+ "- neutral: The hypothesis might be true given the premise\n"
178
+ "- contradiction: The hypothesis is definitely false given the premise\n\n"
179
+ "Respond with only one word: entailment, neutral, or contradiction."
180
+ )
181
+
182
+ # Call API
183
+ response = self.client.messages.create(
184
+ model=self.model_name,
185
+ max_tokens=10,
186
+ temperature=0.0,
187
+ messages=[{"role": "user", "content": prompt}],
188
+ )
189
+
190
+ # Parse response
191
+ if not response.content or len(response.content) == 0:
192
+ raise ValueError("API response did not include content")
193
+
194
+ # Get text from first content block
195
+ answer = response.content[0].text.strip().lower()
196
+
197
+ # Map to scores
198
+ scores: dict[str, float] = {
199
+ "entailment": 0.0,
200
+ "neutral": 0.0,
201
+ "contradiction": 0.0,
202
+ }
203
+
204
+ if "entailment" in answer:
205
+ scores["entailment"] = 1.0
206
+ elif "neutral" in answer:
207
+ scores["neutral"] = 1.0
208
+ elif "contradiction" in answer:
209
+ scores["contradiction"] = 1.0
210
+ else:
211
+ # Default to neutral if unclear
212
+ scores["neutral"] = 1.0
213
+
214
+ # Cache result
215
+ self.cache.set(
216
+ model_name=self.model_name,
217
+ operation="nli",
218
+ result=scores,
219
+ model_version=self.model_version,
220
+ premise=premise,
221
+ hypothesis=hypothesis,
222
+ )
223
+
224
+ return scores
@@ -0,0 +1,167 @@
1
+ """Utilities for API-based model adapters.
2
+
3
+ This module provides shared utilities for API-based model adapters,
4
+ including retry logic with exponential backoff and rate limiting.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import time
10
+ from collections.abc import Callable
11
+ from functools import wraps
12
+ from typing import ParamSpec, TypeVar
13
+
14
+ P = ParamSpec("P")
15
+
16
+ T = TypeVar("T")
17
+
18
+
19
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    backoff_factor: float = 2.0,
    exceptions: tuple[type[Exception], ...] = (Exception,),
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorate function with retry logic and exponential backoff.

    Retries a function call on the given exceptions with exponential backoff
    between attempts: delay = initial_delay * (backoff_factor ** attempt).
    On the final failed attempt the exception propagates unchanged.

    Parameters
    ----------
    max_retries : int
        Maximum number of retry attempts (default: 3); the function is
        called at most max_retries + 1 times.
    initial_delay : float
        Initial delay in seconds before first retry (default: 1.0).
    backoff_factor : float
        Multiplicative factor for delay between retries (default: 2.0).
    exceptions : tuple[type[Exception], ...]
        Tuple of exception types to catch and retry on (default: (Exception,)).

    Returns
    -------
    Callable
        Decorator that wraps a function with retry logic while preserving
        its signature.

    Examples
    --------
    >>> @retry_with_backoff(max_retries=3, initial_delay=1.0)
    ... def call_api():
    ...     # May raise transient errors
    ...     return api.get_data()
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt >= max_retries:
                        # Last attempt failed: re-raise the original exception.
                        raise
                    time.sleep(initial_delay * (backoff_factor**attempt))
            # Unreachable: every loop iteration returns or raises.
            raise RuntimeError("Unexpected state in retry_with_backoff")

        return wrapper

    return decorator
80
+
81
+
82
class RateLimiter:
    """Sliding-window rate limiter for API calls.

    Keeps the timestamps of calls made in the last minute and blocks when a
    new call would push the count past the configured per-minute budget.

    Parameters
    ----------
    calls_per_minute : int
        Maximum number of calls allowed per minute (default: 60).

    Attributes
    ----------
    calls_per_minute : int
        Maximum number of calls allowed per minute.
    call_times : list[float]
        Timestamps of recent API calls.
    """

    def __init__(self, calls_per_minute: int = 60) -> None:
        self.calls_per_minute = calls_per_minute
        self.call_times: list[float] = []

    def _prune(self, now: float) -> None:
        # Drop timestamps that fell out of the one-minute window.
        threshold = now - 60.0
        self.call_times = [t for t in self.call_times if t > threshold]

    def wait_if_needed(self) -> None:
        """Block until a call can be made without exceeding the limit.

        Prunes stale timestamps, sleeps until the oldest in-window call
        expires when the budget is exhausted, then records this call.
        """
        self._prune(time.time())

        if len(self.call_times) >= self.calls_per_minute:
            # Budget exhausted: wait for the oldest call to age out.
            earliest = self.call_times[0]
            pause = 60.0 - (time.time() - earliest)
            if pause > 0:
                time.sleep(pause)
                self._prune(time.time())

        self.call_times.append(time.time())
130
+
131
+
132
def rate_limit(
    calls_per_minute: int = 60,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorate function with rate limiting for API calls.

    All invocations of the decorated function share one RateLimiter, so a
    call that would exceed the per-minute budget blocks until the window
    frees up.

    Parameters
    ----------
    calls_per_minute : int
        Maximum number of calls allowed per minute (default: 60).

    Returns
    -------
    Callable
        Decorator that wraps a function with rate limiting.

    Examples
    --------
    >>> @rate_limit(calls_per_minute=30)
    ... def call_api():
    ...     return api.get_data()
    """
    shared_limiter = RateLimiter(calls_per_minute=calls_per_minute)

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            # Block (if necessary) before forwarding the call.
            shared_limiter.wait_if_needed()
            return func(*args, **kwargs)

        return wrapper

    return decorator
@@ -0,0 +1,216 @@
1
+ """Base class for model adapters used in item construction.
2
+
3
+ This module defines the abstract ModelAdapter interface that all model adapters
4
+ must implement to support judgment prediction operations during Stage 3
5
+ (Item Construction).
6
+
7
+ This is SEPARATE from template filling model adapters
8
+ (bead.templates.models.adapter), which are used in Stage 2.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from abc import ABC, abstractmethod
14
+
15
+ import numpy as np
16
+
17
+ from bead.items.cache import ModelOutputCache
18
+
19
+
20
class ModelAdapter(ABC):
    """Abstract interface for judgment-prediction model adapters.

    Concrete adapters implement this interface to support Stage 3
    (Item Construction) operations: language-model scoring, embeddings,
    and natural language inference.

    This is SEPARATE from template filling model adapters
    (bead.templates.models.adapter), which are used in Stage 2.

    Parameters
    ----------
    model_name : str
        Model identifier (e.g., "gpt2", "roberta-large-mnli").
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    model_version : str
        Version of the model for cache tracking.

    Attributes
    ----------
    model_name : str
        Model identifier (e.g., "gpt2", "roberta-large-mnli").
    model_version : str
        Version of the model.
    cache : ModelOutputCache
        Cache for model outputs.
    """

    def __init__(
        self, model_name: str, cache: ModelOutputCache, model_version: str = "unknown"
    ) -> None:
        self.model_name = model_name
        self.model_version = model_version
        self.cache = cache

    @abstractmethod
    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text under language model.

        Needed for language-model constraints. Implementations that cannot
        support this should raise NotImplementedError.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Log probability of the text.

        Raises
        ------
        NotImplementedError
            If this operation is not supported by the model type.
        """

    @abstractmethod
    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Needed for complexity-based filtering. Implementations that cannot
        support this should raise NotImplementedError.

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (must be positive).

        Raises
        ------
        NotImplementedError
            If this operation is not supported by the model type.
        """

    @abstractmethod
    def get_embedding(
        self, text: str
    ) -> np.ndarray[tuple[int, ...], np.dtype[np.float64]]:
        """Get embedding vector for text.

        Needed for similarity computations and semantic clustering.
        Implementations that cannot support this should raise
        NotImplementedError.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.

        Raises
        ------
        NotImplementedError
            If this operation is not supported by the model type.
        """

    @abstractmethod
    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Must return a dict with keys "entailment", "neutral", and
        "contradiction". Needed for inference-based constraints.
        Implementations that cannot support this should raise
        NotImplementedError.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        dict[str, float]
            Dictionary with keys "entailment", "neutral", "contradiction"
            mapping to probability scores that sum to ~1.0.

        Raises
        ------
        NotImplementedError
            If this operation is not supported by the model type.
        """

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute similarity between two texts.

        Default implementation: cosine similarity of the two embeddings.
        Subclasses may override for specialized similarity computation.

        Parameters
        ----------
        text1 : str
            First text.
        text2 : str
            Second text.

        Returns
        -------
        float
            Similarity score in [-1, 1] (cosine similarity); 0.0 when
            either embedding has zero norm.

        Raises
        ------
        NotImplementedError
            If embeddings are not supported by the model type.
        """
        vec_a = self.get_embedding(text1)
        vec_b = self.get_embedding(text2)

        norm_a = np.linalg.norm(vec_a)
        norm_b = np.linalg.norm(vec_b)

        # A zero vector has no direction; define its similarity as 0.
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))

    def get_nli_label(self, premise: str, hypothesis: str) -> str:
        """Get predicted NLI label (max score).

        Default implementation: argmax over compute_nli() scores.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        str
            Predicted label: "entailment", "neutral", or "contradiction".

        Raises
        ------
        NotImplementedError
            If NLI is not supported by the model type.
        """
        label_scores = self.compute_nli(premise, hypothesis)
        best_label, _ = max(label_scores.items(), key=lambda kv: kv[1])
        return best_label