trismik 0.9.12__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
trismik/adaptive_test.py DELETED
@@ -1,669 +0,0 @@
1
- """
2
- Trismik adaptive test runner.
3
-
4
- This module provides both synchronous and asynchronous interfaces for running
5
- Trismik tests. The async implementation is the core, with sync methods wrapping
6
- the async ones.
7
- """
8
-
9
- import asyncio
10
- from typing import Any, Callable, Dict, List, Literal, Optional, Union, overload
11
-
12
- import nest_asyncio
13
- from tqdm.auto import tqdm
14
-
15
- from trismik.client_async import TrismikAsyncClient
16
- from trismik.settings import evaluation_settings
17
- from trismik.types import (
18
- AdaptiveTestScore,
19
- TrismikAdaptiveTestState,
20
- TrismikClassicEvalRequest,
21
- TrismikClassicEvalResponse,
22
- TrismikDataset,
23
- TrismikItem,
24
- TrismikMeResponse,
25
- TrismikProject,
26
- TrismikReplayRequest,
27
- TrismikReplayRequestItem,
28
- TrismikRunMetadata,
29
- TrismikRunResults,
30
- )
31
-
32
-
33
class AdaptiveTest:
    """
    Trismik test runner with both sync and async interfaces.

    The ``*_async`` coroutines are the core implementation. Each synchronous
    method obtains an event loop from ``_get_loop`` and drives the matching
    coroutine to completion with ``run_until_complete``.
    """

    def __init__(
        self,
        item_processor: Callable[[TrismikItem], Any],
        client: Optional[TrismikAsyncClient] = None,
        api_key: Optional[str] = None,
        max_items: int = evaluation_settings["max_iterations"],
    ) -> None:
        """
        Initialize a new Trismik runner.

        Args:
            item_processor (Callable[[TrismikItem], Any]): Callable invoked
                once per test item; its result is submitted as the response.
                It may be a plain function, an async function, or any
                callable that returns an awaitable.
            client (Optional[TrismikAsyncClient]): Trismik async client to use
                for requests. If not provided, a new one will be created.
            api_key (Optional[str]): API key to use if a new client is created.
            max_items (int): Expected upper bound on the number of items. It
                is used only as the initial ``total`` of the progress bar in
                adaptive runs; the run itself ends when the API reports
                completion or stops returning a next item. Defaults to
                ``evaluation_settings["max_iterations"]``.

        Raises:
            ValueError: If both client and api_key are provided.
        """
        # NOTE(review): earlier documentation also listed TrismikApiError
        # here. No request is made in this constructor; confirm whether
        # TrismikAsyncClient(...) itself can raise it.
        if client and api_key:
            raise ValueError(
                "Either 'client' or 'api_key' should be provided, not both."
            )
        self._item_processor = item_processor
        if client:
            self._client = client
        else:
            self._client = TrismikAsyncClient(api_key=api_key)
        self._max_items = max_items
        # Never read by any method of this class; kept so the instance
        # attribute set stays unchanged for existing code.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """
        Get or create an event loop, handling nested loops if needed.

        A new loop is created and installed as the current one when the
        thread has no event loop, or when the current loop has already been
        closed (a closed loop cannot run anything).

        Returns:
            asyncio.AbstractEventLoop: The event loop to use.
        """
        current: Optional[asyncio.AbstractEventLoop]
        try:
            current = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop in this thread.
            current = None

        if current is None or current.is_closed():
            current = asyncio.new_event_loop()
            asyncio.set_event_loop(current)

        # Allow nested event loops (needed for Jupyter, etc)
        nest_asyncio.apply(current)
        return current

    async def _process_item(self, item: TrismikItem) -> Any:
        """
        Run the configured item processor on one item.

        The processor is called first and its result is awaited only if it
        is awaitable. This covers plain functions, ``async def`` functions,
        and callables that return a coroutine without being coroutine
        functions themselves (lambdas, objects with an async ``__call__``).

        Args:
            item (TrismikItem): Item to hand to the processor.

        Returns:
            Any: The processor's (awaited) result.
        """
        # Local import: the module's import block is left untouched.
        import inspect

        outcome = self._item_processor(item)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    def list_datasets(self) -> List[TrismikDataset]:
        """
        Get a list of available datasets synchronously.

        Returns:
            List[TrismikDataset]: List of available datasets.

        Raises:
            TrismikApiError: If API request fails.
        """
        return self._get_loop().run_until_complete(self.list_datasets_async())

    async def list_datasets_async(self) -> List[TrismikDataset]:
        """
        Get a list of available datasets asynchronously.

        Returns:
            List[TrismikDataset]: List of available datasets.

        Raises:
            TrismikApiError: If API request fails.
        """
        return await self._client.list_datasets()

    def me(self) -> TrismikMeResponse:
        """
        Get current user information synchronously.

        Returns:
            TrismikMeResponse: User information including validity and payload.

        Raises:
            TrismikApiError: If API request fails.
        """
        return self._get_loop().run_until_complete(self.me_async())

    async def me_async(self) -> TrismikMeResponse:
        """
        Get current user information asynchronously.

        Returns:
            TrismikMeResponse: User information including validity and payload.

        Raises:
            TrismikApiError: If API request fails.
        """
        return await self._client.me()

    def create_project(
        self,
        name: str,
        team_id: Optional[str] = None,
        description: Optional[str] = None,
    ) -> TrismikProject:
        """
        Create a new project synchronously.

        Args:
            name (str): Name of the project.
            team_id (Optional[str]): ID of the team to create the
                project in.
            description (Optional[str]): Optional description of the project.

        Returns:
            TrismikProject: Created project information.

        Raises:
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return self._get_loop().run_until_complete(
            self.create_project_async(name, team_id, description)
        )

    async def create_project_async(
        self,
        name: str,
        team_id: Optional[str] = None,
        description: Optional[str] = None,
    ) -> TrismikProject:
        """
        Create a new project asynchronously.

        Args:
            name (str): Name of the project.
            team_id (Optional[str]): ID of the team to create the
                project in.
            description (Optional[str]): Optional description of the project.

        Returns:
            TrismikProject: Created project information.

        Raises:
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return await self._client.create_project(name, team_id, description)

    @overload
    def run(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True],
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    def run(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    def run(
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Run a test synchronously.

        Wraps the run_async method.

        Args:
            test_id (str): ID of the test to run.
            project_id (str): ID of the project.
            experiment (str): Name of the experiment.
            run_metadata (TrismikRunMetadata): Metadata for the
                run.
            return_dict (bool): If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses (bool): If True, responses will be included with
                the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
            NotImplementedError: If with_responses = True (not yet implemented).
            RuntimeError: If the run finishes without a captured state.
        """
        loop = self._get_loop()
        # The two branches pass literal True/False so each call matches one
        # of the Literal-typed run_async overloads.
        if return_dict:
            return loop.run_until_complete(
                self.run_async(
                    test_id,
                    project_id,
                    experiment,
                    run_metadata,
                    True,
                    with_responses,
                )
            )
        return loop.run_until_complete(
            self.run_async(
                test_id,
                project_id,
                experiment,
                run_metadata,
                False,
                with_responses,
            )
        )

    @overload
    async def run_async(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True],
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    async def run_async(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    async def run_async(
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Run a test asynchronously.

        Starts a run, feeds each served item through the item processor
        until the run completes, and builds the score from the last entries
        of the final state's ``thetas`` and ``std_error_history``.

        Args:
            test_id: ID of the test to run.
            project_id: ID of the project.
            experiment: Name of the experiment.
            run_metadata: Metadata for the run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included with
                the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
            NotImplementedError: If with_responses = True (not yet implemented).
            RuntimeError: If the run finishes without a captured state.
        """
        if with_responses:
            raise NotImplementedError(
                "with_responses is not yet implemented for the new API flow"
            )

        # Start run and get first item
        start_response = await self._client.start_run(
            test_id, project_id, experiment, run_metadata
        )
        run_id = start_response.run_info.id

        # State history, seeded with the state returned by start_run
        states: List[TrismikAdaptiveTestState] = [
            TrismikAdaptiveTestState(
                run_id=run_id,
                state=start_response.state,
                completed=start_response.completed,
            )
        ]

        # Run the test and get last state
        last_state = await self._run_async(
            run_id, start_response.next_item, states
        )

        if not last_state:
            raise RuntimeError(
                "Test run completed but no final state was captured"
            )

        score = AdaptiveTestScore(
            theta=last_state.state.thetas[-1],
            std_error=last_state.state.std_error_history[-1],
        )

        results = TrismikRunResults(run_id, score=score)

        if return_dict:
            return {
                "run_id": results.run_id,
                "score": (
                    {
                        "theta": results.score.theta,
                        "std_error": results.score.std_error,
                    }
                    if results.score
                    else None
                ),
                "responses": results.responses,
            }
        return results

    @overload
    def run_replay(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True],
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    def run_replay(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    def run_replay(
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Replay the exact sequence of questions from a previous run.

        Wraps the run_replay_async method.

        Args:
            previous_run_id: ID of a previous run to replay.
            run_metadata: Metadata for the replay run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included
                with the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        # Literal True/False keep each call aligned with one of the
        # Literal-typed run_replay_async overloads.
        if return_dict:
            return loop.run_until_complete(
                self.run_replay_async(
                    previous_run_id,
                    run_metadata,
                    True,
                    with_responses,
                )
            )
        return loop.run_until_complete(
            self.run_replay_async(
                previous_run_id,
                run_metadata,
                False,
                with_responses,
            )
        )

    @overload
    async def run_replay_async(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True],
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    async def run_replay_async(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    async def run_replay_async(
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Replay the exact sequence of questions from a previous run.

        Fetches the summary of the previous run, passes every item of its
        dataset through the item processor in order, and submits all
        collected responses in a single replay request.

        Args:
            previous_run_id: ID of a previous run to replay.
            run_metadata: Metadata for the run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included
                with the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
        """
        # Get the original run summary to access dataset and responses
        original_summary = await self._client.run_summary(previous_run_id)

        # Build replay request by processing each item in the original order
        replay_items = []
        with tqdm(
            total=len(original_summary.dataset), desc="Running replay..."
        ) as pbar:
            for item in original_summary.dataset:
                response = await self._process_item(item)
                replay_items.append(
                    TrismikReplayRequestItem(
                        itemId=item.id, itemChoiceId=response
                    )
                )
                pbar.update(1)

        # Submit replay with metadata
        replay_request = TrismikReplayRequest(responses=replay_items)
        replay_response = await self._client.submit_replay(
            previous_run_id, replay_request, run_metadata
        )

        # Create score from replay response
        score = AdaptiveTestScore(
            theta=replay_response.state.thetas[-1],
            std_error=replay_response.state.std_error_history[-1],
        )

        # Responses are attached only on request; otherwise the
        # TrismikRunResults default for ``responses`` applies.
        if with_responses:
            results = TrismikRunResults(
                run_id=replay_response.id,
                score=score,
                responses=replay_response.responses,
            )
        else:
            results = TrismikRunResults(run_id=replay_response.id, score=score)

        if return_dict:
            return {
                "run_id": results.run_id,
                "score": (
                    {
                        "theta": results.score.theta,
                        "std_error": results.score.std_error,
                    }
                    if results.score
                    else None
                ),
                "responses": (
                    [
                        {
                            "dataset_item_id": resp.dataset_item_id,
                            "value": resp.value,
                            "correct": resp.correct,
                        }
                        for resp in results.responses
                    ]
                    if results.responses
                    else None
                ),
            }
        return results

    async def _run_async(
        self,
        run_id: str,
        first_item: Optional[TrismikItem],
        states: List[TrismikAdaptiveTestState],
    ) -> Optional[TrismikAdaptiveTestState]:
        """
        Drive an already started run until it completes.

        Each item is passed to the item processor and the response is sent
        with ``continue_run``. The loop stops when the API marks the run as
        completed or returns no next item. ``max_items`` only sizes the
        progress bar and does not cap the number of iterations.

        Args:
            run_id (str): ID of the run to execute.
            first_item (Optional[TrismikItem]): First item from run start.
            states (List[TrismikAdaptiveTestState]): List to accumulate states;
                one entry is appended per processed item.

        Returns:
            Optional[TrismikAdaptiveTestState]: Last element of ``states``
                after the loop, or None if ``states`` is empty.

        Raises:
            TrismikApiError: If API request fails.
        """
        item = first_item
        with tqdm(total=self._max_items, desc="Evaluating") as pbar:
            while item is not None:
                response = await self._process_item(item)

                # Continue run with response
                continue_response = await self._client.continue_run(
                    run_id, response
                )

                # Update state tracking
                states.append(
                    TrismikAdaptiveTestState(
                        run_id=run_id,
                        state=continue_response.state,
                        completed=continue_response.completed,
                    )
                )

                pbar.update(1)

                if continue_response.completed:
                    # Shrink the bar so it shows 100% at the actual count.
                    pbar.total = pbar.n
                    pbar.refresh()
                    break

                item = continue_response.next_item

        return states[-1] if states else None

    def submit_classic_eval(
        self, classic_eval_request: TrismikClassicEvalRequest
    ) -> TrismikClassicEvalResponse:
        """
        Submit a classic evaluation run with pre-computed results synchronously.

        Wraps the submit_classic_eval_async method.

        Args:
            classic_eval_request (TrismikClassicEvalRequest): Request containing
                project info, dataset, model outputs, and metrics.

        Returns:
            TrismikClassicEvalResponse: Response from the classic evaluation
                endpoint.

        Raises:
            TrismikPayloadTooLargeError: If the request payload exceeds the
                server's size limit.
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return self._get_loop().run_until_complete(
            self.submit_classic_eval_async(classic_eval_request)
        )

    async def submit_classic_eval_async(
        self, classic_eval_request: TrismikClassicEvalRequest
    ) -> TrismikClassicEvalResponse:
        """
        Submit a classic evaluation run with pre-computed results async.

        This method allows you to submit pre-computed model outputs and metrics
        for evaluation without running an interactive test.

        Args:
            classic_eval_request (TrismikClassicEvalRequest): Request containing
                project info, dataset, model outputs, and metrics.

        Returns:
            TrismikClassicEvalResponse: Response from the classic evaluation
                endpoint.

        Raises:
            TrismikPayloadTooLargeError: If the request payload exceeds the
                server's size limit.
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return await self._client.submit_classic_eval(classic_eval_request)