simplai-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. billing/__init__.py +6 -0
  2. billing/api.py +55 -0
  3. billing/client.py +14 -0
  4. billing/schema.py +15 -0
  5. constants/__init__.py +90 -0
  6. core/__init__.py +53 -0
  7. core/agents/__init__.py +42 -0
  8. core/agents/execution/__init__.py +49 -0
  9. core/agents/execution/api.py +283 -0
  10. core/agents/execution/client.py +1139 -0
  11. core/agents/models.py +99 -0
  12. core/workflows/WORKFLOW_ARCHITECTURE.md +417 -0
  13. core/workflows/__init__.py +31 -0
  14. core/workflows/bulk/__init__.py +14 -0
  15. core/workflows/bulk/api.py +202 -0
  16. core/workflows/bulk/client.py +115 -0
  17. core/workflows/bulk/schema.py +58 -0
  18. core/workflows/models.py +49 -0
  19. core/workflows/scheduling/__init__.py +9 -0
  20. core/workflows/scheduling/api.py +179 -0
  21. core/workflows/scheduling/client.py +128 -0
  22. core/workflows/scheduling/schema.py +74 -0
  23. core/workflows/tool_execution/__init__.py +16 -0
  24. core/workflows/tool_execution/api.py +172 -0
  25. core/workflows/tool_execution/client.py +195 -0
  26. core/workflows/tool_execution/schema.py +40 -0
  27. exceptions/__init__.py +21 -0
  28. simplai_sdk/__init__.py +7 -0
  29. simplai_sdk/simplai.py +239 -0
  30. simplai_sdk-0.1.0.dist-info/METADATA +728 -0
  31. simplai_sdk-0.1.0.dist-info/RECORD +42 -0
  32. simplai_sdk-0.1.0.dist-info/WHEEL +5 -0
  33. simplai_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
  34. simplai_sdk-0.1.0.dist-info/top_level.txt +7 -0
  35. traces/__init__.py +1 -0
  36. traces/agents/__init__.py +55 -0
  37. traces/agents/api.py +350 -0
  38. traces/agents/client.py +697 -0
  39. traces/agents/models.py +249 -0
  40. traces/workflows/__init__.py +0 -0
  41. utils/__init__.py +0 -0
  42. utils/config.py +117 -0
@@ -0,0 +1,697 @@
1
from __future__ import annotations

import copy
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import httpx

