aeri-python 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (391)
  1. aeri/__init__.py +72 -0
  2. aeri/_client/_validation.py +204 -0
  3. aeri/_client/attributes.py +188 -0
  4. aeri/_client/client.py +3761 -0
  5. aeri/_client/constants.py +65 -0
  6. aeri/_client/datasets.py +302 -0
  7. aeri/_client/environment_variables.py +158 -0
  8. aeri/_client/get_client.py +149 -0
  9. aeri/_client/observe.py +661 -0
  10. aeri/_client/propagation.py +475 -0
  11. aeri/_client/resource_manager.py +510 -0
  12. aeri/_client/span.py +1519 -0
  13. aeri/_client/span_filter.py +76 -0
  14. aeri/_client/span_processor.py +206 -0
  15. aeri/_client/utils.py +132 -0
  16. aeri/_task_manager/media_manager.py +331 -0
  17. aeri/_task_manager/media_upload_consumer.py +44 -0
  18. aeri/_task_manager/media_upload_queue.py +12 -0
  19. aeri/_task_manager/score_ingestion_consumer.py +208 -0
  20. aeri/_task_manager/task_manager.py +475 -0
  21. aeri/_utils/__init__.py +19 -0
  22. aeri/_utils/environment.py +34 -0
  23. aeri/_utils/error_logging.py +47 -0
  24. aeri/_utils/parse_error.py +99 -0
  25. aeri/_utils/prompt_cache.py +188 -0
  26. aeri/_utils/request.py +137 -0
  27. aeri/_utils/serializer.py +205 -0
  28. aeri/api/.fern/metadata.json +14 -0
  29. aeri/api/__init__.py +836 -0
  30. aeri/api/annotation_queues/__init__.py +82 -0
  31. aeri/api/annotation_queues/client.py +1111 -0
  32. aeri/api/annotation_queues/raw_client.py +2288 -0
  33. aeri/api/annotation_queues/types/__init__.py +84 -0
  34. aeri/api/annotation_queues/types/annotation_queue.py +28 -0
  35. aeri/api/annotation_queues/types/annotation_queue_assignment_request.py +16 -0
  36. aeri/api/annotation_queues/types/annotation_queue_item.py +34 -0
  37. aeri/api/annotation_queues/types/annotation_queue_object_type.py +26 -0
  38. aeri/api/annotation_queues/types/annotation_queue_status.py +22 -0
  39. aeri/api/annotation_queues/types/create_annotation_queue_assignment_response.py +18 -0
  40. aeri/api/annotation_queues/types/create_annotation_queue_item_request.py +25 -0
  41. aeri/api/annotation_queues/types/create_annotation_queue_request.py +20 -0
  42. aeri/api/annotation_queues/types/delete_annotation_queue_assignment_response.py +14 -0
  43. aeri/api/annotation_queues/types/delete_annotation_queue_item_response.py +15 -0
  44. aeri/api/annotation_queues/types/paginated_annotation_queue_items.py +17 -0
  45. aeri/api/annotation_queues/types/paginated_annotation_queues.py +17 -0
  46. aeri/api/annotation_queues/types/update_annotation_queue_item_request.py +15 -0
  47. aeri/api/blob_storage_integrations/__init__.py +73 -0
  48. aeri/api/blob_storage_integrations/client.py +550 -0
  49. aeri/api/blob_storage_integrations/raw_client.py +976 -0
  50. aeri/api/blob_storage_integrations/types/__init__.py +77 -0
  51. aeri/api/blob_storage_integrations/types/blob_storage_export_frequency.py +26 -0
  52. aeri/api/blob_storage_integrations/types/blob_storage_export_mode.py +26 -0
  53. aeri/api/blob_storage_integrations/types/blob_storage_integration_deletion_response.py +14 -0
  54. aeri/api/blob_storage_integrations/types/blob_storage_integration_file_type.py +26 -0
  55. aeri/api/blob_storage_integrations/types/blob_storage_integration_response.py +64 -0
  56. aeri/api/blob_storage_integrations/types/blob_storage_integration_status_response.py +50 -0
  57. aeri/api/blob_storage_integrations/types/blob_storage_integration_type.py +26 -0
  58. aeri/api/blob_storage_integrations/types/blob_storage_integrations_response.py +15 -0
  59. aeri/api/blob_storage_integrations/types/blob_storage_sync_status.py +47 -0
  60. aeri/api/blob_storage_integrations/types/create_blob_storage_integration_request.py +91 -0
  61. aeri/api/client.py +679 -0
  62. aeri/api/comments/__init__.py +44 -0
  63. aeri/api/comments/client.py +407 -0
  64. aeri/api/comments/raw_client.py +750 -0
  65. aeri/api/comments/types/__init__.py +46 -0
  66. aeri/api/comments/types/create_comment_request.py +47 -0
  67. aeri/api/comments/types/create_comment_response.py +17 -0
  68. aeri/api/comments/types/get_comments_response.py +17 -0
  69. aeri/api/commons/__init__.py +210 -0
  70. aeri/api/commons/errors/__init__.py +56 -0
  71. aeri/api/commons/errors/access_denied_error.py +12 -0
  72. aeri/api/commons/errors/error.py +12 -0
  73. aeri/api/commons/errors/method_not_allowed_error.py +12 -0
  74. aeri/api/commons/errors/not_found_error.py +12 -0
  75. aeri/api/commons/errors/unauthorized_error.py +12 -0
  76. aeri/api/commons/types/__init__.py +190 -0
  77. aeri/api/commons/types/base_score.py +90 -0
  78. aeri/api/commons/types/base_score_v1.py +70 -0
  79. aeri/api/commons/types/boolean_score.py +26 -0
  80. aeri/api/commons/types/boolean_score_v1.py +26 -0
  81. aeri/api/commons/types/categorical_score.py +26 -0
  82. aeri/api/commons/types/categorical_score_v1.py +26 -0
  83. aeri/api/commons/types/comment.py +36 -0
  84. aeri/api/commons/types/comment_object_type.py +30 -0
  85. aeri/api/commons/types/config_category.py +15 -0
  86. aeri/api/commons/types/correction_score.py +26 -0
  87. aeri/api/commons/types/create_score_value.py +5 -0
  88. aeri/api/commons/types/dataset.py +49 -0
  89. aeri/api/commons/types/dataset_item.py +58 -0
  90. aeri/api/commons/types/dataset_run.py +63 -0
  91. aeri/api/commons/types/dataset_run_item.py +40 -0
  92. aeri/api/commons/types/dataset_run_with_items.py +19 -0
  93. aeri/api/commons/types/dataset_status.py +22 -0
  94. aeri/api/commons/types/map_value.py +11 -0
  95. aeri/api/commons/types/model.py +125 -0
  96. aeri/api/commons/types/model_price.py +14 -0
  97. aeri/api/commons/types/model_usage_unit.py +42 -0
  98. aeri/api/commons/types/numeric_score.py +17 -0
  99. aeri/api/commons/types/numeric_score_v1.py +17 -0
  100. aeri/api/commons/types/observation.py +142 -0
  101. aeri/api/commons/types/observation_level.py +30 -0
  102. aeri/api/commons/types/observation_v2.py +235 -0
  103. aeri/api/commons/types/observations_view.py +89 -0
  104. aeri/api/commons/types/pricing_tier.py +91 -0
  105. aeri/api/commons/types/pricing_tier_condition.py +68 -0
  106. aeri/api/commons/types/pricing_tier_input.py +76 -0
  107. aeri/api/commons/types/pricing_tier_operator.py +42 -0
  108. aeri/api/commons/types/score.py +201 -0
  109. aeri/api/commons/types/score_config.py +66 -0
  110. aeri/api/commons/types/score_config_data_type.py +26 -0
  111. aeri/api/commons/types/score_data_type.py +30 -0
  112. aeri/api/commons/types/score_source.py +26 -0
  113. aeri/api/commons/types/score_v1.py +131 -0
  114. aeri/api/commons/types/session.py +25 -0
  115. aeri/api/commons/types/session_with_traces.py +15 -0
  116. aeri/api/commons/types/trace.py +84 -0
  117. aeri/api/commons/types/trace_with_details.py +43 -0
  118. aeri/api/commons/types/trace_with_full_details.py +45 -0
  119. aeri/api/commons/types/usage.py +59 -0
  120. aeri/api/core/__init__.py +111 -0
  121. aeri/api/core/api_error.py +23 -0
  122. aeri/api/core/client_wrapper.py +141 -0
  123. aeri/api/core/datetime_utils.py +30 -0
  124. aeri/api/core/enum.py +20 -0
  125. aeri/api/core/file.py +70 -0
  126. aeri/api/core/force_multipart.py +18 -0
  127. aeri/api/core/http_client.py +711 -0
  128. aeri/api/core/http_response.py +55 -0
  129. aeri/api/core/http_sse/__init__.py +48 -0
  130. aeri/api/core/http_sse/_api.py +114 -0
  131. aeri/api/core/http_sse/_decoders.py +66 -0
  132. aeri/api/core/http_sse/_exceptions.py +7 -0
  133. aeri/api/core/http_sse/_models.py +17 -0
  134. aeri/api/core/jsonable_encoder.py +102 -0
  135. aeri/api/core/pydantic_utilities.py +310 -0
  136. aeri/api/core/query_encoder.py +60 -0
  137. aeri/api/core/remove_none_from_dict.py +11 -0
  138. aeri/api/core/request_options.py +35 -0
  139. aeri/api/core/serialization.py +282 -0
  140. aeri/api/dataset_items/__init__.py +52 -0
  141. aeri/api/dataset_items/client.py +499 -0
  142. aeri/api/dataset_items/raw_client.py +973 -0
  143. aeri/api/dataset_items/types/__init__.py +50 -0
  144. aeri/api/dataset_items/types/create_dataset_item_request.py +37 -0
  145. aeri/api/dataset_items/types/delete_dataset_item_response.py +17 -0
  146. aeri/api/dataset_items/types/paginated_dataset_items.py +17 -0
  147. aeri/api/dataset_run_items/__init__.py +43 -0
  148. aeri/api/dataset_run_items/client.py +323 -0
  149. aeri/api/dataset_run_items/raw_client.py +547 -0
  150. aeri/api/dataset_run_items/types/__init__.py +44 -0
  151. aeri/api/dataset_run_items/types/create_dataset_run_item_request.py +51 -0
  152. aeri/api/dataset_run_items/types/paginated_dataset_run_items.py +17 -0
  153. aeri/api/datasets/__init__.py +55 -0
  154. aeri/api/datasets/client.py +661 -0
  155. aeri/api/datasets/raw_client.py +1368 -0
  156. aeri/api/datasets/types/__init__.py +53 -0
  157. aeri/api/datasets/types/create_dataset_request.py +31 -0
  158. aeri/api/datasets/types/delete_dataset_run_response.py +14 -0
  159. aeri/api/datasets/types/paginated_dataset_runs.py +17 -0
  160. aeri/api/datasets/types/paginated_datasets.py +17 -0
  161. aeri/api/health/__init__.py +44 -0
  162. aeri/api/health/client.py +112 -0
  163. aeri/api/health/errors/__init__.py +42 -0
  164. aeri/api/health/errors/service_unavailable_error.py +13 -0
  165. aeri/api/health/raw_client.py +227 -0
  166. aeri/api/health/types/__init__.py +40 -0
  167. aeri/api/health/types/health_response.py +30 -0
  168. aeri/api/ingestion/__init__.py +169 -0
  169. aeri/api/ingestion/client.py +221 -0
  170. aeri/api/ingestion/raw_client.py +293 -0
  171. aeri/api/ingestion/types/__init__.py +169 -0
  172. aeri/api/ingestion/types/base_event.py +27 -0
  173. aeri/api/ingestion/types/create_event_body.py +14 -0
  174. aeri/api/ingestion/types/create_event_event.py +15 -0
  175. aeri/api/ingestion/types/create_generation_body.py +40 -0
  176. aeri/api/ingestion/types/create_generation_event.py +15 -0
  177. aeri/api/ingestion/types/create_observation_event.py +15 -0
  178. aeri/api/ingestion/types/create_span_body.py +19 -0
  179. aeri/api/ingestion/types/create_span_event.py +15 -0
  180. aeri/api/ingestion/types/ingestion_error.py +17 -0
  181. aeri/api/ingestion/types/ingestion_event.py +155 -0
  182. aeri/api/ingestion/types/ingestion_response.py +17 -0
  183. aeri/api/ingestion/types/ingestion_success.py +15 -0
  184. aeri/api/ingestion/types/ingestion_usage.py +8 -0
  185. aeri/api/ingestion/types/observation_body.py +53 -0
  186. aeri/api/ingestion/types/observation_type.py +54 -0
  187. aeri/api/ingestion/types/open_ai_completion_usage_schema.py +26 -0
  188. aeri/api/ingestion/types/open_ai_response_usage_schema.py +24 -0
  189. aeri/api/ingestion/types/open_ai_usage.py +28 -0
  190. aeri/api/ingestion/types/optional_observation_body.py +36 -0
  191. aeri/api/ingestion/types/score_body.py +75 -0
  192. aeri/api/ingestion/types/score_event.py +15 -0
  193. aeri/api/ingestion/types/sdk_log_body.py +14 -0
  194. aeri/api/ingestion/types/sdk_log_event.py +15 -0
  195. aeri/api/ingestion/types/trace_body.py +36 -0
  196. aeri/api/ingestion/types/trace_event.py +15 -0
  197. aeri/api/ingestion/types/update_event_body.py +14 -0
  198. aeri/api/ingestion/types/update_generation_body.py +40 -0
  199. aeri/api/ingestion/types/update_generation_event.py +15 -0
  200. aeri/api/ingestion/types/update_observation_event.py +15 -0
  201. aeri/api/ingestion/types/update_span_body.py +19 -0
  202. aeri/api/ingestion/types/update_span_event.py +15 -0
  203. aeri/api/ingestion/types/usage_details.py +10 -0
  204. aeri/api/legacy/__init__.py +61 -0
  205. aeri/api/legacy/client.py +105 -0
  206. aeri/api/legacy/metrics_v1/__init__.py +40 -0
  207. aeri/api/legacy/metrics_v1/client.py +214 -0
  208. aeri/api/legacy/metrics_v1/raw_client.py +322 -0
  209. aeri/api/legacy/metrics_v1/types/__init__.py +40 -0
  210. aeri/api/legacy/metrics_v1/types/metrics_response.py +19 -0
  211. aeri/api/legacy/observations_v1/__init__.py +43 -0
  212. aeri/api/legacy/observations_v1/client.py +523 -0
  213. aeri/api/legacy/observations_v1/raw_client.py +759 -0
  214. aeri/api/legacy/observations_v1/types/__init__.py +44 -0
  215. aeri/api/legacy/observations_v1/types/observations.py +17 -0
  216. aeri/api/legacy/observations_v1/types/observations_views.py +17 -0
  217. aeri/api/legacy/raw_client.py +13 -0
  218. aeri/api/legacy/score_v1/__init__.py +43 -0
  219. aeri/api/legacy/score_v1/client.py +329 -0
  220. aeri/api/legacy/score_v1/raw_client.py +545 -0
  221. aeri/api/legacy/score_v1/types/__init__.py +44 -0
  222. aeri/api/legacy/score_v1/types/create_score_request.py +75 -0
  223. aeri/api/legacy/score_v1/types/create_score_response.py +17 -0
  224. aeri/api/llm_connections/__init__.py +55 -0
  225. aeri/api/llm_connections/client.py +311 -0
  226. aeri/api/llm_connections/raw_client.py +541 -0
  227. aeri/api/llm_connections/types/__init__.py +53 -0
  228. aeri/api/llm_connections/types/llm_adapter.py +38 -0
  229. aeri/api/llm_connections/types/llm_connection.py +77 -0
  230. aeri/api/llm_connections/types/paginated_llm_connections.py +17 -0
  231. aeri/api/llm_connections/types/upsert_llm_connection_request.py +69 -0
  232. aeri/api/media/__init__.py +58 -0
  233. aeri/api/media/client.py +427 -0
  234. aeri/api/media/raw_client.py +739 -0
  235. aeri/api/media/types/__init__.py +56 -0
  236. aeri/api/media/types/get_media_response.py +55 -0
  237. aeri/api/media/types/get_media_upload_url_request.py +51 -0
  238. aeri/api/media/types/get_media_upload_url_response.py +28 -0
  239. aeri/api/media/types/media_content_type.py +232 -0
  240. aeri/api/media/types/patch_media_body.py +43 -0
  241. aeri/api/metrics/__init__.py +40 -0
  242. aeri/api/metrics/client.py +422 -0
  243. aeri/api/metrics/raw_client.py +530 -0
  244. aeri/api/metrics/types/__init__.py +40 -0
  245. aeri/api/metrics/types/metrics_v2response.py +19 -0
  246. aeri/api/models/__init__.py +43 -0
  247. aeri/api/models/client.py +523 -0
  248. aeri/api/models/raw_client.py +993 -0
  249. aeri/api/models/types/__init__.py +44 -0
  250. aeri/api/models/types/create_model_request.py +103 -0
  251. aeri/api/models/types/paginated_models.py +17 -0
  252. aeri/api/observations/__init__.py +43 -0
  253. aeri/api/observations/client.py +522 -0
  254. aeri/api/observations/raw_client.py +641 -0
  255. aeri/api/observations/types/__init__.py +44 -0
  256. aeri/api/observations/types/observations_v2meta.py +21 -0
  257. aeri/api/observations/types/observations_v2response.py +28 -0
  258. aeri/api/opentelemetry/__init__.py +67 -0
  259. aeri/api/opentelemetry/client.py +276 -0
  260. aeri/api/opentelemetry/raw_client.py +291 -0
  261. aeri/api/opentelemetry/types/__init__.py +65 -0
  262. aeri/api/opentelemetry/types/otel_attribute.py +27 -0
  263. aeri/api/opentelemetry/types/otel_attribute_value.py +46 -0
  264. aeri/api/opentelemetry/types/otel_resource.py +24 -0
  265. aeri/api/opentelemetry/types/otel_resource_span.py +32 -0
  266. aeri/api/opentelemetry/types/otel_scope.py +34 -0
  267. aeri/api/opentelemetry/types/otel_scope_span.py +28 -0
  268. aeri/api/opentelemetry/types/otel_span.py +76 -0
  269. aeri/api/opentelemetry/types/otel_trace_response.py +16 -0
  270. aeri/api/organizations/__init__.py +73 -0
  271. aeri/api/organizations/client.py +756 -0
  272. aeri/api/organizations/raw_client.py +1707 -0
  273. aeri/api/organizations/types/__init__.py +71 -0
  274. aeri/api/organizations/types/delete_membership_request.py +16 -0
  275. aeri/api/organizations/types/membership_deletion_response.py +17 -0
  276. aeri/api/organizations/types/membership_request.py +18 -0
  277. aeri/api/organizations/types/membership_response.py +20 -0
  278. aeri/api/organizations/types/membership_role.py +30 -0
  279. aeri/api/organizations/types/memberships_response.py +15 -0
  280. aeri/api/organizations/types/organization_api_key.py +31 -0
  281. aeri/api/organizations/types/organization_api_keys_response.py +19 -0
  282. aeri/api/organizations/types/organization_project.py +25 -0
  283. aeri/api/organizations/types/organization_projects_response.py +15 -0
  284. aeri/api/projects/__init__.py +67 -0
  285. aeri/api/projects/client.py +760 -0
  286. aeri/api/projects/raw_client.py +1577 -0
  287. aeri/api/projects/types/__init__.py +65 -0
  288. aeri/api/projects/types/api_key_deletion_response.py +18 -0
  289. aeri/api/projects/types/api_key_list.py +23 -0
  290. aeri/api/projects/types/api_key_response.py +30 -0
  291. aeri/api/projects/types/api_key_summary.py +35 -0
  292. aeri/api/projects/types/organization.py +22 -0
  293. aeri/api/projects/types/project.py +34 -0
  294. aeri/api/projects/types/project_deletion_response.py +15 -0
  295. aeri/api/projects/types/projects.py +15 -0
  296. aeri/api/prompt_version/__init__.py +4 -0
  297. aeri/api/prompt_version/client.py +157 -0
  298. aeri/api/prompt_version/raw_client.py +264 -0
  299. aeri/api/prompts/__init__.py +100 -0
  300. aeri/api/prompts/client.py +550 -0
  301. aeri/api/prompts/raw_client.py +987 -0
  302. aeri/api/prompts/types/__init__.py +96 -0
  303. aeri/api/prompts/types/base_prompt.py +42 -0
  304. aeri/api/prompts/types/chat_message.py +17 -0
  305. aeri/api/prompts/types/chat_message_type.py +15 -0
  306. aeri/api/prompts/types/chat_message_with_placeholders.py +8 -0
  307. aeri/api/prompts/types/chat_prompt.py +15 -0
  308. aeri/api/prompts/types/create_chat_prompt_request.py +37 -0
  309. aeri/api/prompts/types/create_chat_prompt_type.py +15 -0
  310. aeri/api/prompts/types/create_prompt_request.py +8 -0
  311. aeri/api/prompts/types/create_text_prompt_request.py +36 -0
  312. aeri/api/prompts/types/create_text_prompt_type.py +15 -0
  313. aeri/api/prompts/types/placeholder_message.py +16 -0
  314. aeri/api/prompts/types/placeholder_message_type.py +15 -0
  315. aeri/api/prompts/types/prompt.py +58 -0
  316. aeri/api/prompts/types/prompt_meta.py +35 -0
  317. aeri/api/prompts/types/prompt_meta_list_response.py +17 -0
  318. aeri/api/prompts/types/prompt_type.py +20 -0
  319. aeri/api/prompts/types/text_prompt.py +14 -0
  320. aeri/api/scim/__init__.py +94 -0
  321. aeri/api/scim/client.py +686 -0
  322. aeri/api/scim/raw_client.py +1528 -0
  323. aeri/api/scim/types/__init__.py +92 -0
  324. aeri/api/scim/types/authentication_scheme.py +20 -0
  325. aeri/api/scim/types/bulk_config.py +22 -0
  326. aeri/api/scim/types/empty_response.py +16 -0
  327. aeri/api/scim/types/filter_config.py +17 -0
  328. aeri/api/scim/types/resource_meta.py +17 -0
  329. aeri/api/scim/types/resource_type.py +27 -0
  330. aeri/api/scim/types/resource_types_response.py +21 -0
  331. aeri/api/scim/types/schema_extension.py +17 -0
  332. aeri/api/scim/types/schema_resource.py +19 -0
  333. aeri/api/scim/types/schemas_response.py +21 -0
  334. aeri/api/scim/types/scim_email.py +16 -0
  335. aeri/api/scim/types/scim_feature_support.py +14 -0
  336. aeri/api/scim/types/scim_name.py +14 -0
  337. aeri/api/scim/types/scim_user.py +24 -0
  338. aeri/api/scim/types/scim_users_list_response.py +25 -0
  339. aeri/api/scim/types/service_provider_config.py +36 -0
  340. aeri/api/scim/types/user_meta.py +20 -0
  341. aeri/api/score_configs/__init__.py +44 -0
  342. aeri/api/score_configs/client.py +526 -0
  343. aeri/api/score_configs/raw_client.py +1012 -0
  344. aeri/api/score_configs/types/__init__.py +46 -0
  345. aeri/api/score_configs/types/create_score_config_request.py +46 -0
  346. aeri/api/score_configs/types/score_configs.py +17 -0
  347. aeri/api/score_configs/types/update_score_config_request.py +53 -0
  348. aeri/api/scores/__init__.py +76 -0
  349. aeri/api/scores/client.py +420 -0
  350. aeri/api/scores/raw_client.py +656 -0
  351. aeri/api/scores/types/__init__.py +76 -0
  352. aeri/api/scores/types/get_scores_response.py +17 -0
  353. aeri/api/scores/types/get_scores_response_data.py +211 -0
  354. aeri/api/scores/types/get_scores_response_data_boolean.py +15 -0
  355. aeri/api/scores/types/get_scores_response_data_categorical.py +15 -0
  356. aeri/api/scores/types/get_scores_response_data_correction.py +15 -0
  357. aeri/api/scores/types/get_scores_response_data_numeric.py +15 -0
  358. aeri/api/scores/types/get_scores_response_trace_data.py +38 -0
  359. aeri/api/sessions/__init__.py +40 -0
  360. aeri/api/sessions/client.py +262 -0
  361. aeri/api/sessions/raw_client.py +500 -0
  362. aeri/api/sessions/types/__init__.py +40 -0
  363. aeri/api/sessions/types/paginated_sessions.py +17 -0
  364. aeri/api/trace/__init__.py +44 -0
  365. aeri/api/trace/client.py +728 -0
  366. aeri/api/trace/raw_client.py +1208 -0
  367. aeri/api/trace/types/__init__.py +46 -0
  368. aeri/api/trace/types/delete_trace_response.py +14 -0
  369. aeri/api/trace/types/sort.py +14 -0
  370. aeri/api/trace/types/traces.py +17 -0
  371. aeri/api/utils/__init__.py +44 -0
  372. aeri/api/utils/pagination/__init__.py +40 -0
  373. aeri/api/utils/pagination/types/__init__.py +40 -0
  374. aeri/api/utils/pagination/types/meta_response.py +38 -0
  375. aeri/batch_evaluation.py +1643 -0
  376. aeri/experiment.py +1044 -0
  377. aeri/langchain/CallbackHandler.py +1377 -0
  378. aeri/langchain/__init__.py +5 -0
  379. aeri/langchain/utils.py +212 -0
  380. aeri/logger.py +28 -0
  381. aeri/media.py +352 -0
  382. aeri/model.py +477 -0
  383. aeri/openai.py +1124 -0
  384. aeri/py.typed +0 -0
  385. aeri/span_filter.py +17 -0
  386. aeri/types.py +79 -0
  387. aeri/version.py +3 -0
  388. aeri_python-4.0.0.dist-info/METADATA +51 -0
  389. aeri_python-4.0.0.dist-info/RECORD +391 -0
  390. aeri_python-4.0.0.dist-info/WHEEL +4 -0
  391. aeri_python-4.0.0.dist-info/licenses/LICENSE +21 -0
