arize-phoenix 11.23.2__py3-none-any.whl → 11.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -1,14 +1,14 @@
 import re
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Iterable, Iterator, Optional, Union
+from typing import Any, Iterable, Iterator, Literal, Optional, Union
 from typing import cast as type_cast
 
 import numpy as np
 import numpy.typing as npt
 import strawberry
 from sqlalchemy import ColumnElement, String, and_, case, cast, func, select, text
-from sqlalchemy.orm import aliased, joinedload, load_only
+from sqlalchemy.orm import joinedload, load_only
 from starlette.authentication import UnauthenticatedUser
 from strawberry import ID, UNSET
 from strawberry.relay import Connection, GlobalID, Node
@@ -108,29 +108,52 @@ class DbTableStats:
 
 
 @strawberry.type
-class MetricCounts:
-    num_increases: int
-    num_decreases: int
-    num_equal: int
-
+class ExperimentRunMetricComparison:
+    num_runs_improved: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment improved "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_regressed: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment regressed "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_equal: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment is equal to the best run "
+            "in any compare experiment."
+        )
+    )
+    num_total_runs: strawberry.Private[int]
 
-@strawberry.type
-class CompareExperimentRunMetricCounts:
-    compare_experiment_id: GlobalID
-    latency: MetricCounts
-    prompt_token_count: MetricCounts
-    completion_token_count: MetricCounts
-    total_token_count: MetricCounts
-    total_cost: MetricCounts
+    @strawberry.field(
+        description=(
+            "The number of runs in the base experiment that could not be compared, either because "
+            "the base experiment run was missing a value or because all compare experiment runs "
+            "were missing values."
+        )
+    )  # type: ignore[misc]
+    def num_runs_without_comparison(self) -> int:
+        return (
+            self.num_total_runs
+            - self.num_runs_improved
+            - self.num_runs_regressed
+            - self.num_runs_equal
+        )
 
 
 @strawberry.type
-class CompareExperimentRunAnnotationMetricCounts:
-    annotation_name: str
-    compare_experiment_id: GlobalID
-    num_increases: int
-    num_decreases: int
-    num_equal: int
+class ExperimentRunMetricComparisons:
+    latency: ExperimentRunMetricComparison
+    total_token_count: ExperimentRunMetricComparison
+    prompt_token_count: ExperimentRunMetricComparison
+    completion_token_count: ExperimentRunMetricComparison
+    total_cost: ExperimentRunMetricComparison
+    prompt_cost: ExperimentRunMetricComparison
+    completion_cost: ExperimentRunMetricComparison
 
 
 @strawberry.type
@@ -521,12 +544,12 @@ class Query:
         )
 
     @strawberry.field
-    async def compare_experiment_run_metric_counts(
+    async def experiment_run_metric_comparisons(
         self,
         info: Info[Context, None],
         base_experiment_id: GlobalID,
         compare_experiment_ids: list[GlobalID],
-    ) -> list[CompareExperimentRunMetricCounts]:
+    ) -> ExperimentRunMetricComparisons:
         if base_experiment_id in compare_experiment_ids:
             raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
         if not compare_experiment_ids:
@@ -553,375 +576,256 @@ class Query:
                 raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
 
         base_experiment_runs = (
-            select(models.ExperimentRun)
+            select(
+                models.ExperimentRun.dataset_example_id,
+                func.min(models.ExperimentRun.start_time).label("start_time"),
+                func.min(models.ExperimentRun.end_time).label("end_time"),
+                func.sum(models.SpanCost.total_tokens).label("total_tokens"),
+                func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
+                func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
+                func.sum(models.SpanCost.total_cost).label("total_cost"),
+                func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
+                func.sum(models.SpanCost.completion_cost).label("completion_cost"),
+            )
+            .select_from(models.ExperimentRun)
+            .join(
+                models.Trace,
+                onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
+                isouter=True,
+            )
+            .join(
+                models.SpanCost,
+                onclause=models.Trace.id == models.SpanCost.trace_rowid,
+                isouter=True,
+            )
             .where(models.ExperimentRun.experiment_id == base_experiment_rowid)
+            .group_by(models.ExperimentRun.dataset_example_id)
             .subquery()
             .alias("base_experiment_runs")
         )
-        base_experiment_traces = aliased(models.Trace, name="base_experiment_traces")
-        base_experiment_span_costs = (
+        compare_experiment_runs = (
             select(
-                models.SpanCost.trace_rowid,
-                func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
-                func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label("prompt_tokens"),
-                func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
-                    "completion_tokens"
-                ),
-                func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
-            )
-            .select_from(models.SpanCost)
-            .group_by(
-                models.SpanCost.trace_rowid,
+                models.ExperimentRun.dataset_example_id,
+                func.min(
+                    LatencyMs(models.ExperimentRun.start_time, models.ExperimentRun.end_time)
+                ).label("min_latency_ms"),
+                func.min(models.SpanCost.total_tokens).label("min_total_tokens"),
+                func.min(models.SpanCost.prompt_tokens).label("min_prompt_tokens"),
+                func.min(models.SpanCost.completion_tokens).label("min_completion_tokens"),
+                func.min(models.SpanCost.total_cost).label("min_total_cost"),
+                func.min(models.SpanCost.prompt_cost).label("min_prompt_cost"),
+                func.min(models.SpanCost.completion_cost).label("min_completion_cost"),
             )
-            .subquery()
-            .alias("base_experiment_span_costs")
-        )
-
-        query = (
-            select()  # add selected columns below
-            .select_from(base_experiment_runs)
+            .select_from(models.ExperimentRun)
             .join(
-                base_experiment_traces,
-                onclause=base_experiment_runs.c.trace_id == base_experiment_traces.trace_id,
+                models.Trace,
+                onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
                 isouter=True,
             )
             .join(
-                base_experiment_span_costs,
-                onclause=base_experiment_traces.id == base_experiment_span_costs.c.trace_rowid,
+                models.SpanCost,
+                onclause=models.Trace.id == models.SpanCost.trace_rowid,
                 isouter=True,
             )
+            .where(
+                models.ExperimentRun.experiment_id.in_(compare_experiment_rowids),
+            )
+            .group_by(models.ExperimentRun.dataset_example_id)
+            .subquery()
+            .alias("comp_exp_run_mins")
         )
 
         base_experiment_run_latency = LatencyMs(
             base_experiment_runs.c.start_time, base_experiment_runs.c.end_time
         ).label("base_experiment_run_latency_ms")
-        base_experiment_run_prompt_token_count = base_experiment_span_costs.c.prompt_tokens
-        base_experiment_run_completion_token_count = base_experiment_span_costs.c.completion_tokens
-        base_experiment_run_total_token_count = base_experiment_span_costs.c.total_tokens
-        base_experiment_run_total_cost = base_experiment_span_costs.c.total_cost
-
-        for compare_experiment_index, compare_experiment_rowid in enumerate(
-            compare_experiment_rowids
-        ):
-            compare_experiment_runs = (
-                select(models.ExperimentRun)
-                .where(models.ExperimentRun.experiment_id == compare_experiment_rowid)
-                .subquery()
-                .alias(f"comp_exp_{compare_experiment_index}_runs")
-            )
-            compare_experiment_traces = aliased(
-                models.Trace, name=f"comp_exp_{compare_experiment_index}_traces"
-            )
-            compare_experiment_span_costs = (
-                select(
-                    models.SpanCost.trace_rowid,
-                    func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
-                    func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label(
-                        "prompt_tokens"
-                    ),
-                    func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
-                        "completion_tokens"
-                    ),
-                    func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
-                )
-                .select_from(models.SpanCost)
-                .group_by(models.SpanCost.trace_rowid)
-                .subquery()
-                .alias(f"comp_exp_{compare_experiment_index}_span_costs")
-            )
-            compare_experiment_run_latency = LatencyMs(
-                compare_experiment_runs.c.start_time, compare_experiment_runs.c.end_time
-            ).label(f"comp_exp_{compare_experiment_index}_run_latency_ms")
-            compare_experiment_run_prompt_token_count = (
-                compare_experiment_span_costs.c.prompt_tokens
-            )
-            compare_experiment_run_completion_token_count = (
-                compare_experiment_span_costs.c.completion_tokens
-            )
-            compare_experiment_run_total_token_count = compare_experiment_span_costs.c.total_tokens
-            compare_experiment_run_total_cost = compare_experiment_span_costs.c.total_cost
-
-            query = (
-                query.add_columns(
-                    _count_rows(
-                        base_experiment_run_latency < compare_experiment_run_latency,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_latency"),
-                    _count_rows(
-                        base_experiment_run_latency > compare_experiment_run_latency,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_latency"),
-                    _count_rows(
-                        base_experiment_run_latency == compare_experiment_run_latency,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_latency"),
-                    _count_rows(
-                        base_experiment_run_prompt_token_count
-                        < compare_experiment_run_prompt_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_increased_prompt_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_prompt_token_count
-                        > compare_experiment_run_prompt_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_decreased_prompt_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_prompt_token_count
-                        == compare_experiment_run_prompt_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_equal_prompt_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_completion_token_count
-                        < compare_experiment_run_completion_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_increased_completion_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_completion_token_count
-                        > compare_experiment_run_completion_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_decreased_completion_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_completion_token_count
-                        == compare_experiment_run_completion_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_equal_completion_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_token_count
-                        < compare_experiment_run_total_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_increased_total_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_token_count
-                        > compare_experiment_run_total_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_decreased_total_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_token_count
-                        == compare_experiment_run_total_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_equal_total_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_cost < compare_experiment_run_total_cost,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_total_cost"),
-                    _count_rows(
-                        base_experiment_run_total_cost > compare_experiment_run_total_cost,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_total_cost"),
-                    _count_rows(
-                        base_experiment_run_total_cost == compare_experiment_run_total_cost,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_total_cost"),
-                )
-                .join(
-                    compare_experiment_runs,
-                    onclause=base_experiment_runs.c.dataset_example_id
-                    == compare_experiment_runs.c.dataset_example_id,
-                    isouter=True,
-                )
-                .join(
-                    compare_experiment_traces,
-                    onclause=compare_experiment_runs.c.trace_id
-                    == compare_experiment_traces.trace_id,
-                    isouter=True,
-                )
-                .join(
-                    compare_experiment_span_costs,
-                    onclause=compare_experiment_traces.id
-                    == compare_experiment_span_costs.c.trace_rowid,
-                    isouter=True,
-                )
-            )
-
-        async with info.context.db() as session:
-            result = (await session.execute(query)).first()
-            assert result is not None
-
-        num_columns_per_compare_experiment = len(query.columns) // len(compare_experiment_ids)
-        counts = []
-        for compare_experiment_index, compare_experiment_id in enumerate(compare_experiment_ids):
-            start_index = compare_experiment_index * num_columns_per_compare_experiment
-            end_index = start_index + num_columns_per_compare_experiment
-            (
-                num_runs_with_increased_latency,
-                num_runs_with_decreased_latency,
-                num_runs_with_equal_latency,
-                num_runs_with_increased_prompt_token_count,
-                num_runs_with_decreased_prompt_token_count,
-                num_runs_with_equal_prompt_token_count,
-                num_runs_with_increased_completion_token_count,
-                num_runs_with_decreased_completion_token_count,
-                num_runs_with_equal_completion_token_count,
-                num_runs_with_increased_total_token_count,
-                num_runs_with_decreased_total_token_count,
-                num_runs_with_equal_total_token_count,
-                num_runs_with_increased_total_cost,
-                num_runs_with_decreased_total_cost,
-                num_runs_with_equal_total_cost,
-            ) = result[start_index:end_index]
-            counts.append(
-                CompareExperimentRunMetricCounts(
-                    compare_experiment_id=compare_experiment_id,
-                    latency=MetricCounts(
-                        num_increases=num_runs_with_increased_latency,
-                        num_decreases=num_runs_with_decreased_latency,
-                        num_equal=num_runs_with_equal_latency,
-                    ),
-                    prompt_token_count=MetricCounts(
-                        num_increases=num_runs_with_increased_prompt_token_count,
-                        num_decreases=num_runs_with_decreased_prompt_token_count,
-                        num_equal=num_runs_with_equal_prompt_token_count,
-                    ),
-                    completion_token_count=MetricCounts(
-                        num_increases=num_runs_with_increased_completion_token_count,
-                        num_decreases=num_runs_with_decreased_completion_token_count,
-                        num_equal=num_runs_with_equal_completion_token_count,
-                    ),
-                    total_token_count=MetricCounts(
-                        num_increases=num_runs_with_increased_total_token_count,
-                        num_decreases=num_runs_with_decreased_total_token_count,
-                        num_equal=num_runs_with_equal_total_token_count,
-                    ),
-                    total_cost=MetricCounts(
-                        num_increases=num_runs_with_increased_total_cost,
-                        num_decreases=num_runs_with_decreased_total_cost,
-                        num_equal=num_runs_with_equal_total_cost,
-                    ),
-                )
-            )
-        return counts
-
-    @strawberry.field
-    async def compare_experiment_run_annotation_metric_counts(
-        self,
-        info: Info[Context, None],
-        base_experiment_id: GlobalID,
-        compare_experiment_ids: list[GlobalID],
-    ) -> list[CompareExperimentRunAnnotationMetricCounts]:
-        if base_experiment_id in compare_experiment_ids:
-            raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
-        if not compare_experiment_ids:
-            raise BadRequest("At least one compare experiment ID must be provided")
-        if len(set(compare_experiment_ids)) < len(compare_experiment_ids):
-            raise BadRequest("Compare experiment IDs must be unique")
 
-        try:
-            base_experiment_rowid = from_global_id_with_expected_type(
-                base_experiment_id, models.Experiment.__name__
-            )
-        except ValueError:
-            raise BadRequest(f"Invalid base experiment ID: {base_experiment_id}")
-
-        compare_experiment_rowids = []
-        for compare_experiment_id in compare_experiment_ids:
-            try:
-                compare_experiment_rowids.append(
-                    from_global_id_with_expected_type(
-                        compare_experiment_id, models.Experiment.__name__
-                    )
-                )
-            except ValueError:
-                raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
-
-        base_experiment_runs = (
-            select(models.ExperimentRun)
-            .where(
-                models.ExperimentRun.experiment_id == base_experiment_rowid,
+        comparisons_query = (
+            select(
+                func.count().label("num_base_experiment_runs"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_latency_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_latency_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_latency_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_total_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_total_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_total_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_prompt_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_prompt_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_prompt_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_completion_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_completion_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_completion_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_total_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_total_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_total_cost_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_prompt_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_prompt_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_prompt_cost_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_completion_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_completion_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_completion_cost_is_equal"),
             )
-            .subquery()
-            .alias("base_experiment_runs")
-        )
-        base_experiment_run_annotations = aliased(
-            models.ExperimentRunAnnotation, name="base_experiment_run_annotations"
-        )
-        query = (
-            select(base_experiment_run_annotations.name)
             .select_from(base_experiment_runs)
             .join(
-                base_experiment_run_annotations,
-                onclause=base_experiment_runs.c.id
-                == base_experiment_run_annotations.experiment_run_id,
+                compare_experiment_runs,
+                onclause=base_experiment_runs.c.dataset_example_id
+                == compare_experiment_runs.c.dataset_example_id,
                 isouter=True,
             )
-            .group_by(base_experiment_run_annotations.name)
-            .order_by(base_experiment_run_annotations.name)
         )
-        for compare_experiment_index, compare_experiment_rowid in enumerate(
-            compare_experiment_rowids
-        ):
-            compare_experiment_runs = (
-                select(models.ExperimentRun)
-                .where(
-                    models.ExperimentRun.experiment_id == compare_experiment_rowid,
-                )
-                .subquery()
-                .alias(f"comp_exp_{compare_experiment_index}_runs")
-            )
-            compare_experiment_run_annotations = aliased(
-                models.ExperimentRunAnnotation,
-                name=f"comp_exp_{compare_experiment_index}_run_annotations",
-            )
-            query = (
-                query.add_columns(
-                    _count_rows(
-                        base_experiment_run_annotations.score
-                        < compare_experiment_run_annotations.score,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_score"),
-                    _count_rows(
-                        base_experiment_run_annotations.score
-                        > compare_experiment_run_annotations.score,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_score"),
-                    _count_rows(
-                        base_experiment_run_annotations.score
-                        == compare_experiment_run_annotations.score,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_score"),
-                )
-                .join(
-                    compare_experiment_runs,
-                    onclause=base_experiment_runs.c.dataset_example_id
-                    == compare_experiment_runs.c.dataset_example_id,
-                    isouter=True,
-                )
-                .join(
-                    compare_experiment_run_annotations,
-                    onclause=compare_experiment_runs.c.id
-                    == compare_experiment_run_annotations.experiment_run_id,
-                    isouter=True,
-                )
-                .where(
-                    base_experiment_run_annotations.name == compare_experiment_run_annotations.name
-                )
-            )
+
         async with info.context.db() as session:
-            result = (await session.execute(query)).all()
+            result = (await session.execute(comparisons_query)).first()
             assert result is not None
-        num_columns_per_compare_experiment = (len(query.columns) - 1) // len(compare_experiment_ids)
-        metric_counts = []
-        for record in result:
-            annotation_name, *counts = record
-            for compare_experiment_index, compare_experiment_id in enumerate(
-                compare_experiment_ids
-            ):
-                start_index = compare_experiment_index * num_columns_per_compare_experiment
-                end_index = start_index + num_columns_per_compare_experiment
-                (
-                    num_runs_with_increased_score,
-                    num_runs_with_decreased_score,
-                    num_runs_with_equal_score,
-                ) = counts[start_index:end_index]
-                metric_counts.append(
-                    CompareExperimentRunAnnotationMetricCounts(
-                        annotation_name=annotation_name,
-                        compare_experiment_id=compare_experiment_id,
-                        num_increases=num_runs_with_increased_score,
-                        num_decreases=num_runs_with_decreased_score,
-                        num_equal=num_runs_with_equal_score,
-                    )
-                )
-        return metric_counts
+
+        return ExperimentRunMetricComparisons(
+            latency=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_latency_improved,
+                num_runs_regressed=result.num_latency_regressed,
+                num_runs_equal=result.num_latency_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_token_count_improved,
+                num_runs_regressed=result.num_total_token_count_regressed,
+                num_runs_equal=result.num_total_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_token_count_improved,
+                num_runs_regressed=result.num_prompt_token_count_regressed,
+                num_runs_equal=result.num_prompt_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_token_count_improved,
+                num_runs_regressed=result.num_completion_token_count_regressed,
+                num_runs_equal=result.num_completion_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_cost_improved,
+                num_runs_regressed=result.num_total_cost_regressed,
+                num_runs_equal=result.num_total_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_cost_improved,
+                num_runs_regressed=result.num_prompt_cost_regressed,
+                num_runs_equal=result.num_prompt_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_cost_improved,
+                num_runs_regressed=result.num_completion_cost_regressed,
+                num_runs_equal=result.num_completion_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+        )
 
     @strawberry.field
     async def validate_experiment_run_filter_condition(
@@ -1550,16 +1454,36 @@ def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
     return longest
 
 
-def _count_rows(
-    condition: ColumnElement[Any],
-) -> ColumnElement[Any]:
+def _comparison_count_expression(
+    *,
+    base_column: ColumnElement[Any],
+    compare_column: ColumnElement[Any],
+    optimization_direction: Literal["maximize", "minimize"],
+    comparison_type: Literal["improvement", "regression", "equality"],
+) -> ColumnElement[int]:
     """
-    Returns an expression that counts the number of rows satisfying the condition.
+    Given a base and compare column, returns an expression counting the number of
+    improvements, regressions, or equalities given the optimization direction.
     """
+    if optimization_direction == "maximize":
+        raise NotImplementedError
+
+    if comparison_type == "improvement":
+        condition = compare_column > base_column
+    elif comparison_type == "regression":
+        condition = compare_column < base_column
+    elif comparison_type == "equality":
+        condition = compare_column == base_column
+    else:
+        assert_never(comparison_type)
+
     return func.coalesce(
         func.sum(
             case(
-                (condition, 1),
+                (
+                    condition,
+                    1,
+                ),
                 else_=0,
            ),
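
For readers skimming the diff: the new _comparison_count_expression helper counts, per dataset example, whether the base experiment run beat the best (minimum) value across all compare experiments. Below is a minimal plain-Python sketch of that counting rule; the function name count_comparisons and its list-based inputs are illustrative assumptions, not part of arize-phoenix. Pairs with a missing value on either side fall into none of the three buckets, which is what the new num_runs_without_comparison field reports.

from typing import Literal, Optional

def count_comparisons(
    base_values: list[Optional[float]],
    best_compare_values: list[Optional[float]],
    comparison_type: Literal["improvement", "regression", "equality"],
) -> int:
    # Lower is better ("minimize" is the only implemented direction): the base
    # run improves when it is strictly below the best compare value. Missing
    # values mirror SQL NULLs, where the CASE expression yields 0, so the pair
    # is counted in no bucket.
    count = 0
    for base, best_compare in zip(base_values, best_compare_values):
        if base is None or best_compare is None:
            continue
        if comparison_type == "improvement":
            count += best_compare > base
        elif comparison_type == "regression":
            count += best_compare < base
        else:
            count += best_compare == base
    return count

# Three runs: run 1 improved (10 < 12), run 2 regressed (15 > 9),
# run 3 has no base value and is counted in no bucket.
assert count_comparisons([10.0, 15.0, None], [12.0, 9.0, 5.0], "improvement") == 1
assert count_comparisons([10.0, 15.0, None], [12.0, 9.0, 5.0], "regression") == 1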