arize-phoenix 10.15.0__py3-none-any.whl → 11.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (79) hide show
  1. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/RECORD +77 -46
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/token_price_customization.py +29 -0
  12. phoenix/server/api/context.py +38 -4
  13. phoenix/server/api/dataloaders/__init__.py +41 -5
  14. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  15. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  16. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  20. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  21. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  27. phoenix/server/api/dataloaders/span_costs.py +35 -0
  28. phoenix/server/api/dataloaders/types.py +29 -0
  29. phoenix/server/api/helpers/playground_clients.py +103 -12
  30. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  31. phoenix/server/api/input_types/SpanSort.py +17 -0
  32. phoenix/server/api/mutations/__init__.py +2 -0
  33. phoenix/server/api/mutations/chat_mutations.py +17 -0
  34. phoenix/server/api/mutations/model_mutations.py +208 -0
  35. phoenix/server/api/queries.py +82 -41
  36. phoenix/server/api/routers/v1/traces.py +11 -4
  37. phoenix/server/api/subscriptions.py +36 -2
  38. phoenix/server/api/types/CostBreakdown.py +15 -0
  39. phoenix/server/api/types/Experiment.py +59 -1
  40. phoenix/server/api/types/ExperimentRun.py +58 -4
  41. phoenix/server/api/types/GenerativeModel.py +143 -2
  42. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  43. phoenix/server/api/types/ModelInterface.py +11 -0
  44. phoenix/server/api/types/PlaygroundModel.py +10 -0
  45. phoenix/server/api/types/Project.py +42 -0
  46. phoenix/server/api/types/ProjectSession.py +44 -0
  47. phoenix/server/api/types/Span.py +137 -0
  48. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  49. phoenix/server/api/types/SpanCostSummary.py +10 -0
  50. phoenix/server/api/types/TokenPrice.py +16 -0
  51. phoenix/server/api/types/TokenUsage.py +3 -3
  52. phoenix/server/api/types/Trace.py +41 -0
  53. phoenix/server/app.py +59 -0
  54. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  55. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  56. phoenix/server/cost_tracking/helpers.py +68 -0
  57. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  58. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  59. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  60. phoenix/server/daemons/__init__.py +0 -0
  61. phoenix/server/daemons/generative_model_store.py +51 -0
  62. phoenix/server/daemons/span_cost_calculator.py +103 -0
  63. phoenix/server/dml_event_handler.py +1 -0
  64. phoenix/server/static/.vite/manifest.json +36 -36
  65. phoenix/server/static/assets/components-BQWqzM6Z.js +5055 -0
  66. phoenix/server/static/assets/{index-DIlhmbjB.js → index-t6f0PRIo.js} +13 -13
  67. phoenix/server/static/assets/{pages-YX47cEoQ.js → pages-B8Uyb2qa.js} +818 -422
  68. phoenix/server/static/assets/{vendor-DCZoBorz.js → vendor-DqQvHbPa.js} +147 -147
  69. phoenix/server/static/assets/{vendor-arizeai-Ckci3irT.js → vendor-arizeai-CLX44PFA.js} +1 -1
  70. phoenix/server/static/assets/{vendor-codemirror-BODM513D.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  71. phoenix/server/static/assets/{vendor-recharts-C9O2a-N3.js → vendor-recharts-B2PJDrnX.js} +25 -25
  72. phoenix/server/static/assets/{vendor-shiki-Dq54rRC7.js → vendor-shiki-CNbrFjf9.js} +1 -1
  73. phoenix/version.py +1 -1
  74. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  75. phoenix/server/static/assets/components-SpUMF1qV.js +0 -4509
  76. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,56 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import (
10
+ CostBreakdown,
11
+ SpanCostDetailSummaryEntry,
12
+ )
13
+ from phoenix.server.types import DbSessionFactory
14
+
15
+ GenerativeModelId: TypeAlias = int
16
+ Key: TypeAlias = GenerativeModelId
17
+ Result: TypeAlias = list[SpanCostDetailSummaryEntry]
18
+
19
+
20
class SpanCostDetailSummaryEntriesByGenerativeModelDataLoader(DataLoader[Key, Result]):
    """Batch-load cost detail summary entries keyed by generative model id.

    For each requested key the loader produces one entry per
    ``(token_type, is_prompt)`` group, carrying the summed cost and token
    count of the ``SpanCostDetail`` rows whose ``SpanCost.model_id`` matches.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one list of entries per key, in the same order as ``keys``.

        Keys without any matching rows resolve to an empty list.
        """
        detail = models.SpanCostDetail
        model_id = models.SpanCost.model_id
        summed_cost = coalesce(func.sum(detail.cost), 0).label("cost")
        summed_tokens = coalesce(func.sum(detail.tokens), 0).label("tokens")
        query = (
            select(model_id, detail.token_type, detail.is_prompt, summed_cost, summed_tokens)
            .select_from(detail)
            .join(models.SpanCost, detail.span_cost_id == models.SpanCost.id)
            .where(model_id.in_(keys))
            .group_by(model_id, detail.token_type, detail.is_prompt)
        )
        entries_by_key: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                group_key, token_type, is_prompt, cost, tokens = row
                entries_by_key[group_key].append(
                    SpanCostDetailSummaryEntry(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        value=CostBreakdown(tokens=tokens, cost=cost),
                    )
                )
        # Hand back a copy of each accumulated list, one per requested key.
        return [list(entries_by_key[key]) for key in keys]
@@ -0,0 +1,57 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import (
10
+ CostBreakdown,
11
+ SpanCostDetailSummaryEntry,
12
+ )
13
+ from phoenix.server.types import DbSessionFactory
14
+
15
+ ProjectSessionRowId: TypeAlias = int
16
+ Key: TypeAlias = ProjectSessionRowId
17
+ Result: TypeAlias = list[SpanCostDetailSummaryEntry]
18
+
19
+
20
class SpanCostDetailSummaryEntriesByProjectSessionDataLoader(DataLoader[Key, Result]):
    """Batch-load cost detail summary entries keyed by project session row id.

    ``SpanCostDetail`` rows are joined to ``SpanCost`` and then to ``Trace``;
    the results are grouped by ``Trace.project_session_rowid`` together with
    ``token_type`` and ``is_prompt``, summing cost and tokens per group.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one list of entries per key, in the same order as ``keys``.

        Keys without any matching rows resolve to an empty list.
        """
        detail = models.SpanCostDetail
        session_rowid = models.Trace.project_session_rowid
        summed_cost = coalesce(func.sum(detail.cost), 0).label("cost")
        summed_tokens = coalesce(func.sum(detail.tokens), 0).label("tokens")
        query = (
            select(session_rowid, detail.token_type, detail.is_prompt, summed_cost, summed_tokens)
            .select_from(detail)
            .join(models.SpanCost, detail.span_cost_id == models.SpanCost.id)
            .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
            .where(session_rowid.in_(keys))
            .group_by(session_rowid, detail.token_type, detail.is_prompt)
        )
        entries_by_key: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                group_key, token_type, is_prompt, cost, tokens = row
                entries_by_key[group_key].append(
                    SpanCostDetailSummaryEntry(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        value=CostBreakdown(tokens=tokens, cost=cost),
                    )
                )
        # Hand back a copy of each accumulated list, one per requested key.
        return [list(entries_by_key[key]) for key in keys]
@@ -0,0 +1,43 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from sqlalchemy.orm import contains_eager
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import (
10
+ CostBreakdown,
11
+ SpanCostDetailSummaryEntry,
12
+ )
13
+ from phoenix.server.types import DbSessionFactory
14
+
15
+ SpanRowID: TypeAlias = int
16
+ Key: TypeAlias = SpanRowID
17
+ Result: TypeAlias = list[SpanCostDetailSummaryEntry]
18
+
19
+
20
class SpanCostDetailSummaryEntriesBySpanDataLoader(DataLoader[Key, Result]):
    """Batch-load cost detail summary entries keyed by span row id.

    Unlike an aggregating query, this loader emits one entry per stored
    ``SpanCostDetail`` row, grouped by the ``span_rowid`` of its parent
    ``SpanCost``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one list of entries per key, in the same order as ``keys``.

        Keys without any matching rows resolve to an empty list.
        """
        query = (
            select(models.SpanCostDetail)
            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
            .where(models.SpanCost.span_rowid.in_(keys))
            # Populate ``span_cost`` from the join so ``span_rowid`` can be read below.
            .options(contains_eager(models.SpanCostDetail.span_cost))
        )
        entries_by_span: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            details = await session.stream_scalars(query)
            async for detail in details:
                breakdown = CostBreakdown(tokens=detail.tokens, cost=detail.cost)
                entries_by_span[detail.span_cost.span_rowid].append(
                    SpanCostDetailSummaryEntry(
                        token_type=detail.token_type,
                        is_prompt=detail.is_prompt,
                        value=breakdown,
                    )
                )
        # Hand back a copy of each accumulated list, one per requested key.
        return [list(entries_by_span[key]) for key in keys]
@@ -0,0 +1,56 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import (
10
+ CostBreakdown,
11
+ SpanCostDetailSummaryEntry,
12
+ )
13
+ from phoenix.server.types import DbSessionFactory
14
+
15
+ TraceRowId: TypeAlias = int
16
+ Key: TypeAlias = TraceRowId
17
+ Result: TypeAlias = list[SpanCostDetailSummaryEntry]
18
+
19
+
20
class SpanCostDetailSummaryEntriesByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-load cost detail summary entries keyed by trace row id.

    ``SpanCostDetail`` rows are joined to ``SpanCost`` and grouped by
    ``SpanCost.trace_rowid`` together with ``token_type`` and ``is_prompt``,
    summing cost and tokens per group.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one list of entries per key, in the same order as ``keys``.

        Keys without any matching rows resolve to an empty list.
        """
        detail = models.SpanCostDetail
        trace_rowid = models.SpanCost.trace_rowid
        summed_cost = coalesce(func.sum(detail.cost), 0).label("cost")
        summed_tokens = coalesce(func.sum(detail.tokens), 0).label("tokens")
        query = (
            select(trace_rowid, detail.token_type, detail.is_prompt, summed_cost, summed_tokens)
            .select_from(detail)
            .join(models.SpanCost, detail.span_cost_id == models.SpanCost.id)
            .where(trace_rowid.in_(keys))
            .group_by(trace_rowid, detail.token_type, detail.is_prompt)
        )
        entries_by_key: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                group_key, token_type, is_prompt, cost, tokens = row
                entries_by_key[group_key].append(
                    SpanCostDetailSummaryEntry(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        value=CostBreakdown(tokens=tokens, cost=cost),
                    )
                )
        # Hand back a copy of each accumulated list, one per requested key.
        return [list(entries_by_key[key]) for key in keys]
@@ -0,0 +1,27 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
+ SpanCostId: TypeAlias = int
11
+ Key: TypeAlias = SpanCostId
12
+ Result: TypeAlias = list[models.SpanCostDetail]
13
+
14
+
15
class SpanCostDetailsBySpanCostDataLoader(DataLoader[Key, Result]):
    """Batch-load ``SpanCostDetail`` ORM rows keyed by their ``span_cost_id``."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return the detail rows for each key, in the same order as ``keys``.

        Keys without any matching rows resolve to an empty list.
        """
        query = select(models.SpanCostDetail).where(models.SpanCostDetail.span_cost_id.in_(keys))
        details_by_span_cost: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            details = await session.stream_scalars(query)
            async for detail in details:
                details_by_span_cost[detail.span_cost_id].append(detail)
        return [details_by_span_cost[key] for key in keys]
@@ -0,0 +1,58 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
+ ExperimentId: TypeAlias = int
13
+ Key: TypeAlias = ExperimentId
14
+ Result: TypeAlias = SpanCostSummary
15
+
16
+
17
class SpanCostSummaryByExperimentDataLoader(DataLoader[Key, Result]):
    """Batch-load a prompt/completion/total cost summary per experiment id.

    Experiment runs are joined to traces via ``trace_id`` and then to
    ``SpanCost`` rows via the trace row id; costs and token counts are summed
    per ``ExperimentRun.experiment_id``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``.

        Keys without any matching rows get a default-constructed
        ``SpanCostSummary``.
        """
        run = models.ExperimentRun
        span_cost = models.SpanCost
        query = (
            select(
                run.experiment_id,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .select_from(run)
            .join(models.Trace, run.trace_id == models.Trace.trace_id)
            .join(span_cost, span_cost.trace_rowid == models.Trace.id)
            .where(run.experiment_id.in_(keys))
            .group_by(run.experiment_id)
        )

        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                (
                    experiment_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[experiment_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
@@ -0,0 +1,58 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
+ ExperimentRunId: TypeAlias = int
13
+ Key: TypeAlias = ExperimentRunId
14
+ Result: TypeAlias = SpanCostSummary
15
+
16
+
17
class SpanCostSummaryByExperimentRunDataLoader(DataLoader[Key, Result]):
    """Batch-load a prompt/completion/total cost summary per experiment run id.

    Each experiment run is joined to its trace via ``trace_id`` and then to
    ``SpanCost`` rows via the trace row id; costs and token counts are summed
    per ``ExperimentRun.id``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``.

        Keys without any matching rows get a default-constructed
        ``SpanCostSummary``.
        """
        run = models.ExperimentRun
        span_cost = models.SpanCost
        query = (
            select(
                run.id,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .select_from(run)
            .join(models.Trace, run.trace_id == models.Trace.trace_id)
            .join(span_cost, span_cost.trace_rowid == models.Trace.id)
            .where(run.id.in_(keys))
            .group_by(run.id)
        )

        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                (
                    run_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[run_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
@@ -0,0 +1,55 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
+ GenerativeModelId: TypeAlias = int
13
+ Key: TypeAlias = GenerativeModelId
14
+ Result: TypeAlias = SpanCostSummary
15
+
16
+
17
class SpanCostSummaryByGenerativeModelDataLoader(DataLoader[Key, Result]):
    """Batch-load a prompt/completion/total cost summary per generative model id.

    Costs and token counts of ``SpanCost`` rows are summed per
    ``SpanCost.model_id``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``.

        Keys without any matching rows get a default-constructed
        ``SpanCostSummary``.
        """
        span_cost = models.SpanCost
        model_id = span_cost.model_id
        query = (
            select(
                model_id,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .where(model_id.in_(keys))
            .group_by(model_id)
        )
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                (
                    group_key,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[group_key] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
@@ -0,0 +1,140 @@
1
+ from collections import defaultdict
2
+ from datetime import datetime
3
+ from typing import Any, Optional
4
+
5
+ from cachetools import LFUCache, TTLCache
6
+ from sqlalchemy import Select, func, select
7
+ from sqlalchemy.sql.functions import coalesce
8
+ from strawberry.dataloader import AbstractCache, DataLoader
9
+ from typing_extensions import TypeAlias
10
+
11
+ from phoenix.db import models
12
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
13
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
14
+ from phoenix.server.api.input_types.TimeRange import TimeRange
15
+ from phoenix.server.types import DbSessionFactory
16
+ from phoenix.trace.dsl import SpanFilter
17
+
18
+ ProjectRowId: TypeAlias = int
19
+ TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
20
+ FilterCondition: TypeAlias = Optional[str]
21
+
22
+ Segment: TypeAlias = tuple[TimeInterval, FilterCondition]
23
+ Param: TypeAlias = ProjectRowId
24
+
25
+ Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition]
26
+ Result: TypeAlias = SpanCostSummary
27
+ ResultPosition: TypeAlias = int
28
+ DEFAULT_VALUE: Result = SpanCostSummary()
29
+
30
+
31
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Convert a loader key into a ``(segment, project_rowid)`` pair.

    The segment is ``((start, end), filter_condition)``. When the key's time
    range is not a ``TimeRange`` instance, the interval is ``(None, None)``.
    """
    project_rowid, time_range, filter_condition = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    segment: Segment = (interval, filter_condition)
    return segment, project_rowid
37
+
38
+
39
+ _Section: TypeAlias = ProjectRowId
40
+ _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition]
41
+
42
+
43
class SpanCostSummaryCache(TwoTierCache[Key, Result, _Section, _SubKey]):
    """Two-tier cache for span cost summaries.

    The section is the project row id; the sub-key is the
    ``(time interval, filter condition)`` pair.
    """

    def __init__(self) -> None:
        # TTL=3600 (1-hour) because time intervals are always moving forward, but
        # interval endpoints are rounded down to the hour by the UI, so anything
        # older than an hour most likely won't be a cache-hit anyway.
        super().__init__(
            main_cache=TTLCache(maxsize=64, ttl=3600),
            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
        )

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        """Split ``key`` into its section (project row id) and sub-key."""
        (interval, filter_condition), project_rowid = _cache_key_fn(key)
        sub_key: _SubKey = (interval, filter_condition)
        return project_rowid, sub_key