aeri/_client/client.py ADDED
@@ -0,0 +1,3761 @@
1
+ """Aeri OpenTelemetry integration module.
2
+
3
+ This module implements Aeri's core observability functionality on top of the OpenTelemetry (OTel) standard.
4
+ """
5
+
6
+ import asyncio
7
+ import logging
8
+ import os
9
+ import re
10
+ import urllib.parse
11
+ import warnings
12
+ from datetime import datetime
13
+ from hashlib import sha256
14
+ from time import time_ns
15
+ from typing import (
16
+ Any,
17
+ Callable,
18
+ Dict,
19
+ List,
20
+ Literal,
21
+ Optional,
22
+ Type,
23
+ Union,
24
+ cast,
25
+ overload,
26
+ )
27
+
28
+ import backoff
29
+ import httpx
30
+ from opentelemetry import trace as otel_trace_api
31
+ from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
32
+ from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
33
+ from opentelemetry.util._decorator import (
34
+ _AgnosticContextManager,
35
+ _agnosticcontextmanager,
36
+ )
37
+ from packaging.version import Version
38
+ from typing_extensions import deprecated
39
+
40
+ from aeri._client.attributes import AeriOtelSpanAttributes, _serialize
41
+ from aeri._client.constants import (
42
+ AERI_SDK_EXPERIMENT_ENVIRONMENT,
43
+ ObservationTypeGenerationLike,
44
+ ObservationTypeLiteral,
45
+ ObservationTypeLiteralNoEvent,
46
+ ObservationTypeSpanLike,
47
+ get_observation_types_list,
48
+ )
49
+ from aeri._client.datasets import DatasetClient
50
+ from aeri._client.environment_variables import (
51
+ AERI_BASE_URL,
52
+ AERI_DEBUG,
53
+ AERI_HOST,
54
+ AERI_PUBLIC_KEY,
55
+ AERI_RELEASE,
56
+ AERI_SAMPLE_RATE,
57
+ AERI_SECRET_KEY,
58
+ AERI_TIMEOUT,
59
+ AERI_TRACING_ENABLED,
60
+ AERI_TRACING_ENVIRONMENT,
61
+ )
62
+ from aeri._client.propagation import (
63
+ PropagatedExperimentAttributes,
64
+ _propagate_attributes,
65
+ )
66
+ from aeri._client.resource_manager import AeriResourceManager
67
+ from aeri._client.span import (
68
+ AeriAgent,
69
+ AeriChain,
70
+ AeriEmbedding,
71
+ AeriEvaluator,
72
+ AeriEvent,
73
+ AeriGeneration,
74
+ AeriGuardrail,
75
+ AeriRetriever,
76
+ AeriSpan,
77
+ AeriTool,
78
+ )
79
+ from aeri._client._validation import AeriClientConfig, ObservationInput, ScoreInput, TraceContextInput
80
+ from aeri._client.utils import get_sha256_hash_hex, run_async_safely
81
+ from aeri._utils import _get_timestamp
82
+ from aeri._utils.environment import get_common_release_envs
83
+ from aeri._utils.parse_error import handle_fern_exception
84
+ from aeri._utils.prompt_cache import PromptCache
85
+ from aeri.api import (
86
+ CreateChatPromptRequest,
87
+ CreateChatPromptType,
88
+ CreateTextPromptRequest,
89
+ Dataset,
90
+ DatasetItem,
91
+ DatasetRunWithItems,
92
+ DatasetStatus,
93
+ DeleteDatasetRunResponse,
94
+ Error,
95
+ MapValue,
96
+ NotFoundError,
97
+ PaginatedDatasetRuns,
98
+ Prompt_Chat,
99
+ Prompt_Text,
100
+ ScoreBody,
101
+ TraceBody,
102
+ )
103
+ from aeri.batch_evaluation import (
104
+ BatchEvaluationResult,
105
+ BatchEvaluationResumeToken,
106
+ BatchEvaluationRunner,
107
+ CompositeEvaluatorFunction,
108
+ MapperFunction,
109
+ )
110
+ from aeri.experiment import (
111
+ Evaluation,
112
+ EvaluatorFunction,
113
+ ExperimentData,
114
+ ExperimentItem,
115
+ ExperimentItemResult,
116
+ ExperimentResult,
117
+ RunEvaluatorFunction,
118
+ TaskFunction,
119
+ _run_evaluator,
120
+ _run_task,
121
+ )
122
+ from aeri.logger import aeri_logger
123
+ from aeri.media import AeriMedia
124
+ from aeri.model import (
125
+ ChatMessageDict,
126
+ ChatMessageWithPlaceholdersDict,
127
+ ChatPromptClient,
128
+ PromptClient,
129
+ TextPromptClient,
130
+ )
131
+ from aeri.types import MaskFunction, ScoreDataType, SpanLevel, TraceContext
132
+
133
+
134
+ class Aeri:
135
+ """Main client for Aeri tracing and platform features.
136
+
137
+ This class provides an interface for creating and managing traces, spans,
138
+ and generations in Aeri as well as interacting with the Aeri API.
139
+
140
+ The client features a thread-safe singleton pattern for each unique public API key,
141
+ ensuring consistent trace context propagation across your application. It implements
142
+ efficient batching of spans with configurable flush settings and includes background
143
+ thread management for media uploads and score ingestion.
144
+
145
+ Configuration is flexible through either direct parameters or environment variables,
146
+ with graceful fallbacks and runtime configuration updates.
147
+
148
+ Attributes:
149
+ api: Synchronous API client for Aeri backend communication
150
+ async_api: Asynchronous API client for Aeri backend communication
151
+ _otel_tracer: Internal AeriTracer instance managing OpenTelemetry components
152
+
153
+ Parameters:
154
+ public_key (Optional[str]): Your Aeri public API key. Can also be set via AERI_PUBLIC_KEY environment variable.
155
+ secret_key (Optional[str]): Your Aeri secret API key. Can also be set via AERI_SECRET_KEY environment variable.
156
+ base_url (Optional[str]): The Aeri API base URL. Defaults to "https://api.aeri.com". Can also be set via AERI_BASE_URL environment variable.
157
+ host (Optional[str]): Deprecated. Use base_url instead. The Aeri API host URL. Defaults to "https://api.aeri.com". Can also be set via AERI_HOST environment variable.
158
+ timeout (Optional[int]): Timeout in seconds for API requests. Defaults to 5 seconds.
159
+ httpx_client (Optional[httpx.Client]): Custom httpx client for making non-tracing HTTP requests. If not provided, a default client will be created.
160
+ debug (bool): Enable debug logging. Defaults to False. Can also be set via AERI_DEBUG environment variable.
161
+ tracing_enabled (Optional[bool]): Enable or disable tracing. Defaults to True. Can also be set via AERI_TRACING_ENABLED environment variable.
162
+ flush_at (Optional[int]): Number of spans to batch before sending to the API. Defaults to 512. Can also be set via AERI_FLUSH_AT environment variable.
163
+ flush_interval (Optional[float]): Time in seconds between batch flushes. Defaults to 5 seconds. Can also be set via AERI_FLUSH_INTERVAL environment variable.
164
+ environment (Optional[str]): Environment name for tracing. Default is 'default'. Can also be set via AERI_TRACING_ENVIRONMENT environment variable. Can be any lowercase alphanumeric string with hyphens and underscores that does not start with 'aeri'.
165
+ release (Optional[str]): Release version/hash of your application. Used for grouping analytics by release.
166
+ media_upload_thread_count (Optional[int]): Number of background threads for handling media uploads. Defaults to 1. Can also be set via AERI_MEDIA_UPLOAD_THREAD_COUNT environment variable.
167
+ sample_rate (Optional[float]): Sampling rate for traces (0.0 to 1.0). Defaults to 1.0 (100% of traces are sampled). Can also be set via AERI_SAMPLE_RATE environment variable.
168
+ mask (Optional[MaskFunction]): Function to mask sensitive data in traces before sending to the API.
169
+ blocked_instrumentation_scopes (Optional[List[str]]): Deprecated. Use `should_export_span` instead. Equivalent behavior:
170
+ ```python
171
+ from aeri.span_filter import is_default_export_span
172
+ blocked = {"sqlite", "requests"}
173
+
174
+ should_export_span = lambda span: (
175
+ is_default_export_span(span)
176
+ and (
177
+ span.instrumentation_scope is None
178
+ or span.instrumentation_scope.name not in blocked
179
+ )
180
+ )
181
+ ```
182
+ should_export_span (Optional[Callable[[ReadableSpan], bool]]): Callback to decide whether to export a span. If omitted, Aeri uses the default filter (Aeri SDK spans, spans with `gen_ai.*` attributes, and known LLM instrumentation scopes).
183
+ additional_headers (Optional[Dict[str, str]]): Additional headers to include in all API requests and OTLPSpanExporter requests. These headers will be merged with default headers. Note: If httpx_client is provided, additional_headers must be set directly on your custom httpx_client as well.
184
+ tracer_provider (Optional[TracerProvider]): OpenTelemetry TracerProvider to use for Aeri. Supplying a dedicated provider keeps Aeri tracing disconnected from other libraries that emit OpenTelemetry spans. Note: To track active spans, the context is still shared between TracerProviders. This may lead to broken trace trees.
185
+
186
+ Example:
187
+ ```python
188
+ from aeri import Aeri
189
+
190
+ # Initialize the client (reads from env vars if not provided)
191
+ aeri = Aeri(
192
+ public_key="your-public-key",
193
+ secret_key="your-secret-key",
194
+ base_url="https://api.aeri.com", # Optional, default shown
195
+ )
196
+
197
+ # Create a trace span
198
+ with aeri.start_as_current_observation(name="process-query") as span:
199
+ # Your application code here
200
+
201
+ # Create a nested generation span for an LLM call
202
+ with span.start_as_current_generation(
203
+ name="generate-response",
204
+ model="gpt-4",
205
+ input={"query": "Tell me about AI"},
206
+ model_parameters={"temperature": 0.7, "max_tokens": 500}
207
+ ) as generation:
208
+ # Generate response here
209
+ response = "AI is a field of computer science..."
210
+
211
+ generation.update(
212
+ output=response,
213
+ usage_details={"prompt_tokens": 10, "completion_tokens": 50},
214
+ cost_details={"total_cost": 0.0023}
215
+ )
216
+
217
+ # Score the generation (supports NUMERIC, BOOLEAN, CATEGORICAL)
218
+ generation.score(name="relevance", value=0.95, data_type="NUMERIC")
219
+ ```
220
+ """
221
+
222
+ _resources: Optional[AeriResourceManager] = None
223
+ _mask: Optional[MaskFunction] = None
224
+ _otel_tracer: otel_trace_api.Tracer
225
+
226
def __init__(
    self,
    *,
    public_key: Optional[str] = None,
    secret_key: Optional[str] = None,
    base_url: Optional[str] = None,
    host: Optional[str] = None,
    timeout: Optional[int] = None,
    httpx_client: Optional[httpx.Client] = None,
    debug: bool = False,
    tracing_enabled: Optional[bool] = True,
    flush_at: Optional[int] = None,
    flush_interval: Optional[float] = None,
    environment: Optional[str] = None,
    release: Optional[str] = None,
    media_upload_thread_count: Optional[int] = None,
    sample_rate: Optional[float] = None,
    mask: Optional[MaskFunction] = None,
    blocked_instrumentation_scopes: Optional[List[str]] = None,
    should_export_span: Optional[Callable[[ReadableSpan], bool]] = None,
    additional_headers: Optional[Dict[str, str]] = None,
    tracer_provider: Optional[TracerProvider] = None,
):
    """Resolve the client configuration and set up tracing and API resources.

    Settings with an environment fallback are taken from the explicit keyword
    argument first and from the corresponding environment variable otherwise.
    The resolved values are passed through ``AeriClientConfig.model_validate``
    before any resource is created.

    If no public key or no secret key can be resolved, a warning is logged, a
    no-op tracer is installed and initialization stops early. On that path
    ``api`` and ``async_api`` are never assigned.

    Raises:
        ValueError: If the resolved sample rate lies outside ``[0.0, 1.0]``,
            or if the timeout / sample-rate environment value cannot be
            parsed as a number.
    """
    # ── Resolve raw values from kwargs + env vars (each exactly once) ────
    resolved_base_url = (
        base_url
        or os.environ.get(AERI_BASE_URL)
        or host
        or os.environ.get(AERI_HOST, "https://api.aeri.com")
    )
    public_key = public_key or os.environ.get(AERI_PUBLIC_KEY)
    secret_key = secret_key or os.environ.get(AERI_SECRET_KEY)
    timeout = timeout or int(os.environ.get(AERI_TIMEOUT, 5))
    debug = debug or (os.getenv(AERI_DEBUG, "false").lower() == "true")
    tracing_active = (
        tracing_enabled
        and os.environ.get(AERI_TRACING_ENABLED, "true").lower() != "false"
    )
    resolved_environment = environment or cast(
        str, os.environ.get(AERI_TRACING_ENVIRONMENT)
    )
    explicit_or_env_release = release or os.environ.get(AERI_RELEASE)
    # An explicit 0.0 is inside the accepted range checked below, so only a
    # missing argument (None) may fall back to the environment. A plain `or`
    # would silently replace 0.0 with the environment value or 1.0.
    if sample_rate is None:
        sample_rate = float(os.environ.get(AERI_SAMPLE_RATE, 1.0))

    # ── Strict Pydantic V2 validation of constructor config ──────────────
    _cfg = AeriClientConfig.model_validate(
        {
            "public_key": public_key,
            "secret_key": secret_key,
            "base_url": resolved_base_url,
            "timeout": timeout,
            "debug": debug,
            "tracing_enabled": tracing_active,
            "flush_at": flush_at or 15,
            "flush_interval": flush_interval or 1.0,
            "environment": resolved_environment,
            "release": explicit_or_env_release,
            "media_upload_thread_count": media_upload_thread_count or 1,
            "sample_rate": sample_rate,
        }
    )

    self._base_url = _cfg.base_url
    self._environment = resolved_environment
    # Only consult the common release environment variables when neither the
    # argument nor the dedicated release variable provided a value.
    self._release = explicit_or_env_release or get_common_release_envs()
    self._project_id: Optional[str] = None
    if not 0.0 <= sample_rate <= 1.0:
        raise ValueError(
            f"Sample rate must be between 0.0 and 1.0, got {sample_rate}"
        )

    self._tracing_enabled = tracing_active
    if not self._tracing_enabled:
        aeri_logger.info(
            "Configuration: Aeri tracing is explicitly disabled. No data will be sent to the Aeri API."
        )

    if debug:
        logging.basicConfig(
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        aeri_logger.setLevel(logging.DEBUG)

    # Without credentials the client degrades to a no-op tracer instead of
    # raising, so instrumented application code keeps running.
    if public_key is None:
        aeri_logger.warning(
            "Authentication error: Aeri client initialized without public_key. Client will be disabled. "
            "Provide a public_key parameter or set AERI_PUBLIC_KEY environment variable. "
        )
        self._otel_tracer = otel_trace_api.NoOpTracer()
        return

    if secret_key is None:
        aeri_logger.warning(
            "Authentication error: Aeri client initialized without secret_key. Client will be disabled. "
            "Provide a secret_key parameter or set AERI_SECRET_KEY environment variable. "
        )
        self._otel_tracer = otel_trace_api.NoOpTracer()
        return

    if os.environ.get("OTEL_SDK_DISABLED", "false").lower() == "true":
        aeri_logger.warning(
            "OTEL_SDK_DISABLED is set. Aeri tracing will be disabled and no traces will appear in the UI."
        )

    if blocked_instrumentation_scopes is not None:
        warnings.warn(
            "`blocked_instrumentation_scopes` is deprecated and will be removed in a future release. "
            "Use `should_export_span` instead. Example: "
            "from aeri.span_filter import is_default_export_span; "
            'blocked={"scope"}; should_export_span=lambda span: '
            "is_default_export_span(span) and (span.instrumentation_scope is None or "
            "span.instrumentation_scope.name not in blocked).",
            DeprecationWarning,
            stacklevel=2,
        )

    # Initialize api and tracer if requirements are met.
    # flush_at, flush_interval, media_upload_thread_count and release are
    # forwarded as given by the caller (possibly None).
    self._resources = AeriResourceManager(
        public_key=public_key,
        secret_key=secret_key,
        base_url=self._base_url,
        timeout=timeout,
        environment=self._environment,
        release=release,
        flush_at=flush_at,
        flush_interval=flush_interval,
        httpx_client=httpx_client,
        media_upload_thread_count=media_upload_thread_count,
        sample_rate=sample_rate,
        mask=mask,
        tracing_enabled=self._tracing_enabled,
        blocked_instrumentation_scopes=blocked_instrumentation_scopes,
        should_export_span=should_export_span,
        additional_headers=additional_headers,
        tracer_provider=tracer_provider,
    )
    self._mask = self._resources.mask

    self._otel_tracer = (
        self._resources.tracer
        if self._tracing_enabled and self._resources.tracer is not None
        else otel_trace_api.NoOpTracer()
    )
    self.api = self._resources.api
    self.async_api = self._resources.async_api
377
+
378
@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["generation"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> AeriGeneration: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["span"] = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriSpan: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["agent"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriAgent: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["tool"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriTool: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["chain"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriChain: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["retriever"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriRetriever: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["evaluator"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriEvaluator: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["embedding"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> AeriEmbedding: ...

@overload
def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["guardrail"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriGuardrail: ...

def start_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: ObservationTypeLiteralNoEvent = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> Union[
    AeriSpan,
    AeriGeneration,
    AeriAgent,
    AeriTool,
    AeriChain,
    AeriRetriever,
    AeriEvaluator,
    AeriEmbedding,
    AeriGuardrail,
]:
    """Start an observation of the requested type without activating it.

    The new observation is not installed as the current span in the context;
    use start_as_current_observation() when the observation should become the
    active span for a block of code.

    Args:
        trace_context: Optional context for connecting to an existing trace
        name: Name of the observation
        as_type: Type of observation to create (defaults to "span")
        input: Input data for the operation
        output: Output data from the operation
        metadata: Additional metadata to associate with the observation
        version: Version identifier for the code or component
        level: Importance level of the observation
        status_message: Optional status message for the observation
        completion_start_time: When the model started generating (for generation types)
        model: Name/identifier of the AI model used (for generation types)
        model_parameters: Parameters used for the model (for generation types)
        usage_details: Token usage information (for generation types)
        cost_details: Cost information (for generation types)
        prompt: Associated prompt template (for generation types)

    Returns:
        An observation object of the appropriate type that must be ended with .end()
    """
    # Everything except the OTEL span itself is forwarded unchanged on both
    # code paths, so collect it once.
    observation_kwargs: Dict[str, Any] = {
        "as_type": as_type,
        "input": input,
        "output": output,
        "metadata": metadata,
        "version": version,
        "level": level,
        "status_message": status_message,
        "completion_start_time": completion_start_time,
        "model": model,
        "model_parameters": model_parameters,
        "usage_details": usage_details,
        "cost_details": cost_details,
        "prompt": prompt,
    }

    if trace_context:
        trace_id = trace_context.get("trace_id", None)
        parent_span_id = trace_context.get("parent_span_id", None)

        if trace_id:
            remote_parent_span = self._create_remote_parent_span(
                trace_id=trace_id, parent_span_id=parent_span_id
            )

            # Both the span and its observation wrapper are created while
            # the remote parent is the active span.
            with otel_trace_api.use_span(
                cast(otel_trace_api.Span, remote_parent_span)
            ):
                attached_span = self._otel_tracer.start_span(name=name)
                attached_span.set_attribute(AeriOtelSpanAttributes.AS_ROOT, True)

                return self._create_observation_from_otel_span(
                    otel_span=attached_span, **observation_kwargs
                )

    # No usable trace id: start the span under whatever context is current.
    fresh_span = self._otel_tracer.start_span(name=name)

    return self._create_observation_from_otel_span(
        otel_span=fresh_span, **observation_kwargs
    )
629
+
630
def _create_observation_from_otel_span(
    self,
    *,
    otel_span: otel_trace_api.Span,
    as_type: ObservationTypeLiteralNoEvent,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> Union[
    AeriSpan,
    AeriGeneration,
    AeriAgent,
    AeriTool,
    AeriChain,
    AeriRetriever,
    AeriEvaluator,
    AeriEmbedding,
    AeriGuardrail,
]:
    """Wrap an already started OTEL span in the observation class for ``as_type``.

    Generation-like types additionally receive the model-related arguments
    (completion_start_time, model, model_parameters, usage_details,
    cost_details, prompt); for every other type those arguments are ignored.
    """
    wants_generation_fields = as_type in get_observation_types_list(
        ObservationTypeGenerationLike
    )
    observation_class = self._get_span_class(as_type)

    if not wants_generation_fields:
        # Span-like types (e.g. span, guardrail) take no generation properties.
        # Type ignore to prevent overloads of internal _get_span_class function,
        # issue is that AeriEvent could be returned and that classes have diff. args
        return observation_class(  # type: ignore[return-value,call-arg]
            otel_span=otel_span,
            aeri_client=self,
            environment=self._environment,
            release=self._release,
            input=input,
            output=output,
            metadata=metadata,
            version=version,
            level=level,
            status_message=status_message,
        )

    # Type ignore to prevent overloads of internal _get_span_class function,
    # issue is that AeriEvent could be returned and that classes have diff. args
    return observation_class(  # type: ignore[return-value,call-arg]
        otel_span=otel_span,
        aeri_client=self,
        environment=self._environment,
        release=self._release,
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
        completion_start_time=completion_start_time,
        model=model,
        model_parameters=model_parameters,
        usage_details=usage_details,
        cost_details=cost_details,
        prompt=prompt,
    )
701
+
702
@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["generation"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriGeneration]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["span"] = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriSpan]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["agent"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriAgent]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["tool"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriTool]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["chain"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriChain]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["retriever"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriRetriever]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["evaluator"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriEvaluator]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["embedding"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriEmbedding]: ...

@overload
def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: Literal["guardrail"],
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    end_on_exit: Optional[bool] = None,
) -> _AgnosticContextManager[AeriGuardrail]: ...

def start_as_current_observation(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    as_type: ObservationTypeLiteralNoEvent = "span",
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
    end_on_exit: Optional[bool] = None,
) -> Union[
    _AgnosticContextManager[AeriGeneration],
    _AgnosticContextManager[AeriSpan],
    _AgnosticContextManager[AeriAgent],
    _AgnosticContextManager[AeriTool],
    _AgnosticContextManager[AeriChain],
    _AgnosticContextManager[AeriRetriever],
    _AgnosticContextManager[AeriEvaluator],
    _AgnosticContextManager[AeriEmbedding],
    _AgnosticContextManager[AeriGuardrail],
]:
    """Create an observation and make it the current span inside a ``with`` block.

    The returned context manager starts an observation of the requested type
    and installs it as the current span for the duration of the block, so the
    observation lifecycle is handled automatically.

    The created observation will be the child of the current span in the context.

    Args:
        trace_context: Optional context for connecting to an existing trace
        name: Name of the observation (e.g., function or operation name)
        as_type: Type of observation to create (defaults to "span")
        input: Input data for the operation (can be any JSON-serializable object)
        output: Output data from the operation (can be any JSON-serializable object)
        metadata: Additional metadata to associate with the observation
        version: Version identifier for the code or component
        level: Importance level of the observation (info, warning, error)
        status_message: Optional status message for the observation
        end_on_exit (default: True): Whether to end the span automatically when leaving the context manager. If False, the span must be manually ended to avoid memory leaks.

        The following parameters are available when as_type is: "generation" or "embedding".
        completion_start_time: When the model started generating the response
        model: Name/identifier of the AI model used (e.g., "gpt-4")
        model_parameters: Parameters used for the model (e.g., temperature, max_tokens)
        usage_details: Token usage information (e.g., prompt_tokens, completion_tokens)
        cost_details: Cost information for the model call
        prompt: Associated prompt template from Aeri prompt management

    Returns:
        A context manager that yields the appropriate observation type based on as_type

    Example:
        ```python
        # Create a span
        with aeri.start_as_current_observation(name="process-query", as_type="span") as span:
            # Do work
            result = process_data()
            span.update(output=result)

            # Create a child span automatically
            with span.start_as_current_observation(name="sub-operation") as child_span:
                # Do sub-operation work
                child_span.update(output="sub-result")

            # Create a tool observation
            with aeri.start_as_current_observation(name="web-search", as_type="tool") as tool:
                # Do tool work
                results = search_web(query)
                tool.update(output=results)

            # Create a generation observation
            with aeri.start_as_current_observation(
                name="answer-generation",
                as_type="generation",
                model="gpt-4"
            ) as generation:
                # Generate answer
                response = llm.generate(...)
                generation.update(output=response)
        ```
    """
    # Arguments every observation type accepts.
    shared_kwargs: Dict[str, Any] = {
        "name": name,
        "end_on_exit": end_on_exit,
        "input": input,
        "output": output,
        "metadata": metadata,
        "version": version,
        "level": level,
        "status_message": status_message,
    }

    def _open(observation_kwargs: Dict[str, Any]) -> Any:
        """Build the context manager, attaching to a remote parent if one is named."""
        if trace_context:
            trace_id = trace_context.get("trace_id", None)
            parent_span_id = trace_context.get("parent_span_id", None)

            if trace_id:
                remote_parent_span = self._create_remote_parent_span(
                    trace_id=trace_id, parent_span_id=parent_span_id
                )
                return self._create_span_with_parent_context(
                    as_type=as_type,
                    remote_parent_span=remote_parent_span,
                    parent=None,
                    **observation_kwargs,
                )

        # No usable trace id: nest under whatever span is currently active.
        return self._start_as_current_otel_span_with_processed_media(
            as_type=as_type, **observation_kwargs
        )

    if as_type in get_observation_types_list(ObservationTypeGenerationLike):
        generation_kwargs: Dict[str, Any] = {
            **shared_kwargs,
            "completion_start_time": completion_start_time,
            "model": model,
            "model_parameters": model_parameters,
            "usage_details": usage_details,
            "cost_details": cost_details,
            "prompt": prompt,
        }
        return cast(
            Union[
                _AgnosticContextManager[AeriGeneration],
                _AgnosticContextManager[AeriEmbedding],
            ],
            _open(generation_kwargs),
        )

    if as_type in get_observation_types_list(ObservationTypeSpanLike):
        return cast(
            Union[
                _AgnosticContextManager[AeriSpan],
                _AgnosticContextManager[AeriAgent],
                _AgnosticContextManager[AeriTool],
                _AgnosticContextManager[AeriChain],
                _AgnosticContextManager[AeriRetriever],
                _AgnosticContextManager[AeriEvaluator],
                _AgnosticContextManager[AeriGuardrail],
            ],
            _open(shared_kwargs),
        )

    # This should never be reached since all valid types are handled above
    aeri_logger.warning(
        f"Unknown observation type: {as_type}, falling back to span"
    )
    return self._start_as_current_otel_span_with_processed_media(
        as_type="span", **shared_kwargs
    )
1082
+
1083
def _get_span_class(
    self,
    as_type: ObservationTypeLiteral,
) -> Union[
    Type[AeriAgent],
    Type[AeriTool],
    Type[AeriChain],
    Type[AeriRetriever],
    Type[AeriEvaluator],
    Type[AeriEmbedding],
    Type[AeriGuardrail],
    Type[AeriGeneration],
    Type[AeriEvent],
    Type[AeriSpan],
]:
    """Map an observation type name to its wrapper class.

    The lookup is case-insensitive; any name that is not recognised resolves
    to ``AeriSpan``.
    """
    class_by_type: Dict[
        str,
        Union[
            Type[AeriAgent],
            Type[AeriTool],
            Type[AeriChain],
            Type[AeriRetriever],
            Type[AeriEvaluator],
            Type[AeriEmbedding],
            Type[AeriGuardrail],
            Type[AeriGeneration],
            Type[AeriEvent],
            Type[AeriSpan],
        ],
    ] = {
        "agent": AeriAgent,
        "tool": AeriTool,
        "chain": AeriChain,
        "retriever": AeriRetriever,
        "evaluator": AeriEvaluator,
        "embedding": AeriEmbedding,
        "guardrail": AeriGuardrail,
        "generation": AeriGeneration,
        "event": AeriEvent,
        "span": AeriSpan,
    }
    return class_by_type.get(as_type.lower(), AeriSpan)
1123
+
1124
@_agnosticcontextmanager
def _create_span_with_parent_context(
    self,
    *,
    name: str,
    parent: Optional[otel_trace_api.Span] = None,
    remote_parent_span: Optional[otel_trace_api.Span] = None,
    as_type: ObservationTypeLiteralNoEvent,
    end_on_exit: Optional[bool] = None,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> Any:
    """Yield a current observation started underneath an explicit parent span.

    ``parent`` is used as the activated span when it is truthy, otherwise
    ``remote_parent_span`` is used. Whenever a remote parent span was supplied,
    the new observation's OTEL span is flagged with the ``AS_ROOT`` attribute.
    """
    activated_parent = parent or cast(otel_trace_api.Span, remote_parent_span)
    has_remote_parent = remote_parent_span is not None

    # The parent is activated first so the observation opened next is created
    # as its child; both contexts unwind together when the caller's block ends.
    with otel_trace_api.use_span(
        activated_parent
    ), self._start_as_current_otel_span_with_processed_media(
        name=name,
        as_type=as_type,
        end_on_exit=end_on_exit,
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
        completion_start_time=completion_start_time,
        model=model,
        model_parameters=model_parameters,
        usage_details=usage_details,
        cost_details=cost_details,
        prompt=prompt,
    ) as observation:
        if has_remote_parent:
            observation._otel_span.set_attribute(
                AeriOtelSpanAttributes.AS_ROOT, True
            )

        yield observation
1172
+
1173
@_agnosticcontextmanager
def _start_as_current_otel_span_with_processed_media(
    self,
    *,
    name: str,
    as_type: Optional[ObservationTypeLiteralNoEvent] = None,
    end_on_exit: Optional[bool] = None,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> Any:
    """Start an OTel span as current and yield it wrapped in an Aeri observation.

    The wrapper class is chosen by ``self._get_span_class`` from ``as_type``;
    a falsy ``as_type`` is treated as ``"generation"``. The model-related
    arguments (``completion_start_time``, ``model``, ``model_parameters``,
    ``usage_details``, ``cost_details``, ``prompt``) are only handed to the
    wrapper when the chosen class is ``AeriGeneration`` or ``AeriEmbedding``.

    Args:
        end_on_exit: Forwarded to the tracer; ``None`` is treated as ``True``.

    Yields:
        An instance of the selected observation class bound to the new span.
    """
    finish_on_exit = True if end_on_exit is None else end_on_exit

    with self._otel_tracer.start_as_current_span(
        name=name,
        end_on_exit=finish_on_exit,
    ) as otel_span:
        # default was "generation"
        span_class = self._get_span_class(as_type or "generation")

        # Arguments accepted by every observation class.
        wrapper_kwargs = {
            "otel_span": otel_span,
            "aeri_client": self,
            "environment": self._environment,
            "release": self._release,
            "input": input,
            "output": output,
            "metadata": metadata,
            "version": version,
            "level": level,
            "status_message": status_message,
        }

        # Span-like types (span, agent, tool, chain, retriever, evaluator,
        # guardrail) receive no generation properties.
        if span_class in (AeriGeneration, AeriEmbedding):
            wrapper_kwargs["completion_start_time"] = completion_start_time
            wrapper_kwargs["model"] = model
            wrapper_kwargs["model_parameters"] = model_parameters
            wrapper_kwargs["usage_details"] = usage_details
            wrapper_kwargs["cost_details"] = cost_details
            wrapper_kwargs["prompt"] = prompt

        yield span_class(**wrapper_kwargs)  # type: ignore[arg-type]
1230
+
1231
def _get_current_otel_span(self) -> Optional[otel_trace_api.Span]:
    """Return the span active in the current OTel context, if there is one.

    Returns:
        The current span, or ``None`` (after logging a warning) when the
        context only holds ``otel_trace_api.INVALID_SPAN``.
    """
    active_span = otel_trace_api.get_current_span()

    if active_span is not otel_trace_api.INVALID_SPAN:
        return active_span

    aeri_logger.warning(
        "Context error: No active span in current context. Operations that depend on an active span will be skipped. "
        "Ensure spans are created with start_as_current_observation() or that you're operating within an active span context."
    )
    return None
1242
+
1243
def update_current_generation(
    self,
    *,
    name: Optional[str] = None,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
    completion_start_time: Optional[datetime] = None,
    model: Optional[str] = None,
    model_parameters: Optional[Dict[str, MapValue]] = None,
    usage_details: Optional[Dict[str, int]] = None,
    cost_details: Optional[Dict[str, float]] = None,
    prompt: Optional[PromptClient] = None,
) -> None:
    """Attach new information to the generation active in the current context.

    Use this to record output, token usage or other details that only become
    known during or after the model call. The call is a no-op when tracing is
    disabled or when no span is active.

    Args:
        name: New name for the generation; applied only when truthy
        input: Updated input data for the model
        output: Output from the model (e.g., completions)
        metadata: Additional metadata to associate with the generation
        version: Version identifier for the model or component
        level: Importance level of the generation (info, warning, error)
        status_message: Optional status message for the generation
        completion_start_time: When the model started generating the response
        model: Name/identifier of the AI model used (e.g., "gpt-4")
        model_parameters: Parameters used for the model (e.g., temperature, max_tokens)
        usage_details: Token usage information (e.g., prompt_tokens, completion_tokens)
        cost_details: Cost information for the model call
        prompt: Associated prompt template from Aeri prompt management

    Example:
        ```python
        with aeri.start_as_current_generation(name="answer-query") as generation:
            response = llm.generate(...)

            # Record results that were not available when the span was opened
            aeri.update_current_generation(
                output=response.text,
                usage_details={
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens
                }
            )
        ```
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: update_current_generation - Tracing is disabled or client is in no-op mode."
        )
        return

    current_otel_span = self._get_current_otel_span()
    if current_otel_span is None:
        # Nothing to update; _get_current_otel_span has already logged a warning.
        return

    # Wrap the live OTel span so the generation-specific update logic applies.
    generation = AeriGeneration(otel_span=current_otel_span, aeri_client=self)

    if name:
        current_otel_span.update_name(name)

    generation.update(
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
        completion_start_time=completion_start_time,
        model=model,
        model_parameters=model_parameters,
        usage_details=usage_details,
        cost_details=cost_details,
        prompt=prompt,
    )
1327
+
1328
def update_current_span(
    self,
    *,
    name: Optional[str] = None,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> None:
    """Attach new information to the span active in the current context.

    Use this to record outputs or metadata that only become available while
    the operation runs. The call is a no-op when tracing is disabled or when
    no span is active.

    Args:
        name: New name for the span; applied only when truthy
        input: Updated input data for the operation
        output: Output data from the operation
        metadata: Additional metadata to associate with the span
        version: Version identifier for the code or component
        level: Importance level of the span (info, warning, error)
        status_message: Optional status message for the span

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-data") as span:
            result = process_first_part()

            # Record an intermediate result
            aeri.update_current_span(metadata={"intermediate_result": result})

            final_result = process_second_part(result)

            # Record the final output
            aeri.update_current_span(output=final_result)
        ```
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: update_current_span - Tracing is disabled or client is in no-op mode."
        )
        return

    current_otel_span = self._get_current_otel_span()
    if current_otel_span is None:
        # Nothing to update; _get_current_otel_span has already logged a warning.
        return

    span = AeriSpan(
        otel_span=current_otel_span,
        aeri_client=self,
        environment=self._environment,
        release=self._release,
    )

    if name:
        current_otel_span.update_name(name)

    span.update(
        input=input,
        output=output,
        metadata=metadata,
        version=version,
        level=level,
        status_message=status_message,
    )
1397
+
1398
@deprecated(
    "Trace-level input/output is deprecated. "
    "For trace attributes (user_id, session_id, tags, etc.), use propagate_attributes() instead. "
    "This method will be removed in a future major version."
)
def set_current_trace_io(
    self,
    *,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
) -> None:
    """Set trace-level input and output on the trace of the current span.

    .. deprecated::
        Legacy method kept for backward compatibility with Aeri platform
        features that still rely on trace-level input/output (e.g., legacy
        LLM-as-a-judge evaluators). It will be removed in a future major version.

    To set other trace attributes (user_id, session_id, metadata, tags, version),
    use :meth:`propagate_attributes` instead.

    The call is a no-op when tracing is disabled, when no span is active, or
    when the active span is not recording.

    Args:
        input: Input data to associate with the trace.
        output: Output data to associate with the trace.
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: set_current_trace_io - Tracing is disabled or client is in no-op mode."
        )
        return

    current_otel_span = self._get_current_otel_span()
    if current_otel_span is None or not current_otel_span.is_recording():
        return

    # Re-wrap the span with the class matching its recorded observation type,
    # so that wrapping it does not change that type.
    existing_observation_type = current_otel_span.attributes.get(  # type: ignore[attr-defined]
        AeriOtelSpanAttributes.OBSERVATION_TYPE, "span"
    )
    wrapper_class = self._get_span_class(existing_observation_type)
    span = wrapper_class(
        otel_span=current_otel_span,
        aeri_client=self,
        environment=self._environment,
        release=self._release,
    )

    span.set_trace_io(input=input, output=output)
1448
+
1449
def set_current_trace_as_public(self) -> None:
    """Publish the trace of the currently active span.

    Once a trace is published, anyone holding the trace link can view the
    full trace without being logged in to Aeri. This cannot be undone
    programmatically - the entire trace becomes public.

    This is a convenience wrapper for use inside a traced function when the
    span object itself is not at hand. The call is a no-op when tracing is
    disabled, when no span is active, or when the active span is not recording.
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: set_current_trace_as_public - Tracing is disabled or client is in no-op mode."
        )
        return

    current_otel_span = self._get_current_otel_span()
    if current_otel_span is None or not current_otel_span.is_recording():
        return

    # Re-wrap the span with the class matching its recorded observation type,
    # so that wrapping it does not change that type.
    existing_observation_type = current_otel_span.attributes.get(  # type: ignore[attr-defined]
        AeriOtelSpanAttributes.OBSERVATION_TYPE, "span"
    )
    wrapper_class = self._get_span_class(existing_observation_type)
    span = wrapper_class(
        otel_span=current_otel_span,
        aeri_client=self,
        environment=self._environment,
    )

    span.set_trace_as_public()
1481
+
1482
def create_event(
    self,
    *,
    trace_context: Optional[TraceContext] = None,
    name: str,
    input: Optional[Any] = None,
    output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    version: Optional[str] = None,
    level: Optional[SpanLevel] = None,
    status_message: Optional[str] = None,
) -> AeriEvent:
    """Create a new Aeri observation of type 'EVENT'.

    By default the event becomes a child of the span that is current in the
    context. When ``trace_context`` carries a truthy ``trace_id``, the event
    is instead started under a parent span built from that trace context and
    is tagged with ``AeriOtelSpanAttributes.AS_ROOT``.

    Args:
        trace_context: Optional context for connecting to an existing trace
        name: Name of the event (e.g., function or operation name)
        input: Input data for the operation (can be any JSON-serializable object)
        output: Output data from the operation (can be any JSON-serializable object)
        metadata: Additional metadata to associate with the event
        version: Version identifier for the code or component
        level: Importance level of the event (info, warning, error)
        status_message: Optional status message for the event

    Returns:
        The Aeri Event object

    Example:
        ```python
        event = aeri.create_event(name="process-event")
        ```
    """
    timestamp = time_ns()

    def _wrap_and_end(span: Any) -> AeriEvent:
        # The same timestamp is used as the span's start and end time.
        return cast(
            AeriEvent,
            AeriEvent(
                otel_span=span,
                aeri_client=self,
                environment=self._environment,
                release=self._release,
                input=input,
                output=output,
                metadata=metadata,
                version=version,
                level=level,
                status_message=status_message,
            ).end(end_time=timestamp),
        )

    if trace_context:
        trace_id = trace_context.get("trace_id", None)
        parent_span_id = trace_context.get("parent_span_id", None)

        if trace_id:
            remote_parent_span = self._create_remote_parent_span(
                trace_id=trace_id, parent_span_id=parent_span_id
            )

            with otel_trace_api.use_span(
                cast(otel_trace_api.Span, remote_parent_span)
            ):
                linked_span = self._otel_tracer.start_span(
                    name=name, start_time=timestamp
                )
                linked_span.set_attribute(AeriOtelSpanAttributes.AS_ROOT, True)

                return _wrap_and_end(linked_span)

    # No usable trace context: start the event under whatever span is current.
    return _wrap_and_end(
        self._otel_tracer.start_span(name=name, start_time=timestamp)
    )
1568
+
1569
def _create_remote_parent_span(
    self, *, trace_id: str, parent_span_id: Optional[str]
) -> Any:
    """Build a non-recording span that stands in for a parent in another trace.

    Args:
        trace_id: Hexadecimal trace ID to attach to.
        parent_span_id: Hexadecimal span ID of the parent; when falsy a random
            span ID is generated instead.

    Returns:
        An ``otel_trace_api.NonRecordingSpan`` whose span context carries the
        parsed IDs and the sampled trace flag.
    """
    # NOTE(review): both warnings say the value is being ignored, but the IDs
    # are still parsed with int(..., 16) below - confirm whether malformed
    # IDs are meant to be dropped rather than used.
    trace_id_is_valid = self._is_valid_trace_id(trace_id)
    if not trace_id_is_valid:
        aeri_logger.warning(
            f"Passed trace ID '{trace_id}' is not a valid 32 lowercase hex char Aeri trace id. Ignoring trace ID."
        )

    if parent_span_id and not self._is_valid_span_id(parent_span_id):
        aeri_logger.warning(
            f"Passed span ID '{parent_span_id}' is not a valid 16 lowercase hex char Aeri span id. Ignoring parent span ID."
        )

    int_trace_id = int(trace_id, 16)

    if parent_span_id:
        int_parent_span_id = int(parent_span_id, 16)
    else:
        int_parent_span_id = RandomIdGenerator().generate_span_id()

    sampled_flags = otel_trace_api.TraceFlags(0x01)  # mark span as sampled
    span_context = otel_trace_api.SpanContext(
        trace_id=int_trace_id,
        span_id=int_parent_span_id,
        trace_flags=sampled_flags,
        is_remote=False,
    )

    return otel_trace_api.NonRecordingSpan(span_context)
1597
+
1598
+ def _is_valid_trace_id(self, trace_id: str) -> bool:
1599
+ pattern = r"^[0-9a-f]{32}$"
1600
+
1601
+ return bool(re.match(pattern, trace_id))
1602
+
1603
+ def _is_valid_span_id(self, span_id: str) -> bool:
1604
+ pattern = r"^[0-9a-f]{16}$"
1605
+
1606
+ return bool(re.match(pattern, span_id))
1607
+
1608
+ def _create_observation_id(self, *, seed: Optional[str] = None) -> str:
1609
+ """Create a unique observation ID for use with Aeri.
1610
+
1611
+ This method generates a unique observation ID (span ID in OpenTelemetry terms)
1612
+ for use with various Aeri APIs. It can either generate a random ID or
1613
+ create a deterministic ID based on a seed string.
1614
+
1615
+ Observation IDs must be 16 lowercase hexadecimal characters, representing 8 bytes.
1616
+ This method ensures the generated ID meets this requirement. If you need to
1617
+ correlate an external ID with a Aeri observation ID, use the external ID as
1618
+ the seed to get a valid, deterministic observation ID.
1619
+
1620
+ Args:
1621
+ seed: Optional string to use as a seed for deterministic ID generation.
1622
+ If provided, the same seed will always produce the same ID.
1623
+ If not provided, a random ID will be generated.
1624
+
1625
+ Returns:
1626
+ A 16-character lowercase hexadecimal string representing the observation ID.
1627
+
1628
+ Example:
1629
+ ```python
1630
+ # Generate a random observation ID
1631
+ obs_id = aeri.create_observation_id()
1632
+
1633
+ # Generate a deterministic ID based on a seed
1634
+ user_obs_id = aeri.create_observation_id(seed="user-123-feedback")
1635
+
1636
+ # Correlate an external item ID with a Aeri observation ID
1637
+ item_id = "item-789012"
1638
+ correlated_obs_id = aeri.create_observation_id(seed=item_id)
1639
+
1640
+ # Use the ID with Aeri APIs
1641
+ aeri.create_score(
1642
+ name="relevance",
1643
+ value=0.95,
1644
+ trace_id=trace_id,
1645
+ observation_id=obs_id
1646
+ )
1647
+ ```
1648
+ """
1649
+ if not seed:
1650
+ span_id_int = RandomIdGenerator().generate_span_id()
1651
+
1652
+ return self._format_otel_span_id(span_id_int)
1653
+
1654
+ return sha256(seed.encode("utf-8")).digest()[:8].hex()
1655
+
1656
+ @staticmethod
1657
+ def create_trace_id(*, seed: Optional[str] = None) -> str:
1658
+ """Create a unique trace ID for use with Aeri.
1659
+
1660
+ This method generates a unique trace ID for use with various Aeri APIs.
1661
+ It can either generate a random ID or create a deterministic ID based on
1662
+ a seed string.
1663
+
1664
+ Trace IDs must be 32 lowercase hexadecimal characters, representing 16 bytes.
1665
+ This method ensures the generated ID meets this requirement. If you need to
1666
+ correlate an external ID with a Aeri trace ID, use the external ID as the
1667
+ seed to get a valid, deterministic Aeri trace ID.
1668
+
1669
+ Args:
1670
+ seed: Optional string to use as a seed for deterministic ID generation.
1671
+ If provided, the same seed will always produce the same ID.
1672
+ If not provided, a random ID will be generated.
1673
+
1674
+ Returns:
1675
+ A 32-character lowercase hexadecimal string representing the Aeri trace ID.
1676
+
1677
+ Example:
1678
+ ```python
1679
+ # Generate a random trace ID
1680
+ trace_id = aeri.create_trace_id()
1681
+
1682
+ # Generate a deterministic ID based on a seed
1683
+ session_trace_id = aeri.create_trace_id(seed="session-456")
1684
+
1685
+ # Correlate an external ID with a Aeri trace ID
1686
+ external_id = "external-system-123456"
1687
+ correlated_trace_id = aeri.create_trace_id(seed=external_id)
1688
+
1689
+ # Use the ID with trace context
1690
+ with aeri.start_as_current_observation(
1691
+ name="process-request",
1692
+ trace_context={"trace_id": trace_id}
1693
+ ) as span:
1694
+ # Operation will be part of the specific trace
1695
+ pass
1696
+ ```
1697
+ """
1698
+ if not seed:
1699
+ trace_id_int = RandomIdGenerator().generate_trace_id()
1700
+
1701
+ return Aeri._format_otel_trace_id(trace_id_int)
1702
+
1703
+ return sha256(seed.encode("utf-8")).digest()[:16].hex()
1704
+
1705
def _get_otel_trace_id(self, otel_span: otel_trace_api.Span) -> str:
    """Return the trace ID of ``otel_span`` as a 32-character lowercase hex string."""
    raw_trace_id = otel_span.get_span_context().trace_id
    return self._format_otel_trace_id(raw_trace_id)
1709
+
1710
def _get_otel_span_id(self, otel_span: otel_trace_api.Span) -> str:
    """Return the span ID of ``otel_span`` as a 16-character lowercase hex string."""
    raw_span_id = otel_span.get_span_context().span_id
    return self._format_otel_span_id(raw_span_id)
1714
+
1715
+ @staticmethod
1716
+ def _format_otel_span_id(span_id_int: int) -> str:
1717
+ """Format an integer span ID to a 16-character lowercase hex string.
1718
+
1719
+ Internal method to convert an OpenTelemetry integer span ID to the standard
1720
+ W3C Trace Context format (16-character lowercase hex string).
1721
+
1722
+ Args:
1723
+ span_id_int: 64-bit integer representing a span ID
1724
+
1725
+ Returns:
1726
+ A 16-character lowercase hexadecimal string
1727
+ """
1728
+ return format(span_id_int, "016x")
1729
+
1730
+ @staticmethod
1731
+ def _format_otel_trace_id(trace_id_int: int) -> str:
1732
+ """Format an integer trace ID to a 32-character lowercase hex string.
1733
+
1734
+ Internal method to convert an OpenTelemetry integer trace ID to the standard
1735
+ W3C Trace Context format (32-character lowercase hex string).
1736
+
1737
+ Args:
1738
+ trace_id_int: 128-bit integer representing a trace ID
1739
+
1740
+ Returns:
1741
+ A 32-character lowercase hexadecimal string
1742
+ """
1743
+ return format(trace_id_int, "032x")
1744
+
1745
@overload
def create_score(
    self,
    *,
    name: str,
    value: float,
    session_id: Optional[str] = None,
    dataset_run_id: Optional[str] = None,
    trace_id: Optional[str] = None,
    observation_id: Optional[str] = None,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["NUMERIC", "BOOLEAN"]] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
    timestamp: Optional[datetime] = None,
) -> None: ...

@overload
def create_score(
    self,
    *,
    name: str,
    value: str,
    session_id: Optional[str] = None,
    dataset_run_id: Optional[str] = None,
    trace_id: Optional[str] = None,
    score_id: Optional[str] = None,
    observation_id: Optional[str] = None,
    data_type: Optional[Literal["CATEGORICAL"]] = "CATEGORICAL",
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
    timestamp: Optional[datetime] = None,
) -> None: ...

def create_score(
    self,
    *,
    name: str,
    value: Union[float, str],
    session_id: Optional[str] = None,
    dataset_run_id: Optional[str] = None,
    trace_id: Optional[str] = None,
    observation_id: Optional[str] = None,
    score_id: Optional[str] = None,
    data_type: Optional[ScoreDataType] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
    timestamp: Optional[datetime] = None,
) -> None:
    """Record a score against a trace, observation, session or dataset run.

    Scores capture quality metrics, user feedback or automated evaluations.
    The arguments are validated through ``ScoreInput`` first; a validation
    failure is logged as an error and the score is dropped. Any later failure
    while building or enqueueing the event is logged and swallowed, so this
    method does not raise. Nothing happens when tracing is disabled.

    Args:
        name: Name of the score (e.g., "relevance", "accuracy")
        value: Score value (numeric for NUMERIC/BOOLEAN types, string for CATEGORICAL)
        session_id: ID of the Aeri session to associate the score with
        dataset_run_id: ID of the Aeri dataset run to associate the score with
        trace_id: ID of the Aeri trace to associate the score with
        observation_id: Optional ID of the specific observation to score. Trace ID must be provided too.
        score_id: Optional custom ID for the score (auto-generated if not provided)
        data_type: Type of score (NUMERIC, BOOLEAN, or CATEGORICAL)
        comment: Optional comment or explanation for the score
        config_id: Optional ID of a score config defined in Aeri
        metadata: Optional metadata to be attached to the score
        timestamp: Optional timestamp for the score; when falsy, the value
            returned by ``_get_timestamp()`` is used

    Example:
        ```python
        # Numeric score on a trace
        aeri.create_score(
            name="accuracy",
            value=0.92,
            trace_id="abcdef1234567890abcdef1234567890",
            data_type="NUMERIC",
            comment="High accuracy with minor irrelevant details"
        )

        # Categorical score on a single observation
        aeri.create_score(
            name="sentiment",
            value="positive",
            trace_id="abcdef1234567890abcdef1234567890",
            observation_id="abcdef1234567890",
            data_type="CATEGORICAL"
        )
        ```
    """
    if not self._tracing_enabled:
        return

    # Validate the user-supplied fields before any further processing.
    raw_score_fields = {
        "name": name,
        "value": value,
        "data_type": data_type,
        "comment": comment,
        "config_id": config_id,
        "trace_id": trace_id,
        "observation_id": observation_id,
    }
    try:
        _validated_score = ScoreInput.model_validate(raw_score_fields)
    except Exception as validation_exc:
        aeri_logger.error(
            f"Score validation failed for name={name!r}: {validation_exc}"
        )
        return

    if not score_id:
        score_id = self._create_observation_id()

    try:
        new_body = ScoreBody(
            id=score_id,
            session_id=session_id,
            datasetRunId=dataset_run_id,
            traceId=_validated_score.trace_id,
            observationId=_validated_score.observation_id,
            name=_validated_score.name,
            value=_validated_score.value,
            dataType=_validated_score.data_type,  # type: ignore
            comment=_validated_score.comment,
            configId=_validated_score.config_id,
            environment=self._environment,
            metadata=metadata,
        )

        event = {
            "id": self.create_trace_id(),
            "type": "score-create",
            "timestamp": timestamp or _get_timestamp(),
            "body": new_body,
        }

        if self._resources is not None:
            # Force the score to be in sample if it was for a legacy trace ID,
            # i.e. non-32 hexchar, or when no trace ID was given at all.
            if trace_id:
                force_sample = not self._is_valid_trace_id(trace_id)
            else:
                force_sample = True

            self._resources.add_score_task(event, force_sample=force_sample)

    except Exception as e:
        aeri_logger.exception(
            f"Error creating score: Failed to process score event for trace_id={trace_id}, name={name}. Error: {e}"
        )
1898
+
1899
+ def _create_trace_tags_via_ingestion(
1900
+ self,
1901
+ *,
1902
+ trace_id: str,
1903
+ tags: List[str],
1904
+ ) -> None:
1905
+ """Private helper to enqueue trace tag updates via ingestion API events."""
1906
+ if not self._tracing_enabled:
1907
+ return
1908
+
1909
+ if len(tags) == 0:
1910
+ return
1911
+
1912
+ try:
1913
+ new_body = TraceBody(
1914
+ id=trace_id,
1915
+ tags=tags,
1916
+ )
1917
+
1918
+ event = {
1919
+ "id": self.create_trace_id(),
1920
+ "type": "trace-create",
1921
+ "timestamp": _get_timestamp(),
1922
+ "body": new_body,
1923
+ }
1924
+
1925
+ if self._resources is not None:
1926
+ self._resources.add_trace_task(event)
1927
+ except Exception as e:
1928
+ aeri_logger.exception(
1929
+ f"Error updating trace tags: Failed to process trace update event for trace_id={trace_id}. Error: {e}"
1930
+ )
1931
+
1932
@overload
def score_current_span(
    self,
    *,
    name: str,
    value: float,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["NUMERIC", "BOOLEAN"]] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

@overload
def score_current_span(
    self,
    *,
    name: str,
    value: str,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["CATEGORICAL"]] = "CATEGORICAL",
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

def score_current_span(
    self,
    *,
    name: str,
    value: Union[float, str],
    score_id: Optional[str] = None,
    data_type: Optional[ScoreDataType] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None:
    """Score the span that is active in the current context.

    Convenience wrapper around ``create_score`` that reads the trace ID and
    span ID from the active span, so the caller does not need to know them.
    Nothing is scored when no span is active.

    Args:
        name: Name of the score (e.g., "relevance", "accuracy")
        value: Score value (numeric for NUMERIC/BOOLEAN types, string for CATEGORICAL)
        score_id: Optional custom ID for the score (auto-generated if not provided)
        data_type: Type of score (NUMERIC, BOOLEAN, or CATEGORICAL)
        comment: Optional comment or explanation for the score
        config_id: Optional ID of a score config defined in Aeri
        metadata: Optional metadata to be attached to the score

    Example:
        ```python
        with aeri.start_as_current_generation(name="answer-query") as generation:
            response = generate_answer(...)
            generation.update(output=response)

            # Score the generation that is currently active
            aeri.score_current_span(
                name="relevance",
                value=0.85,
                data_type="NUMERIC",
                comment="Mostly relevant but contains some tangential information",
                metadata={"model": "gpt-4", "prompt_version": "v2"}
            )
        ```
    """
    current_span = self._get_current_otel_span()
    if current_span is None:
        return

    trace_id = self._get_otel_trace_id(current_span)
    observation_id = self._get_otel_span_id(current_span)

    aeri_logger.info(
        f"Score: Creating score name='{name}' value={value} for current span ({observation_id}) in trace {trace_id}"
    )

    # The casts only satisfy the create_score overloads; values pass through as-is.
    self.create_score(
        trace_id=trace_id,
        observation_id=observation_id,
        name=name,
        value=cast(str, value),
        score_id=score_id,
        data_type=cast(Literal["CATEGORICAL"], data_type),
        comment=comment,
        config_id=config_id,
        metadata=metadata,
    )
2021
+
2022
@overload
def score_current_trace(
    self,
    *,
    name: str,
    value: float,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["NUMERIC", "BOOLEAN"]] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

@overload
def score_current_trace(
    self,
    *,
    name: str,
    value: str,
    score_id: Optional[str] = None,
    data_type: Optional[Literal["CATEGORICAL"]] = "CATEGORICAL",
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None: ...

def score_current_trace(
    self,
    *,
    name: str,
    value: Union[float, str],
    score_id: Optional[str] = None,
    data_type: Optional[ScoreDataType] = None,
    comment: Optional[str] = None,
    config_id: Optional[str] = None,
    metadata: Optional[Any] = None,
) -> None:
    """Score the trace that the currently active span belongs to.

    Unlike ``score_current_span``, no observation ID is passed on, so the
    score is attached to the whole trace rather than to one span. This suits
    scores for the overall quality of an operation. Nothing is scored when no
    span is active.

    Args:
        name: Name of the score (e.g., "user_satisfaction", "overall_quality")
        value: Score value (numeric for NUMERIC/BOOLEAN types, string for CATEGORICAL)
        score_id: Optional custom ID for the score (auto-generated if not provided)
        data_type: Type of score (NUMERIC, BOOLEAN, or CATEGORICAL)
        comment: Optional comment or explanation for the score
        config_id: Optional ID of a score config defined in Aeri
        metadata: Optional metadata to be attached to the score

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-user-request") as span:
            result = process_complete_request()
            span.update(output=result)

            # Score the trace as a whole
            aeri.score_current_trace(
                name="overall_quality",
                value=0.95,
                data_type="NUMERIC",
                comment="High quality end-to-end response",
                metadata={"evaluator": "gpt-4", "criteria": "comprehensive"}
            )
        ```
    """
    current_span = self._get_current_otel_span()
    if current_span is None:
        return

    trace_id = self._get_otel_trace_id(current_span)

    aeri_logger.info(
        f"Score: Creating score name='{name}' value={value} for entire trace {trace_id}"
    )

    # The casts only satisfy the create_score overloads; values pass through as-is.
    self.create_score(
        trace_id=trace_id,
        name=name,
        value=cast(str, value),
        score_id=score_id,
        data_type=cast(Literal["CATEGORICAL"], data_type),
        comment=comment,
        config_id=config_id,
        metadata=metadata,
    )
2110
+
2111
def flush(self) -> None:
    """Send all pending spans, scores and events to the Aeri API now.

    Use this when buffered data must be delivered before the program
    continues, instead of waiting for the automatic flush interval. The call
    is a no-op when the client has no resource manager attached.

    Example:
        ```python
        with aeri.start_as_current_observation(name="operation") as span:
            pass  # Do work...

        # Ensure all data is sent to Aeri before proceeding
        aeri.flush()
        ```
    """
    resources = self._resources
    if resources is None:
        return

    resources.flush()
2133
+
2134
def shutdown(self) -> None:
    """Shut the Aeri client down by delegating to its resource manager.

    Call this when the application is exiting so that pending data is not
    lost and resources are released. For most applications, using the client
    as a context manager or relying on the automatic shutdown via atexit is
    sufficient. The call is a no-op when the client has no resource manager
    attached.

    Example:
        ```python
        aeri = Aeri(public_key="...", secret_key="...")

        # Use Aeri throughout your application
        # ...

        # When application is shutting down
        aeri.shutdown()
        ```
    """
    resources = self._resources
    if resources is None:
        return

    resources.shutdown()
2158
+
2159
def get_current_trace_id(self) -> Optional[str]:
    """Return the trace ID of the span that is currently active.

    Handy for correlating logs or external systems with the running trace.

    Returns:
        The current trace ID as a 32-character lowercase hexadecimal string,
        or None if tracing is disabled or there is no active span.

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-request") as span:
            trace_id = aeri.get_current_trace_id()

            # Use it for external correlation
            log.info(f"Processing request with trace_id: {trace_id}")

            # Or pass to another system
            external_system.process(data, trace_id=trace_id)
        ```
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: get_current_trace_id - Tracing is disabled or client is in no-op mode."
        )
        return None

    active_span = self._get_current_otel_span()
    if not active_span:
        return None

    return self._get_otel_trace_id(active_span)
2192
+
2193
def get_current_observation_id(self) -> Optional[str]:
    """Return the observation ID (span ID) of the currently active span.

    Useful for referencing the observation in logs, external systems, or when
    creating scores and other related records later on.

    Returns:
        The current observation ID as a 16-character lowercase hexadecimal string,
        or None if tracing is disabled or there is no active span.

    Example:
        ```python
        with aeri.start_as_current_observation(name="process-user-query") as span:
            observation_id = aeri.get_current_observation_id()

            # Store it for later reference
            cache.set(f"query_{query_id}_observation", observation_id)

            # Process the query...
        ```
    """
    if not self._tracing_enabled:
        aeri_logger.debug(
            "Operation skipped: get_current_observation_id - Tracing is disabled or client is in no-op mode."
        )
        return None

    active_span = self._get_current_otel_span()
    if not active_span:
        return None

    return self._get_otel_span_id(active_span)
2225
+
2226
+ def _get_project_id(self) -> Optional[str]:
2227
+ """Fetch and return the current project id. Persisted across requests. Returns None if no project id is found for api keys."""
2228
+ if not self._project_id:
2229
+ proj = self.api.projects.get()
2230
+ if not proj.data or not proj.data[0].id:
2231
+ return None
2232
+
2233
+ self._project_id = proj.data[0].id
2234
+
2235
+ return self._project_id
2236
+
2237
def get_trace_url(self, *, trace_id: Optional[str] = None) -> Optional[str]:
    """Build the Aeri UI URL for a trace.

    Useful for putting direct links into logs, notifications, or debugging
    tools.

    Args:
        trace_id: Optional trace ID to generate a URL for. If not provided,
            the trace ID of the current active span will be used.

    Returns:
        A URL string pointing to the trace in the Aeri UI,
        or None if the project ID couldn't be retrieved or no trace ID is available.

    Example:
        ```python
        # Get URL for the current trace
        with aeri.start_as_current_observation(name="process-request") as span:
            trace_url = aeri.get_trace_url()
            log.info(f"Processing trace: {trace_url}")

        # Get URL for a specific trace
        specific_trace_url = aeri.get_trace_url(trace_id="1234567890abcdef1234567890abcdef")
        send_notification(f"Review needed for trace: {specific_trace_url}")
        ```
    """
    resolved_trace_id = trace_id or self.get_current_trace_id()
    if not resolved_trace_id:
        return None

    project_id = self._get_project_id()
    if not project_id:
        return None

    return f"{self._base_url}/project/{project_id}/traces/{resolved_trace_id}"
2274
+
2275
def get_dataset(
    self,
    name: str,
    *,
    fetch_items_page_size: Optional[int] = 50,
    version: Optional[datetime] = None,
) -> "DatasetClient":
    """Fetch a dataset and all of its items by dataset name.

    Items are retrieved page by page until the API reports that the last
    page has been reached.

    Args:
        name (str): The name of the dataset to fetch.
        fetch_items_page_size (Optional[int]): All items of the dataset will be fetched in chunks of this size. Defaults to 50.
        version (Optional[datetime]): Retrieve dataset items as they existed at this specific point in time (UTC).
            If provided, returns the state of items at the specified UTC timestamp.
            If not provided, returns the latest version. Must be a timezone-aware datetime object in UTC.

    Returns:
        DatasetClient: The dataset with the given name.

    Raises:
        Error: Re-raised after being passed to ``handle_fern_exception``.
    """
    try:
        aeri_logger.debug(f"Getting datasets {name}")
        dataset_record = self.api.datasets.get(dataset_name=self._url_encode(name))

        collected_items = []
        current_page = 0
        more_pages = True

        while more_pages:
            current_page += 1
            batch = self.api.dataset_items.list(
                dataset_name=self._url_encode(name, is_url_param=True),
                page=current_page,
                limit=fetch_items_page_size,
                version=version,
            )
            collected_items.extend(batch.data)
            # Stop once the page just fetched is the last one reported.
            more_pages = batch.meta.total_pages > current_page

        return DatasetClient(
            dataset=dataset_record,
            items=collected_items,
            version=version,
            aeri_client=self,
        )

    except Error as e:
        handle_fern_exception(e)
        raise e
2325
+
2326
def get_dataset_run(
    self, *, dataset_name: str, run_name: str
) -> DatasetRunWithItems:
    """Fetch a single dataset run, identified by dataset name and run name.

    Both names are URL-encoded before being sent to the API.

    Args:
        dataset_name (str): The name of the dataset.
        run_name (str): The name of the run.

    Returns:
        DatasetRunWithItems: The dataset run with its items.

    Raises:
        Error: Re-raised after being passed to ``handle_fern_exception``.
    """
    try:
        run = self.api.datasets.get_run(
            dataset_name=self._url_encode(dataset_name),
            run_name=self._url_encode(run_name),
            request_options=None,
        )
        return cast(DatasetRunWithItems, run)
    except Error as e:
        handle_fern_exception(e)
        raise e
2350
+
2351
def get_dataset_runs(
    self,
    *,
    dataset_name: str,
    page: Optional[int] = None,
    limit: Optional[int] = None,
) -> PaginatedDatasetRuns:
    """Fetch one page of runs for a dataset.

    The dataset name is URL-encoded before being sent to the API; ``page``
    and ``limit`` are passed through unchanged.

    Args:
        dataset_name (str): The name of the dataset.
        page (Optional[int]): Page number, starts at 1.
        limit (Optional[int]): Limit of items per page.

    Returns:
        PaginatedDatasetRuns: Paginated list of dataset runs.

    Raises:
        Error: Re-raised after being passed to ``handle_fern_exception``.
    """
    try:
        runs_page = self.api.datasets.get_runs(
            dataset_name=self._url_encode(dataset_name),
            page=page,
            limit=limit,
            request_options=None,
        )
        return cast(PaginatedDatasetRuns, runs_page)
    except Error as e:
        handle_fern_exception(e)
        raise e
2381
+
2382
def delete_dataset_run(
    self, *, dataset_name: str, run_name: str
) -> DeleteDatasetRunResponse:
    """Delete a dataset run together with all of its run items.

    This action is irreversible. Both names are URL-encoded before being
    sent to the API.

    Args:
        dataset_name (str): The name of the dataset.
        run_name (str): The name of the run.

    Returns:
        DeleteDatasetRunResponse: Confirmation of deletion.

    Raises:
        Error: Re-raised after being passed to ``handle_fern_exception``.
    """
    try:
        deletion = self.api.datasets.delete_run(
            dataset_name=self._url_encode(dataset_name),
            run_name=self._url_encode(run_name),
            request_options=None,
        )
        return cast(DeleteDatasetRunResponse, deletion)
    except Error as e:
        handle_fern_exception(e)
        raise e
2406
+
2407
def run_experiment(
    self,
    *,
    name: str,
    run_name: Optional[str] = None,
    description: Optional[str] = None,
    data: ExperimentData,
    task: TaskFunction,
    evaluators: Optional[List[EvaluatorFunction]] = None,
    composite_evaluator: Optional[CompositeEvaluatorFunction] = None,
    run_evaluators: Optional[List[RunEvaluatorFunction]] = None,
    max_concurrency: int = 50,
    metadata: Optional[Dict[str, str]] = None,
    _dataset_version: Optional[datetime] = None,
) -> ExperimentResult:
    """Run an experiment on a dataset with automatic tracing and evaluation.

    This method executes a task function on each item in the provided dataset,
    automatically traces all executions with Aeri for observability, runs
    item-level and run-level evaluators on the outputs, and returns comprehensive
    results with evaluation metrics.

    The experiment system provides:
    - Automatic tracing of all task executions
    - Concurrent processing with configurable limits
    - Comprehensive error handling that isolates failures
    - Integration with Aeri datasets for experiment tracking
    - Flexible evaluation framework supporting both sync and async evaluators

    Args:
        name: Human-readable name for the experiment. Used for identification
            in the Aeri UI.
        run_name: Optional exact name for the experiment run. If provided, this will be
            used as the exact dataset run name if the `data` contains Aeri dataset items.
            If not provided, this will default to the experiment name appended with an ISO timestamp.
        description: Optional description explaining the experiment's purpose,
            methodology, or expected outcomes.
        data: Array of data items to process. Can be either:
            - List of dict-like items with 'input', 'expected_output', 'metadata' keys
            - List of Aeri DatasetItem objects from dataset.items
        task: Function that processes each data item and returns output.
            Must accept 'item' as keyword argument and can return sync or async results.
            The task function signature should be: task(*, item, **kwargs) -> Any
        evaluators: List of functions to evaluate each item's output individually.
            Each evaluator receives input, output, expected_output, and metadata.
            Can return single Evaluation dict or list of Evaluation dicts.
            Defaults to no item-level evaluators.
        composite_evaluator: Optional function that creates composite scores from item-level evaluations.
            Receives the same inputs as item-level evaluators (input, output, expected_output, metadata)
            plus the list of evaluations from item-level evaluators. Useful for weighted averages,
            pass/fail decisions based on multiple criteria, or custom scoring logic combining multiple metrics.
        run_evaluators: List of functions to evaluate the entire experiment run.
            Each run evaluator receives all item_results and can compute aggregate metrics.
            Useful for calculating averages, distributions, or cross-item comparisons.
            Defaults to no run-level evaluators.
        max_concurrency: Maximum number of concurrent task executions (default: 50).
            Controls the number of items processed simultaneously. Adjust based on
            API rate limits and system resources.
        metadata: Optional metadata dictionary to attach to all experiment traces.
            This metadata will be included in every trace created during the experiment.
            If `data` are Aeri dataset items, the metadata will be attached to the dataset run, too.
        _dataset_version: Internal parameter, forwarded unchanged to the async
            implementation as ``dataset_version``.

    Returns:
        ExperimentResult containing:
        - run_name: The experiment run name. This is equal to the dataset run name if experiment was on Aeri dataset.
        - item_results: List of results for each processed item with outputs and evaluations
        - run_evaluations: List of aggregate evaluation results for the entire run
        - dataset_run_id: ID of the dataset run (if using Aeri datasets)
        - dataset_run_url: Direct URL to view results in Aeri UI (if applicable)

    Raises:
        ValueError: If required parameters are missing or invalid
        Exception: If experiment setup fails (individual item failures are handled gracefully)

    Examples:
        Basic experiment with local data:
        ```python
        def summarize_text(*, item, **kwargs):
            return f"Summary: {item['input'][:50]}..."

        def length_evaluator(*, input, output, expected_output=None, **kwargs):
            return {
                "name": "output_length",
                "value": len(output),
                "comment": f"Output contains {len(output)} characters"
            }

        result = aeri.run_experiment(
            name="Text Summarization Test",
            description="Evaluate summarization quality and length",
            data=[
                {"input": "Long article text...", "expected_output": "Expected summary"},
                {"input": "Another article...", "expected_output": "Another summary"}
            ],
            task=summarize_text,
            evaluators=[length_evaluator]
        )

        print(f"Processed {len(result.item_results)} items")
        for item_result in result.item_results:
            print(f"Input: {item_result.item['input']}")
            print(f"Output: {item_result.output}")
            print(f"Evaluations: {item_result.evaluations}")
        ```

        Advanced experiment with async task and multiple evaluators:
        ```python
        async def llm_task(*, item, **kwargs):
            # Simulate async LLM call
            response = await openai_client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": item["input"]}]
            )
            return response.choices[0].message.content

        def accuracy_evaluator(*, input, output, expected_output=None, **kwargs):
            if expected_output and expected_output.lower() in output.lower():
                return {"name": "accuracy", "value": 1.0, "comment": "Correct answer"}
            return {"name": "accuracy", "value": 0.0, "comment": "Incorrect answer"}

        def toxicity_evaluator(*, input, output, expected_output=None, **kwargs):
            # Simulate toxicity check
            toxicity_score = check_toxicity(output)  # Your toxicity checker
            return {
                "name": "toxicity",
                "value": toxicity_score,
                "comment": f"Toxicity level: {'high' if toxicity_score > 0.7 else 'low'}"
            }

        def average_accuracy(*, item_results, **kwargs):
            accuracies = [
                eval.value for result in item_results
                for eval in result.evaluations
                if eval.name == "accuracy"
            ]
            return {
                "name": "average_accuracy",
                "value": sum(accuracies) / len(accuracies) if accuracies else 0,
                "comment": f"Average accuracy across {len(accuracies)} items"
            }

        result = aeri.run_experiment(
            name="LLM Safety and Accuracy Test",
            description="Evaluate model accuracy and safety across diverse prompts",
            data=test_dataset,  # Your dataset items
            task=llm_task,
            evaluators=[accuracy_evaluator, toxicity_evaluator],
            run_evaluators=[average_accuracy],
            max_concurrency=5,  # Limit concurrent API calls
            metadata={"model": "gpt-4", "temperature": 0.7}
        )
        ```

        Using with Aeri datasets:
        ```python
        # Get dataset from Aeri
        dataset = aeri.get_dataset("my-eval-dataset")

        result = dataset.run_experiment(
            name="Production Model Evaluation",
            description="Monthly evaluation of production model performance",
            task=my_production_task,
            evaluators=[accuracy_evaluator, latency_evaluator]
        )

        # Results automatically linked to dataset in Aeri UI
        print(f"View results: {result.dataset_run_url}")
        ```

    Note:
        - Task and evaluator functions can be either synchronous or asynchronous
        - Individual item failures are logged but don't stop the experiment
        - All executions are automatically traced and visible in Aeri UI
        - When using Aeri datasets, results are automatically linked for easy comparison
        - This method works in both sync and async contexts (Jupyter notebooks, web apps, etc.)
        - Async execution is handled automatically with smart event loop detection
    """
    # Resolve the run name once so every item of this run shares it.
    resolved_run_name = self._create_experiment_run_name(name=name, run_name=run_name)

    return cast(
        ExperimentResult,
        run_async_safely(
            self._run_experiment_async(
                name=name,
                run_name=resolved_run_name,
                description=description,
                data=data,
                task=task,
                # None (the default) and empty lists both mean "no evaluators".
                evaluators=evaluators or [],
                composite_evaluator=composite_evaluator,
                run_evaluators=run_evaluators or [],
                max_concurrency=max_concurrency,
                metadata=metadata,
                dataset_version=_dataset_version,
            ),
        ),
    )
2602
+
2603
async def _run_experiment_async(
    self,
    *,
    name: str,
    run_name: str,
    description: Optional[str],
    data: ExperimentData,
    task: TaskFunction,
    evaluators: List[EvaluatorFunction],
    composite_evaluator: Optional[CompositeEvaluatorFunction],
    run_evaluators: List[RunEvaluatorFunction],
    max_concurrency: int,
    metadata: Optional[Dict[str, Any]] = None,
    dataset_version: Optional[datetime] = None,
) -> ExperimentResult:
    """Process every experiment item concurrently and assemble the run result.

    Steps: run each item through ``_process_experiment_item`` (at most
    ``max_concurrency`` at a time), drop and log failed items, run the
    run-level evaluators on the successful results, build the dataset run URL
    when the items belong to a dataset, store run-level evaluations as scores,
    and flush.

    Args:
        name: Experiment name.
        run_name: Resolved experiment run name.
        description: Optional experiment description.
        data: Items to process.
        task: Task function executed for each item.
        evaluators: Item-level evaluators.
        composite_evaluator: Optional composite evaluator run per item.
        run_evaluators: Run-level evaluators; each receives ``item_results``.
        max_concurrency: Maximum number of items processed at the same time.
            Must be at least 1.
        metadata: Optional experiment metadata passed on to each item.
        dataset_version: Optional dataset version passed on to each item.

    Returns:
        ExperimentResult with the successful item results, the run-level
        evaluations and, when available, the dataset run id and URL.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1.
    """
    # A semaphore initialised with 0 never lets any item start, so the gather
    # below would wait forever; fail fast instead.
    if max_concurrency < 1:
        raise ValueError("max_concurrency must be at least 1")

    aeri_logger.debug(
        f"Starting experiment '{name}' run '{run_name}' with {len(data)} items"
    )

    # Set up concurrency control
    semaphore = asyncio.Semaphore(max_concurrency)

    async def process_item(item: ExperimentItem) -> ExperimentItemResult:
        async with semaphore:
            return await self._process_experiment_item(
                item,
                task,
                evaluators,
                composite_evaluator,
                name,
                run_name,
                description,
                metadata,
                dataset_version,
            )

    # Run all items concurrently; failures come back as exception objects.
    tasks = [process_item(item) for item in data]
    item_results = await asyncio.gather(*tasks, return_exceptions=True)

    # Keep successful results and log failures. BaseException is checked so
    # that cancelled items (CancelledError) are reported instead of being
    # dropped without a trace.
    valid_results: List[ExperimentItemResult] = []
    for i, result in enumerate(item_results):
        if isinstance(result, BaseException):
            aeri_logger.error(f"Item {i} failed: {result}")
        elif isinstance(result, ExperimentItemResult):
            valid_results.append(result)

    # Run experiment-level evaluators; one failing evaluator must not
    # prevent the others from running.
    run_evaluations: List[Evaluation] = []
    for run_evaluator in run_evaluators:
        try:
            evaluations = await _run_evaluator(
                run_evaluator, item_results=valid_results
            )
            run_evaluations.extend(evaluations)
        except Exception as e:
            aeri_logger.error(f"Run evaluator failed: {e}")

    # Generate dataset run URL if applicable
    dataset_run_id = valid_results[0].dataset_run_id if valid_results else None
    dataset_run_url = None
    if dataset_run_id and data:
        try:
            # Only dataset items carry a dataset_id; plain dict items do not.
            dataset_id = getattr(data[0], "dataset_id", None)

            if dataset_id:
                project_id = self._get_project_id()

                if project_id:
                    dataset_run_url = f"{self._base_url}/project/{project_id}/datasets/{dataset_id}/runs/{dataset_run_id}"

        except Exception as e:
            # URL generation is optional, so the experiment carries on; the
            # reason is still logged to make a missing URL diagnosable.
            aeri_logger.debug(f"Skipping dataset run URL generation: {e}")

    # Store run-level evaluations as scores (only possible with a dataset run)
    if dataset_run_id:
        for evaluation in run_evaluations:
            try:
                self.create_score(
                    dataset_run_id=dataset_run_id,
                    name=evaluation.name or "<unknown>",
                    value=evaluation.value,  # type: ignore
                    comment=evaluation.comment,
                    metadata=evaluation.metadata,
                    data_type=evaluation.data_type,  # type: ignore
                    config_id=evaluation.config_id,
                )

            except Exception as e:
                aeri_logger.error(f"Failed to store run evaluation: {e}")

    # Flush scores and traces
    self.flush()

    return ExperimentResult(
        name=name,
        run_name=run_name,
        description=description,
        item_results=valid_results,
        run_evaluations=run_evaluations,
        dataset_run_id=dataset_run_id,
        dataset_run_url=dataset_run_url,
    )
2713
+
2714
async def _process_experiment_item(
    self,
    item: ExperimentItem,
    task: Callable,
    evaluators: List[Callable],
    composite_evaluator: Optional[CompositeEvaluatorFunction],
    experiment_name: str,
    experiment_run_name: str,
    experiment_description: Optional[str],
    experiment_metadata: Optional[Dict[str, Any]] = None,
    dataset_version: Optional[datetime] = None,
) -> ExperimentItemResult:
    """Run the task and evaluators for one experiment item inside its own span.

    The item may be a dict (keys ``input``, ``expected_output``, ``metadata``)
    or an object exposing the same names as attributes. Objects that also have
    ``id`` and ``dataset_id`` are treated as dataset items and linked to a
    dataset run.

    Args:
        item: The experiment item to process.
        task: Task function, executed through ``_run_task``.
        evaluators: Item-level evaluators, executed through ``_run_evaluator``.
        composite_evaluator: Optional evaluator that also receives the
            item-level evaluations; only called when there is at least one.
        experiment_name: Experiment name, stored in the observation metadata.
        experiment_run_name: Run name, used for the dataset run and metadata.
        experiment_description: Optional description of the experiment.
        experiment_metadata: Optional metadata merged into the observation metadata.
        dataset_version: Optional dataset version sent when creating the
            dataset run item.

    Returns:
        ExperimentItemResult with the task output, all evaluations, the trace
        id and the dataset run id (None for items not linked to a dataset run).

    Raises:
        ValueError: If the item has no input.
        Exception: Any error raised while preparing the item or running the
            task is recorded on the span and re-raised. Evaluator failures are
            only logged.
    """
    span_name = "experiment-item-run"

    with self.start_as_current_observation(name=span_name) as span:
        try:
            is_dict_item = isinstance(item, dict)

            input_data = (
                item.get("input") if is_dict_item else getattr(item, "input", None)
            )

            if input_data is None:
                raise ValueError("Experiment Item is missing input. Skipping item.")

            expected_output = (
                item.get("expected_output")
                if is_dict_item
                else getattr(item, "expected_output", None)
            )

            item_metadata = (
                item.get("metadata")
                if is_dict_item
                else getattr(item, "metadata", None)
            )

            final_observation_metadata = {
                "experiment_name": experiment_name,
                "experiment_run_name": experiment_run_name,
                **(experiment_metadata or {}),
            }

            trace_id = span.trace_id
            dataset_id = None
            dataset_item_id = None
            dataset_run_id = None

            # Link to dataset run if this is a dataset item
            if hasattr(item, "id") and hasattr(item, "dataset_id"):
                try:
                    # Use sync API to avoid event loop issues when run_async_safely
                    # creates multiple event loops across different threads
                    dataset_run_item = await asyncio.to_thread(
                        self.api.dataset_run_items.create,
                        run_name=experiment_run_name,
                        run_description=experiment_description,
                        metadata=experiment_metadata,
                        dataset_item_id=item.id,  # type: ignore
                        trace_id=trace_id,
                        observation_id=span.id,
                        dataset_version=dataset_version,
                    )

                    dataset_run_id = dataset_run_item.dataset_run_id

                except Exception as e:
                    # Linking is best-effort: the item is still processed.
                    aeri_logger.error(f"Failed to create dataset run item: {e}")

            if (
                not is_dict_item
                and hasattr(item, "dataset_id")
                and hasattr(item, "id")
            ):
                dataset_id = item.dataset_id
                dataset_item_id = item.id

                final_observation_metadata.update(
                    {"dataset_id": dataset_id, "dataset_item_id": dataset_item_id}
                )

            # Item metadata is merged last, so its keys override the
            # experiment-level entries above.
            if isinstance(item_metadata, dict):
                final_observation_metadata.update(item_metadata)

            # Without a dataset run, fall back to a generated experiment id and
            # an item id derived from a hash of the serialized input.
            experiment_id = dataset_run_id or self._create_observation_id()
            experiment_item_id = (
                dataset_item_id or get_sha256_hash_hex(_serialize(input_data))[:16]
            )
            span._otel_span.set_attributes(
                {
                    k: v
                    for k, v in {
                        AeriOtelSpanAttributes.ENVIRONMENT: AERI_SDK_EXPERIMENT_ENVIRONMENT,
                        AeriOtelSpanAttributes.EXPERIMENT_DESCRIPTION: experiment_description,
                        AeriOtelSpanAttributes.EXPERIMENT_ITEM_EXPECTED_OUTPUT: _serialize(
                            expected_output
                        ),
                    }.items()
                    if v is not None
                }
            )

            propagated_experiment_attributes = PropagatedExperimentAttributes(
                experiment_id=experiment_id,
                experiment_name=experiment_run_name,
                experiment_metadata=_serialize(experiment_metadata),
                experiment_dataset_id=dataset_id,
                experiment_item_id=experiment_item_id,
                experiment_item_metadata=_serialize(item_metadata),
                experiment_item_root_observation_id=span.id,
            )

            with _propagate_attributes(experiment=propagated_experiment_attributes):
                output = await _run_task(task, item)

            span.update(
                input=input_data,
                output=output,
                metadata=final_observation_metadata,
            )

        except Exception as e:
            span.update(
                output=f"Error: {e}", level="ERROR", status_message=str(e)
            )
            raise

        # Run evaluators. They all receive the item's metadata, which was
        # already resolved above, so it is not looked up again per evaluator.
        evaluations = []

        for evaluator in evaluators:
            try:
                with _propagate_attributes(
                    experiment=propagated_experiment_attributes
                ):
                    eval_results = await _run_evaluator(
                        evaluator,
                        input=input_data,
                        output=output,
                        expected_output=expected_output,
                        metadata=item_metadata,
                    )
                evaluations.extend(eval_results)

                # Store evaluations as scores
                for evaluation in eval_results:
                    self.create_score(
                        trace_id=trace_id,
                        observation_id=span.id,
                        name=evaluation.name,
                        value=evaluation.value,  # type: ignore
                        comment=evaluation.comment,
                        metadata=evaluation.metadata,
                        config_id=evaluation.config_id,
                        data_type=evaluation.data_type,  # type: ignore
                    )

            except Exception as e:
                aeri_logger.error(f"Evaluator failed: {e}")

        # Run composite evaluator if provided and we have evaluations
        if composite_evaluator and evaluations:
            try:
                with _propagate_attributes(
                    experiment=propagated_experiment_attributes
                ):
                    result = composite_evaluator(
                        input=input_data,
                        output=output,
                        expected_output=expected_output,
                        metadata=item_metadata,
                        evaluations=evaluations,
                    )

                    # Handle async composite evaluators
                    if asyncio.iscoroutine(result):
                        result = await result

                # Normalize to list.
                # NOTE(review): a dict result is accepted here, but the loop
                # below reads attributes (.name, .value, ...), so a plain dict
                # ends up in the except branch - confirm whether dict results
                # should be converted to Evaluation first.
                composite_evals: List[Evaluation] = []
                if isinstance(result, (dict, Evaluation)):
                    composite_evals = [result]  # type: ignore
                elif isinstance(result, list):
                    composite_evals = result  # type: ignore

                # Store composite evaluations as scores and add to evaluations list
                for composite_evaluation in composite_evals:
                    self.create_score(
                        trace_id=trace_id,
                        observation_id=span.id,
                        name=composite_evaluation.name,
                        value=composite_evaluation.value,  # type: ignore
                        comment=composite_evaluation.comment,
                        metadata=composite_evaluation.metadata,
                        config_id=composite_evaluation.config_id,
                        data_type=composite_evaluation.data_type,  # type: ignore
                    )
                    evaluations.append(composite_evaluation)

            except Exception as e:
                aeri_logger.error(f"Composite evaluator failed: {e}")

        return ExperimentItemResult(
            item=item,
            output=output,
            evaluations=evaluations,
            trace_id=trace_id,
            dataset_run_id=dataset_run_id,
        )
2936
+
2937
+ def _create_experiment_run_name(
2938
+ self, *, name: Optional[str] = None, run_name: Optional[str] = None
2939
+ ) -> str:
2940
+ if run_name:
2941
+ return run_name
2942
+
2943
+ iso_timestamp = _get_timestamp().isoformat().replace("+00:00", "Z")
2944
+
2945
+ return f"{name} - {iso_timestamp}"
2946
+
2947
def run_batched_evaluation(
    self,
    *,
    scope: Literal["traces", "observations"],
    mapper: MapperFunction,
    filter: Optional[str] = None,
    fetch_batch_size: int = 50,
    fetch_trace_fields: Optional[str] = None,
    max_items: Optional[int] = None,
    max_retries: int = 3,
    evaluators: List[EvaluatorFunction],
    composite_evaluator: Optional[CompositeEvaluatorFunction] = None,
    max_concurrency: int = 5,
    metadata: Optional[Dict[str, Any]] = None,
    _add_observation_scores_to_trace: bool = False,
    _additional_trace_tags: Optional[List[str]] = None,
    resume_from: Optional[BatchEvaluationResumeToken] = None,
    verbose: bool = False,
) -> BatchEvaluationResult:
    """Fetch traces or observations and run evaluators on each of them.

    Existing data in Aeri is fetched in batches according to ``filter``,
    converted to evaluator inputs by ``mapper``, scored by every function in
    ``evaluators`` and the resulting scores are linked back to the original
    entities. Typical uses are evaluating production traces after a
    deployment, backtesting new metrics on historical data and periodic
    quality monitoring. The actual work is delegated to
    ``BatchEvaluationRunner.run_async``; this method only drives that
    coroutine synchronously.

    Args:
        scope: What to evaluate: ``"traces"`` (complete traces with their
            observations) or ``"observations"`` (individual spans,
            generations, events).
        mapper: Sync or async function that receives a trace/observation
            object and returns an ``EvaluatorInputs`` instance (input,
            output, expected_output, metadata).
        evaluators: Evaluation functions run on each mapped item. Each one
            returns Evaluation object(s); a failing evaluator is logged and
            does not stop the run.
        filter: Optional JSON filter string in the Aeri API format, e.g.
            ``'{"tags": ["production"]}'``. ``None`` fetches all items.
        fetch_batch_size: Items fetched per API call and held in memory.
            Default: 50.
        fetch_trace_fields: Comma-separated field groups to include when
            fetching traces: 'core' (always included), 'io' (input, output,
            metadata), 'scores', 'observations', 'metrics'. All fields are
            returned when omitted. Excluded 'observations' or 'scores' come
            back as empty arrays; excluded 'metrics' yields -1 for
            'totalCost' and 'latency'. Only relevant for scope "traces".
        max_items: Upper bound on the number of items processed; ``None``
            processes everything matching the filter.
        max_concurrency: Maximum number of items evaluated concurrently.
            Default: 5.
        composite_evaluator: Optional function receiving the original item
            and its evaluations and returning one combined Evaluation
            (e.g. a weighted average).
        metadata: Optional metadata dict attached to every created score.
        max_retries: Retry attempts for failed batch fetches, with
            exponential backoff (1s, 2s, 4s). Default: 3.
        verbose: Log progress information while running. Default: False.
        resume_from: Resume token of a previous incomplete run, used to
            continue where that run stopped.
        _add_observation_scores_to_trace: Internal option, forwarded
            unchanged to the runner.
        _additional_trace_tags: Internal option, forwarded unchanged to the
            runner.

    Returns:
        BatchEvaluationResult with, among others: total_items_fetched,
        total_items_processed, total_items_failed, total_scores_created,
        total_composite_scores_created, total_evaluations_failed,
        evaluator_stats, resume_token (None when completed), completed,
        duration_seconds, failed_item_ids, error_summary and has_more_items.

    Raises:
        ValueError: If an invalid scope is provided.

    Example:
        ```python
        from aeri import Aeri, EvaluatorInputs, Evaluation

        client = Aeri()

        def trace_mapper(trace):
            return EvaluatorInputs(
                input=trace.input,
                output=trace.output,
                expected_output=None,
                metadata={"trace_id": trace.id},
            )

        def length_evaluator(*, input, output, expected_output, metadata):
            return Evaluation(
                name="output_length",
                value=len(output) if output else 0,
            )

        result = client.run_batched_evaluation(
            scope="traces",
            mapper=trace_mapper,
            evaluators=[length_evaluator],
            filter='{"tags": ["production"]}',
            max_items=1000,
            verbose=True,
        )

        # Continue an interrupted run from its resume token.
        if not result.completed and result.resume_token:
            result = client.run_batched_evaluation(
                scope="traces",
                mapper=trace_mapper,
                evaluators=[length_evaluator],
                resume_from=result.resume_token,
            )

        print(f"Processed {result.total_items_processed} traces")
        print(f"Created {result.total_scores_created} scores")
        ```

    Note:
        - Evaluator and individual item failures are tracked but do not
          stop processing.
        - Fetch failures are retried with exponential backoff.
        - All scores are flushed to Aeri at the end of the run.
        - Resuming relies on timestamp-based filtering to avoid duplicates.
    """
    runner = BatchEvaluationRunner(self)

    # Build the coroutine first, then drive it to completion synchronously.
    evaluation_coroutine = runner.run_async(
        scope=scope,
        mapper=mapper,
        evaluators=evaluators,
        filter=filter,
        fetch_batch_size=fetch_batch_size,
        fetch_trace_fields=fetch_trace_fields,
        max_items=max_items,
        max_concurrency=max_concurrency,
        composite_evaluator=composite_evaluator,
        metadata=metadata,
        _add_observation_scores_to_trace=_add_observation_scores_to_trace,
        _additional_trace_tags=_additional_trace_tags,
        max_retries=max_retries,
        verbose=verbose,
        resume_from=resume_from,
    )
    outcome = run_async_safely(evaluation_coroutine)

    return cast(BatchEvaluationResult, outcome)
3181
+
3182
def auth_check(self) -> bool:
    """Check if the provided credentials (public and secret key) are valid.

    Returns:
        True when at least one project is returned for the credentials.
        False when the lookup raises an AttributeError, which is treated as
        the client not being properly initialized.

    Raises:
        Exception: If no projects were found for the provided credentials.
        Error: API errors are passed to ``handle_fern_exception`` and then
            re-raised.

    Note:
        This method is blocking. It is discouraged to use it in production code.
    """
    try:
        projects = self.api.projects.get()
        project_count = len(projects.data)

        if project_count == 0:
            raise Exception(
                "Auth check failed, no project found for the keys provided."
            )

        # Log success only after the emptiness check, so a failing check
        # never reports "successful" right before raising.
        aeri_logger.debug(
            f"Auth check successful, found {project_count} projects"
        )
        return True

    except AttributeError as e:
        aeri_logger.warning(
            f"Auth check failed: Client not properly initialized. Error: {e}"
        )
        return False

    except Error as e:
        handle_fern_exception(e)
        raise e
3211
+
3212
def create_dataset(
    self,
    *,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Any] = None,
    input_schema: Optional[Any] = None,
    expected_output_schema: Optional[Any] = None,
) -> Dataset:
    """Create a dataset with the given name on Aeri.

    Args:
        name: Name of the dataset to create.
        description: Optional description of the dataset.
        metadata: Optional additional metadata.
        input_schema: JSON Schema for validating dataset item inputs. When
            set, all new items will be validated against this schema.
        expected_output_schema: JSON Schema for validating dataset item
            expected outputs. When set, all new items will be validated
            against this schema.

    Returns:
        Dataset: The created dataset as returned by the Aeri API.

    Raises:
        Error: API errors are passed to ``handle_fern_exception`` and then
            re-raised.
    """
    dataset_fields = {
        "name": name,
        "description": description,
        "metadata": metadata,
        "input_schema": input_schema,
        "expected_output_schema": expected_output_schema,
    }

    try:
        aeri_logger.debug(f"Creating datasets {name}")
        created = self.api.datasets.create(**dataset_fields)
    except Error as api_error:
        handle_fern_exception(api_error)
        raise

    return cast(Dataset, created)
3249
+
3250
def create_dataset_item(
    self,
    *,
    dataset_name: str,
    input: Optional[Any] = None,
    expected_output: Optional[Any] = None,
    metadata: Optional[Any] = None,
    source_trace_id: Optional[str] = None,
    source_observation_id: Optional[str] = None,
    status: Optional[DatasetStatus] = None,
    id: Optional[str] = None,
) -> DatasetItem:
    """Create a dataset item, upserting when an item with ``id`` already exists.

    Args:
        dataset_name: Name of the dataset in which the item should be created.
        input: Input data; any dict, list or scalar. Defaults to None.
        expected_output: Expected output data; any dict, list or scalar.
            Defaults to None.
        metadata: Additional metadata; any dict, list or scalar.
            Defaults to None.
        source_trace_id: Id of the source trace. Defaults to None.
        source_observation_id: Id of the source observation. Defaults to None.
        status: Status of the dataset item. Defaults to ACTIVE for newly
            created items.
        id: Id of the dataset item. Provide your own id to dedupe items. The
            id must be globally unique and cannot be reused across datasets.

    Returns:
        DatasetItem: The created dataset item as returned by the Aeri API.

    Raises:
        Error: API errors are passed to ``handle_fern_exception`` and then
            re-raised.

    Example:
        ```python
        from aeri import Aeri

        aeri = Aeri()

        # Upload an item to the Aeri dataset named "capital_cities"
        aeri.create_dataset_item(
            dataset_name="capital_cities",
            input={"input": {"country": "Italy"}},
            expected_output={"expected_output": "Rome"},
            metadata={"foo": "bar"}
        )
        ```
    """
    item_fields = {
        "dataset_name": dataset_name,
        "input": input,
        "expected_output": expected_output,
        "metadata": metadata,
        "source_trace_id": source_trace_id,
        "source_observation_id": source_observation_id,
        "status": status,
        "id": id,
    }

    try:
        aeri_logger.debug(f"Creating dataset item for dataset {dataset_name}")
        created = self.api.dataset_items.create(**item_fields)
    except Error as api_error:
        handle_fern_exception(api_error)
        raise

    return cast(DatasetItem, created)
3312
+
3313
def resolve_media_references(
    self,
    *,
    obj: Any,
    resolve_with: Literal["base64_data_uri"],
    max_depth: int = 10,
    content_fetch_timeout_seconds: int = 5,
) -> Any:
    """Replace media reference strings in an object with base64 data URIs.

    Delegates to ``AeriMedia.resolve_media_references`` with this client.
    The object is traversed recursively (up to ``max_depth``) looking for
    media reference strings of the form ``"@@@aeriMedia:...@@@"``. For each
    one the media content is fetched synchronously and the reference string
    is replaced with a base64 data URI. When fetching fails for a reference,
    a warning is logged and that reference string is left unchanged.

    Args:
        obj: The object to process. Can be a primitive value, array, or
            nested object. If the object has a ``__dict__`` attribute, a
            dict is returned instead of the original object type.
        resolve_with: Representation that replaces the media reference
            string. Currently only "base64_data_uri" is supported.
        max_depth: Maximum depth to traverse the object. Default is 10.
        content_fetch_timeout_seconds: Timeout in seconds for fetching media
            content. Default is 5.

    Returns:
        A deep copy of the input object with all media references replaced
        with base64 data URIs where possible.

    Example:
        ```python
        obj = {
            "image": "@@@aeriMedia:type=image/jpeg|id=123|source=bytes@@@",
            "nested": {
                "pdf": "@@@aeriMedia:type=application/pdf|id=456|source=bytes@@@"
            }
        }

        result = aeri.resolve_media_references(
            obj=obj, resolve_with="base64_data_uri"
        )

        # Result:
        # {
        #     "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg...",
        #     "nested": {
        #         "pdf": "data:application/pdf;base64,JVBERi0xLjcK..."
        #     }
        # }
        ```
    """
    resolved = AeriMedia.resolve_media_references(
        aeri_client=self,
        obj=obj,
        resolve_with=resolve_with,
        max_depth=max_depth,
        content_fetch_timeout_seconds=content_fetch_timeout_seconds,
    )
    return resolved
3367
+
3368
@overload
def get_prompt(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    type: Literal["chat"],
    cache_ttl_seconds: Optional[int] = None,
    fallback: Optional[List[ChatMessageDict]] = None,
    max_retries: Optional[int] = None,
    fetch_timeout_seconds: Optional[int] = None,
) -> ChatPromptClient: ...


@overload
def get_prompt(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    type: Literal["text"] = "text",
    cache_ttl_seconds: Optional[int] = None,
    fallback: Optional[str] = None,
    max_retries: Optional[int] = None,
    fetch_timeout_seconds: Optional[int] = None,
) -> TextPromptClient: ...


def get_prompt(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    type: Literal["chat", "text"] = "text",
    cache_ttl_seconds: Optional[int] = None,
    fallback: Union[Optional[List[ChatMessageDict]], Optional[str]] = None,
    max_retries: Optional[int] = None,
    fetch_timeout_seconds: Optional[int] = None,
) -> PromptClient:
    """Get a prompt, preferring the local prompt cache.

    Lookup order:

    1. No cached entry, or ``cache_ttl_seconds == 0``: the prompt is fetched
       from the server (and cached). If that fetch fails and a truthy
       ``fallback`` was given, a fallback prompt client is returned instead.
    2. Cached entry that has expired: a background refresh task is scheduled
       and the stale cached prompt is returned immediately. Failing to
       schedule the refresh is logged and the stale prompt is still returned.
    3. Cached entry that is still fresh: returned as is.

    Args:
        name (str): The name of the prompt to retrieve.

    Keyword Args:
        version (Optional[int]): The version of the prompt to retrieve. If
            neither label nor version is specified, the `production` label
            is returned. Specify either version or label, not both.
        label (Optional[str]): The label of the prompt to retrieve. Specify
            either version or label, not both.
        type (Literal["chat", "text"]): The type of the prompt to retrieve.
            Defaults to "text".
        cache_ttl_seconds (Optional[int]): Time-to-live in seconds for
            caching the prompt. If not set, the cache default applies
            (documented as 60 seconds). ``0`` disables caching.
        fallback: Prompt content returned if fetching the prompt fails: a
            string for "text", a list of chat messages for "chat". Important
            on the first call where no cached prompt is available. Follows
            Aeri prompt formatting with double curly braces for variables.
            Defaults to None.
        max_retries (Optional[int]): Maximum number of retries on fetch
            errors. Defaults to 2; values are clamped to the range 0 to 4.
        fetch_timeout_seconds (Optional[int]): Timeout in seconds for
            fetching the prompt. When None, no explicit timeout is passed
            and the SDK default applies.

    Returns:
        - TextPromptClient, if type argument is 'text'.
        - ChatPromptClient, if type argument is 'chat'.

    Raises:
        Error: If the SDK resources are not initialized.
        ValueError: If both ``version`` and ``label`` are given, or ``name``
            is empty.
        Exception: Any error raised while fetching an uncached prompt when
            no usable fallback is available.
    """
    if self._resources is None:
        raise Error(
            "SDK is not correctly initialized. Check the init logs for more details."
        )
    if version is not None and label is not None:
        raise ValueError("Cannot specify both version and label at the same time.")
    if not name:
        raise ValueError("Prompt name cannot be empty.")

    cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
    bounded_max_retries = self._get_bounded_max_retries(
        max_retries, default_max_retries=2, max_retries_upper_bound=4
    )

    aeri_logger.debug(f"Getting prompt '{cache_key}'")
    cached_prompt = self._resources.prompt_cache.get(cache_key)

    cache_usable = cached_prompt is not None and cache_ttl_seconds != 0
    if cache_usable:
        if not cached_prompt.is_expired():
            return cached_prompt.value

        aeri_logger.debug(f"Stale prompt '{cache_key}' found in cache.")
        try:
            # Refresh in a background thread; the cache deduplicates tasks.
            aeri_logger.debug(f"Refreshing prompt '{cache_key}' in background.")

            def refresh_task() -> None:
                self._fetch_prompt_and_update_cache(
                    name,
                    version=version,
                    label=label,
                    ttl_seconds=cache_ttl_seconds,
                    max_retries=bounded_max_retries,
                    fetch_timeout_seconds=fetch_timeout_seconds,
                )

            self._resources.prompt_cache.add_refresh_prompt_task(
                cache_key,
                refresh_task,
            )
            aeri_logger.debug(f"Returning stale prompt '{cache_key}' from cache.")
        except Exception as schedule_error:
            # Scheduling the refresh failed; the stale prompt is still served.
            aeri_logger.warning(
                f"Error when refreshing cached prompt '{cache_key}', returning cached version. Error: {schedule_error}"
            )

        return cached_prompt.value

    aeri_logger.debug(
        f"Prompt '{cache_key}' not found in cache or caching disabled."
    )
    try:
        return self._fetch_prompt_and_update_cache(
            name,
            version=version,
            label=label,
            ttl_seconds=cache_ttl_seconds,
            max_retries=bounded_max_retries,
            fetch_timeout_seconds=fetch_timeout_seconds,
        )
    except Exception as fetch_error:
        if not fallback:
            raise fetch_error

        aeri_logger.warning(
            f"Returning fallback prompt for '{cache_key}' due to fetch error: {fetch_error}"
        )

        fallback_client_args: Dict[str, Any] = {
            "name": name,
            "prompt": fallback,
            "type": type,
            "version": version or 0,
            "config": {},
            "labels": [label] if label else [],
            "tags": [],
        }

        if type == "text":
            return TextPromptClient(
                prompt=Prompt_Text(**fallback_client_args),
                is_fallback=True,
            )

        if type == "chat":
            return ChatPromptClient(
                prompt=Prompt_Chat(**fallback_client_args),
                is_fallback=True,
            )

        # Unknown prompt type: no fallback client can be built.
        raise fetch_error
3532
+
3533
def _fetch_prompt_and_update_cache(
    self,
    name: str,
    *,
    version: Optional[int] = None,
    label: Optional[str] = None,
    ttl_seconds: Optional[int] = None,
    max_retries: int,
    fetch_timeout_seconds: Optional[int],
) -> PromptClient:
    """Fetch a prompt from the server and store it in the prompt cache.

    The API call is retried on any exception using ``backoff.constant``,
    for at most ``max_retries + 1`` attempts in total.

    Args:
        name: Prompt name; URL-encoded before being sent.
        version: Optional prompt version.
        label: Optional prompt label.
        ttl_seconds: TTL handed to the cache together with the prompt.
        max_retries: Number of retries after the first attempt.
        fetch_timeout_seconds: Per-request timeout in seconds. When None,
            no request options are passed.

    Returns:
        A ChatPromptClient when the response type is "chat", otherwise a
        TextPromptClient.

    Raises:
        NotFoundError: The prompt does not exist; its cache entry is
            evicted before re-raising.
        Exception: Any other fetch error, logged and re-raised.
    """
    cache_key = PromptCache.generate_cache_key(name, version=version, label=label)
    aeri_logger.debug(f"Fetching prompt '{cache_key}' from server...")

    try:

        def fetch_prompts() -> Any:
            # Only pass request options when a timeout was requested.
            request_options = (
                {"timeout_in_seconds": fetch_timeout_seconds}
                if fetch_timeout_seconds is not None
                else None
            )
            return self.api.prompts.get(
                self._url_encode(name),
                version=version,
                label=label,
                request_options=request_options,
            )

        with_retries = backoff.on_exception(
            backoff.constant, Exception, max_tries=max_retries + 1, logger=None
        )
        prompt_response = with_retries(fetch_prompts)()

        client_class = (
            ChatPromptClient if prompt_response.type == "chat" else TextPromptClient
        )
        prompt: PromptClient = client_class(prompt_response)

        if self._resources is not None:
            self._resources.prompt_cache.set(cache_key, prompt, ttl_seconds)

        return prompt

    except NotFoundError as not_found_error:
        aeri_logger.warning(
            f"Prompt '{cache_key}' not found during refresh, evicting from cache."
        )
        if self._resources is not None:
            self._resources.prompt_cache.delete(cache_key)
        raise not_found_error

    except Exception as e:
        aeri_logger.error(f"Error while fetching prompt '{cache_key}': {str(e)}")
        raise e
3589
+
3590
+ def _get_bounded_max_retries(
3591
+ self,
3592
+ max_retries: Optional[int],
3593
+ *,
3594
+ default_max_retries: int = 2,
3595
+ max_retries_upper_bound: int = 4,
3596
+ ) -> int:
3597
+ if max_retries is None:
3598
+ return default_max_retries
3599
+
3600
+ bounded_max_retries = min(
3601
+ max(max_retries, 0),
3602
+ max_retries_upper_bound,
3603
+ )
3604
+
3605
+ return bounded_max_retries
3606
+
3607
@overload
def create_prompt(
    self,
    *,
    name: str,
    prompt: List[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict]],
    labels: List[str] = [],
    tags: Optional[List[str]] = None,
    type: Optional[Literal["chat"]],
    config: Optional[Any] = None,
    commit_message: Optional[str] = None,
) -> ChatPromptClient: ...


@overload
def create_prompt(
    self,
    *,
    name: str,
    prompt: str,
    labels: List[str] = [],
    tags: Optional[List[str]] = None,
    type: Optional[Literal["text"]] = "text",
    config: Optional[Any] = None,
    commit_message: Optional[str] = None,
) -> TextPromptClient: ...


def create_prompt(
    self,
    *,
    name: str,
    prompt: Union[
        str, List[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict]]
    ],
    labels: List[str] = [],
    tags: Optional[List[str]] = None,
    type: Optional[Literal["chat", "text"]] = "text",
    config: Optional[Any] = None,
    commit_message: Optional[str] = None,
) -> PromptClient:
    """Create a new prompt in Aeri.

    After the prompt is created, the SDK prompt cache is invalidated for
    ``name`` (when SDK resources are available).

    Keyword Args:
        name: The name of the prompt to be created.
        prompt: The content of the prompt: a string for "text" prompts, a
            list of chat messages for "chat" prompts.
        labels: The labels of the prompt. Defaults to an empty list. To
            create a default-served prompt, add the 'production' label.
        tags: The tags of the prompt. Defaults to None. Will be applied to
            all versions of the prompt.
        type: The type of the prompt to be created, "chat" or "text".
            Defaults to "text"; any value other than "chat" takes the text
            path.
        config: Additional structured data to be saved with the prompt.
            An empty dict is sent when omitted or falsy.
        commit_message: Optional string describing the change.

    Returns:
        TextPromptClient: The prompt if type argument is 'text'.
        ChatPromptClient: The prompt if type argument is 'chat'.

    Raises:
        ValueError: If ``prompt`` does not match the requested type.
        Error: API errors are passed to ``handle_fern_exception`` and then
            re-raised.
    """
    try:
        aeri_logger.debug(f"Creating prompt {name=}, {labels=}")

        is_chat = type == "chat"
        request: Union[CreateChatPromptRequest, CreateTextPromptRequest]

        if is_chat:
            if not isinstance(prompt, list):
                raise ValueError(
                    "For 'chat' type, 'prompt' must be a list of chat messages with role and content attributes."
                )
            request = CreateChatPromptRequest(
                name=name,
                prompt=cast(Any, prompt),
                labels=labels,
                tags=tags,
                config=config or {},
                commit_message=commit_message,
                type=CreateChatPromptType.CHAT,
            )
        else:
            if not isinstance(prompt, str):
                raise ValueError("For 'text' type, 'prompt' must be a string.")
            request = CreateTextPromptRequest(
                name=name,
                prompt=prompt,
                labels=labels,
                tags=tags,
                config=config or {},
                commit_message=commit_message,
            )

        server_prompt = self.api.prompts.create(request=request)

        # Cached versions of this prompt name are outdated from here on.
        if self._resources is not None:
            self._resources.prompt_cache.invalidate(name)

        if is_chat:
            return ChatPromptClient(prompt=cast(Prompt_Chat, server_prompt))
        return TextPromptClient(prompt=cast(Prompt_Text, server_prompt))

    except Error as e:
        handle_fern_exception(e)
        raise e
3710
+
3711
def update_prompt(
    self,
    *,
    name: str,
    version: int,
    new_labels: List[str] = [],
) -> Any:
    """Update an existing prompt version in Aeri.

    The Aeri SDK prompt cache is invalidated for all prompts with the
    specified name after a successful update.

    Args:
        name (str): The name of the prompt to update.
        version (int): The version number of the prompt to update.
        new_labels (List[str], optional): New labels to assign to the prompt
            version. Labels are unique across versions. The "latest" label
            is reserved and managed by Aeri. Defaults to [].

    Returns:
        Prompt: The updated prompt from the Aeri API.

    Raises:
        Error: API errors are passed to ``handle_fern_exception`` and then
            re-raised, matching the other API-calling methods of the client.
    """
    try:
        updated_prompt = self.api.prompt_version.update(
            name=self._url_encode(name),
            version=version,
            new_labels=new_labels,
        )
    except Error as e:
        handle_fern_exception(e)
        raise e

    # Only reached on success: cached versions of this prompt are now stale.
    if self._resources is not None:
        self._resources.prompt_cache.invalidate(name)

    return updated_prompt
3739
+
3740
+ def _url_encode(self, url: str, *, is_url_param: Optional[bool] = False) -> str:
3741
+ # httpx ≥ 0.28 does its own WHATWG-compliant quoting (eg. encodes bare
3742
+ # “%”, “?”, “#”, “|”, … in query/path parts). Re-quoting here would
3743
+ # double-encode, so we skip when the value is about to be sent straight
3744
+ # to httpx (`is_url_param=True`) and the installed version is ≥ 0.28.
3745
+ if is_url_param and Version(httpx.__version__) >= Version("0.28.0"):
3746
+ return url
3747
+
3748
+ # urllib.parse.quote does not escape slashes "/" by default; we need to add safe="" to force escaping
3749
+ # we need add safe="" to force escaping of slashes
3750
+ # This is necessary for prompts in prompt folders
3751
+ return urllib.parse.quote(url, safe="")
3752
+
3753
def clear_prompt_cache(self) -> None:
    """Remove every cached prompt from the prompt cache.

    Use this to force a complete refresh of all cached prompts, for example
    after major updates or when the latest versions must be fetched from
    the server. Does nothing when the SDK resources are not initialized.
    """
    resources = self._resources
    if resources is None:
        return

    resources.prompt_cache.clear()