arize 8.0.0b2__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arize/__init__.py +8 -1
  2. arize/_exporter/client.py +18 -17
  3. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  4. arize/_exporter/validation.py +1 -1
  5. arize/_flight/client.py +33 -13
  6. arize/_lazy.py +37 -2
  7. arize/client.py +61 -35
  8. arize/config.py +168 -14
  9. arize/constants/config.py +1 -0
  10. arize/datasets/client.py +32 -19
  11. arize/embeddings/auto_generator.py +14 -7
  12. arize/embeddings/base_generators.py +15 -9
  13. arize/embeddings/cv_generators.py +2 -2
  14. arize/embeddings/nlp_generators.py +8 -8
  15. arize/embeddings/tabular_generators.py +5 -5
  16. arize/exceptions/config.py +22 -0
  17. arize/exceptions/parameters.py +1 -1
  18. arize/exceptions/values.py +8 -5
  19. arize/experiments/__init__.py +4 -0
  20. arize/experiments/client.py +17 -11
  21. arize/experiments/evaluators/base.py +6 -3
  22. arize/experiments/evaluators/executors.py +6 -4
  23. arize/experiments/evaluators/rate_limiters.py +3 -1
  24. arize/experiments/evaluators/types.py +7 -5
  25. arize/experiments/evaluators/utils.py +7 -5
  26. arize/experiments/functions.py +111 -48
  27. arize/experiments/tracing.py +4 -1
  28. arize/experiments/types.py +31 -26
  29. arize/logging.py +53 -32
  30. arize/ml/batch_validation/validator.py +82 -70
  31. arize/ml/bounded_executor.py +25 -6
  32. arize/ml/casting.py +45 -27
  33. arize/ml/client.py +35 -28
  34. arize/ml/proto.py +16 -17
  35. arize/ml/stream_validation.py +63 -25
  36. arize/ml/surrogate_explainer/mimic.py +15 -7
  37. arize/ml/types.py +26 -12
  38. arize/pre_releases.py +7 -6
  39. arize/py.typed +0 -0
  40. arize/regions.py +10 -10
  41. arize/spans/client.py +113 -21
  42. arize/spans/conversion.py +7 -5
  43. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  44. arize/spans/validation/annotations/value_validation.py +11 -14
  45. arize/spans/validation/common/dataframe_form_validation.py +1 -1
  46. arize/spans/validation/common/value_validation.py +10 -13
  47. arize/spans/validation/evals/value_validation.py +1 -1
  48. arize/spans/validation/metadata/argument_validation.py +1 -1
  49. arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
  50. arize/spans/validation/metadata/value_validation.py +23 -1
  51. arize/utils/arrow.py +37 -1
  52. arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
  53. arize/utils/proto.py +0 -1
  54. arize/utils/types.py +6 -6
  55. arize/version.py +1 -1
  56. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/METADATA +18 -3
  57. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/RECORD +60 -58
  58. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/WHEEL +0 -0
  59. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/LICENSE +0 -0
  60. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/NOTICE +0 -0
@@ -7,7 +7,14 @@ import json
7
7
  import logging
8
8
  import traceback
9
9
  from binascii import hexlify
10
- from collections.abc import Awaitable, Callable, Mapping, Sequence
10
+ from collections.abc import (
11
+ Awaitable,
12
+ Callable,
13
+ Coroutine,
14
+ Iterable,
15
+ Mapping,
16
+ Sequence,
17
+ )
11
18
  from contextlib import ExitStack
12
19
  from copy import deepcopy
13
20
  from datetime import date, datetime, time, timedelta, timezone
@@ -34,7 +41,7 @@ from openinference.semconv.trace import (
34
41
  )
35
42
  from opentelemetry.context import Context
36
43
  from opentelemetry.sdk.resources import Resource
37
- from opentelemetry.trace import Status, StatusCode, Tracer
44
+ from opentelemetry.trace import NoOpTracer, Status, StatusCode, Tracer
38
45
 