58
+
59
+
60
class SpanCostSummaryByProjectDataLoader(DataLoader[Key, Result]):
    """Batch-load a prompt/completion/total cost summary per project.

    Keys are ``(project row id, optional time range, optional filter
    condition)`` tuples. Keys that share the same time interval and filter
    condition are resolved by a single query.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``.

        Positions for which no row comes back keep ``DEFAULT_VALUE``.
        """
        results: list[Result] = [DEFAULT_VALUE for _ in keys]

        # segment -> project row id -> positions in ``keys`` asking for that pair
        positions_by_segment: defaultdict[
            Segment,
            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for position, key in enumerate(keys):
            segment, project_rowid = _cache_key_fn(key)
            positions_by_segment[segment][project_rowid].append(position)

        async with self._db() as session:
            for segment, positions_by_project in positions_by_segment.items():
                rows = await session.stream(_get_stmt(segment, *positions_by_project))
                async for row in rows:
                    (
                        project_id,
                        prompt_cost,
                        completion_cost,
                        total_cost,
                        prompt_tokens,
                        completion_tokens,
                        total_tokens,
                    ) = row
                    summary = SpanCostSummary(
                        prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                        completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                        total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                    )
                    # ``.get`` avoids inserting an entry for an id nobody asked for.
                    for position in positions_by_project.get(project_id, ()):
                        results[position] = summary
        return results
103
+
104
+
105
def _get_stmt(
    segment: Segment,
    *params: Param,
) -> Select[Any]:
    """Build the per-project cost aggregation query for one segment.

    Args:
        segment: ``((start_time, end_time), filter_condition)``. A ``None``
            bound leaves that side of the time window open; a falsy filter
            condition applies no span filter.
        *params: Project row ids to restrict the query to.

    Returns:
        A ``Select`` yielding, per project row id, the summed prompt,
        completion and total cost followed by the summed prompt, completion
        and total token counts.
    """
    (start_time, end_time), filter_condition = segment
    pid = models.Trace.project_rowid

    stmt: Select[Any] = (
        select(
            pid,
            coalesce(func.sum(models.SpanCost.prompt_cost), 0).label("prompt_cost"),
            coalesce(func.sum(models.SpanCost.completion_cost), 0).label("completion_cost"),
            coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
            coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label("prompt_tokens"),
            coalesce(func.sum(models.SpanCost.completion_tokens), 0).label("completion_tokens"),
            coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
        )
        .select_from(models.Trace)
        .join(models.Span, models.Span.trace_rowid == models.Trace.id)
        .join(models.SpanCost, models.Span.id == models.SpanCost.span_rowid)
        .group_by(pid)
    )

    # Compare the optional bounds against None explicitly instead of relying
    # on truthiness. The window is inclusive at the start, exclusive at the end.
    if start_time is not None:
        stmt = stmt.where(start_time <= models.Span.start_time)
    if end_time is not None:
        stmt = stmt.where(models.Span.start_time < end_time)

    if filter_condition:
        stmt = SpanFilter(filter_condition)(stmt)

    return stmt.where(pid.in_(list(params)))
@@ -0,0 +1,56 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
+ ProjectSessionRowId: TypeAlias = int
13
+ Key: TypeAlias = ProjectSessionRowId
14
+ Result: TypeAlias = SpanCostSummary
15
+
16
+
17
class SpanCostSummaryByProjectSessionDataLoader(DataLoader[Key, Result]):
    """Batch-load a prompt/completion/total cost summary per project session.

    ``SpanCost`` rows are joined to ``Trace`` and summed per
    ``Trace.project_session_rowid``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``.

        Keys without any matching rows get a default-constructed
        ``SpanCostSummary``.
        """
        span_cost = models.SpanCost
        session_rowid = models.Trace.project_session_rowid
        query = (
            select(
                session_rowid,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .join_from(span_cost, models.Trace)
            .where(session_rowid.in_(keys))
            .group_by(session_rowid)
        )
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                (
                    group_key,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[group_key] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
@@ -0,0 +1,55 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
+ TraceRowId: TypeAlias = int
13
+ Key: TypeAlias = TraceRowId
14
+ Result: TypeAlias = SpanCostSummary
15
+
16
+
17
class SpanCostSummaryByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-load a prompt/completion/total cost summary per trace row id.

    Costs and token counts of ``SpanCost`` rows are summed per
    ``SpanCost.trace_rowid``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``.

        Keys without any matching rows get a default-constructed
        ``SpanCostSummary``.
        """
        span_cost = models.SpanCost
        trace_rowid = span_cost.trace_rowid
        query = (
            select(
                trace_rowid,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .where(trace_rowid.in_(keys))
            .group_by(trace_rowid)
        )
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(query)
            async for row in rows:
                (
                    group_key,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[group_key] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]