from constants import (
    DEFAULT_BASE_URL,
    METRICS_V1_AGENT_PATH,
    METRICS_V1_TOOL_PATH,
    TRACE_V1_DETAILS_PATH,
    TRACE_V1_FETCH_PATH,
    TRACE_V2_AGGREGATE_OUTPUT_PATH,
    TRACE_V2_DETAILS_PATH,
    TRACE_V2_DOWNLOAD_PATH,
    TRACE_V2_FETCH_PATH,
    TRACE_V2_TREE_PATH,
)
from .models import (
    FetchTraceFilters,
    MetricsRequestDto,
    PageableResponse,
    RagTrace,
    TraceError,
    TraceFilters,
    TraceNode,
)
30
+
31
+
32
class TraceClient:
    """Low-level HTTP client for the Simplai tracer API.

    Instances are reusable: the underlying ``httpx`` clients are created lazily
    on first use and kept for later requests. Call :meth:`close` / :meth:`aclose`,
    or use the instance as a (sync or async) context manager, to release them.

    Args:
        api_key: PIM-SID key used for authenticating with the Simplai edge service.
        base_url: Base URL of the Simplai edge service (a trailing "/" is stripped).
        timeout: Default request timeout in seconds.
        max_retries: Number of retries for transient failures (network errors,
            timeouts, HTTP 429 and HTTP 5xx responses).
        backoff_factor: Base factor (in seconds) for exponential backoff; the wait
            after failed attempt ``n`` (0-based) is ``backoff_factor * 2**n``.
        user_id: Optional user ID, sent as the ``X-USER-ID`` header.
        tenant_id: Tenant ID (defaults to "1"), sent as ``X-TENANT-ID`` and used
            as the default ``tenant_id`` for filters that leave it unset.
        project_id: Optional project ID, sent as ``X-PROJECT-ID`` and used as the
            default ``project_id`` for V2 filters / metrics requests that leave it unset.
        seller_id: Optional seller ID, sent as ``X-SELLER-ID``.
        client_id: Optional client ID, sent as ``X-CLIENT-ID``.
        seller_profile_id: Optional seller profile ID, sent as ``X-SELLER-PROFILE-ID``.
    """

    # Keys of a node's nested "data" object that TraceData is built from; the API
    # may return extra fields (e.g. user_message) that the constructor rejects.
    _TRACE_DATA_KEYS = ("input", "output", "error", "metadata")
    # Keys of a descendant entry that ChildInfo is built from.
    _CHILD_INFO_KEYS = ("node_id", "name", "entity_type")

    def __init__(
        self,
        api_key: str,
        *,
        base_url: str = DEFAULT_BASE_URL,
        timeout: float = 30.0,
        max_retries: int = 3,
        backoff_factor: float = 0.5,
        user_id: Optional[str] = None,
        tenant_id: str = "1",
        project_id: Optional[int] = None,
        seller_id: Optional[str] = None,
        client_id: Optional[str] = None,
        seller_profile_id: Optional[str] = None,
    ) -> None:
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        self.backoff_factor = backoff_factor
        self.user_id = user_id
        self.tenant_id = tenant_id
        self.project_id = project_id
        self.seller_id = seller_id
        self.client_id = client_id
        self.seller_profile_id = seller_profile_id

        self._sync_client: Optional[httpx.Client] = None
        self._async_client: Optional[httpx.AsyncClient] = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the synchronous HTTP client, if one has been created."""
        if self._sync_client is not None:
            self._sync_client.close()
            self._sync_client = None

    async def aclose(self) -> None:
        """Close both the asynchronous and synchronous HTTP clients, if created."""
        if self._async_client is not None:
            await self._async_client.aclose()
            self._async_client = None
        self.close()

    def __enter__(self) -> TraceClient:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    async def __aenter__(self) -> TraceClient:
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        await self.aclose()

    # ------------------------------------------------------------------
    # Internal HTTP helpers
    # ------------------------------------------------------------------

    def _get_sync_client(self) -> httpx.Client:
        """Return the shared synchronous client, creating it on first use."""
        if self._sync_client is None:
            self._sync_client = httpx.Client(timeout=self.timeout)
        return self._sync_client

    def _get_async_client(self) -> httpx.AsyncClient:
        """Return the shared asynchronous client, creating it on first use."""
        if self._async_client is None:
            self._async_client = httpx.AsyncClient(timeout=self.timeout)
        return self._async_client

    def _headers(self) -> Dict[str, str]:
        """Build request headers; optional identity headers are sent only when truthy."""
        headers = {
            "PIM-SID": self.api_key,
            "Content-Type": "application/json",
        }
        optional_headers = (
            ("X-TENANT-ID", self.tenant_id),
            ("X-USER-ID", self.user_id),
            ("X-PROJECT-ID", self.project_id),
            ("X-SELLER-ID", self.seller_id),
            ("X-CLIENT-ID", self.client_id),
            ("X-SELLER-PROFILE-ID", self.seller_profile_id),
        )
        for name, value in optional_headers:
            if value:
                headers[name] = str(value)
        return headers

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True for rate limiting (429) and server-side (5xx) responses."""
        return status_code == 429 or status_code >= 500

    def _request_with_retries_sync(
        self, method: str, url: str, **kwargs: Any
    ) -> httpx.Response:
        """Send an HTTP request, retrying transient failures with exponential backoff.

        Network errors, timeouts, HTTP 429 and HTTP 5xx responses are retried up
        to ``max_retries`` times, sleeping ``backoff_factor * 2**attempt`` seconds
        after each failed attempt. Any other response is returned immediately.
        If the final attempt still yields a retryable status, that response is
        returned so the caller can report it together with the response body.

        Raises:
            TraceError: If the final attempt fails with a network error or timeout.
        """
        client = self._get_sync_client()
        # Always make at least one attempt, even if max_retries is negative.
        attempts = max(self.max_retries, 0) + 1

        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                response = client.request(method, url, **kwargs)
            except (httpx.NetworkError, httpx.TimeoutException) as exc:
                if is_last:
                    raise TraceError(
                        f"Request failed after {attempts} attempts: {exc}"
                    ) from exc
            else:
                if is_last or not self._is_retryable_status(response.status_code):
                    return response
            time.sleep(self.backoff_factor * (2 ** attempt))

        # Unreachable: the final iteration always returns or raises.
        raise TraceError(f"Request failed after {attempts} attempts")

    def _send(self, method: str, path: str, action: str, **kwargs: Any) -> httpx.Response:
        """Send a request to ``base_url + path`` and require an HTTP 200 response.

        Raises:
            TraceError: ``"Failed to {action}: {status} - {body}"`` on any non-200 status.
        """
        response = self._request_with_retries_sync(
            method, self.base_url + path, headers=self._headers(), **kwargs
        )
        if response.status_code != 200:
            raise TraceError(f"Failed to {action}: {response.status_code} - {response.text}")
        return response

    def _scoped(self, filters: Any, *, with_project: bool = True) -> Any:
        """Return a shallow copy of *filters* with this client's scope defaults applied.

        ``tenant_id`` (and ``project_id`` when *with_project* is true) is filled in
        from the client when the filter leaves it falsy and the client has a value.
        The caller's object is never modified, so it can safely be reused.
        """
        scoped = copy.copy(filters)
        if not scoped.tenant_id and self.tenant_id:
            scoped.tenant_id = int(self.tenant_id)
        if with_project and not scoped.project_id and self.project_id:
            scoped.project_id = int(self.project_id)
        return scoped

    def _list_params(
        self,
        filters: Any,
        page: int,
        size: int,
        sort: str,
        direction: str,
        *,
        with_project: bool = True,
    ) -> Dict[str, Any]:
        """Build pagination/sort query params merged with the (scoped) filter params."""
        params: Dict[str, Any] = {
            "page": page,
            "size": size,
            "sort": f"{sort},{direction}",
        }
        if filters is not None:
            params.update(self._scoped(filters, with_project=with_project).to_query_params())
        return params

    @staticmethod
    def _unwrap(payload: Any, *keys: str) -> Any:
        """Return the value of the first of *keys* present in a dict *payload*.

        Non-dict payloads, and dicts containing none of the keys, are returned as-is.
        """
        if isinstance(payload, dict):
            for key in keys:
                if key in payload:
                    return payload[key]
        return payload

    # ------------------------------------------------------------------
    # Request-level traces (RAG traces) - V1 API
    # ------------------------------------------------------------------

    def fetch_traces(
        self,
        filters: Optional[FetchTraceFilters] = None,
        *,
        page: int = 0,
        size: int = 20,
        sort: str = "id",
        direction: str = "DESC",
    ) -> PageableResponse:
        """Fetch paginated request-level traces (RAG traces).

        Args:
            filters: Optional filters for querying traces. An unset ``tenant_id``
                defaults to the client's tenant; the object itself is not modified.
            page: Page number (0-indexed).
            size: Page size.
            sort: Sort field.
            direction: Sort direction (ASC or DESC).

        Returns:
            PageableResponse containing paginated RagTrace objects.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        params = self._list_params(filters, page, size, sort, direction, with_project=False)
        response = self._send("GET", TRACE_V1_FETCH_PATH, "fetch traces", params=params)
        return self._parse_pageable_response(response.json(), RagTrace)

    def get_trace_details(
        self, *, trace_id: Optional[str] = None, request_id: Optional[str] = None
    ) -> RagTrace:
        """Get details of a specific request-level trace.

        Args:
            trace_id: Trace ID (optional, but either trace_id or request_id required).
            request_id: Request ID (optional, but either trace_id or request_id required).

        Returns:
            RagTrace object with full details.

        Raises:
            ValueError: If neither trace_id nor request_id is given.
            TraceError: If the request fails or returns a non-200 status.
        """
        if not trace_id and not request_id:
            raise ValueError("Either trace_id or request_id must be provided")

        params: Dict[str, Any] = {}
        if trace_id:
            params["traceId"] = trace_id
        if request_id:
            params["requestId"] = request_id

        response = self._send("GET", TRACE_V1_DETAILS_PATH, "get trace details", params=params)
        return self._parse_rag_trace(self._unwrap(response.json(), "data"))

    # ------------------------------------------------------------------
    # Tree-based traces - V2 API
    # ------------------------------------------------------------------

    def fetch_trace_nodes(
        self,
        filters: Optional[TraceFilters] = None,
        *,
        page: int = 0,
        size: int = 20,
        sort: str = "id",
        direction: str = "DESC",
    ) -> PageableResponse:
        """Fetch paginated tree-based trace nodes (root traces).

        Args:
            filters: Optional filters for querying traces. Unset ``tenant_id`` /
                ``project_id`` default to the client's values; the object itself
                is not modified.
            page: Page number (0-indexed).
            size: Page size.
            sort: Sort field.
            direction: Sort direction (ASC or DESC).

        Returns:
            PageableResponse containing paginated TraceNode objects.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        params = self._list_params(filters, page, size, sort, direction)
        response = self._send("GET", TRACE_V2_FETCH_PATH, "fetch trace nodes", params=params)
        return self._parse_pageable_response(response.json(), TraceNode)

    def get_aggregate_output(
        self,
        filters: Optional[TraceFilters] = None,
        *,
        page: int = 0,
        size: int = 20,
        sort: str = "id",
        direction: str = "DESC",
    ) -> Dict[str, Any]:
        """Get aggregate output traces (summary view).

        Args:
            filters: Optional filters for querying traces. Unset ``tenant_id`` /
                ``project_id`` default to the client's values; the object itself
                is not modified.
            page: Page number (0-indexed).
            size: Page size.
            sort: Sort field.
            direction: Sort direction (ASC or DESC).

        Returns:
            The response's ``data`` member if present, otherwise the whole JSON body.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        params = self._list_params(filters, page, size, sort, direction)
        response = self._send(
            "GET", TRACE_V2_AGGREGATE_OUTPUT_PATH, "get aggregate output", params=params
        )
        return self._unwrap(response.json(), "data")

    def get_trace_tree(
        self, tree_id: str, node_id: str, *, max_depth: Optional[int] = None, raw_response: bool = False
    ) -> Union[TraceNode, Dict[str, Any]]:
        """Get a trace sub-tree starting from a specific node.

        Args:
            tree_id: Tree ID of the trace.
            node_id: Node ID to start the tree from.
            max_depth: Maximum depth to traverse (optional; sent whenever not None,
                including 0).
            raw_response: If True, return raw JSON response instead of parsed TraceNode.

        Returns:
            TraceNode representing the root of the sub-tree, or raw JSON dict if
            raw_response=True. The node is read from the response's ``result`` or
            ``data`` member, falling back to the whole body.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        url = self.base_url + TRACE_V2_TREE_PATH

        params: Dict[str, Any] = {
            "treeId": tree_id,
            "nodeId": node_id,
        }
        if max_depth is not None:
            params["maxDepth"] = max_depth

        # Include an empty data payload, matching the request shape of the working cURL.
        request_body = {"query": "", "variables": {}}

        response = self._request_with_retries_sync(
            "GET", url, headers=self._headers(), params=params, json=request_body
        )

        if response.status_code != 200:
            error_detail = response.text
            try:
                error_json = response.json()
            except ValueError:  # body is not valid JSON; keep the raw text
                error_json = None
            if isinstance(error_json, dict):
                error_detail = json.dumps(error_json, indent=2)
            raise TraceError(
                f"Failed to get trace tree: {response.status_code} - {error_detail}\n"
                f"URL: {url}\n"
                f"Params: treeId={tree_id}, nodeId={node_id}"
            )

        payload = response.json()
        if raw_response:
            return payload
        return self._parse_trace_node(self._unwrap(payload, "result", "data"))

    def get_trace_node_details(
        self, tree_id: str, node_id: str, *, raw_response: bool = False
    ) -> Union[TraceNode, Dict[str, Any]]:
        """Get details of a specific trace node.

        Args:
            tree_id: Tree ID of the trace.
            node_id: Node ID to get details for.
            raw_response: If True, return raw JSON response instead of parsed TraceNode.

        Returns:
            TraceNode object with full details (read from the response's ``result``
            or ``data`` member, falling back to the whole body), or the raw JSON
            if raw_response=True.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        params = {
            "treeId": tree_id,
            "nodeId": node_id,
        }
        response = self._send(
            "GET", TRACE_V2_DETAILS_PATH, "get trace node details", params=params
        )

        payload = response.json()
        if raw_response:
            return payload
        return self._parse_trace_node(self._unwrap(payload, "result", "data"))

    def download_traces(
        self, filters: Optional[TraceFilters] = None
    ) -> bytes:
        """Download traces as CSV.

        The filters are sent both as query parameters and (with falsy values
        dropped) as the JSON body of a POST request.

        Args:
            filters: Optional filters for querying traces. Unset ``tenant_id`` /
                ``project_id`` default to the client's values; the object itself
                is not modified.

        Returns:
            Bytes containing CSV file content.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        params: Dict[str, Any] = {}
        body: Dict[str, Any] = {}
        if filters is not None:
            query = self._scoped(filters).to_query_params()
            params.update(query)
            body = {k: v for k, v in query.items() if v}

        response = self._send(
            "POST", TRACE_V2_DOWNLOAD_PATH, "download traces", json=body, params=params
        )
        return response.content

    # ------------------------------------------------------------------
    # Metrics API
    # ------------------------------------------------------------------

    def _fetch_metrics(
        self, path: str, action: str, metrics_request: Optional[MetricsRequestDto]
    ) -> Dict[str, Any]:
        """GET a metrics endpoint and return the response's ``data`` member (or whole body)."""
        params: Dict[str, Any] = {}
        if metrics_request is not None:
            params.update(self._scoped(metrics_request).to_query_params())
        response = self._send("GET", path, action, params=params)
        return self._unwrap(response.json(), "data")

    def get_agent_metrics(
        self, metrics_request: Optional[MetricsRequestDto] = None
    ) -> Dict[str, Any]:
        """Get agent-level metrics from traces.

        Args:
            metrics_request: Optional metrics request with filters. Unset
                ``tenant_id`` / ``project_id`` default to the client's values;
                the object itself is not modified.

        Returns:
            Dictionary containing aggregated agent metrics.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        return self._fetch_metrics(METRICS_V1_AGENT_PATH, "get agent metrics", metrics_request)

    def get_tool_metrics(
        self, metrics_request: Optional[MetricsRequestDto] = None
    ) -> Dict[str, Any]:
        """Get tool-level metrics from traces.

        Args:
            metrics_request: Optional metrics request with filters. Unset
                ``tenant_id`` / ``project_id`` default to the client's values;
                the object itself is not modified.

        Returns:
            Dictionary containing aggregated tool metrics.

        Raises:
            TraceError: If the request fails or returns a non-200 status.
        """
        return self._fetch_metrics(METRICS_V1_TOOL_PATH, "get tool metrics", metrics_request)

    # ------------------------------------------------------------------
    # Response parsing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_datetime(value: Any) -> Optional[datetime]:
        """Parse an ISO-8601 string (a trailing "Z" is treated as UTC).

        Returns None for empty values and for anything that cannot be parsed.
        """
        if not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            # Non-string or malformed timestamps are tolerated; the raw value
            # remains available through the parsed object's payload.
            return None

    def _parse_pageable_response(self, data: Dict[str, Any], item_class: type) -> PageableResponse:
        """Parse a paginated response, converting items of RagTrace/TraceNode pages.

        The page is read from the response's ``data`` member when present. Items
        for any other *item_class* are passed through unchanged.
        """
        outer = data if isinstance(data, dict) else {}
        page_data = self._unwrap(data, "data")
        if isinstance(page_data, list):
            page_data = {"content": page_data}
        elif not isinstance(page_data, dict):
            page_data = {}

        # Extract pagination info
        content = page_data.get("content") or []
        total_elements = page_data.get("totalElements", len(content))
        total_pages = page_data.get("totalPages", 1)
        page_number = page_data.get("number", 0)
        page_size = page_data.get("size", len(content))

        parsers = {
            RagTrace: self._parse_rag_trace,
            TraceNode: self._parse_trace_node,
        }
        parse = parsers.get(item_class)
        parsed_items = [parse(item) for item in content] if parse else list(content)

        return PageableResponse(
            success=outer.get("success", True),
            status=outer.get("status", "OK"),
            data=parsed_items,
            total_elements=total_elements,
            total_pages=total_pages,
            page_number=page_number,
            page_size=page_size,
        )

    def _parse_rag_trace(self, data: Dict[str, Any]) -> RagTrace:
        """Build a RagTrace from an API dict; the raw dict is kept as ``payload``."""
        eval_metrics = None
        if data.get("eval_metrics"):
            from .models import EvalMetrics

            eval_metrics = [
                EvalMetrics(**metric) if isinstance(metric, dict) else metric
                for metric in data["eval_metrics"]
            ]

        return RagTrace(
            id=data.get("id"),
            request_id=data.get("request_id"),
            conversation_id=data.get("conversation_id"),
            user_id=data.get("user_id"),
            username=data.get("username"),
            tenant_id=data.get("tenant_id"),
            workflow_id=data.get("workflow_id"),
            workflow_name=data.get("workflow_name"),
            agent_id=data.get("agent_id"),
            agent_name=data.get("agent_name"),
            kb_id=data.get("kb_id"),
            kb_name=data.get("kb_name"),
            model_id=data.get("model_id"),
            model_name=data.get("model_name"),
            model_input_tokens=data.get("model_input_tokens"),
            model_output_tokens=data.get("model_output_tokens"),
            kb_query_tokens=data.get("kb_query_tokens"),
            context_tokens=data.get("context_tokens"),
            query=data.get("query"),
            prompt=data.get("prompt"),
            context=data.get("context"),
            response=data.get("response"),
            response_time=data.get("response_time"),
            model_response_time=data.get("model_response_time"),
            kb_response_time=data.get("kb_response_time"),
            user_feedback=data.get("user_feedback"),
            source=data.get("source", []),
            created_at=self._parse_datetime(data.get("created_at")),
            updated_at=self._parse_datetime(data.get("updated_at")),
            evaluation_started_by=data.get("evaluation_started_by"),
            evaluation_started_by_name=data.get("evaluation_started_by_name"),
            eval_metrics=eval_metrics,
            status=data.get("status"),
            ground_truth=data.get("ground_truth"),
            payload=data,
        )

    def _parse_trace_node(self, data: Dict[str, Any]) -> TraceNode:
        """Build a TraceNode (recursively, for ``children``) from an API dict.

        The raw dict is kept as ``payload`` so fields not modelled here, or nested
        data that failed to parse, remain accessible.
        """
        from .models import TraceData, ChildInfo

        # Nested data: only known TraceData keys are passed to the constructor.
        trace_data = None
        raw_data = data.get("data")
        if raw_data:
            if isinstance(raw_data, dict):
                known = {k: raw_data[k] for k in self._TRACE_DATA_KEYS if k in raw_data}
                try:
                    trace_data = TraceData(**known)
                except (TypeError, ValueError):
                    trace_data = None  # keep raw in payload
            else:
                trace_data = raw_data

        children = None
        if data.get("children"):
            children = [self._parse_trace_node(child) for child in data["children"]]

        # Descendants: only known ChildInfo keys are used; entries that cannot be
        # converted are kept as-is.
        descendants = None
        if data.get("descendants"):
            descendants = []
            for entry in data["descendants"]:
                if isinstance(entry, dict):
                    known_child = {k: entry[k] for k in self._CHILD_INFO_KEYS if k in entry}
                    try:
                        descendants.append(ChildInfo(**known_child))
                    except (TypeError, ValueError):
                        descendants.append(entry)
                else:
                    descendants.append(entry)

        return TraceNode(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            project_id=data.get("project_id"),
            user_id=data.get("user_id"),
            source=data.get("source"),
            is_bulk_run_parent=data.get("is_bulk_run_parent"),
            node_id=data.get("node_id"),
            name=data.get("name"),
            entity_type=data.get("entity_type"),
            entity_id=data.get("entity_id"),
            tree_id=data.get("tree_id"),
            parent_id=data.get("parent_id"),
            parent_span_id=data.get("parent_span_id"),
            children=children,
            descendants=descendants,
            path=data.get("path"),
            data=trace_data,
            has_children=data.get("has_children"),
            created_at=self._parse_datetime(data.get("created_at")),
            execution_time=self._parse_datetime(data.get("execution_time")),
            profile_id=data.get("profile_id"),
            profile=data.get("profile"),
            is_root=data.get("is_root"),
            sub_entity_id=data.get("sub_entity_id"),
            completion_tokens=data.get("completion_tokens"),
            prompt_tokens=data.get("prompt_tokens"),
            total_tokens=data.get("total_tokens"),
            latency=data.get("latency"),
            context_window_utilization=data.get("context_window_utilization"),
            status=data.get("status"),
            origin=data.get("origin"),
            app_id=data.get("app_id"),
            child_count=data.get("child_count"),
            payload=data,
        )
697
+