bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/adapters/openai.py
@@ -0,0 +1,323 @@
+"""OpenAI API adapter for item construction.
+
+This module provides a ModelAdapter implementation for OpenAI's API,
+supporting GPT models for various NLP tasks including log probability
+computation, embeddings, and natural language inference via prompting.
+"""
+
+from __future__ import annotations
+
+import os
+
+import numpy as np
+
+try:
+    import openai
+except ImportError as e:
+    raise ImportError(
+        "openai package is required for OpenAI adapter. "
+        "Install it with: pip install openai"
+    ) from e
+
+from bead.items.adapters.api_utils import rate_limit, retry_with_backoff
+from bead.items.adapters.base import ModelAdapter
+from bead.items.cache import ModelOutputCache
+
+
+class OpenAIAdapter(ModelAdapter):
+    """Adapter for OpenAI API models.
+
+    Provides access to OpenAI's GPT models for language model operations,
+    embeddings, and prompted natural language inference.
+
+    Parameters
+    ----------
+    model_name : str
+        OpenAI model identifier (default: "gpt-3.5-turbo").
+    api_key : str | None
+        OpenAI API key. If None, uses OPENAI_API_KEY environment variable.
+    cache : ModelOutputCache | None
+        Cache for model outputs. If None, creates in-memory cache.
+    model_version : str
+        Model version for cache tracking (default: "latest").
+    embedding_model : str
+        Model to use for embeddings (default: "text-embedding-ada-002").
+
+    Attributes
+    ----------
+    model_name : str
+        OpenAI model identifier (e.g., "gpt-3.5-turbo", "gpt-4").
+    client : openai.OpenAI
+        OpenAI API client.
+    embedding_model : str
+        Model to use for embeddings (default: "text-embedding-ada-002").
+
+    Raises
+    ------
+    ValueError
+        If no API key is provided and OPENAI_API_KEY is not set.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "gpt-3.5-turbo",
+        api_key: str | None = None,
+        cache: ModelOutputCache | None = None,
+        model_version: str = "latest",
+        embedding_model: str = "text-embedding-ada-002",
+    ) -> None:
+        if cache is None:
+            cache = ModelOutputCache(backend="memory")
+
+        super().__init__(
+            model_name=model_name, cache=cache, model_version=model_version
+        )
+
+        # Get API key from parameter or environment
+        if api_key is None:
+            api_key = os.environ.get("OPENAI_API_KEY")
+        if api_key is None:
+            raise ValueError(
+                "OpenAI API key must be provided via api_key parameter "
+                "or OPENAI_API_KEY environment variable"
+            )
+
+        self.client = openai.OpenAI(api_key=api_key)
+        self.embedding_model = embedding_model
+
+    @retry_with_backoff(
+        max_retries=3,
+        initial_delay=1.0,
+        backoff_factor=2.0,
+        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
+    )
+    @rate_limit(calls_per_minute=60)
+    def compute_log_probability(self, text: str) -> float:
+        """Compute log probability of text using OpenAI completions API.
+
+        Uses the completions API with logprobs to get token-level log
+        probabilities and sums them to get the total log probability.
+
+        Parameters
+        ----------
+        text : str
+            Text to compute log probability for.
+
+        Returns
+        -------
+        float
+            Log probability of the text (sum of token log probabilities).
+        """
+        # Check cache
+        cached = self.cache.get(
+            model_name=self.model_name, operation="log_probability", text=text
+        )
+        if cached is not None:
+            return float(cached)
+
+        # Call API
+        response = self.client.completions.create(
+            model=self.model_name,
+            prompt=text,
+            max_tokens=0,
+            echo=True,
+            logprobs=1,
+        )
+
+        # Sum token log probabilities
+        logprobs = response.choices[0].logprobs
+        if logprobs is None or logprobs.token_logprobs is None:
+            raise ValueError("API response did not include logprobs")
+
+        # Filter out None values (first token may have None)
+        token_logprobs = [lp for lp in logprobs.token_logprobs if lp is not None]
+        total_log_prob = sum(token_logprobs)
+
+        # Cache result
+        self.cache.set(
+            model_name=self.model_name,
+            operation="log_probability",
+            result=total_log_prob,
+            model_version=self.model_version,
+            text=text,
+        )
+
+        return float(total_log_prob)
+
+    def compute_perplexity(self, text: str) -> float:
+        """Compute perplexity of text.
+
+        Perplexity is computed as exp(-log_prob / num_tokens).
+
+        Parameters
+        ----------
+        text : str
+            Text to compute perplexity for.
+
+        Returns
+        -------
+        float
+            Perplexity of the text (must be positive).
+        """
+        # Check cache
+        cached = self.cache.get(
+            model_name=self.model_name, operation="perplexity", text=text
+        )
+        if cached is not None:
+            return float(cached)
+
+        # Get log probability
+        log_prob = self.compute_log_probability(text)
+
+        # Estimate number of tokens (rough approximation: 1 token ~ 4 chars)
+        num_tokens = max(1, len(text) // 4)
+
+        # Compute perplexity: exp(-log_prob / num_tokens)
+        perplexity = np.exp(-log_prob / num_tokens)
+
+        # Cache result
+        self.cache.set(
+            model_name=self.model_name,
+            operation="perplexity",
+            result=float(perplexity),
+            model_version=self.model_version,
+            text=text,
+        )
+
+        return float(perplexity)
+
+    @retry_with_backoff(
+        max_retries=3,
+        initial_delay=1.0,
+        backoff_factor=2.0,
+        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
+    )
+    @rate_limit(calls_per_minute=60)
+    def get_embedding(self, text: str) -> np.ndarray:
+        """Get embedding vector for text using OpenAI embeddings API.
+
+        Parameters
+        ----------
+        text : str
+            Text to embed.
+
+        Returns
+        -------
+        np.ndarray
+            Embedding vector for the text.
+        """
+        # Check cache
+        cached = self.cache.get(
+            model_name=self.embedding_model, operation="embedding", text=text
+        )
+        if cached is not None:
+            return np.array(cached)
+
+        # Call API
+        response = self.client.embeddings.create(model=self.embedding_model, input=text)
+
+        embedding = np.array(response.data[0].embedding)
+
+        # Cache result
+        self.cache.set(
+            model_name=self.embedding_model,
+            operation="embedding",
+            result=embedding.tolist(),
+            model_version=self.model_version,
+            text=text,
+        )
+
+        return embedding
+
+    @retry_with_backoff(
+        max_retries=3,
+        initial_delay=1.0,
+        backoff_factor=2.0,
+        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
+    )
+    @rate_limit(calls_per_minute=60)
+    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
+        """Compute natural language inference scores via prompting.
+
+        Uses chat completions API with a prompt to classify the relationship
+        between premise and hypothesis.
+
+        Parameters
+        ----------
+        premise : str
+            Premise text.
+        hypothesis : str
+            Hypothesis text.
+
+        Returns
+        -------
+        dict[str, float]
+            Dictionary with keys "entailment", "neutral", "contradiction"
+            mapping to probability scores.
+        """
+        # Check cache
+        cached = self.cache.get(
+            model_name=self.model_name,
+            operation="nli",
+            premise=premise,
+            hypothesis=hypothesis,
+        )
+        if cached is not None:
+            return dict(cached)
+
+        # Construct prompt
+        prompt = (
+            "Given the following premise and hypothesis, "
+            "determine the relationship between them.\n\n"
+            f"Premise: {premise}\n"
+            f"Hypothesis: {hypothesis}\n\n"
+            "Choose one of the following:\n"
+            "- entailment: The hypothesis is definitely true given the premise\n"
+            "- neutral: The hypothesis might be true given the premise\n"
+            "- contradiction: The hypothesis is definitely false given the premise\n\n"
+            "Respond with only one word: entailment, neutral, or contradiction."
+        )
+
+        # Call API
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=[{"role": "user", "content": prompt}],
+            temperature=0.0,
+            max_tokens=10,
+        )
+
+        # Parse response
+        answer = response.choices[0].message.content
+        if answer is None:
+            raise ValueError("API response did not include content")
+
+        answer = answer.strip().lower()
+
+        # Map to scores
+        scores: dict[str, float] = {
+            "entailment": 0.0,
+            "neutral": 0.0,
+            "contradiction": 0.0,
+        }
+
+        if "entailment" in answer:
+            scores["entailment"] = 1.0
+        elif "neutral" in answer:
+            scores["neutral"] = 1.0
+        elif "contradiction" in answer:
+            scores["contradiction"] = 1.0
+        else:
+            # Default to neutral if unclear
+            scores["neutral"] = 1.0
+
+        # Cache result
+        self.cache.set(
+            model_name=self.model_name,
+            operation="nli",
+            result=scores,
+            model_version=self.model_version,
+            premise=premise,
+            hypothesis=hypothesis,
+        )
+
+        return scores
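
For orientation, a minimal usage sketch of the adapter above (not part of the released wheel): it assumes bead and its optional openai dependency are installed, that OPENAI_API_KEY is set in the environment, and that real API calls are acceptable. The calls shown are only those defined in the diff above.

# Usage sketch, not package code. Requires OPENAI_API_KEY in the environment.
from bead.items.adapters.openai import OpenAIAdapter

adapter = OpenAIAdapter(model_name="gpt-3.5-turbo")  # default in-memory cache

text = "The cat sat on the mat."

# Log probability (sum over token logprobs) and derived perplexity.
log_prob = adapter.compute_log_probability(text)
ppl = adapter.compute_perplexity(text)  # exp(-log_prob / approx_num_tokens)

# Embedding vector from the separate embedding model (text-embedding-ada-002 by default).
vec = adapter.get_embedding(text)

# Prompted NLI; returns one-hot scores over entailment/neutral/contradiction.
scores = adapter.compute_nli(
    premise=text,
    hypothesis="There is a cat on the mat.",
)

Repeated calls with the same inputs are served from the cache rather than re-hitting the API.
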
bead/items/adapters/registry.py
@@ -0,0 +1,202 @@
+"""Model adapter registry for centralized adapter management.
+
+This module provides a registry for managing all model adapters,
+both local (HuggingFace) and API-based (OpenAI, Anthropic, etc.).
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Unpack
+
+from typing_extensions import TypedDict
+
+if TYPE_CHECKING:
+    from bead.items.cache import ModelOutputCache
+
+
+class AdapterKwargs(TypedDict, total=False):
+    """Keyword arguments for adapter initialization."""
+
+    api_key: str
+    device: str
+    model_version: str
+    embedding_model: str
+    normalize_embeddings: bool
+    cache: ModelOutputCache
+
+
+from bead.items.adapters.base import ModelAdapter  # noqa: E402
+
+
+class ModelAdapterRegistry:
+    """Registry for all model adapters (local and API-based).
+
+    Provides centralized management of adapter types and instances,
+    with automatic instance caching to avoid redundant initialization.
+
+    Attributes
+    ----------
+    adapters : dict[str, type[ModelAdapter]]
+        Registered adapter classes keyed by adapter type name.
+    instances : dict[str, ModelAdapter]
+        Cached adapter instances keyed by unique identifier.
+    """
+
+    def __init__(self) -> None:
+        self.adapters: dict[str, type[ModelAdapter]] = {}
+        self.instances: dict[str, ModelAdapter] = {}
+
+    def register(self, name: str, adapter_class: type[ModelAdapter]) -> None:
+        """Register an adapter class.
+
+        Parameters
+        ----------
+        name : str
+            Unique name for the adapter type (e.g., "openai", "huggingface_lm").
+        adapter_class : type[ModelAdapter]
+            Adapter class to register (must inherit from ModelAdapter).
+
+        Raises
+        ------
+        ValueError
+            If adapter class does not inherit from ModelAdapter.
+        """
+        if not issubclass(adapter_class, ModelAdapter):  # type: ignore[misc]
+            raise ValueError(
+                f"Adapter class {adapter_class.__name__} must inherit from ModelAdapter"
+            )
+        self.adapters[name] = adapter_class
+
+    def get_adapter(
+        self, adapter_type: str, model_name: str, **kwargs: Unpack[AdapterKwargs]
+    ) -> ModelAdapter:
+        """Get or create adapter instance (with caching).
+
+        Creates a new adapter instance if not cached, otherwise returns
+        the cached instance. Instances are cached by adapter type and model name.
+
+        Parameters
+        ----------
+        adapter_type
+            Type of adapter (must be registered).
+        model_name
+            Model identifier for the adapter.
+        **kwargs
+            Additional keyword arguments to pass to adapter constructor
+            (api_key, device, model_version, embedding_model, etc.).
+
+        Returns
+        -------
+        ModelAdapter
+            Adapter instance (cached or newly created).
+
+        Raises
+        ------
+        ValueError
+            If adapter type is not registered.
+
+        Examples
+        --------
+        >>> registry = ModelAdapterRegistry()
+        >>> registry.register("openai", OpenAIAdapter)
+        >>> adapter = registry.get_adapter("openai", "gpt-4", api_key="...")
+        """
+        if adapter_type not in self.adapters:
+            raise ValueError(
+                f"Unknown adapter type: {adapter_type}. "
+                f"Available types: {list(self.adapters.keys())}"
+            )
+
+        # create cache key from adapter type and model name
+        cache_key = f"{adapter_type}:{model_name}"
+
+        # return cached instance if available
+        if cache_key in self.instances:
+            return self.instances[cache_key]
+
+        # create new instance
+        adapter_class = self.adapters[adapter_type]
+        adapter = adapter_class(model_name=model_name, **kwargs)  # type: ignore[misc]
+
+        # cache and return
+        self.instances[cache_key] = adapter
+        return adapter
+
+    def clear_cache(self) -> None:
+        """Clear all cached adapter instances.
+
+        Useful for testing or when you want to force recreation of adapters
+        with different parameters.
+        """
+        self.instances.clear()
+
+    def list_adapters(self) -> list[str]:
+        """List all registered adapter types.
+
+        Returns
+        -------
+        list[str]
+            List of registered adapter type names.
+        """
+        return list(self.adapters.keys())
+
+
+# Create default registry with all built-in adapters
+default_registry = ModelAdapterRegistry()
+
+# Register HuggingFace adapters
+try:
+    from bead.items.adapters.huggingface import (
+        HuggingFaceLanguageModel,
+        HuggingFaceMaskedLanguageModel,
+        HuggingFaceNLI,
+    )
+
+    default_registry.register("huggingface_lm", HuggingFaceLanguageModel)
+    default_registry.register("huggingface_mlm", HuggingFaceMaskedLanguageModel)
+    default_registry.register("huggingface_nli", HuggingFaceNLI)
+except ImportError:
+    # HuggingFace adapters not available (missing dependencies)
+    pass
+
+# Register sentence transformers
+try:
+    from bead.items.adapters.sentence_transformers import HuggingFaceSentenceTransformer
+
+    default_registry.register("sentence_transformer", HuggingFaceSentenceTransformer)
+except ImportError:
+    # Sentence transformers not available
+    pass
+
+# Register API adapters (these are optional)
+try:
+    from bead.items.adapters.openai import OpenAIAdapter
+
+    default_registry.register("openai", OpenAIAdapter)
+except ImportError:
+    # OpenAI not available
+    pass
+
+try:
+    from bead.items.adapters.anthropic import AnthropicAdapter
+
+    default_registry.register("anthropic", AnthropicAdapter)
+except ImportError:
+    # Anthropic not available
+    pass
+
+try:
+    from bead.items.adapters.google import GoogleAdapter
+
+    default_registry.register("google", GoogleAdapter)
+except ImportError:
+    # Google not available
+    pass
+
+try:
+    from bead.items.adapters.togetherai import TogetherAIAdapter
+
+    default_registry.register("togetherai", TogetherAIAdapter)
+except ImportError:
+    # Together AI not available
+    pass
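
Likewise, a minimal sketch of how the module-level default_registry above might be used (not part of the released wheel): it assumes the optional openai dependency is installed so that the "openai" adapter type is registered, and the api_key value is a placeholder.

# Usage sketch, not package code. Uses only methods defined in the diff above.
from bead.items.adapters.registry import default_registry

# Which adapter types are available depends on which optional extras are installed.
print(default_registry.list_adapters())

# Instances are cached under "adapter_type:model_name", so a second request for
# the same pair returns the same object (constructor kwargs are only used the
# first time).
adapter = default_registry.get_adapter("openai", "gpt-4", api_key="...")
assert default_registry.get_adapter("openai", "gpt-4") is adapter

# clear_cache() drops cached instances so they are rebuilt with fresh kwargs.
default_registry.clear_cache()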