trismik 0.9.11__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
trismik/adaptive_test.py DELETED
@@ -1,671 +0,0 @@
1
- """
2
- Trismik adaptive test runner.
3
-
4
- This module provides both synchronous and asynchronous interfaces for running
5
- Trismik tests. The async implementation is the core, with sync methods wrapping
6
- the async ones.
7
- """
8
-
9
- import asyncio
10
- from typing import Any, Callable, Dict, List, Literal, Optional, Union, overload
11
-
12
- import nest_asyncio
13
- from tqdm.auto import tqdm
14
-
15
- from trismik.client_async import TrismikAsyncClient
16
- from trismik.settings import evaluation_settings
17
- from trismik.types import (
18
- AdaptiveTestScore,
19
- TrismikAdaptiveTestState,
20
- TrismikClassicEvalRequest,
21
- TrismikClassicEvalResponse,
22
- TrismikDataset,
23
- TrismikItem,
24
- TrismikMeResponse,
25
- TrismikProject,
26
- TrismikReplayRequest,
27
- TrismikReplayRequestItem,
28
- TrismikRunMetadata,
29
- TrismikRunResults,
30
- )
31
-
32
-
33
class AdaptiveTest:
    """
    Trismik test runner with both sync and async interfaces.

    Every public operation exists in two flavours: an ``*_async`` coroutine
    that performs the work through a ``TrismikAsyncClient``, and a
    synchronous wrapper that drives that coroutine to completion with
    ``run_until_complete`` on the loop returned by ``_get_loop``.

    The item processor given at construction may be a regular function or
    an async function; whatever it returns is awaited when it is a
    coroutine or future.
    """

    def __init__(
        self,
        item_processor: Callable[[TrismikItem], Any],
        client: Optional[TrismikAsyncClient] = None,
        api_key: Optional[str] = None,
        max_items: int = evaluation_settings["max_iterations"],
    ) -> None:
        """
        Initialize a new Trismik runner.

        Args:
            item_processor (Callable[[TrismikItem], Any]): Function that
                answers a test item. It may be a regular function, an async
                function, or any callable returning a coroutine; its
                (awaited) return value is submitted as the item's response.
            client (Optional[TrismikAsyncClient]): Trismik async client to use
                for requests. If not provided, a new one will be created.
            api_key (Optional[str]): API key to use if a new client is created.
            max_items (int): Expected number of items in an adaptive run. It
                is only used as the initial total of the progress bar; the
                run itself continues until the server reports completion.
                Defaults to ``evaluation_settings["max_iterations"]``.

        Raises:
            ValueError: If both client and api_key are provided.
        """
        if client and api_key:
            raise ValueError(
                "Either 'client' or 'api_key' should be provided, not both."
            )
        self._item_processor = item_processor
        if client:
            self._client = client
        else:
            self._client = TrismikAsyncClient(api_key=api_key)
        self._max_items = max_items
        # Loop most recently handed out by _get_loop (None until first use).
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """
        Return an event loop the synchronous wrappers can run coroutines on.

        The thread's current event loop is reused when it exists and is still
        open. Otherwise a new loop is created and installed as the thread's
        current loop. ``nest_asyncio`` is applied so that
        ``run_until_complete`` also works while the loop is already running
        (e.g. inside Jupyter).

        Returns:
            asyncio.AbstractEventLoop: The event loop to use.
        """
        try:
            loop: Optional[asyncio.AbstractEventLoop] = (
                asyncio.get_event_loop()
            )
        except RuntimeError:
            # No event loop in this thread.
            loop = None

        if loop is None or loop.is_closed():
            # A closed loop (e.g. one the caller closed explicitly) cannot run
            # anything; replace it instead of failing later with
            # "Event loop is closed".
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Allow nested event loops (needed for Jupyter, etc).
        nest_asyncio.apply(loop)
        self._loop = loop
        return loop

    async def _process_item(self, item: TrismikItem) -> Any:
        """
        Run the item processor on ``item`` and return its response.

        The processor is called exactly once. If the value it returns is a
        coroutine or future, that value is awaited. Inspecting the result,
        rather than the processor via ``asyncio.iscoroutinefunction`` (which
        is deprecated as of Python 3.14), also handles callables that return
        a coroutine without being declared ``async def``, such as lambdas
        wrapping an async function.

        Args:
            item (TrismikItem): The item to answer.

        Returns:
            Any: The processor's response for ``item``.
        """
        response = self._item_processor(item)
        if asyncio.iscoroutine(response) or asyncio.isfuture(response):
            response = await response
        return response

    @staticmethod
    def _score_to_dict(
        score: Optional[AdaptiveTestScore],
    ) -> Optional[Dict[str, Any]]:
        """
        Convert a score into its dictionary form.

        Args:
            score (Optional[AdaptiveTestScore]): Score to convert.

        Returns:
            Optional[Dict[str, Any]]: ``{"theta": ..., "std_error": ...}``,
            or None when ``score`` is falsy.
        """
        if not score:
            return None
        return {"theta": score.theta, "std_error": score.std_error}

    def list_datasets(self) -> List[TrismikDataset]:
        """
        Get a list of available datasets synchronously.

        Returns:
            List[TrismikDataset]: List of available datasets.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(self.list_datasets_async())

    async def list_datasets_async(self) -> List[TrismikDataset]:
        """
        Get a list of available datasets asynchronously.

        Returns:
            List[TrismikDataset]: List of available datasets.

        Raises:
            TrismikApiError: If API request fails.
        """
        return await self._client.list_datasets()

    def me(self) -> TrismikMeResponse:
        """
        Get current user information synchronously.

        Returns:
            TrismikMeResponse: User information including validity and payload.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(self.me_async())

    async def me_async(self) -> TrismikMeResponse:
        """
        Get current user information asynchronously.

        Returns:
            TrismikMeResponse: User information including validity and payload.

        Raises:
            TrismikApiError: If API request fails.
        """
        return await self._client.me()

    def create_project(
        self,
        name: str,
        organization_id: str,
        description: Optional[str] = None,
    ) -> TrismikProject:
        """
        Create a new project synchronously.

        Args:
            name (str): Name of the project.
            organization_id (str): ID of the organization to create the
                project in.
            description (Optional[str]): Optional description of the project.

        Returns:
            TrismikProject: Created project information.

        Raises:
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(
            self.create_project_async(name, organization_id, description)
        )

    async def create_project_async(
        self,
        name: str,
        organization_id: str,
        description: Optional[str] = None,
    ) -> TrismikProject:
        """
        Create a new project asynchronously.

        Args:
            name (str): Name of the project.
            organization_id (str): ID of the organization to create the
                project in.
            description (Optional[str]): Optional description of the project.

        Returns:
            TrismikProject: Created project information.

        Raises:
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return await self._client.create_project(
            name, organization_id, description
        )

    # The Literal[True] overloads carry the "= True" default so that calls
    # omitting return_dict (the runtime default) still match an overload.
    @overload
    def run(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    def run(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    def run(
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Run a test synchronously.

        Args:
            test_id (str): ID of the test to run.
            project_id (str): ID of the project.
            experiment (str): Name of the experiment.
            run_metadata (TrismikRunMetadata): Metadata for the
                run.
            return_dict (bool): If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses (bool): If True, responses will be included with
                the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
            NotImplementedError: If with_responses = True (not yet implemented).
        """
        loop = self._get_loop()
        # Literal True/False arguments (rather than return_dict itself) let a
        # type checker pick the matching run_async overload in each branch.
        if return_dict:
            return loop.run_until_complete(
                self.run_async(
                    test_id,
                    project_id,
                    experiment,
                    run_metadata,
                    True,
                    with_responses,
                )
            )
        return loop.run_until_complete(
            self.run_async(
                test_id,
                project_id,
                experiment,
                run_metadata,
                False,
                with_responses,
            )
        )

    @overload
    async def run_async(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    async def run_async(  # noqa: E704
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    async def run_async(
        self,
        test_id: str,
        project_id: str,
        experiment: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Run a test asynchronously.

        Starts a run, answers items with the item processor until the server
        reports completion (or stops sending items), and builds the result
        from the final theta and standard error of the last recorded state.

        Args:
            test_id: ID of the test to run.
            project_id: ID of the project.
            experiment: Name of the experiment.
            run_metadata: Metadata for the run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included with
                the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
            NotImplementedError: If with_responses = True (not yet implemented).
            RuntimeError: If no final state was captured.
        """
        if with_responses:
            raise NotImplementedError(
                "with_responses is not yet implemented for the new API flow"
            )

        # Start run and get first item.
        start_response = await self._client.start_run(
            test_id, project_id, experiment, run_metadata
        )
        run_id = start_response.run_info.id

        # Seed state tracking with the state returned by start_run.
        states: List[TrismikAdaptiveTestState] = [
            TrismikAdaptiveTestState(
                run_id=run_id,
                state=start_response.state,
                completed=start_response.completed,
            )
        ]

        last_state = await self._run_async(
            run_id, start_response.next_item, states
        )
        if last_state is None:
            raise RuntimeError(
                "Test run completed but no final state was captured"
            )

        score = AdaptiveTestScore(
            theta=last_state.state.thetas[-1],
            std_error=last_state.state.std_error_history[-1],
        )
        results = TrismikRunResults(run_id, score=score)

        if not return_dict:
            return results
        return {
            "run_id": results.run_id,
            "score": self._score_to_dict(results.score),
            "responses": results.responses,
        }

    @overload
    def run_replay(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    def run_replay(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    def run_replay(
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Replay the sequence of questions from a previous run synchronously.

        Wraps the run_replay_async method.

        Args:
            previous_run_id: ID of a previous run to replay.
            run_metadata: Metadata for the replay run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included
                with the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        # Literal True/False arguments keep overload resolution precise.
        if return_dict:
            return loop.run_until_complete(
                self.run_replay_async(
                    previous_run_id,
                    run_metadata,
                    True,
                    with_responses,
                )
            )
        return loop.run_until_complete(
            self.run_replay_async(
                previous_run_id,
                run_metadata,
                False,
                with_responses,
            )
        )

    @overload
    async def run_replay_async(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[True] = True,
        with_responses: bool = False,
    ) -> Dict[str, Any]: ...

    @overload
    async def run_replay_async(  # noqa: E704
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: Literal[False],
        with_responses: bool = False,
    ) -> TrismikRunResults: ...

    async def run_replay_async(
        self,
        previous_run_id: str,
        run_metadata: TrismikRunMetadata,
        return_dict: bool = True,
        with_responses: bool = False,
    ) -> Union[TrismikRunResults, Dict[str, Any]]:
        """
        Replay the sequence of questions from a previous run asynchronously.

        Fetches the previous run's summary, answers each item of its dataset
        (in the order the summary lists them) with the item processor, and
        submits all answers in a single replay request.

        Args:
            previous_run_id: ID of a previous run to replay.
            run_metadata: Metadata for the run.
            return_dict: If True, return results as a dictionary instead
                of TrismikRunResults object. Defaults to True.
            with_responses: If True, responses will be included
                with the results.

        Returns:
            Union[TrismikRunResults, Dict[str, Any]]: Either TrismikRunResults
                object or dictionary representation based on return_dict
                parameter.

        Raises:
            TrismikApiError: If API request fails.
        """
        # Get the original run summary to access its dataset.
        original_summary = await self._client.run_summary(previous_run_id)

        replay_items: List[TrismikReplayRequestItem] = []
        with tqdm(
            total=len(original_summary.dataset), desc="Running replay..."
        ) as pbar:
            for item in original_summary.dataset:
                response = await self._process_item(item)
                replay_items.append(
                    TrismikReplayRequestItem(
                        itemId=item.id, itemChoiceId=response
                    )
                )
                pbar.update(1)

        replay_request = TrismikReplayRequest(responses=replay_items)
        replay_response = await self._client.submit_replay(
            previous_run_id, replay_request, run_metadata
        )

        score = AdaptiveTestScore(
            theta=replay_response.state.thetas[-1],
            std_error=replay_response.state.std_error_history[-1],
        )

        if with_responses:
            results = TrismikRunResults(
                run_id=replay_response.id,
                score=score,
                responses=replay_response.responses,
            )
        else:
            results = TrismikRunResults(run_id=replay_response.id, score=score)

        if not return_dict:
            return results

        responses: Optional[List[Dict[str, Any]]] = None
        if results.responses:
            responses = [
                {
                    "dataset_item_id": resp.dataset_item_id,
                    "value": resp.value,
                    "correct": resp.correct,
                }
                for resp in results.responses
            ]
        return {
            "run_id": results.run_id,
            "score": self._score_to_dict(results.score),
            "responses": responses,
        }

    async def _run_async(
        self,
        run_id: str,
        first_item: Optional[TrismikItem],
        states: List[TrismikAdaptiveTestState],
    ) -> Optional[TrismikAdaptiveTestState]:
        """
        Drive an adaptive run item by item until it ends.

        Each item is answered with the item processor and the answer is sent
        with ``continue_run``. The loop stops when the server marks the run
        completed or returns no next item. A state is appended to ``states``
        after every answered item.

        Args:
            run_id (str): ID of the run to execute.
            first_item (Optional[TrismikItem]): First item from run start.
            states (List[TrismikAdaptiveTestState]): List to accumulate
                states; it is mutated in place.

        Returns:
            Optional[TrismikAdaptiveTestState]: Last state in ``states``, or
            None if ``states`` is empty.

        Raises:
            TrismikApiError: If API request fails.
        """
        item = first_item
        with tqdm(total=self._max_items, desc="Evaluating") as pbar:
            while item is not None:
                response = await self._process_item(item)

                continue_response = await self._client.continue_run(
                    run_id, response
                )
                states.append(
                    TrismikAdaptiveTestState(
                        run_id=run_id,
                        state=continue_response.state,
                        completed=continue_response.completed,
                    )
                )
                pbar.update(1)

                if continue_response.completed:
                    break
                item = continue_response.next_item

            # The server decides when the run ends, so the processed count
            # rarely equals max_items; shrink (or grow) the bar to match it
            # however the loop ended.
            pbar.total = pbar.n
            pbar.refresh()

        return states[-1] if states else None

    def submit_classic_eval(
        self, classic_eval_request: TrismikClassicEvalRequest
    ) -> TrismikClassicEvalResponse:
        """
        Submit a classic evaluation run with pre-computed results synchronously.

        Args:
            classic_eval_request (TrismikClassicEvalRequest): Request containing
                project info, dataset, model outputs, and metrics.

        Returns:
            TrismikClassicEvalResponse: Response from the classic evaluation
                endpoint.

        Raises:
            TrismikPayloadTooLargeError: If the request payload exceeds the
                server's size limit.
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        loop = self._get_loop()
        return loop.run_until_complete(
            self.submit_classic_eval_async(classic_eval_request)
        )

    async def submit_classic_eval_async(
        self, classic_eval_request: TrismikClassicEvalRequest
    ) -> TrismikClassicEvalResponse:
        """
        Submit a classic evaluation run with pre-computed results async.

        This method allows you to submit pre-computed model outputs and metrics
        for evaluation without running an interactive test.

        Args:
            classic_eval_request (TrismikClassicEvalRequest): Request containing
                project info, dataset, model outputs, and metrics.

        Returns:
            TrismikClassicEvalResponse: Response from the classic evaluation
                endpoint.

        Raises:
            TrismikPayloadTooLargeError: If the request payload exceeds the
                server's size limit.
            TrismikValidationError: If the request fails validation.
            TrismikApiError: If API request fails.
        """
        return await self._client.submit_classic_eval(classic_eval_request)