trismik 0.9.12__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trismik/__init__.py +28 -0
- trismik/_async/__init__.py +1 -0
- trismik/_async/_test_transform.py +58 -0
- trismik/_async/client.py +731 -0
- trismik/_async/helpers.py +23 -0
- trismik/_mapper.py +9 -29
- trismik/_sync/__init__.py +1 -0
- trismik/_sync/_test_transform.py +58 -0
- trismik/_sync/client.py +731 -0
- trismik/_sync/helpers.py +27 -0
- trismik/_utils.py +1 -3
- trismik-1.0.0.dist-info/METADATA +258 -0
- trismik-1.0.0.dist-info/RECORD +18 -0
- trismik/adaptive_test.py +0 -669
- trismik/client_async.py +0 -405
- trismik-0.9.12.dist-info/METADATA +0 -177
- trismik-0.9.12.dist-info/RECORD +0 -12
- {trismik-0.9.12.dist-info → trismik-1.0.0.dist-info}/WHEEL +0 -0
- {trismik-0.9.12.dist-info → trismik-1.0.0.dist-info}/licenses/LICENSE +0 -0
trismik/_sync/client.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trismik sync client for interacting with the Trismik API.
|
|
3
|
+
|
|
4
|
+
This module provides a synchronous client for interacting with the Trismik
|
|
5
|
+
API. It uses httpx for making HTTP requests.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Union, overload
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from trismik._sync.helpers import process_item
|
|
13
|
+
from trismik._mapper import TrismikResponseMapper
|
|
14
|
+
from trismik._utils import TrismikUtils
|
|
15
|
+
from trismik.exceptions import TrismikApiError, TrismikPayloadTooLargeError, TrismikValidationError
|
|
16
|
+
from trismik.settings import client_settings, environment_settings, evaluation_settings
|
|
17
|
+
from trismik.types import (
|
|
18
|
+
AdaptiveTestScore,
|
|
19
|
+
TrismikAdaptiveTestState,
|
|
20
|
+
TrismikClassicEvalRequest,
|
|
21
|
+
TrismikClassicEvalResponse,
|
|
22
|
+
TrismikDataset,
|
|
23
|
+
TrismikItem,
|
|
24
|
+
TrismikMeResponse,
|
|
25
|
+
TrismikProject,
|
|
26
|
+
TrismikReplayRequest,
|
|
27
|
+
TrismikReplayRequestItem,
|
|
28
|
+
TrismikReplayResponse,
|
|
29
|
+
TrismikRunMetadata,
|
|
30
|
+
TrismikRunResponse,
|
|
31
|
+
TrismikRunResults,
|
|
32
|
+
TrismikRunSummary,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TrismikClient:
|
|
37
|
+
"""
|
|
38
|
+
Client for the Trismik API.
|
|
39
|
+
|
|
40
|
+
Provides methods to interact with the Trismik API, including
|
|
41
|
+
dataset management, test runs, and response handling.
|
|
42
|
+
|
|
43
|
+
Supports context manager protocol for automatic resource cleanup.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(
    self,
    service_url: Optional[str] = None,
    api_key: Optional[str] = None,
    http_client: Optional[httpx.Client] = None,
    max_items: int = evaluation_settings["max_iterations"],
) -> None:
    """
    Initialize the Trismik client.

    Args:
        service_url: URL of the Trismik service. Resolved through
            TrismikUtils.option together with client_settings["endpoint"]
            and environment_settings["trismik_service_url"].
        api_key: API key for authentication. Resolved through
            TrismikUtils.required_option together with
            environment_settings["trismik_api_key"].
        http_client: Custom HTTP client to use for requests. When given it
            is used as-is: this constructor does not apply the base URL,
            the "x-api-key" header or the timeout to it, and close() will
            not close it. When omitted, a client is created here and is
            owned (and later closed) by this instance.
        max_items: Total reported to progress callbacks during adaptive
            runs. Defaults to evaluation_settings["max_iterations"].

    Raises:
        TrismikError: Presumably raised by TrismikUtils.required_option
            when no API key can be resolved (defined outside this module).
    """
    self._service_url = TrismikUtils.option(
        service_url,
        client_settings["endpoint"],
        environment_settings["trismik_service_url"],
    )
    self._api_key = TrismikUtils.required_option(
        api_key, "api_key", environment_settings["trismik_api_key"]
    )

    # Decide ownership and pick the client with the SAME `is None` test.
    # Using truthiness (`http_client or ...`) for the selection could
    # replace a falsy user-supplied client with a fresh one that is
    # flagged as "not owned" and therefore never closed.
    self._owns_client = http_client is None
    if http_client is None:
        http_client = httpx.Client(
            base_url=self._service_url,
            headers={"x-api-key": self._api_key},
            timeout=30.0,
        )
    self._http_client = http_client
    self._max_items = max_items
|
|
91
|
+
|
|
92
|
+
def __enter__(self) -> "TrismikClient":
|
|
93
|
+
"""
|
|
94
|
+
Enter context manager.
|
|
95
|
+
|
|
96
|
+
Returns the client instance for use in with-statement.
|
|
97
|
+
"""
|
|
98
|
+
return self
|
|
99
|
+
|
|
100
|
+
def __exit__(
|
|
101
|
+
self,
|
|
102
|
+
exc_type: Optional[type[BaseException]],
|
|
103
|
+
exc_val: Optional[BaseException],
|
|
104
|
+
exc_tb: Optional[object],
|
|
105
|
+
) -> None:
|
|
106
|
+
"""
|
|
107
|
+
Exit context manager and close client if owned.
|
|
108
|
+
|
|
109
|
+
Automatically closes the HTTP client if it was created by this
|
|
110
|
+
instance (not user-provided). Ensures proper resource cleanup.
|
|
111
|
+
"""
|
|
112
|
+
self.close()
|
|
113
|
+
|
|
114
|
+
def close(self) -> None:
    """
    Release the underlying HTTP client when this instance owns it.

    A client passed in by the caller is left open; only a client created
    by the constructor is closed. Leaving a ``with`` block calls this
    method automatically.
    """
    if not self._owns_client:
        # Caller-supplied client: its lifetime is the caller's business.
        return
    self._http_client.close()
|
|
127
|
+
|
|
128
|
+
def _handle_http_error(self, e: httpx.HTTPStatusError) -> Exception:
    """
    Translate an HTTP status error into a Trismik exception.

    The exception is returned, not raised; callers raise it themselves.

    Args:
        e (httpx.HTTPStatusError): The HTTP status error to translate.

    Returns:
        Exception: TrismikPayloadTooLargeError for status 413,
            TrismikValidationError for status 422, and TrismikApiError
            (message from TrismikUtils.get_error_message) otherwise.
    """
    status = e.response.status_code
    if status == 413:
        error_cls, fallback = TrismikPayloadTooLargeError, "Payload too large."
    elif status == 422:
        error_cls, fallback = TrismikValidationError, "Validation failed."
    else:
        return TrismikApiError(TrismikUtils.get_error_message(e.response))

    # Prefer the backend's "detail" field; fall back to the canned message
    # when the body is not JSON or is not a mapping.
    try:
        detail = e.response.json().get("detail", fallback)
    except Exception:
        detail = fallback
    return error_cls(detail)
|
|
154
|
+
|
|
155
|
+
def list_datasets(self) -> List[TrismikDataset]:
    """
    Fetch the available datasets.

    Sends GET "/datasets" and maps the JSON body with
    TrismikResponseMapper.to_datasets.

    Returns:
        List[TrismikDataset]: The mapped datasets.

    Raises:
        TrismikApiError: On an error status (message taken from
            TrismikUtils.get_error_message) or any other httpx.HTTPError
            (message is the string form of that error).
    """
    try:
        reply = self._http_client.get("/datasets")
        reply.raise_for_status()
        return TrismikResponseMapper.to_datasets(reply.json())
    except httpx.HTTPStatusError as status_err:
        # Status errors are handled first: they are the more specific case.
        message = TrismikUtils.get_error_message(status_err.response)
        raise TrismikApiError(message) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
175
|
+
|
|
176
|
+
def start_run(
    self,
    dataset_id: str,
    project_id: str,
    experiment: str,
    metadata: Optional[TrismikRunMetadata] = None,
) -> TrismikRunResponse:
    """
    Start a new run for a dataset.

    Sends POST "/runs/start" and maps the JSON body with
    TrismikResponseMapper.to_run_response.

    Args:
        dataset_id (str): Sent as "datasetId".
        project_id (str): Sent as "projectId".
        experiment (str): Sent as "experiment".
        metadata (Optional[TrismikRunMetadata]): Sent as "metadata" via
            its toDict(); an empty dict is sent when it is falsy.

    Returns:
        TrismikRunResponse: The mapped run response.

    Raises:
        TrismikPayloadTooLargeError: On HTTP status 413.
        TrismikValidationError: On HTTP status 422.
        TrismikApiError: On any other error status or httpx.HTTPError.
    """
    try:
        payload = {
            "datasetId": dataset_id,
            "projectId": project_id,
            "experiment": experiment,
            "metadata": metadata.toDict() if metadata else {},
        }
        reply = self._http_client.post("/runs/start", json=payload)
        reply.raise_for_status()
        return TrismikResponseMapper.to_run_response(reply.json())
    except httpx.HTTPStatusError as status_err:
        raise self._handle_http_error(status_err) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
216
|
+
|
|
217
|
+
def continue_run(self, run_id: str, item_choice_id: str) -> TrismikRunResponse:
    """
    Submit the answer to the current item of a run.

    Sends POST "/runs/continue" and maps the JSON body with
    TrismikResponseMapper.to_run_response.

    Args:
        run_id (str): Sent as "runId".
        item_choice_id (str): Sent as "itemChoiceId".

    Returns:
        TrismikRunResponse: The mapped run response.

    Raises:
        TrismikApiError: On an error status (message taken from
            TrismikUtils.get_error_message) or any other httpx.HTTPError.
    """
    try:
        payload = {"itemChoiceId": item_choice_id, "runId": run_id}
        reply = self._http_client.post("/runs/continue", json=payload)
        reply.raise_for_status()
        return TrismikResponseMapper.to_run_response(reply.json())
    except httpx.HTTPStatusError as status_err:
        message = TrismikUtils.get_error_message(status_err.response)
        raise TrismikApiError(message) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
242
|
+
|
|
243
|
+
def run_summary(self, run_id: str) -> TrismikRunSummary:
    """
    Fetch the summary of a run.

    Sends GET "/runs/adaptive/{run_id}" and maps the JSON body with
    TrismikResponseMapper.to_run_summary.

    Args:
        run_id (str): ID of the run, interpolated into the URL path.

    Returns:
        TrismikRunSummary: The mapped run summary.

    Raises:
        TrismikApiError: On an error status (message taken from
            TrismikUtils.get_error_message) or any other httpx.HTTPError.
    """
    try:
        reply = self._http_client.get(f"/runs/adaptive/{run_id}")
        reply.raise_for_status()
        return TrismikResponseMapper.to_run_summary(reply.json())
    except httpx.HTTPStatusError as status_err:
        message = TrismikUtils.get_error_message(status_err.response)
        raise TrismikApiError(message) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
267
|
+
|
|
268
|
+
def submit_replay(
    self,
    run_id: str,
    replay_request: TrismikReplayRequest,
    metadata: Optional[TrismikRunMetadata] = None,
) -> TrismikReplayResponse:
    """
    Submit a set of answers as a replay of an existing run.

    Sends POST "runs/{run_id}/replay" and maps the JSON body with
    TrismikResponseMapper.to_replay_response.

    Args:
        run_id (str): ID of the run to replay, interpolated into the URL.
        replay_request (TrismikReplayRequest): Its ``responses`` are sent
            as a list of {"itemId", "itemChoiceId"} objects.
        metadata (Optional[TrismikRunMetadata]): Sent as "metadata" via
            its toDict(); an empty dict is sent when it is falsy.

    Returns:
        TrismikReplayResponse: The mapped replay response.

    Raises:
        TrismikPayloadTooLargeError: On HTTP status 413.
        TrismikValidationError: On HTTP status 422.
        TrismikApiError: On any other error status or httpx.HTTPError.
    """
    try:
        endpoint = f"runs/{run_id}/replay"

        # The wire format wants plain dicts, not request-item objects.
        answers = [
            {"itemId": entry.itemId, "itemChoiceId": entry.itemChoiceId}
            for entry in replay_request.responses
        ]
        payload = {
            "responses": answers,
            "metadata": metadata.toDict() if metadata else {},
        }

        reply = self._http_client.post(endpoint, json=payload)
        reply.raise_for_status()
        return TrismikResponseMapper.to_replay_response(reply.json())
    except httpx.HTTPStatusError as status_err:
        raise self._handle_http_error(status_err) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
314
|
+
|
|
315
|
+
def me(self) -> TrismikMeResponse:
    """
    Fetch information about the current API key / user.

    Sends GET "../admin/api-keys/me" and maps the JSON body with
    TrismikResponseMapper.to_me_response.

    Returns:
        TrismikMeResponse: The mapped response.

    Raises:
        TrismikApiError: On an error status (message taken from
            TrismikUtils.get_error_message) or any other httpx.HTTPError.
    """
    try:
        # NOTE(review): relative path; presumably resolved by the HTTP
        # client against its base URL, stepping up one segment - confirm.
        reply = self._http_client.get("../admin/api-keys/me")
        reply.raise_for_status()
        return TrismikResponseMapper.to_me_response(reply.json())
    except httpx.HTTPStatusError as status_err:
        message = TrismikUtils.get_error_message(status_err.response)
        raise TrismikApiError(message) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
335
|
+
|
|
336
|
+
def submit_classic_eval(
    self, classic_eval_request: TrismikClassicEvalRequest
) -> TrismikClassicEvalResponse:
    """
    Submit a classic evaluation run with pre-computed results.

    Sends POST "/runs/classic" and maps the JSON body with
    TrismikResponseMapper.to_classic_eval_response.

    Args:
        classic_eval_request (TrismikClassicEvalRequest): Source of the
            project, experiment, dataset, model and hyperparameter fields,
            plus the per-item records and the run-level metrics.

    Returns:
        TrismikClassicEvalResponse: The mapped response.

    Raises:
        TrismikPayloadTooLargeError: On HTTP status 413.
        TrismikValidationError: On HTTP status 422.
        TrismikApiError: On any other error status or httpx.HTTPError.
    """
    try:
        request = classic_eval_request

        # Flatten the request objects into JSON-serialisable dicts.
        item_rows = [
            {
                "datasetItemId": entry.datasetItemId,
                "modelInput": entry.modelInput,
                "modelOutput": entry.modelOutput,
                "goldOutput": entry.goldOutput,
                "metrics": entry.metrics,
            }
            for entry in request.items
        ]
        metric_rows = [
            {
                "metricId": entry.metricId,
                # The value's type tag is derived by TrismikUtils.
                "valueType": TrismikUtils.metric_value_to_type(entry.value),
                "value": entry.value,
            }
            for entry in request.metrics
        ]
        payload = {
            "projectId": request.projectId,
            "experimentName": request.experimentName,
            "datasetId": request.datasetId,
            "modelName": request.modelName,
            "hyperparameters": request.hyperparameters,
            "items": item_rows,
            "metrics": metric_rows,
        }

        reply = self._http_client.post("/runs/classic", json=payload)
        reply.raise_for_status()
        return TrismikResponseMapper.to_classic_eval_response(reply.json())
    except httpx.HTTPStatusError as status_err:
        raise self._handle_http_error(status_err) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
398
|
+
|
|
399
|
+
def create_project(
    self,
    name: str,
    team_id: Optional[str] = None,
    description: Optional[str] = None,
) -> TrismikProject:
    """
    Create a new project.

    Sends POST "../admin/public/projects" and maps the JSON body with
    TrismikResponseMapper.to_project.

    Args:
        name (str): Sent as "name".
        team_id (Optional[str]): Sent as "teamId" only when not None.
        description (Optional[str]): Sent as "description" only when not
            None.

    Returns:
        TrismikProject: The mapped project.

    Raises:
        TrismikPayloadTooLargeError: On HTTP status 413.
        TrismikValidationError: On HTTP status 422.
        TrismikApiError: On any other error status or httpx.HTTPError.
    """
    try:
        payload = {"name": name}
        # Optional fields are omitted entirely rather than sent as null.
        for key, value in (("teamId", team_id), ("description", description)):
            if value is not None:
                payload[key] = value

        reply = self._http_client.post("../admin/public/projects", json=payload)
        reply.raise_for_status()
        return TrismikResponseMapper.to_project(reply.json())
    except httpx.HTTPStatusError as status_err:
        raise self._handle_http_error(status_err) from status_err
    except httpx.HTTPError as http_err:
        raise TrismikApiError(str(http_err)) from http_err
|
|
440
|
+
|
|
441
|
+
# ===== Test Orchestration Methods =====
|
|
442
|
+
|
|
443
|
+
@overload
def run(  # noqa: E704
    self,
    test_id: str,
    project_id: str,
    experiment: str,
    run_metadata: TrismikRunMetadata,
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
    return_dict: Literal[True] = True,
    with_responses: bool = False,
) -> Dict[str, Any]: ...

@overload
def run(  # noqa: E704
    self,
    test_id: str,
    project_id: str,
    experiment: str,
    run_metadata: TrismikRunMetadata,
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
    return_dict: Literal[False] = False,
    with_responses: bool = False,
) -> TrismikRunResults: ...

def run(
    self,
    test_id: str,
    project_id: str,
    experiment: str,
    run_metadata: TrismikRunMetadata,
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
    return_dict: bool = True,
    with_responses: bool = False,
) -> Union[TrismikRunResults, Dict[str, Any]]:
    """
    Run an adaptive test from start to finish.

    Starts the run with start_run, drives it with run_test_loop, and
    builds the score from the last theta and standard error of the final
    state.

    Args:
        test_id: Passed to start_run as the dataset ID.
        project_id: ID of the project.
        experiment: Name of the experiment.
        run_metadata: Metadata for the run.
        item_processor: Callable applied to each item (through
            process_item); its result is submitted as the item choice.
        on_progress: Optional callback receiving (current, total).
        return_dict: If True (the default), return a plain dict with the
            keys "run_id", "score" and "responses"; otherwise return the
            TrismikRunResults object.
        with_responses: Must be False; True is not supported here.

    Returns:
        Test results as TrismikRunResults or dict.

    Raises:
        NotImplementedError: If with_responses is True.
        RuntimeError: If run_test_loop yields no final state.
        TrismikApiError: If an API request fails.
    """
    if with_responses:
        raise NotImplementedError("with_responses is not yet implemented for the new API flow")

    started = self.start_run(test_id, project_id, experiment, run_metadata)
    run_id = started.run_info.id

    # The state returned at start is the first entry of the history.
    history: List[TrismikAdaptiveTestState] = [
        TrismikAdaptiveTestState(
            run_id=run_id,
            state=started.state,
            completed=started.completed,
        )
    ]

    final_state = self.run_test_loop(
        run_id,
        started.next_item,
        history,
        item_processor,
        on_progress,
    )
    if not final_state:
        raise RuntimeError("Test run completed but no final state was captured")

    score = AdaptiveTestScore(
        theta=final_state.state.thetas[-1],
        std_error=final_state.state.std_error_history[-1],
    )
    results = TrismikRunResults(run_id, score=score)

    if not return_dict:
        return results

    score_dict = None
    if results.score:
        score_dict = {
            "theta": results.score.theta,
            "std_error": results.score.std_error,
        }
    return {
        "run_id": results.run_id,
        "score": score_dict,
        "responses": results.responses,
    }
|
|
552
|
+
|
|
553
|
+
def run_test_loop(
    self,
    run_id: str,
    first_item: Optional[TrismikItem],
    states: List[TrismikAdaptiveTestState],
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
) -> Optional[TrismikAdaptiveTestState]:
    """
    Drive a started run until it completes or no item is left.

    For each item: report progress, obtain an answer through
    process_item, submit it with continue_run, and append the resulting
    state to ``states``. The loop ends when a response is marked completed
    or when the next item is None.

    NOTE: the original docstring states this sync version is generated by
    unasync from the async client.

    Args:
        run_id: ID of the run to execute.
        first_item: First item from run start; None skips the loop.
        states: List that is appended to in place, one state per answer.
        item_processor: Callable handed to process_item with each item.
        on_progress: Optional callback. Called with (answered so far,
            self._max_items) before each item, and once with
            (answered, answered) when the run reports completion.

    Returns:
        The last element of ``states``, or None when ``states`` is empty.

    Raises:
        TrismikApiError: If an API request fails.
    """
    answered = 0
    pending = first_item

    while pending is not None:
        if on_progress:
            on_progress(answered, self._max_items)

        # process_item applies the caller's processor to the item.
        choice = process_item(item_processor, pending)
        step = self.continue_run(run_id, choice)

        states.append(
            TrismikAdaptiveTestState(
                run_id=run_id,
                state=step.state,
                completed=step.completed,
            )
        )
        answered += 1

        finished = step.completed
        if finished and on_progress:
            # Final update: the total collapses to the actual item count.
            on_progress(answered, answered)
        # Once completed, the next item is never read.
        pending = None if finished else step.next_item

    if not states:
        return None
    return states[-1]
|
|
614
|
+
|
|
615
|
+
@overload
def run_replay(  # noqa: E704
    self,
    previous_run_id: str,
    run_metadata: TrismikRunMetadata,
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
    return_dict: Literal[True] = True,
    with_responses: bool = False,
) -> Dict[str, Any]: ...

@overload
def run_replay(  # noqa: E704
    self,
    previous_run_id: str,
    run_metadata: TrismikRunMetadata,
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
    return_dict: Literal[False] = False,
    with_responses: bool = False,
) -> TrismikRunResults: ...

def run_replay(
    self,
    previous_run_id: str,
    run_metadata: TrismikRunMetadata,
    item_processor: Callable[[TrismikItem], Any],
    on_progress: Optional[Callable[[int, int], None]] = None,
    return_dict: bool = True,
    with_responses: bool = False,
) -> Union[TrismikRunResults, Dict[str, Any]]:
    """
    Replay the items of a previous run and submit the new answers.

    Loads the previous run with run_summary, answers every item of its
    ``dataset`` through process_item, submits all answers at once with
    submit_replay, and builds the score from the last theta and standard
    error of the returned state.

    Args:
        previous_run_id: ID of a previous run to replay.
        run_metadata: Metadata for the replay run.
        item_processor: Callable applied to each item (through
            process_item); its result becomes the item choice.
        on_progress: Optional callback. Called with (index, total) before
            each item and with (total, total) after the last one.
        return_dict: If True (the default), return a plain dict with the
            keys "run_id", "score" and "responses"; otherwise return the
            TrismikRunResults object.
        with_responses: If True, include the replay responses in results.

    Returns:
        Test results as TrismikRunResults or dict.

    Raises:
        TrismikApiError: If an API request fails.
    """
    previous = self.run_summary(previous_run_id)

    answers = []
    total = len(previous.dataset)
    for position, question in enumerate(previous.dataset):
        if on_progress:
            on_progress(position, total)

        choice = process_item(item_processor, question)
        answers.append(TrismikReplayRequestItem(itemId=question.id, itemChoiceId=choice))

    if on_progress:
        on_progress(total, total)

    replayed = self.submit_replay(
        previous_run_id, TrismikReplayRequest(responses=answers), run_metadata
    )

    score = AdaptiveTestScore(
        theta=replayed.state.thetas[-1],
        std_error=replayed.state.std_error_history[-1],
    )

    # Responses are attached only on request.
    result_fields: Dict[str, Any] = {"run_id": replayed.id, "score": score}
    if with_responses:
        result_fields["responses"] = replayed.responses
    results = TrismikRunResults(**result_fields)

    if not return_dict:
        return results

    score_dict = None
    if results.score:
        score_dict = {
            "theta": results.score.theta,
            "std_error": results.score.std_error,
        }

    response_rows = None
    if results.responses:
        response_rows = [
            {
                "dataset_item_id": resp.dataset_item_id,
                "value": resp.value,
                "correct": resp.correct,
            }
            for resp in results.responses
        ]

    return {
        "run_id": results.run_id,
        "score": score_dict,
        "responses": response_rows,
    }
|