trismik 0.9.0__py3-none-any.whl → 0.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,617 @@
1
+ """
2
+ Trismik adaptive test runner.
3
+
4
+ This module provides both synchronous and asynchronous interfaces for running
5
+ Trismik tests. The async implementation is the core, with sync methods wrapping
6
+ the async ones.
7
+ """
8
+
9
+ import asyncio
10
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union, overload
11
+
12
+ import nest_asyncio
13
+ from tqdm.auto import tqdm
14
+
15
+ from trismik.client_async import TrismikAsyncClient
16
+ from trismik.settings import evaluation_settings
17
+ from trismik.types import (
18
+ AdaptiveTestScore,
19
+ TrismikAdaptiveTestState,
20
+ TrismikClassicEvalRequest,
21
+ TrismikClassicEvalResponse,
22
+ TrismikDataset,
23
+ TrismikItem,
24
+ TrismikMeResponse,
25
+ TrismikReplayRequest,
26
+ TrismikReplayRequestItem,
27
+ TrismikRunMetadata,
28
+ TrismikRunResults,
29
+ )
30
+
31
+
32
class AdaptiveTest:
    """
    Trismik test runner with both sync and async interfaces.

    This class provides both synchronous and asynchronous interfaces for
    running Trismik tests. The async implementation is the core: each sync
    method obtains an event loop via ``_get_loop`` and drives the matching
    ``*_async`` coroutine with ``run_until_complete``.
    """

    def __init__(
        self,
        item_processor: Callable[[TrismikItem], Any],
        client: Optional[TrismikAsyncClient] = None,
        api_key: Optional[str] = None,
        max_items: int = evaluation_settings["max_iterations"],
    ) -> None:
        """
        Initialize a new Trismik runner.

        Args:
            item_processor (Callable[[TrismikItem], Any]): Function that
                receives a test item and returns the response to submit for
                it. It may be a regular function or a coroutine function (or
                any callable returning an awaitable); awaitable results are
                awaited before being submitted. This applies to both the
                sync and async entry points.
            client (Optional[TrismikAsyncClient]): Trismik async client to use
                for requests. If not provided, a new one will be created.
            api_key (Optional[str]): API key to use if a new client is created.
            max_items (int): Expected number of items in an adaptive run. It
                is used as the total of the progress bar; the run itself ends
                when the API reports completion (or returns no next item).
                Defaults to ``evaluation_settings["max_iterations"]``.

        Raises:
            ValueError: If both client and api_key are provided.
            Any exception raised while constructing the TrismikAsyncClient
            propagates unchanged.
        """
        if client is not None and api_key:
            raise ValueError(
                "Either 'client' or 'api_key' should be provided, not both."
            )
        self._item_processor = item_processor
        self._client = (
            client
            if client is not None
            else TrismikAsyncClient(api_key=api_key)
        )
        self._max_items = max_items
        # Not read by this class; the loop is looked up per call in
        # _get_loop. Kept so the attribute still exists for any caller.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """
        Get a usable event loop for the sync wrappers.

        Reuses the current thread's event loop when one exists and is still
        open. Otherwise (no loop in this thread, or the current one has been
        closed) a fresh loop is created and installed as the current loop.
        ``nest_asyncio`` is applied so ``run_until_complete`` also works when
        the loop is already running (e.g. inside Jupyter).

        Returns:
            asyncio.AbstractEventLoop: The event loop to use.
        """
        loop: Optional[asyncio.AbstractEventLoop]
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop in this thread.
            loop = None

        if loop is None or loop.is_closed():
            # A closed loop would make run_until_complete fail with
            # "Event loop is closed", so replace it with a fresh one.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Allow nested event loops (needed for Jupyter, etc).
        nest_asyncio.apply(loop)
        return loop

    async def _process_item(self, item: TrismikItem) -> Any:
        """
        Run the item processor on one item and return its response.

        The processor is called, and if its result is awaitable (a coroutine
        function, an async callable object, or any callable returning an
        awaitable) the result is awaited. Plain return values are passed
        through unchanged.

        Args:
            item (TrismikItem): The item to process.

        Returns:
            Any: The processor's response for the item.
        """
        import inspect

        response = self._item_processor(item)
        if inspect.isawaitable(response):
            response = await response
        return response

    @staticmethod
    def _score_to_dict(
        score: Optional[AdaptiveTestScore],
    ) -> Optional[Dict[str, Any]]:
        """
        Convert a score into the plain-dict form used by ``return_dict``.

        Args:
            score (Optional[AdaptiveTestScore]): Score to convert.

        Returns:
            Optional[Dict[str, Any]]: ``{"theta": ..., "std_error": ...}``,
            or None when ``score`` is falsy.
        """
        if not score:
            return None
        return {"theta": score.theta, "std_error": score.std_error}

    def list_datasets(self) -> List[TrismikDataset]:
        """
        Get a list of available datasets synchronously.

        Returns:
            List[TrismikDataset]: List of available datasets.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(self.list_datasets_async())

    async def list_datasets_async(self) -> List[TrismikDataset]:
        """
        Get a list of available datasets asynchronously.

        Returns:
            List[TrismikDataset]: List of available datasets.

        Raises:
            TrismikApiError: If API request fails.
        """
        return await self._client.list_datasets()

    def me(self) -> TrismikMeResponse:
        """
        Get current user information synchronously.

        Returns:
            TrismikMeResponse: User information including validity and payload.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(self.me_async())

    async def me_async(self) -> TrismikMeResponse:
        """
        Get current user information asynchronously.

        Returns:
            TrismikMeResponse: User information including validity and payload.

        Raises:
            TrismikApiError: If API request fails.
        """
        return await self._client.me()

    # The Literal[True] overloads carry the ``= True`` default so that calls
    # omitting return_dict type-check as returning a dict, matching the
    # runtime default of the implementations.
    @overload
    def run(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    def run(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    def run(
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Run a test synchronously.

        Args:
            test_id (str): ID of the test to run.
            project_id (str): ID of the project.
            experiment (str): Name of the experiment.
            run_metadata (TrismikRunMetadata): Metadata for the
                run.
            return_dict (bool): If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses (bool): If True, responses will be included with
                the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
            NotImplementedError: If with_responses = True (not yet implemented).
        """
        loop = self._get_loop()
        # Branch so each call passes a Literal, letting type checkers pick
        # the matching run_async overload.
        if return_dict:
            return loop.run_until_complete(
                self.run_async(
                    test_id,
                    project_id,
                    experiment,
                    run_metadata,
                    True,
                    with_responses,
                )
            )
        else:
            return loop.run_until_complete(
                self.run_async(
                    test_id,
                    project_id,
                    experiment,
                    run_metadata,
                    False,
                    with_responses,
                )
            )

    @overload
    async def run_async(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    async def run_async(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    async def run_async(
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Run a test asynchronously.

        Args:
            test_id: ID of the test to run.
            project_id: ID of the project.
            experiment: Name of the experiment.
            run_metadata: Metadata for the run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included with
                the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
            NotImplementedError: If with_responses = True (not yet implemented).
            RuntimeError: If the run finished without any captured state.
        """
        if with_responses:
            raise NotImplementedError(
                "with_responses is not yet implemented for the new API flow"
            )

        # Start the run; the response carries the first item and the
        # initial state.
        start_response = await self._client.start_run(
            test_id, project_id, experiment, run_metadata
        )
        run_id = start_response.run_info.id

        # State history, seeded with the state returned by start_run.
        states: List[TrismikAdaptiveTestState] = [
            TrismikAdaptiveTestState(
                run_id=run_id,
                state=start_response.state,
                completed=start_response.completed,
            )
        ]

        last_state = await self._run_async(
            run_id, start_response.next_item, states
        )
        if last_state is None:
            raise RuntimeError(
                "Test run completed but no final state was captured"
            )

        # Use the last entry of each history as the run's score.
        score = AdaptiveTestScore(
            theta=last_state.state.thetas[-1],
            std_error=last_state.state.std_error_history[-1],
        )
        results = TrismikRunResults(run_id=run_id, score=score)

        if return_dict:
            return {
                "run_id": results.run_id,
                "score": self._score_to_dict(results.score),
                "responses": results.responses,
            }
        return results

    @overload
    def run_replay(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    def run_replay(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    def run_replay(
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Replay the exact sequence of questions from a previous run.

        Wraps the run_replay_async method.

        Args:
            previous_run_id: ID of a previous run to replay.
            run_metadata: Metadata for the replay run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included
                with the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        # Branch so each call passes a Literal, letting type checkers pick
        # the matching run_replay_async overload.
        if return_dict:
            return loop.run_until_complete(
                self.run_replay_async(
                    previous_run_id,
                    run_metadata,
                    True,
                    with_responses,
                )
            )
        else:
            return loop.run_until_complete(
                self.run_replay_async(
                    previous_run_id,
                    run_metadata,
                    False,
                    with_responses,
                )
            )

    @overload
    async def run_replay_async(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    async def run_replay_async(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    async def run_replay_async(
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Replay the exact sequence of questions from a previous run.

        Every item of the previous run's dataset is passed, in order, to the
        item processor, and the collected responses are submitted together
        as a single replay request.

        Args:
            previous_run_id: ID of a previous run to replay.
            run_metadata: Metadata for the run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included
                with the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
        """
        # The original run summary provides the dataset to replay.
        original_summary = await self._client.run_summary(previous_run_id)

        replay_items: List[TrismikReplayRequestItem] = []
        with tqdm(
            total=len(original_summary.dataset), desc="Running replay..."
        ) as pbar:
            for item in original_summary.dataset:
                response = await self._process_item(item)
                replay_items.append(
                    TrismikReplayRequestItem(
                        itemId=item.id, itemChoiceId=response
                    )
                )
                pbar.update(1)

        replay_request = TrismikReplayRequest(responses=replay_items)

        # Submit the replay together with the new run's metadata.
        replay_response = await self._client.submit_replay(
            previous_run_id, replay_request, run_metadata
        )

        # Use the last entry of each history as the replay's score.
        score = AdaptiveTestScore(
            theta=replay_response.state.thetas[-1],
            std_error=replay_response.state.std_error_history[-1],
        )

        if with_responses:
            results = TrismikRunResults(
                run_id=replay_response.id,
                score=score,
                responses=replay_response.responses,
            )
        else:
            results = TrismikRunResults(run_id=replay_response.id, score=score)

        if return_dict:
            return {
                "run_id": results.run_id,
                "score": self._score_to_dict(results.score),
                "responses": (
                    [
                        {
                            "dataset_item_id": resp.dataset_item_id,
                            "value": resp.value,
                            "correct": resp.correct,
                        }
                        for resp in results.responses
                    ]
                    if results.responses
                    else None
                ),
            }
        return results

    async def _run_async(
        self,
        run_id: str,
        first_item: Optional[TrismikItem],
        states: List[TrismikAdaptiveTestState],
    ) -> Optional[TrismikAdaptiveTestState]:
        """
        Drive an already-started run until it completes.

        Each item is passed to the item processor, the response is sent with
        ``continue_run``, and the returned state is appended to ``states``.
        The loop stops when the API reports completion or returns no next
        item.

        Args:
            run_id (str): ID of the run to execute.
            first_item (Optional[TrismikItem]): First item from run start.
            states (List[TrismikAdaptiveTestState]): List to accumulate
                states; it is mutated in place.

        Returns:
            Optional[TrismikAdaptiveTestState]: Last state in ``states``, or
            None if ``states`` is empty.

        Raises:
            TrismikApiError: If API request fails.
        """
        item = first_item
        with tqdm(total=self._max_items, desc="Evaluating") as pbar:
            while item is not None:
                response = await self._process_item(item)

                continue_response = await self._client.continue_run(
                    run_id, response
                )

                states.append(
                    TrismikAdaptiveTestState(
                        run_id=run_id,
                        state=continue_response.state,
                        completed=continue_response.completed,
                    )
                )

                pbar.update(1)

                if continue_response.completed:
                    # The run can finish before max_items; shrink the total
                    # so the bar displays as complete.
                    pbar.total = pbar.n
                    pbar.refresh()
                    break

                item = continue_response.next_item

        return states[-1] if states else None

    def submit_classic_eval(
        self, classic_eval_request: TrismikClassicEvalRequest
    ) -> TrismikClassicEvalResponse:
        """
        Submit a classic evaluation run with pre-computed results synchronously.

        Args:
            classic_eval_request (TrismikClassicEvalRequest): Request containing
                project info, dataset, model outputs, and metrics.

        Returns:
            TrismikClassicEvalResponse: Response from the classic evaluation
                endpoint.

        Raises:
            TrismikPayloadTooLargeError: If the request payload exceeds the
                server's size limit.
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(
            self.submit_classic_eval_async(classic_eval_request)
        )

    async def submit_classic_eval_async(
        self, classic_eval_request: TrismikClassicEvalRequest
    ) -> TrismikClassicEvalResponse:
        """
        Submit a classic evaluation run with pre-computed results async.

        This method allows you to submit pre-computed model outputs and metrics
        for evaluation without running an interactive test.

        Args:
            classic_eval_request (TrismikClassicEvalRequest): Request containing
                project info, dataset, model outputs, and metrics.

        Returns:
            TrismikClassicEvalResponse: Response from the classic evaluation
                endpoint.

        Raises:
            TrismikPayloadTooLargeError: If the request payload exceeds the
                server's size limit.
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return await self._client.submit_classic_eval(classic_eval_request)