bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,224 @@
1
+ """Sentence transformer adapter for semantic embeddings.
2
+
3
+ This module provides an adapter for sentence-transformers models,
4
+ which are optimized for generating sentence embeddings for semantic
5
+ similarity tasks.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from typing import TYPE_CHECKING
12
+
13
+ import numpy as np
14
+
15
+ from bead.items.adapters.base import ModelAdapter
16
+ from bead.items.cache import ModelOutputCache
17
+
18
+ if TYPE_CHECKING:
19
+ from sentence_transformers import SentenceTransformer
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class HuggingFaceSentenceTransformer(ModelAdapter):
    """Adapter for sentence-transformers models.

    Supports sentence-transformers models like "all-MiniLM-L6-v2",
    "all-mpnet-base-v2", etc. These models are optimized for generating
    sentence embeddings for semantic similarity tasks.

    Parameters
    ----------
    model_name : str
        Sentence transformer model identifier.
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : str | None
        Device to run model on. If None, uses sentence-transformers default.
    model_version : str
        Version string for cache tracking.
    normalize_embeddings : bool
        Whether to normalize embeddings to unit length.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> model = HuggingFaceSentenceTransformer("all-MiniLM-L6-v2", cache)
    >>> embedding = model.get_embedding("The cat sat on the mat.")
    >>> similarity = model.compute_similarity("The cat sat.", "The dog stood.")
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: str | None = None,
        model_version: str = "unknown",
        normalize_embeddings: bool = True,
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        # Loaded lazily on first use via the `model` property.
        self._model: SentenceTransformer | None = None

    def _load_model(self) -> None:
        """Load model lazily on first use."""
        if self._model is None:
            from sentence_transformers import SentenceTransformer  # noqa: PLC0415

            logger.info(f"Loading sentence transformer: {self.model_name}")
            self._model = SentenceTransformer(self.model_name, device=self.device)

    @property
    def model(self) -> SentenceTransformer:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text.

        Not supported for sentence transformer models.

        Raises
        ------
        NotImplementedError
            Always raised, as sentence transformers don't provide log probabilities.
        """
        raise NotImplementedError(
            f"Log probability is not supported for sentence transformer "
            f"{self.model_name}. Use HuggingFaceLanguageModel or "
            "HuggingFaceMaskedLanguageModel instead."
        )

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Not supported for sentence transformer models.

        Raises
        ------
        NotImplementedError
            Always raised, as sentence transformers don't provide perplexity.
        """
        raise NotImplementedError(
            f"Perplexity is not supported for sentence transformer {self.model_name}. "
            "Use HuggingFaceLanguageModel or HuggingFaceMaskedLanguageModel instead."
        )

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses sentence-transformers encode() method to generate
        optimized sentence embeddings.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # Include the normalization setting in the cache key: the same text
        # yields a different vector depending on normalize_embeddings, so
        # adapters with different settings must not share cache entries.
        cached = self.cache.get(
            self.model_name,
            "embedding",
            text=text,
            normalize=self.normalize_embeddings,
        )
        if cached is not None:
            # Cache backends may round-trip arrays through lists; always
            # hand callers an ndarray (no-op when already an ndarray).
            return np.asarray(cached)

        # Encode text
        embedding = self.model.encode(
            text,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Ensure it's a numpy array
        if not isinstance(embedding, np.ndarray):
            embedding = np.array(embedding)

        # Cache result
        self.cache.set(
            self.model_name,
            "embedding",
            embedding,
            model_version=self.model_version,
            text=text,
            normalize=self.normalize_embeddings,
        )

        return embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Not supported for sentence transformer models.

        Raises
        ------
        NotImplementedError
            Always raised, as sentence transformers don't support NLI directly.
        """
        raise NotImplementedError(
            f"NLI is not supported for sentence transformer {self.model_name}. "
            "Use HuggingFaceNLI adapter with an NLI-trained model instead."
        )

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute similarity between two texts.

        Uses cosine similarity of embeddings. For sentence transformers,
        this is optimized as embeddings are already normalized (if
        normalize_embeddings=True). Cosine similarity is invariant to the
        normalization setting, so the cache key does not include it.

        Parameters
        ----------
        text1 : str
            First text.
        text2 : str
            Second text.

        Returns
        -------
        float
            Similarity score in [-1, 1] (cosine similarity).
        """
        # Check cache
        cached = self.cache.get(self.model_name, "similarity", text1=text1, text2=text2)
        if cached is not None:
            # Coerce for consistency with the computed path, which always
            # returns a Python float.
            return float(cached)

        # Get embeddings
        emb1 = self.get_embedding(text1)
        emb2 = self.get_embedding(text2)

        # Compute cosine similarity
        if self.normalize_embeddings:
            # Embeddings are already normalized, just dot product
            similarity = float(np.dot(emb1, emb2))
        else:
            # Need to normalize
            dot_product = np.dot(emb1, emb2)
            norm1 = np.linalg.norm(emb1)
            norm2 = np.linalg.norm(emb2)

            # Guard against zero vectors to avoid division by zero.
            if norm1 == 0 or norm2 == 0:
                similarity = 0.0
            else:
                similarity = float(dot_product / (norm1 * norm2))

        # Cache result
        self.cache.set(
            self.model_name,
            "similarity",
            similarity,
            model_version=self.model_version,
            text1=text1,
            text2=text2,
        )

        return similarity
@@ -0,0 +1,309 @@
1
+ """Together AI adapter for item construction.
2
+
3
+ This module provides a ModelAdapter implementation for Together AI's API,
4
+ which provides access to various open-source models. Together AI uses an
5
+ OpenAI-compatible API, so we use the OpenAI client with a custom base URL.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+
12
+ import numpy as np
13
+
14
+ try:
15
+ import openai
16
+ except ImportError as e:
17
+ raise ImportError(
18
+ "openai package is required for Together AI adapter. "
19
+ "Install it with: pip install openai"
20
+ ) from e
21
+
22
+ from bead.items.adapters.api_utils import rate_limit, retry_with_backoff
23
+ from bead.items.adapters.base import ModelAdapter
24
+ from bead.items.cache import ModelOutputCache
25
+
26
+
27
class TogetherAIAdapter(ModelAdapter):
    """Adapter for Together AI's OpenAI-compatible model API.

    Together AI hosts a range of open-source models behind an API that
    mirrors OpenAI's, so this adapter simply drives the ``openai`` client
    against Together's base URL.

    Parameters
    ----------
    model_name : str
        Together AI model identifier
        (default: "meta-llama/Llama-3-70b-chat-hf").
    api_key : str | None
        Together AI API key. If None, uses TOGETHER_API_KEY environment variable.
    cache : ModelOutputCache | None
        Cache for model outputs. If None, creates in-memory cache.
    model_version : str
        Model version for cache tracking (default: "latest").

    Attributes
    ----------
    model_name : str
        Together AI model identifier (e.g., "meta-llama/Llama-3-70b-chat-hf").
    client : openai.OpenAI
        OpenAI-compatible client configured for Together AI.

    Raises
    ------
    ValueError
        If no API key is provided and TOGETHER_API_KEY is not set.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3-70b-chat-hf",
        api_key: str | None = None,
        cache: ModelOutputCache | None = None,
        model_version: str = "latest",
    ) -> None:
        super().__init__(
            model_name=model_name,
            # Fall back to a throwaway in-memory cache when none is supplied.
            cache=cache if cache is not None else ModelOutputCache(backend="memory"),
            model_version=model_version,
        )

        # Credential resolution: an explicit argument wins over the environment.
        key = api_key if api_key is not None else os.environ.get("TOGETHER_API_KEY")
        if key is None:
            raise ValueError(
                "Together AI API key must be provided via api_key parameter "
                "or TOGETHER_API_KEY environment variable"
            )

        # Together AI speaks the OpenAI wire protocol at a custom base URL.
        self.client = openai.OpenAI(
            api_key=key, base_url="https://api.together.xyz/v1"
        )

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def compute_log_probability(self, text: str) -> float:
        """Compute the total log probability of ``text``.

        Issues a completions request with ``echo=True`` and ``logprobs`` so
        the API scores the prompt itself, then sums the per-token log
        probabilities.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Log probability of the text (sum of token log probabilities).

        Raises
        ------
        NotImplementedError
            If the model does not support the completions API with logprobs.
        """
        cached = self.cache.get(
            model_name=self.model_name, operation="log_probability", text=text
        )
        if cached is not None:
            return float(cached)

        try:
            response = self.client.completions.create(
                model=self.model_name,
                prompt=text,
                # max_tokens=0 + echo=True: score the prompt, generate nothing.
                max_tokens=0,
                echo=True,
                logprobs=1,
            )

            logprob_info = response.choices[0].logprobs
            if logprob_info is None or logprob_info.token_logprobs is None:
                raise ValueError("API response did not include logprobs")

            # The first token's logprob may be None; skip any such entries.
            summed = sum(lp for lp in logprob_info.token_logprobs if lp is not None)

        except (openai.BadRequestError, AttributeError) as e:
            # Chat-only models reject the completions endpoint; surface that
            # as an explicit capability error.
            raise NotImplementedError(
                f"Log probability computation is not supported for model "
                f"{self.model_name}. This model may not support the "
                "completions API with logprobs."
            ) from e

        self.cache.set(
            model_name=self.model_name,
            operation="log_probability",
            result=summed,
            model_version=self.model_version,
            text=text,
        )

        return float(summed)

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of ``text``.

        Defined as exp(-log_prob / num_tokens), with the token count
        approximated from the character length.

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (must be positive).

        Raises
        ------
        NotImplementedError
            If log probability computation is not supported.
        """
        cached = self.cache.get(
            model_name=self.model_name, operation="perplexity", text=text
        )
        if cached is not None:
            return float(cached)

        log_prob = self.compute_log_probability(text)

        # Rough heuristic: ~4 characters per token, never fewer than one.
        token_estimate = max(1, len(text) // 4)
        perplexity = float(np.exp(-log_prob / token_estimate))

        self.cache.set(
            model_name=self.model_name,
            operation="perplexity",
            result=perplexity,
            model_version=self.model_version,
            text=text,
        )

        return perplexity

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Not supported by Together AI (no embedding-specific models).

        Raises
        ------
        NotImplementedError
            Always raised - Together AI does not provide embeddings.
        """
        raise NotImplementedError(
            "Embedding computation is not supported by Together AI. "
            "Together AI focuses on text generation models. "
            "Consider using OpenAI's text-embedding models or sentence transformers."
        )

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores via prompting.

        Asks the chat model to label the premise/hypothesis pair with a
        single word and converts that label into one-hot scores.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        dict[str, float]
            Dictionary with keys "entailment", "neutral", "contradiction"
            mapping to probability scores.

        Raises
        ------
        ValueError
            If the API response contains no message content.
        """
        cached = self.cache.get(
            model_name=self.model_name,
            operation="nli",
            premise=premise,
            hypothesis=hypothesis,
        )
        if cached is not None:
            return dict(cached)

        prompt = (
            "Given the following premise and hypothesis, "
            "determine the relationship between them.\n\n"
            f"Premise: {premise}\n"
            f"Hypothesis: {hypothesis}\n\n"
            "Choose one of the following:\n"
            "- entailment: The hypothesis is definitely true given the premise\n"
            "- neutral: The hypothesis might be true given the premise\n"
            "- contradiction: The hypothesis is definitely false given the premise\n\n"
            "Respond with only one word: entailment, neutral, or contradiction."
        )

        reply = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            # Deterministic, single-word classification.
            temperature=0.0,
            max_tokens=10,
        )

        content = reply.choices[0].message.content
        if content is None:
            raise ValueError("API response did not include content")
        verdict = content.strip().lower()

        # One-hot scores: first label found in the reply wins (checked in
        # the order below); an unrecognized reply defaults to neutral.
        labels = ("entailment", "neutral", "contradiction")
        scores: dict[str, float] = {label: 0.0 for label in labels}
        scores[next((label for label in labels if label in verdict), "neutral")] = 1.0

        self.cache.set(
            model_name=self.model_name,
            operation="nli",
            result=scores,
            model_version=self.model_version,
            premise=premise,
            hypothesis=hypothesis,
        )

        return scores