bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,349 @@
1
+ """Validation utilities for data integrity checks.
2
+
3
+ This module provides validation functions beyond Pydantic's built-in validation,
4
+ including file validation, reference validation, and provenance chain validation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+ from typing import get_type_hints
11
+ from uuid import UUID
12
+
13
+ from pydantic import BaseModel, Field, ValidationError
14
+
15
+ from bead.data.metadata import MetadataTracker
16
+
17
+
18
class ValidationReport(BaseModel):
    """Accumulator for the outcome of a validation pass.

    Holds the overall pass/fail flag together with the error and warning
    messages gathered along the way and a counter of validated objects.

    Attributes
    ----------
    valid : bool
        Overall status; flipped to False by ``add_error``.
    errors : list[str]
        Collected error messages (starts empty).
    warnings : list[str]
        Collected warning messages (starts empty).
    object_count : int
        Number of objects validated (starts at 0).

    Examples
    --------
    >>> report = ValidationReport(valid=True)
    >>> report.add_error("Invalid field")
    >>> report.valid
    False
    >>> report.has_errors()
    True
    >>> len(report.errors)
    1
    """

    valid: bool
    errors: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
    object_count: int = 0

    def add_error(self, message: str) -> None:
        """Record an error and mark the report as invalid.

        Parameters
        ----------
        message
            Text of the error to record.

        Examples
        --------
        >>> report = ValidationReport(valid=True)
        >>> report.add_error("Something went wrong")
        >>> report.valid
        False
        >>> "Something went wrong" in report.errors
        True
        """
        self.errors.append(message)
        self.valid = False

    def add_warning(self, message: str) -> None:
        """Record a warning; the ``valid`` flag is left untouched.

        Parameters
        ----------
        message
            Text of the warning to record.

        Examples
        --------
        >>> report = ValidationReport(valid=True)
        >>> report.add_warning("This might be an issue")
        >>> report.valid
        True
        >>> report.has_warnings()
        True
        """
        self.warnings.append(message)

    def has_errors(self) -> bool:
        """Tell whether at least one error has been recorded.

        Returns
        -------
        bool
            True when ``errors`` is non-empty.

        Examples
        --------
        >>> report = ValidationReport(valid=True)
        >>> report.has_errors()
        False
        >>> report.add_error("error")
        >>> report.has_errors()
        True
        """
        return bool(self.errors)

    def has_warnings(self) -> bool:
        """Tell whether at least one warning has been recorded.

        Returns
        -------
        bool
            True when ``warnings`` is non-empty.

        Examples
        --------
        >>> report = ValidationReport(valid=True)
        >>> report.has_warnings()
        False
        >>> report.add_warning("warning")
        >>> report.has_warnings()
        True
        """
        return bool(self.warnings)
130
+
131
+
132
def validate_jsonlines_file(
    path: Path, model_class: type[BaseModel], strict: bool = True
) -> ValidationReport:
    """Validate JSONLines file against Pydantic model schema.

    Reads the file as UTF-8 and validates every non-blank line with
    ``model_class.model_validate_json``. Blank lines are skipped and do not
    count towards ``object_count``.

    Parameters
    ----------
    path
        Path to JSONLines file to validate.
    model_class
        Pydantic model class to validate against.
    strict
        If True, stop at first error. If False, collect all errors (default: True).

    Returns
    -------
    ValidationReport
        Report whose ``object_count`` is the number of lines that validated
        successfully. A missing file, an unreadable file, or a file that is
        not valid UTF-8 is reported as an error rather than raised.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.data.base import BeadBaseModel
    >>> class TestModel(BeadBaseModel):
    ...     name: str
    >>> # validate file
    >>> report = validate_jsonlines_file(
    ...     Path("data.jsonl"), TestModel
    ... )  # doctest: +SKIP
    >>> report.valid  # doctest: +SKIP
    True
    """
    report = ValidationReport(valid=True)

    # A missing file gets its own message instead of a generic read failure.
    if not path.exists():
        report.add_error(f"File not found: {path}")
        return report

    try:
        with path.open("r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:  # skip empty lines
                    continue

                try:
                    model_class.model_validate_json(line)
                except ValidationError as e:
                    report.add_error(f"Line {line_num}: Validation error - {e}")
                except Exception as e:
                    # Deliberate best-effort: anything else the model raises
                    # (e.g. from custom validators) is reported, not raised.
                    report.add_error(f"Line {line_num}: Parse error - {e}")
                else:
                    report.object_count += 1
                    continue

                # Only reached after a failed line.
                if strict:
                    return report

    except (OSError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError; it is raised
        # while iterating the file when the content is not valid UTF-8.
        report.add_error(f"Failed to read file: {e}")

    return report
201
+
202
+
203
def validate_uuid_references(
    objects: list[BaseModel], reference_pool: dict[UUID, BaseModel]
) -> ValidationReport:
    """Validate that UUID references point to existing objects.

    For every object, each annotated field whose type hint mentions ``UUID``
    (other than the object's own ``id``) is inspected. A field holding a
    single UUID, or a list containing UUIDs, must reference keys present in
    ``reference_pool``; every missing reference is reported as an error.

    Whether a field is treated as a single reference or a list of references
    is decided from the value actually stored on the object, not from the
    textual form of its annotation, so hints such as ``UUID | list[UUID]`` or
    hints naming a class whose name contains "list" are handled correctly.

    Parameters
    ----------
    objects
        List of objects to validate.
    reference_pool
        Dictionary of valid UUIDs to objects.

    Returns
    -------
    ValidationReport
        Validation report with results. ``object_count`` is ``len(objects)``.
        A warning is added (once per class) when the type hints of a class
        cannot be resolved, because its instances are then not checked.

    Examples
    --------
    >>> from uuid import uuid4
    >>> from bead.data.base import BeadBaseModel
    >>> class Item(BeadBaseModel):
    ...     name: str
    >>> items = [Item(name="test")]
    >>> pool = {items[0].id: items[0]}
    >>> report = validate_uuid_references(items, pool)
    >>> report.valid
    True
    """
    report = ValidationReport(valid=True)
    report.object_count = len(objects)

    # Type hints are a property of the class, so resolve them once per class.
    # None marks a class whose hints could not be resolved.
    hints_by_class: dict[type, dict[str, object] | None] = {}

    for obj in objects:
        cls = type(obj)
        if cls not in hints_by_class:
            try:
                hints_by_class[cls] = get_type_hints(cls)
            except Exception as exc:
                # Best-effort: unresolved forward references and the like must
                # not abort validation, but the gap is surfaced as a warning.
                hints_by_class[cls] = None
                report.add_warning(
                    f"Could not resolve type hints for {cls.__name__}; "
                    f"UUID references were not checked: {exc}"
                )

        type_hints = hints_by_class[cls]
        if type_hints is None:
            continue

        # used in error messages only
        obj_id = getattr(obj, "id", "unknown")

        for field_name, field_type in type_hints.items():
            # skip 'id' field; it's the object's own ID, not a reference
            if field_name == "id":
                continue

            # only fields whose annotation mentions UUID can hold references
            if "UUID" not in str(field_type):
                continue

            try:
                field_value = getattr(obj, field_name)
            except AttributeError:
                continue

            # Dispatch on the runtime value rather than the annotation text.
            if isinstance(field_value, UUID):
                referenced = [field_value]
            elif isinstance(field_value, list):
                referenced = [
                    item
                    for item in field_value  # pyright: ignore[reportUnknownVariableType]
                    if isinstance(item, UUID)
                ]
            else:
                continue

            for ref in referenced:
                if ref not in reference_pool:
                    report.add_error(
                        f"Object {obj_id}: "
                        f"Field '{field_name}' references "
                        f"missing UUID {ref}"
                    )

    return report
291
+
292
+
293
def validate_provenance_chain(
    metadata: MetadataTracker, repository: dict[UUID, BaseModel]
) -> ValidationReport:
    """Check every provenance record against the repository.

    Each record in ``metadata.provenance`` must name a ``parent_id`` that is
    a key of ``repository``, and its ``parent_type`` must equal the class
    name of the object stored under that key.

    Parameters
    ----------
    metadata
        Metadata tracker whose provenance records are checked.
    repository
        Mapping from UUID to the object it identifies.

    Returns
    -------
    ValidationReport
        Report with one error per missing parent or type mismatch;
        ``object_count`` is the number of provenance records.

    Examples
    --------
    >>> from uuid import uuid4
    >>> from bead.data.base import BeadBaseModel
    >>> from bead.data.metadata import MetadataTracker
    >>> class Template(BeadBaseModel):
    ...     name: str
    >>> template = Template(name="test")
    >>> metadata = MetadataTracker()
    >>> metadata.add_provenance(template.id, "Template", "filled_from")
    >>> repo = {template.id: template}
    >>> report = validate_provenance_chain(metadata, repo)
    >>> report.valid
    True
    """
    report = ValidationReport(valid=True)
    records = metadata.provenance
    report.object_count = len(records)

    for entry in records:
        parent_id = entry.parent_id

        if parent_id in repository:
            # parent exists: compare the declared type with the real class name
            declared_type = entry.parent_type
            actual_type = type(repository[parent_id]).__name__
            if declared_type != actual_type:
                report.add_error(
                    f"Provenance record for {parent_id}: "
                    f"Expected type '{declared_type}', got '{actual_type}'"
                )
        else:
            report.add_error(
                f"Provenance record references missing parent: {parent_id}"
            )

    return report
@@ -0,0 +1,11 @@
1
+ """Data collection infrastructure for human experiments."""
2
+
3
+ from bead.data_collection.jatos import JATOSDataCollector
4
+ from bead.data_collection.merger import DataMerger
5
+ from bead.data_collection.prolific import ProlificDataCollector
6
+
7
+ __all__ = [
8
+ "JATOSDataCollector",
9
+ "ProlificDataCollector",
10
+ "DataMerger",
11
+ ]
@@ -0,0 +1,223 @@
1
+ """JATOS data collection for model training.
2
+
3
+ This module provides the JATOSDataCollector class for downloading experimental
4
+ results from JATOS servers. It wraps the existing JATOSClient and adds
5
+ functionality for:
6
+ - Downloading all results for a study
7
+ - Filtering by component and worker type
8
+ - Adding metadata (timestamps, etc.)
9
+ - Saving to JSONLines format
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ from pathlib import Path
16
+
17
+ from bead.data.base import JsonValue
18
+ from bead.data.timestamps import now_iso8601
19
+ from bead.deployment.jatos.api import JATOSClient
20
+
21
+
22
class JATOSDataCollector:
    """Collects experimental data from JATOS API.

    This class wraps the existing JATOSClient to provide data collection
    functionality specifically for model training. It downloads results,
    adds metadata, and saves in JSONLines format.

    Parameters
    ----------
    base_url : str
        JATOS instance URL (e.g., https://jatos.example.com).
    api_token : str
        API authentication token.
    study_id : int
        JATOS study ID to collect data from.

    Attributes
    ----------
    study_id : int
        JATOS study ID to collect data from.
    client : JATOSClient
        Underlying JATOS API client.

    Examples
    --------
    Create a collector and download results::

        collector = JATOSDataCollector(
            base_url="https://jatos.example.com",
            api_token="my-token",
            study_id=123
        )
        results = collector.download_results(Path("results.jsonl"))
    """

    def __init__(
        self,
        base_url: str,
        api_token: str,
        study_id: int,
    ) -> None:
        self.study_id = study_id
        self.client = JATOSClient(base_url, api_token)

    def download_results(
        self,
        output_path: Path,
        component_id: int | None = None,
        worker_type: str | None = None,
    ) -> list[dict[str, JsonValue]]:
        """Download all results for the study.

        Fetches the study's result IDs, downloads every result, keeps those
        matching the optional filters, and writes the kept results to
        ``output_path`` as UTF-8 JSONLines (one result per line). Parent
        directories of ``output_path`` are created if needed and an existing
        file is overwritten.

        Filtering happens after download. A result whose ``metadata`` is not
        a JSON object is treated as not matching any active filter.

        Parameters
        ----------
        output_path : Path
            Path to save results (JSONLines format).
        component_id : int | None
            Keep only results whose metadata ``componentId`` equals this
            value (optional).
        worker_type : str | None
            Keep only results whose metadata ``workerType`` equals this
            value (optional).

        Returns
        -------
        list[dict[str, JsonValue]]
            The results that were written, in download order.

        Raises
        ------
        requests.HTTPError
            If the API request fails.

        Examples
        --------
        Download all results::

            results = collector.download_results(Path("results.jsonl"))

        Download with filters::

            results = collector.download_results(
                Path("results.jsonl"),
                component_id=1,
                worker_type="Prolific"
            )
        """
        result_ids = self.client.get_results(self.study_id)

        results: list[dict[str, JsonValue]] = []

        for result_id in result_ids:
            result = self._download_single_result(result_id)

            if (
                component_id is not None
                and self._metadata_value(result, "componentId") != component_id
            ):
                continue

            if (
                worker_type is not None
                and self._metadata_value(result, "workerType") != worker_type
            ):
                continue

            results.append(result)

        # Explicit UTF-8 keeps the output independent of the platform locale
        # and consistent with how the project reads JSONLines files.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(result) + "\n" for result in results)

        return results

    @staticmethod
    def _metadata_value(result: dict[str, JsonValue], key: str) -> JsonValue:
        """Return ``result["metadata"][key]``, or None when unavailable.

        None is returned when the result has no ``metadata`` entry, when the
        metadata is not a dict (e.g. the server sent ``null`` or a list), or
        when the key is absent.
        """
        metadata = result.get("metadata")
        if not isinstance(metadata, dict):
            return None
        return metadata.get(key)

    def _download_single_result(self, result_id: int) -> dict[str, JsonValue]:
        """Download a single result with metadata.

        Issues two GET requests through the client's session: one to
        ``/api/v1/results/{result_id}/data`` and one to
        ``/api/v1/results/{result_id}``.

        Parameters
        ----------
        result_id : int
            Result ID to download.

        Returns
        -------
        dict[str, JsonValue]
            Dict with keys ``result_id``, ``data``, ``metadata`` and
            ``download_timestamp``.

        Raises
        ------
        requests.HTTPError
            If the API request fails.
        """
        # Get result data
        data_url = f"{self.client.base_url}/api/v1/results/{result_id}/data"
        data_response = self.client.session.get(data_url)
        data_response.raise_for_status()

        # Get result metadata
        meta_url = f"{self.client.base_url}/api/v1/results/{result_id}"
        meta_response = self.client.session.get(meta_url)
        meta_response.raise_for_status()

        metadata = meta_response.json()

        return {
            "result_id": result_id,
            "data": data_response.json(),
            "metadata": metadata,
            # NOTE(review): presumably now_iso8601() returns a datetime - confirm
            "download_timestamp": now_iso8601().isoformat(),
        }

    def get_study_info(self) -> dict[str, JsonValue]:
        """Get study information.

        Delegates to the underlying JATOSClient.

        Returns
        -------
        dict[str, JsonValue]
            Study details dictionary.

        Raises
        ------
        requests.HTTPError
            If the API request fails.

        Examples
        --------
        ::

            info = collector.get_study_info()
            print(info["title"])
        """
        return self.client.get_study(self.study_id)

    def get_result_count(self) -> int:
        """Get count of results.

        Returns
        -------
        int
            Number of results available for the study.

        Raises
        ------
        requests.HTTPError
            If the API request fails.

        Examples
        --------
        ::

            count = collector.get_result_count()
            print(f"Found {count} results")
        """
        result_ids = self.client.get_results(self.study_id)
        return len(result_ids)