39
46
  if TYPE_CHECKING:
40
47
  from opentelemetry.sdk.trace import Span
@@ -48,6 +55,7 @@ from arize.experiments.evaluators.types import (
48
55
  EvaluationResult,
49
56
  EvaluationResultFieldNames,
50
57
  EvaluatorName,
58
+ JSONSerializable,
51
59
  )
52
60
  from arize.experiments.evaluators.utils import create_evaluator
53
61
  from arize.experiments.tracing import capture_spans, flatten
@@ -64,6 +72,9 @@ RateLimitErrors: TypeAlias = type[BaseException] | Sequence[type[BaseException]]
64
72
 
65
73
  logger = logging.getLogger(__name__)
66
74
 
75
+ # Module-level singleton for no-op tracing
76
+ _NOOP_TRACER = NoOpTracer()
77
+
67
78
 
68
79
  def run_experiment(
69
80
  experiment_name: str,
@@ -116,11 +127,12 @@ def run_experiment(
116
127
  error: BaseException | None = None
117
128
  status = Status(StatusCode.OK)
118
129
  with ExitStack() as stack:
119
- span: Span = stack.enter_context(
130
+ # Type ignore: OpenTelemetry interface vs implementation type mismatch
131
+ span: Span = stack.enter_context( # type: ignore[assignment]
120
132
  cm=tracer.start_as_current_span(
121
133
  name=root_span_name, context=Context()
122
134
  )
123
- ) # type:ignore
135
+ )
124
136
  stack.enter_context(capture_spans(resource))
125
137
  span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
126
138
  try:
@@ -146,9 +158,12 @@ def run_experiment(
146
158
  raise TypeError(sync_error_message)
147
159
  output = _output
148
160
 
149
- output = jsonify(output)
161
+ # Type ignore: jsonify returns object but runtime result is JSONSerializable
162
+ output = jsonify(output) # type: ignore[assignment]
150
163
  if example.input:
151
- span.set_attribute(INPUT_VALUE, example.input) # type: ignore
164
+ # OpenTelemetry type hints are restrictive, but Arize's tracing layer
165
+ # accepts JSON-serializable structures which are auto-serialized
166
+ span.set_attribute(INPUT_VALUE, example.input) # type: ignore[arg-type]
152
167
  else:
153
168
  span.set_attribute(
154
169
  INPUT_VALUE,
@@ -187,9 +202,9 @@ def run_experiment(
187
202
  else datetime.now(tz=timezone.utc)
188
203
  ),
189
204
  dataset_example_id=example.id,
190
- output=output, # type:ignore
205
+ output=output,
191
206
  error=repr(error) if error else None,
192
- trace_id=_str_trace_id(span.get_span_context().trace_id), # type: ignore
207
+ trace_id=_str_trace_id(span.get_span_context().trace_id),
193
208
  )
194
209
 
195
210
  async def async_run_experiment(example: Example) -> ExperimentRun:
@@ -197,11 +212,12 @@ def run_experiment(
197
212
  error: BaseException | None = None
198
213
  status = Status(StatusCode.OK)
199
214
  with ExitStack() as stack:
200
- span: Span = stack.enter_context(
215
+ # Type ignore: OpenTelemetry interface vs implementation type mismatch
216
+ span: Span = stack.enter_context( # type: ignore[assignment]
201
217
  cm=tracer.start_as_current_span(
202
218
  name=root_span_name, context=Context()
203
219
  )
204
- ) # type:ignore
220
+ )
205
221
  stack.enter_context(capture_spans(resource))
206
222
  span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
207
223
  try:
@@ -220,9 +236,12 @@ def run_experiment(
220
236
  )
221
237
  error = exc
222
238
  _print_experiment_error(exc, example_id=example.id, kind="task")
223
- output = jsonify(output)
239
+ # Type ignore: jsonify returns object but runtime result is JSONSerializable
240
+ output = jsonify(output) # type: ignore[assignment]
224
241
  if example.input:
225
- span.set_attribute(INPUT_VALUE, example.input) # type: ignore
242
+ # OpenTelemetry type hints are restrictive, but Arize's tracing layer
243
+ # accepts JSON-serializable structures which are auto-serialized
244
+ span.set_attribute(INPUT_VALUE, example.input) # type: ignore[arg-type]
226
245
  else:
227
246
  span.set_attribute(
228
247
  INPUT_VALUE,
@@ -261,14 +280,14 @@ def run_experiment(
261
280
  else datetime.now(tz=timezone.utc)
262
281
  ),
263
282
  dataset_example_id=example.id,
264
- output=output, # type: ignore
283
+ output=output,
265
284
  error=repr(error) if error else None,
266
- trace_id=_str_trace_id(span.get_span_context().trace_id), # type: ignore
285
+ trace_id=_str_trace_id(span.get_span_context().trace_id),
267
286
  )
268
287
 
269
288
  _errors: tuple[type[BaseException], ...]
270
289
  if not isinstance(rate_limit_errors, Sequence):
271
- _errors = (rate_limit_errors,) # type: ignore
290
+ _errors = (rate_limit_errors,) if rate_limit_errors is not None else ()
272
291
  else:
273
292
  _errors = tuple(filter(None, rate_limit_errors))
274
293
  rate_limiters = [RateLimiter(rate_limit_error=rle) for rle in _errors]
@@ -284,8 +303,13 @@ def run_experiment(
284
303
  )
285
304
 
286
305
  executor = get_executor_on_sync_context(
287
- sync_fn=rate_limited_sync_run_experiment,
288
- async_fn=rate_limited_async_run_experiment,
306
+ sync_fn=cast(
307
+ "Callable[[object], Any]", rate_limited_sync_run_experiment
308
+ ),
309
+ async_fn=cast(
310
+ "Callable[[object], Coroutine[Any, Any, Any]]",
311
+ rate_limited_async_run_experiment,
312
+ ),
289
313
  max_retries=0,
290
314
  exit_on_error=exit_on_error,
291
315
  fallback_return_value=None,
@@ -295,20 +319,27 @@ def run_experiment(
295
319
  )
296
320
 
297
321
  runs, _ = executor.run(examples)
298
- task_summary = _TaskSummary.from_task_runs(len(dataset), runs)
322
+ task_summary = _TaskSummary.from_task_runs(
323
+ len(dataset), cast("list[ExperimentRun | None]", runs)
324
+ )
299
325
 
300
326
  if exit_on_error and (None in runs):
301
327
  # When exit_on_error is True, the result of a failed task execution is None
302
328
  # If any task execution failed, raise an error to exit early
303
329
  raise RuntimeError("An error occurred during execution of tasks.")
304
330
 
331
+ # Filter out None values before accessing attributes
332
+ runs_filtered = [
333
+ r for r in cast("list[ExperimentRun | None]", runs) if r is not None
334
+ ]
335
+
305
336
  out_df = pd.DataFrame()
306
- out_df["id"] = [run.id for run in runs]
307
- out_df["example_id"] = [run.dataset_example_id for run in runs]
308
- out_df["result"] = [run.output for run in runs]
309
- out_df["result.trace.id"] = [run.trace_id for run in runs]
337
+ out_df["id"] = [run.id for run in runs_filtered]
338
+ out_df["example_id"] = [run.dataset_example_id for run in runs_filtered]
339
+ out_df["result"] = [run.output for run in runs_filtered] # type: ignore[assignment]
340
+ out_df["result.trace.id"] = [run.trace_id for run in runs_filtered]
310
341
  out_df["result.trace.timestamp"] = [
311
- int(run.start_time.timestamp() * 1e3) for run in runs
342
+ int(run.start_time.timestamp() * 1e3) for run in runs_filtered
312
343
  ]
313
344
  out_df.set_index("id", inplace=True, drop=False)
314
345
  logger.info(f"✅ Task runs completed.\n{task_summary}")
@@ -317,7 +348,7 @@ def run_experiment(
317
348
  eval_results = evaluate_experiment(
318
349
  experiment_name=experiment_name,
319
350
  examples=examples,
320
- experiment_results=runs,
351
+ experiment_results=cast("Sequence[ExperimentRun]", runs),
321
352
  evaluators=evaluators,
322
353
  rate_limit_errors=rate_limit_errors,
323
354
  concurrency=concurrency,
@@ -333,7 +364,7 @@ def run_experiment(
333
364
  )
334
365
 
335
366
  # group evaluation results by name
336
- eval_results_by_name = {}
367
+ eval_results_by_name: dict[str, list[ExperimentEvaluationRun]] = {}
337
368
  for r in eval_results:
338
369
  if r is None:
339
370
  continue
@@ -355,7 +386,8 @@ def run_experiment(
355
386
  }
356
387
 
357
388
  for attr, getter in eval_data.items():
358
- out_df[f"eval.{eval_name}.{attr}"] = out_df.index.map(
389
+ # Type ignore: pandas DataFrame column assignment type is overly restrictive
390
+ out_df[f"eval.{eval_name}.{attr}"] = out_df.index.map( # type: ignore[assignment]
359
391
  {r.experiment_run_id: getter(r) for r in eval_res}
360
392
  )
361
393
  out_df = _add_metadata_to_output_df(out_df, eval_res, eval_name)
@@ -372,7 +404,7 @@ def evaluate_experiment(
372
404
  evaluators: Evaluators | None = None,
373
405
  rate_limit_errors: RateLimitErrors | None = None,
374
406
  concurrency: int = 3,
375
- tracer: Tracer | None = None,
407
+ tracer: Tracer = _NOOP_TRACER,
376
408
  resource: Resource | None = None,
377
409
  exit_on_error: bool = False,
378
410
  timeout: int = 120,
@@ -386,7 +418,7 @@ def evaluate_experiment(
386
418
  evaluators (Evaluators): The evaluators to use for assessment.
387
419
  rate_limit_errors (RateLimitErrors | :obj:`None`): Optional rate limit errors.
388
420
  concurrency (int): The number of concurrent tasks to run. Default is 3.
389
- tracer (Tracer | :obj:`None`): Optional tracer for tracing the evaluation.
421
+ tracer (Tracer): Tracer for tracing the evaluation. Defaults to NoOpTracer().
390
422
  resource (Resource | :obj:`None`): Optional resource for the evaluation.
391
423
  exit_on_error (bool): Whether to exit on error. Default is False.
392
424
  timeout (int): The timeout for each evaluation in seconds. Default is 120.
@@ -425,12 +457,16 @@ def evaluate_experiment(
425
457
  status = Status(StatusCode.OK)
426
458
  root_span_name = f"Evaluation: {evaluator.name}"
427
459
  with ExitStack() as stack:
428
- span: Span = stack.enter_context(
429
- tracer.start_as_current_span( # type:ignore
430
- name=root_span_name, context=Context()
431
- )
460
+ span: Span = cast(
461
+ "Span",
462
+ stack.enter_context(
463
+ tracer.start_as_current_span(
464
+ name=root_span_name, context=Context()
465
+ )
466
+ ),
432
467
  )
433
- stack.enter_context(capture_spans(resource)) # type:ignore
468
+ if resource is not None:
469
+ stack.enter_context(capture_spans(resource))
434
470
  span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
435
471
  try:
436
472
  result = evaluator.evaluate(
@@ -456,7 +492,15 @@ def evaluate_experiment(
456
492
  )
457
493
  if result:
458
494
  span.set_attributes(
459
- dict(flatten(jsonify(result), recurse_on_sequence=True))
495
+ dict(
496
+ flatten(
497
+ cast(
498
+ "Mapping[str, Any] | Iterable[Any]",
499
+ jsonify(result),
500
+ ),
501
+ recurse_on_sequence=True,
502
+ )
503
+ )
460
504
  )
461
505
  span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
462
506
  span.set_status(status)
@@ -473,7 +517,7 @@ def evaluate_experiment(
473
517
  annotator_kind=evaluator.kind,
474
518
  error=repr(error) if error else None,
475
519
  result=result,
476
- trace_id=_str_trace_id(span.get_span_context().trace_id), # type:ignore
520
+ trace_id=_str_trace_id(span.get_span_context().trace_id),
477
521
  )
478
522
 
479
523
  async def async_eval_run(
@@ -485,12 +529,16 @@ def evaluate_experiment(
485
529
  status = Status(StatusCode.OK)
486
530
  root_span_name = f"Evaluation: {evaluator.name}"
487
531
  with ExitStack() as stack:
488
- span: Span = stack.enter_context(
489
- tracer.start_as_current_span( # type:ignore
490
- name=root_span_name, context=Context()
491
- )
532
+ span: Span = cast(
533
+ "Span",
534
+ stack.enter_context(
535
+ tracer.start_as_current_span(
536
+ name=root_span_name, context=Context()
537
+ )
538
+ ),
492
539
  )
493
- stack.enter_context(capture_spans(resource)) # type:ignore
540
+ if resource is not None:
541
+ stack.enter_context(capture_spans(resource))
494
542
  span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
495
543
  try:
496
544
  result = await evaluator.async_evaluate(
@@ -516,7 +564,15 @@ def evaluate_experiment(
516
564
  )
517
565
  if result:
518
566
  span.set_attributes(
519
- dict(flatten(jsonify(result), recurse_on_sequence=True))
567
+ dict(
568
+ flatten(
569
+ cast(
570
+ "Mapping[str, Any] | Iterable[Any]",
571
+ jsonify(result),
572
+ ),
573
+ recurse_on_sequence=True,
574
+ )
575
+ )
520
576
  )
521
577
  span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
522
578
  span.set_status(status)
@@ -532,7 +588,7 @@ def evaluate_experiment(
532
588
  annotator_kind=evaluator.kind,
533
589
  error=repr(error) if error else None,
534
590
  result=result,
535
- trace_id=_str_trace_id(span.get_span_context().trace_id), # type:ignore
591
+ trace_id=_str_trace_id(span.get_span_context().trace_id),
536
592
  )
537
593
 
538
594
  _errors: tuple[type[BaseException], ...]
@@ -553,8 +609,11 @@ def evaluate_experiment(
553
609
  )
554
610
 
555
611
  executor = get_executor_on_sync_context(
556
- rate_limited_sync_evaluate_run,
557
- rate_limited_async_evaluate_run,
612
+ cast("Callable[[object], Any]", rate_limited_sync_evaluate_run),
613
+ cast(
614
+ "Callable[[object], Coroutine[Any, Any, Any]]",
615
+ rate_limited_async_evaluate_run,
616
+ ),
558
617
  max_retries=0,
559
618
  exit_on_error=exit_on_error,
560
619
  fallback_return_value=None,
@@ -565,14 +624,15 @@ def evaluate_experiment(
565
624
  timeout=timeout,
566
625
  )
567
626
  eval_runs, _ = executor.run(evaluation_input)
568
- return eval_runs
627
+ # Cast: run returns list[Unset | object], but sync/async_eval_run guarantee ExperimentEvaluationRun
628
+ return cast("list[ExperimentEvaluationRun]", eval_runs)
569
629
 
570
630
 
571
631
  def _add_metadata_to_output_df(
572
632
  output_df: pd.DataFrame,
573
633
  eval_runs: list[ExperimentEvaluationRun],
574
634
  evaluator_name: str,
575
- ) -> object:
635
+ ) -> pd.DataFrame:
576
636
  for eval_run in eval_runs:
577
637
  if eval_run.result is None:
578
638
  continue
@@ -603,7 +663,9 @@ def _dataframe_to_examples(dataset: pd.DataFrame) -> list[Example]:
603
663
  examples = []
604
664
 
605
665
  for _, row in dataset.iterrows():
606
- example = Example(dataset_row=row.to_dict())
666
+ example = Example(
667
+ dataset_row=cast("Mapping[str, JSONSerializable]", row.to_dict())
668
+ )
607
669
  examples.append(example)
608
670
  return examples
609
671
 
@@ -770,7 +832,8 @@ def get_result_attr(r: object, attr: str, default: object = None) -> object:
770
832
  Returns:
771
833
  The attribute value if found, otherwise the default value.
772
834
  """
773
- return getattr(r.result, attr, default) if r.result else default
835
+ # Type ignore: r typed as object but expected to have result attribute at runtime
836
+ return getattr(r.result, attr, default) if r.result else default # type: ignore[attr-defined]
774
837
 
775
838
 
776
839
  def transform_to_experiment_format(
@@ -57,7 +57,10 @@ _ACTIVE_MODIFIER: ContextVar[SpanModifier | None] = ContextVar(
57
57
 
58
58
 
59
59
  def override_span(
60
- init: Callable[..., None], span: ReadableSpan, args: object, kwargs: object
60
+ init: Callable[..., None],
61
+ span: ReadableSpan,
62
+ args: tuple[object, ...],
63
+ kwargs: dict[str, object],
61
64
  ) -> None:
62
65
  """Override span initialization to apply active span modifiers.
63
66
 
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
11
11
  from importlib.metadata import version
12
12
  from random import getrandbits
13
13
  from typing import (
14
+ NoReturn,
14
15
  cast,
15
16
  )
16
17
 
@@ -90,11 +91,13 @@ class Example:
90
91
  def from_dict(cls, obj: Mapping[str, object]) -> Example:
91
92
  """Create an Example instance from a dictionary."""
92
93
  return cls(
93
- id=obj["id"],
94
- input=obj["input"],
95
- output=obj["output"],
96
- metadata=obj.get("metadata") or {},
97
- updated_at=obj["updated_at"],
94
+ id=cast("str", obj["id"]),
95
+ input=cast("Mapping[str, JSONSerializable]", obj["input"]),
96
+ output=cast("Mapping[str, JSONSerializable]", obj["output"]),
97
+ metadata=cast(
98
+ "Mapping[str, JSONSerializable]", obj.get("metadata") or {}
99
+ ),
100
+ updated_at=cast("datetime", obj["updated_at"]),
98
101
  )
99
102
 
100
103
  def __repr__(self) -> str:
@@ -148,7 +151,7 @@ def _make_read_only(
148
151
  return obj
149
152
 
150
153
 
151
- class _ReadOnly(ObjectProxy): # type: ignore[misc]
154
+ class _ReadOnly(ObjectProxy):
152
155
  def __setitem__(self, *args: object, **kwargs: object) -> object:
153
156
  raise NotImplementedError
154
157
 
@@ -227,15 +230,15 @@ class ExperimentRun:
227
230
  def from_dict(cls, obj: Mapping[str, object]) -> ExperimentRun:
228
231
  """Create an ExperimentRun instance from a dictionary."""
229
232
  return cls(
230
- start_time=obj["start_time"],
231
- end_time=obj["end_time"],
232
- experiment_id=obj["experiment_id"],
233
- dataset_example_id=obj["dataset_example_id"],
234
- repetition_number=obj.get("repetition_number") or 1,
235
- output=_make_read_only(obj.get("output")),
236
- error=obj.get("error"),
237
- id=obj["id"],
238
- trace_id=obj.get("trace_id"),
233
+ start_time=cast("datetime", obj["start_time"]),
234
+ end_time=cast("datetime", obj["end_time"]),
235
+ experiment_id=cast("str", obj["experiment_id"]),
236
+ dataset_example_id=cast("str", obj["dataset_example_id"]),
237
+ repetition_number=cast("int", obj.get("repetition_number") or 1),
238
+ output=cast("JSONSerializable", _make_read_only(obj.get("output"))),
239
+ error=cast("str | None", obj.get("error")),
240
+ id=cast("str", obj["id"]),
241
+ trace_id=cast("str | None", obj.get("trace_id")),
239
242
  )
240
243
 
241
244
  def __post_init__(self) -> None:
@@ -280,15 +283,17 @@ class ExperimentEvaluationRun:
280
283
  def from_dict(cls, obj: Mapping[str, object]) -> ExperimentEvaluationRun:
281
284
  """Create an ExperimentEvaluationRun instance from a dictionary."""
282
285
  return cls(
283
- experiment_run_id=obj["experiment_run_id"],
284
- start_time=obj["start_time"],
285
- end_time=obj["end_time"],
286
- name=obj["name"],
287
- annotator_kind=obj["annotator_kind"],
288
- error=obj.get("error"),
289
- result=EvaluationResult.from_dict(obj.get("result")),
290
- id=obj["id"],
291
- trace_id=obj.get("trace_id"),
286
+ experiment_run_id=cast("str", obj["experiment_run_id"]),
287
+ start_time=cast("datetime", obj["start_time"]),
288
+ end_time=cast("datetime", obj["end_time"]),
289
+ name=cast("str", obj["name"]),
290
+ annotator_kind=cast("str", obj["annotator_kind"]),
291
+ error=cast("str | None", obj.get("error")),
292
+ result=EvaluationResult.from_dict(
293
+ cast("Mapping[str, object] | None", obj.get("result"))
294
+ ),
295
+ id=cast("str", obj["id"]),
296
+ trace_id=cast("str | None", obj.get("trace_id")),
292
297
  )
293
298
 
294
299
  def __post_init__(self) -> None:
@@ -334,7 +339,7 @@ class _HasStats:
334
339
  text = self.stats.__str__()
335
340
  else:
336
341
  text = self.stats.to_markdown(index=False)
337
- return f"{self.title}\n{'-' * len(self.title)}\n" + text # type: ignore
342
+ return f"{self.title}\n{'-' * len(self.title)}\n" + text
338
343
 
339
344
 
340
345
  @dataclass(frozen=True)
@@ -378,7 +383,7 @@ class _TaskSummary(_HasStats):
378
383
  return summary
379
384
 
380
385
  @classmethod
381
- def __new__(cls, *args: object, **kwargs: object) -> object:
386
+ def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
382
387
  # Direct instantiation by users is discouraged.
383
388
  raise NotImplementedError
384
389
 
arize/logging.py CHANGED
@@ -6,8 +6,11 @@ import json
6
6
  import logging
7
7
  import os
8
8
  import sys
9
- from collections.abc import Iterable, Mapping
10
- from typing import Any, ClassVar
9
+ from collections.abc import Iterable, Mapping, MutableMapping
10
+ from typing import TYPE_CHECKING, Any, ClassVar
11
+
12
+ if TYPE_CHECKING:
13
+ import requests
11
14
 
12
15
  from arize.config import _parse_bool
13
16
  from arize.constants.config import (
@@ -34,9 +37,18 @@ class CtxAdapter(logging.LoggerAdapter):
34
37
  """LoggerAdapter that merges bound context with per-call extras safely."""
35
38
 
36
39
  def process(
37
- self, msg: object, kwargs: dict[str, object]
38
- ) -> tuple[object, dict[str, object]]:
39
- """Process the logging call by merging bound and call extras."""
40
+ self, msg: object, kwargs: MutableMapping[str, Any]
41
+ ) -> tuple[object, MutableMapping[str, Any]]:
42
+ """Process the logging call by merging bound and call extras.
43
+
44
+ Args:
45
+ msg: The log message to process.
46
+ kwargs: Keyword arguments from the logging call, may include 'extra' dict.
47
+
48
+ Returns:
49
+ tuple[object, dict[str, object]]: A tuple of (message, modified_kwargs) with
50
+ merged extra context.
51
+ """
40
52
  call_extra = _coerce_mapping(kwargs.pop("extra", None))
41
53
  bound_extra = _coerce_mapping(self.extra)
42
54
  merged = (
@@ -49,13 +61,24 @@ class CtxAdapter(logging.LoggerAdapter):
49
61
  return msg, kwargs
50
62
 
51
63
  def with_extra(self, **more: object) -> CtxAdapter:
52
- """Return a copy of this adapter with additional bound extras."""
64
+ """Return a copy of this adapter with additional bound extras.
65
+
66
+ Args:
67
+ **more: Additional key-value pairs to merge into the bound extras.
68
+
69
+ Returns:
70
+ CtxAdapter: A new adapter instance with merged extra context.
71
+ """
53
72
  base = _coerce_mapping(self.extra)
54
73
  base.update(_coerce_mapping(more))
55
74
  return type(self)(self.logger, base)
56
75
 
57
76
  def without_extra(self) -> CtxAdapter:
58
- """Return a copy of this adapter with *no* bound extras."""
77
+ """Return a copy of this adapter with *no* bound extras.
78
+
79
+ Returns:
80
+ CtxAdapter: A new adapter instance without any bound extra context.
81
+ """
59
82
  return type(self)(self.logger, None)
60
83
 
61
84
 
@@ -86,7 +109,14 @@ class CustomLogFormatter(logging.Formatter):
86
109
  super().__init__(fmt=fmt)
87
110
 
88
111
  def format(self, record: logging.LogRecord) -> str:
89
- """Format the log record with color based on log level."""
112
+ """Format the log record with color based on log level.
113
+
114
+ Args:
115
+ record: The log record to format.
116
+
117
+ Returns:
118
+ str: Formatted and color-coded log message with any extra fields appended.
119
+ """
90
120
  # Build the base message without any color.
91
121
  base = super().format(record)
92
122
 
@@ -113,31 +143,22 @@ class JsonFormatter(logging.Formatter):
113
143
  """Minimal JSON formatter (one JSON object per line)."""
114
144
 
115
145
  # fields to skip copying from record.__dict__
116
- _skip: ClassVar[set[str]] = {
117
- # "name",
118
- # "msg",
119
- # "args",
120
- # "levelname",
121
- # "levelno",
122
- # "pathname",
123
- # "filename",
124
- # "module",
125
- # "exc_info",
126
- # "exc_text",
127
- # "stack_info",
128
- # "lineno",
129
- # "funcName",
130
- # "created",
131
- # "msecs",
132
- # "relativeCreated",
133
- # "thread",
134
- # "threadName",
135
- # "processName",
136
- # "process",
137
- }
146
+ _skip: ClassVar[set[str]] = set()
147
+ # Potential fields to skip:
148
+ # "name", "msg", "args", "levelname", "levelno", "pathname",
149
+ # "filename", "module", "exc_info", "exc_text", "stack_info",
150
+ # "lineno", "funcName", "created", "msecs", "relativeCreated",
151
+ # "thread", "threadName", "processName", "process"
138
152
 
139
153
  def format(self, record: logging.LogRecord) -> str:
140
- """Format the log record as a JSON string."""
154
+ """Format the log record as a JSON string.
155
+
156
+ Args:
157
+ record: The log record to format.
158
+
159
+ Returns:
160
+ str: JSON-formatted log message as a single line with all fields and extras.
161
+ """
141
162
  payload: dict[str, object] = {
142
163
  # "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S%z"),
143
164
  # "logger": record.name,
@@ -277,7 +298,7 @@ def log_a_list(values: Iterable[Any] | None, join_word: str) -> str:
277
298
  )
278
299
 
279
300
 
280
- def get_arize_project_url(response: object) -> str:
301
+ def get_arize_project_url(response: requests.Response) -> str:
281
302
  """Extract the Arize project URL from an API response.
282
303
 
283
304
  Args: