bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/adapters/sentence_transformers.py
@@ -0,0 +1,224 @@
"""Sentence transformer adapter for semantic embeddings.

This module provides an adapter for sentence-transformers models,
which are optimized for generating sentence embeddings for semantic
similarity tasks.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np

from bead.items.adapters.base import ModelAdapter
from bead.items.cache import ModelOutputCache

if TYPE_CHECKING:
    from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class HuggingFaceSentenceTransformer(ModelAdapter):
    """Adapter for sentence-transformers models.

    Supports sentence-transformers models like "all-MiniLM-L6-v2",
    "all-mpnet-base-v2", etc. These models are optimized for generating
    sentence embeddings for semantic similarity tasks.

    Parameters
    ----------
    model_name : str
        Sentence transformer model identifier.
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : str | None
        Device to run model on. If None, uses sentence-transformers default.
    model_version : str
        Version string for cache tracking.
    normalize_embeddings : bool
        Whether to normalize embeddings to unit length.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> model = HuggingFaceSentenceTransformer("all-MiniLM-L6-v2", cache)
    >>> embedding = model.get_embedding("The cat sat on the mat.")
    >>> similarity = model.compute_similarity("The cat sat.", "The dog stood.")
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: str | None = None,
        model_version: str = "unknown",
        normalize_embeddings: bool = True,
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        self._model: SentenceTransformer | None = None

    def _load_model(self) -> None:
        """Load model lazily on first use."""
        if self._model is None:
            from sentence_transformers import SentenceTransformer  # noqa: PLC0415

            logger.info(f"Loading sentence transformer: {self.model_name}")
            self._model = SentenceTransformer(self.model_name, device=self.device)

    @property
    def model(self) -> SentenceTransformer:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text.

        Not supported for sentence transformer models.

        Raises
        ------
        NotImplementedError
            Always raised, as sentence transformers don't provide log probabilities.
        """
        raise NotImplementedError(
            f"Log probability is not supported for sentence transformer "
            f"{self.model_name}. Use HuggingFaceLanguageModel or "
            "HuggingFaceMaskedLanguageModel instead."
        )

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Not supported for sentence transformer models.

        Raises
        ------
        NotImplementedError
            Always raised, as sentence transformers don't provide perplexity.
        """
        raise NotImplementedError(
            f"Perplexity is not supported for sentence transformer {self.model_name}. "
            "Use HuggingFaceLanguageModel or HuggingFaceMaskedLanguageModel instead."
        )

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses sentence-transformers encode() method to generate
        optimized sentence embeddings.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # Check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # Encode text
        embedding = self.model.encode(
            text,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Ensure it's a numpy array
        if not isinstance(embedding, np.ndarray):
            embedding = np.array(embedding)

        # Cache result
        self.cache.set(
            self.model_name,
            "embedding",
            embedding,
            model_version=self.model_version,
            text=text,
        )

        return embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Not supported for sentence transformer models.

        Raises
        ------
        NotImplementedError
            Always raised, as sentence transformers don't support NLI directly.
        """
        raise NotImplementedError(
            f"NLI is not supported for sentence transformer {self.model_name}. "
            "Use HuggingFaceNLI adapter with an NLI-trained model instead."
        )

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute similarity between two texts.

        Uses cosine similarity of embeddings. For sentence transformers,
        this is optimized as embeddings are already normalized (if
        normalize_embeddings=True).

        Parameters
        ----------
        text1 : str
            First text.
        text2 : str
            Second text.

        Returns
        -------
        float
            Similarity score in [-1, 1] (cosine similarity).
        """
        # Check cache
        cached = self.cache.get(self.model_name, "similarity", text1=text1, text2=text2)
        if cached is not None:
            return cached

        # Get embeddings
        emb1 = self.get_embedding(text1)
        emb2 = self.get_embedding(text2)

        # Compute cosine similarity
        if self.normalize_embeddings:
            # Embeddings are already normalized, just dot product
            similarity = float(np.dot(emb1, emb2))
        else:
            # Need to normalize
            dot_product = np.dot(emb1, emb2)
            norm1 = np.linalg.norm(emb1)
            norm2 = np.linalg.norm(emb2)

            if norm1 == 0 or norm2 == 0:
                similarity = 0.0
            else:
                similarity = float(dot_product / (norm1 * norm2))

        # Cache result
        self.cache.set(
            self.model_name,
            "similarity",
            similarity,
            model_version=self.model_version,
            text1=text1,
            text2=text2,
        )

        return similarity
bead/items/adapters/togetherai.py
@@ -0,0 +1,309 @@
"""Together AI adapter for item construction.

This module provides a ModelAdapter implementation for Together AI's API,
which provides access to various open-source models. Together AI uses an
OpenAI-compatible API, so we use the OpenAI client with a custom base URL.
"""

from __future__ import annotations

import os

import numpy as np

try:
    import openai
except ImportError as e:
    raise ImportError(
        "openai package is required for Together AI adapter. "
        "Install it with: pip install openai"
    ) from e

from bead.items.adapters.api_utils import rate_limit, retry_with_backoff
from bead.items.adapters.base import ModelAdapter
from bead.items.cache import ModelOutputCache


class TogetherAIAdapter(ModelAdapter):
    """Adapter for Together AI models.

    Together AI provides access to various open-source models through an
    OpenAI-compatible API. This adapter uses the OpenAI client with a
    custom base URL.

    Parameters
    ----------
    model_name : str
        Together AI model identifier
        (default: "meta-llama/Llama-3-70b-chat-hf").
    api_key : str | None
        Together AI API key. If None, uses TOGETHER_API_KEY environment variable.
    cache : ModelOutputCache | None
        Cache for model outputs. If None, creates in-memory cache.
    model_version : str
        Model version for cache tracking (default: "latest").

    Attributes
    ----------
    model_name : str
        Together AI model identifier (e.g., "meta-llama/Llama-3-70b-chat-hf").
    client : openai.OpenAI
        OpenAI-compatible client configured for Together AI.

    Raises
    ------
    ValueError
        If no API key is provided and TOGETHER_API_KEY is not set.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3-70b-chat-hf",
        api_key: str | None = None,
        cache: ModelOutputCache | None = None,
        model_version: str = "latest",
    ) -> None:
        if cache is None:
            cache = ModelOutputCache(backend="memory")

        super().__init__(
            model_name=model_name, cache=cache, model_version=model_version
        )

        # Get API key from parameter or environment
        if api_key is None:
            api_key = os.environ.get("TOGETHER_API_KEY")
        if api_key is None:
            raise ValueError(
                "Together AI API key must be provided via api_key parameter "
                "or TOGETHER_API_KEY environment variable"
            )

        # Together AI uses OpenAI-compatible API
        self.client = openai.OpenAI(
            api_key=api_key, base_url="https://api.together.xyz/v1"
        )

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text using Together AI API.

        Uses the completions API with logprobs to get token-level log probabilities
        and sums them to get the total log probability.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Log probability of the text (sum of token log probabilities).
        """
        # Check cache
        cached = self.cache.get(
            model_name=self.model_name, operation="log_probability", text=text
        )
        if cached is not None:
            return float(cached)

        # Call API
        try:
            response = self.client.completions.create(
                model=self.model_name,
                prompt=text,
                max_tokens=0,
                echo=True,
                logprobs=1,
            )

            # Sum token log probabilities
            logprobs = response.choices[0].logprobs
            if logprobs is None or logprobs.token_logprobs is None:
                raise ValueError("API response did not include logprobs")

            # Filter out None values (first token may have None)
            token_logprobs = [lp for lp in logprobs.token_logprobs if lp is not None]
            total_log_prob = sum(token_logprobs)

        except (openai.BadRequestError, AttributeError) as e:
            # Some models may not support completions API, fall back to chat
            raise NotImplementedError(
                f"Log probability computation is not supported for model "
                f"{self.model_name}. This model may not support the "
                "completions API with logprobs."
            ) from e

        # Cache result
        self.cache.set(
            model_name=self.model_name,
            operation="log_probability",
            result=total_log_prob,
            model_version=self.model_version,
            text=text,
        )

        return float(total_log_prob)

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Perplexity is computed as exp(-log_prob / num_tokens).

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (must be positive).

        Raises
        ------
        NotImplementedError
            If log probability computation is not supported.
        """
        # Check cache
        cached = self.cache.get(
            model_name=self.model_name, operation="perplexity", text=text
        )
        if cached is not None:
            return float(cached)

        # Get log probability
        log_prob = self.compute_log_probability(text)

        # Estimate number of tokens (rough approximation: 1 token ~ 4 chars)
        num_tokens = max(1, len(text) // 4)

        # Compute perplexity: exp(-log_prob / num_tokens)
        perplexity = np.exp(-log_prob / num_tokens)

        # Cache result
        self.cache.set(
            model_name=self.model_name,
            operation="perplexity",
            result=float(perplexity),
            model_version=self.model_version,
            text=text,
        )

        return float(perplexity)

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Not supported by Together AI (no embedding-specific models).

        Raises
        ------
        NotImplementedError
            Always raised - Together AI does not provide embeddings.
        """
        raise NotImplementedError(
            "Embedding computation is not supported by Together AI. "
            "Together AI focuses on text generation models. "
            "Consider using OpenAI's text-embedding models or sentence transformers."
        )

    @retry_with_backoff(
        max_retries=3,
        initial_delay=1.0,
        backoff_factor=2.0,
        exceptions=(openai.APIError, openai.APIConnectionError, openai.RateLimitError),
    )
    @rate_limit(calls_per_minute=60)
    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores via prompting.

        Uses chat completions API with a prompt to classify the relationship
        between premise and hypothesis.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        dict[str, float]
            Dictionary with keys "entailment", "neutral", "contradiction"
            mapping to probability scores.
        """
        # Check cache
        cached = self.cache.get(
            model_name=self.model_name,
            operation="nli",
            premise=premise,
            hypothesis=hypothesis,
        )
        if cached is not None:
            return dict(cached)

        # Construct prompt
        prompt = (
            "Given the following premise and hypothesis, "
            "determine the relationship between them.\n\n"
            f"Premise: {premise}\n"
            f"Hypothesis: {hypothesis}\n\n"
            "Choose one of the following:\n"
            "- entailment: The hypothesis is definitely true given the premise\n"
            "- neutral: The hypothesis might be true given the premise\n"
            "- contradiction: The hypothesis is definitely false given the premise\n\n"
            "Respond with only one word: entailment, neutral, or contradiction."
        )

        # Call API
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=10,
        )

        # Parse response
        answer = response.choices[0].message.content
        if answer is None:
            raise ValueError("API response did not include content")

        answer = answer.strip().lower()

        # Map to scores
        scores: dict[str, float] = {
            "entailment": 0.0,
            "neutral": 0.0,
            "contradiction": 0.0,
        }

        if "entailment" in answer:
            scores["entailment"] = 1.0
        elif "neutral" in answer:
            scores["neutral"] = 1.0
        elif "contradiction" in answer:
            scores["contradiction"] = 1.0
        else:
            # Default to neutral if unclear
            scores["neutral"] = 1.0

        # Cache result
        self.cache.set(
            model_name=self.model_name,
            operation="nli",
            result=scores,
            model_version=self.model_version,
            premise=premise,
            hypothesis=hypothesis,
        )

        return scores