arize-phoenix 10.15.0__py3-none-any.whl → 11.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.
Files changed (79)
  1. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +77 -46
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/token_price_customization.py +29 -0
  12. phoenix/server/api/context.py +38 -4
  13. phoenix/server/api/dataloaders/__init__.py +41 -5
  14. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  15. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  16. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  20. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  21. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  27. phoenix/server/api/dataloaders/span_costs.py +35 -0
  28. phoenix/server/api/dataloaders/types.py +29 -0
  29. phoenix/server/api/helpers/playground_clients.py +103 -12
  30. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  31. phoenix/server/api/input_types/SpanSort.py +17 -0
  32. phoenix/server/api/mutations/__init__.py +2 -0
  33. phoenix/server/api/mutations/chat_mutations.py +17 -0
  34. phoenix/server/api/mutations/model_mutations.py +208 -0
  35. phoenix/server/api/queries.py +82 -41
  36. phoenix/server/api/routers/v1/traces.py +11 -4
  37. phoenix/server/api/subscriptions.py +36 -2
  38. phoenix/server/api/types/CostBreakdown.py +15 -0
  39. phoenix/server/api/types/Experiment.py +59 -1
  40. phoenix/server/api/types/ExperimentRun.py +58 -4
  41. phoenix/server/api/types/GenerativeModel.py +143 -2
  42. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  43. phoenix/server/api/types/ModelInterface.py +11 -0
  44. phoenix/server/api/types/PlaygroundModel.py +10 -0
  45. phoenix/server/api/types/Project.py +42 -0
  46. phoenix/server/api/types/ProjectSession.py +44 -0
  47. phoenix/server/api/types/Span.py +137 -0
  48. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  49. phoenix/server/api/types/SpanCostSummary.py +10 -0
  50. phoenix/server/api/types/TokenPrice.py +16 -0
  51. phoenix/server/api/types/TokenUsage.py +3 -3
  52. phoenix/server/api/types/Trace.py +41 -0
  53. phoenix/server/app.py +59 -0
  54. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  55. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  56. phoenix/server/cost_tracking/helpers.py +68 -0
  57. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  58. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  59. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  60. phoenix/server/daemons/__init__.py +0 -0
  61. phoenix/server/daemons/generative_model_store.py +51 -0
  62. phoenix/server/daemons/span_cost_calculator.py +103 -0
  63. phoenix/server/dml_event_handler.py +1 -0
  64. phoenix/server/static/.vite/manifest.json +36 -36
  65. phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
  66. phoenix/server/static/assets/{index-DIlhmbjB.js → index-S3YKLmbo.js} +13 -13
  67. phoenix/server/static/assets/{pages-YX47cEoQ.js → pages-BW6PBHZb.js} +811 -419
  68. phoenix/server/static/assets/{vendor-DCZoBorz.js → vendor-DqQvHbPa.js} +147 -147
  69. phoenix/server/static/assets/{vendor-arizeai-Ckci3irT.js → vendor-arizeai-CLX44PFA.js} +1 -1
  70. phoenix/server/static/assets/{vendor-codemirror-BODM513D.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  71. phoenix/server/static/assets/{vendor-recharts-C9O2a-N3.js → vendor-recharts-B2PJDrnX.js} +25 -25
  72. phoenix/server/static/assets/{vendor-shiki-Dq54rRC7.js → vendor-shiki-CNbrFjf9.js} +1 -1
  73. phoenix/version.py +1 -1
  74. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  75. phoenix/server/static/assets/components-SpUMF1qV.js +0 -4509
  76. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/db/facilitator.py CHANGED
@@ -1,13 +1,20 @@
  from __future__ import annotations

  import asyncio
+ import json
  import logging
+ import re
  import secrets
  from asyncio import gather
+ from datetime import datetime, timedelta, timezone
  from functools import partial
- from typing import Optional
+ from pathlib import Path
+ from typing import NamedTuple, Optional, Union

  import sqlalchemy as sa
+ from sqlalchemy import select
+ from sqlalchemy.orm import InstrumentedAttribute, joinedload
+ from sqlalchemy.sql.dml import ReturningDelete

  from phoenix import config
  from phoenix.auth import (
@@ -26,7 +33,6 @@ from phoenix.config import (
  from phoenix.db import models
  from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
  from phoenix.db.enums import ENUM_COLUMNS
- from phoenix.db.models import UserRoleName
  from phoenix.db.types.trace_retention import (
      MaxDaysRule,
      TraceRetentionCronExpression,
@@ -62,6 +68,8 @@ class Facilitator:
              _get_system_user_id,
              partial(_ensure_admins, email_sender=self._email_sender),
              _ensure_default_project_trace_retention_policy,
+             _ensure_model_costs,
+             _delete_expired_childless_records,
          ):
              await fn(self._db)

@@ -92,13 +100,13 @@ async def _ensure_user_roles(db: DbSessionFactory) -> None:
      the email "admin@localhost".
      """
      async with db() as session:
-         role_ids: dict[UserRoleName, int] = {
+         role_ids: dict[models.UserRoleName, int] = {
              name: id_
              async for name, id_ in await session.stream(
                  sa.select(models.UserRole.name, models.UserRole.id)
              )
          }
-         existing_roles: list[UserRoleName] = [
+         existing_roles: list[models.UserRoleName] = [
              name
              async for name in await session.stream_scalars(
                  sa.select(sa.distinct(models.UserRole.name)).join_from(models.User, models.UserRole)
@@ -216,6 +224,83 @@ async def _ensure_admins(
          logger.error(f"Failed to send welcome email: {exc}")


+ _CHILDLESS_RECORD_DELETION_GRACE_PERIOD_DAYS = 1
+
+
+ def _stmt_to_delete_expired_childless_records(
+     table: type[models.Base],
+     foreign_key: Union[InstrumentedAttribute[int], InstrumentedAttribute[Optional[int]]],
+ ) -> ReturningDelete[tuple[int]]:
+     """
+     Creates a SQLAlchemy DELETE statement to permanently remove childless records.
+
+     Args:
+         table: The table model class that has a deleted_at column
+         foreign_key: The foreign key attribute to check for child relationships
+
+     Returns:
+         A DELETE statement that removes childless records marked for deletion more than
+         _CHILDLESS_RECORD_DELETION_GRACE_PERIOD_DAYS days ago
+     """  # noqa: E501
+     if not hasattr(table, "deleted_at"):
+         raise TypeError("Table must have a 'deleted_at' column")
+     cutoff_time = datetime.now(timezone.utc) - timedelta(
+         days=_CHILDLESS_RECORD_DELETION_GRACE_PERIOD_DAYS
+     )
+     return (
+         sa.delete(table)
+         .where(table.deleted_at.isnot(None))
+         .where(table.deleted_at < cutoff_time)
+         .where(~sa.exists().where(table.id == foreign_key))
+         .returning(table.id)
+     )
+
+
+ async def _delete_expired_childless_records_on_generative_models(
+     db: DbSessionFactory,
+ ) -> None:
+     """
+     Permanently deletes childless GenerativeModel records that have been marked for deletion.
+
+     This function removes GenerativeModel records that:
+     - Have been marked for deletion (deleted_at is not NULL)
+     - Were marked more than 1 day ago (grace period expired)
+     - Have no associated SpanCost records (childless)
+
+     This cleanup is necessary to remove orphaned records that may have been left behind
+     due to previous migrations or deletions.
+     """  # noqa: E501
+     stmt = _stmt_to_delete_expired_childless_records(
+         models.GenerativeModel,
+         models.SpanCost.model_id,
+     )
+     async with db() as session:
+         result = (await session.scalars(stmt)).all()
+     if result:
+         logger.info(f"Permanently deleted {len(result)} expired childless GenerativeModel records")
+     else:
+         logger.debug("No expired childless GenerativeModel records found for permanent deletion")
+
+
+ async def _delete_expired_childless_records(
+     db: DbSessionFactory,
+ ) -> None:
+     """
+     Permanently deletes childless records across all relevant tables.
+
+     This function runs the deletion process for all table types that support soft deletion,
+     handling any exceptions that occur during the process. Only records that have been
+     marked for deletion for more than the grace period (1 day) are permanently removed.
+     """  # noqa: E501
+     exceptions = await gather(
+         _delete_expired_childless_records_on_generative_models(db),
+         return_exceptions=True,
+     )
+     for exc in exceptions:
+         if isinstance(exc, Exception):
+             logger.error(f"Failed to delete childless records: {exc}")
+
+
  async def _ensure_default_project_trace_retention_policy(db: DbSessionFactory) -> None:
      """
      Ensures the default trace retention policy (id=1) exists in the database. Default policy
@@ -261,3 +346,177 @@ async def _ensure_default_project_trace_retention_policy(db: DbSessionFactory) -> None:
                  }
              ],
          )
+
+
+ _COST_MODEL_MANIFEST: Path = (
+     Path(__file__).parent.parent / "server" / "cost_tracking" / "model_cost_manifest.json"
+ )
+
+
+ class _TokenTypeKey(NamedTuple):
+     """
+     Composite key for uniquely identifying token price configurations.
+
+     Token prices are differentiated by both their type (e.g., "input", "output", "audio")
+     and whether they represent prompt tokens (input to the model) or completion tokens
+     (output from the model). Some token types like "audio" can exist in both categories.
+
+     Attributes:
+         token_type: The category of token (e.g., "input", "output", "audio", "cache_write")
+         is_prompt: True if these are prompt/input tokens, False if completion/output tokens
+     """
+
+     token_type: str
+     is_prompt: bool
+
+
+ async def _ensure_model_costs(db: DbSessionFactory) -> None:
+     """
+     Ensures that built-in generative models and their token pricing information are up-to-date
+     in the database based on the model cost manifest file.
+
+     This function performs a comprehensive synchronization between the database and the manifest:
+
+     1. **Model Management**: Creates new built-in models from the manifest or updates existing ones
+     2. **Token Price Synchronization**: Ensures all token prices match the manifest rates
+     3. **Cleanup**: Soft-deletes built-in models no longer present in the manifest
+
+     The function handles different token types including:
+     - Input tokens (prompt): Standard input tokens for generation
+     - Cache write tokens (prompt): Tokens written to cache systems
+     - Cache read tokens (prompt): Tokens read from cache systems
+     - Output tokens (non-prompt): Generated response tokens
+     - Audio tokens (both prompt and non-prompt): Audio processing tokens
+
+     Token prices are uniquely identified by (token_type, is_prompt) pairs to handle
+     cases like audio tokens that can be both prompt and non-prompt.
+
+     Args:
+         db (DbSessionFactory): Database session factory for database operations
+
+     Returns:
+         None
+
+     Raises:
+         FileNotFoundError: If the model cost manifest file is not found
+         json.JSONDecodeError: If the manifest file contains invalid JSON
+         ValueError: If manifest data is malformed or missing required fields
+     """
+     # Load the authoritative model cost data from the manifest file
+     with open(_COST_MODEL_MANIFEST) as f:
+         manifest = json.load(f)
+
+     # Define all supported token types with their prompt/non-prompt classification
+     # This determines how tokens are categorized for billing purposes
+     token_types: list[_TokenTypeKey] = [
+         _TokenTypeKey("input", True),  # Standard input tokens
+         _TokenTypeKey("cache_write", True),  # Tokens written to cache
+         _TokenTypeKey("cache_read", True),  # Tokens read from cache
+         _TokenTypeKey("output", False),  # Generated output tokens
+         _TokenTypeKey("audio", True),  # Audio input tokens
+         _TokenTypeKey("audio", False),  # Audio output tokens
+     ]
+
+     async with db() as session:
+         # Fetch all existing built-in models with their token prices eagerly loaded
+         # Using .unique() to deduplicate models when multiple token prices are joined
+         built_in_models = {
+             omodel.name: omodel
+             for omodel in (
+                 await session.scalars(
+                     select(models.GenerativeModel)
+                     .where(models.GenerativeModel.deleted_at.is_(None))
+                     .where(models.GenerativeModel.is_built_in.is_(True))
+                     .options(joinedload(models.GenerativeModel.token_prices))
+                 )
+             ).unique()
+         }
+
+         seen_names: set[str] = set()
+         seen_patterns: set[tuple[re.Pattern[str], str]] = set()
+
+         # Process each model in the manifest
+         for model_data in manifest:
+             name = str(model_data.get("model") or "").strip()
+             if not name:
+                 logger.warning("Skipping model with empty name in manifest")
+                 continue
+             if name in seen_names:
+                 logger.warning(f"Skipping model '{name}' with duplicate name in manifest")
+                 continue
+             seen_names.add(name)
+             regex = str(model_data.get("regex") or "").strip()
+             try:
+                 pattern = re.compile(regex)
+             except re.error as e:
+                 logger.warning(f"Skipping model '{name}' with invalid regex: {e}")
+                 continue
+             provider = str(model_data.get("provider") or "").strip()
+             if (pattern, provider) in seen_patterns:
+                 logger.warning(
+                     f"Skipping model '{name}' with duplicate name_pattern/provider combination"
+                 )
+                 continue
+             seen_patterns.add((pattern, provider))
+             # Remove model from built_in_models dict (for cleanup tracking)
+             # or create new model if not found
+             model = built_in_models.pop(model_data["model"], None)
+             if model is None:
+                 # Create new built-in model from manifest data
+                 model = models.GenerativeModel(
+                     name=name,
+                     provider=provider,
+                     name_pattern=pattern,
+                     is_built_in=True,
+                 )
+                 session.add(model)
+             else:
+                 # Update existing model's metadata from manifest
+                 model.provider = provider
+                 model.name_pattern = pattern
+
+             # Create lookup table for existing token prices by (token_type, is_prompt)
+             # Using pop() during iteration allows us to track which prices are no longer needed
+             token_prices = {
+                 _TokenTypeKey(token_price.token_type, token_price.is_prompt): token_price
+                 for token_price in model.token_prices
+             }
+
+             # Synchronize token prices for all supported token types
+             for token_type, is_prompt in token_types:
+                 # Skip if this token type has no rate in the manifest
+                 if not (base_rate := model_data.get(token_type)):
+                     continue
+
+                 key = _TokenTypeKey(token_type, is_prompt)
+                 # Remove from tracking dict and get existing price (if any)
+                 if not (token_price := token_prices.pop(key, None)):
+                     # Create new token price if it doesn't exist
+                     token_price = models.TokenPrice(
+                         token_type=token_type,
+                         is_prompt=is_prompt,
+                         base_rate=base_rate,
+                     )
+                     model.token_prices.append(token_price)
+                 elif token_price.base_rate != base_rate:
+                     # Update existing price if rate has changed
+                     token_price.base_rate = base_rate
+
+             # Remove any token prices that are no longer in the manifest
+             # These are prices that weren't popped from the token_prices dict above
+             for token_price in token_prices.values():
+                 model.token_prices.remove(token_price)
+
+     # Clean up built-in models that are no longer in the manifest
+     # These are models that weren't popped from built_in_models dict above
+     remaining_models = list(built_in_models.values())
+     if not remaining_models:
+         return
+
+     # Soft delete obsolete built-in models
+     async with db() as session:
+         await session.execute(
+             sa.update(models.GenerativeModel)
+             .values(deleted_at=sa.func.now())
+             .where(models.GenerativeModel.id.in_([m.id for m in remaining_models]))
+         )
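
A minimal sketch of what the _stmt_to_delete_expired_childless_records helper added above builds, assuming arize-phoenix 11.0.0 is importable; the rendered SQL is approximate and the SQLite dialect is chosen only for illustration:

from sqlalchemy.dialects import sqlite

from phoenix.db import models
from phoenix.db.facilitator import _stmt_to_delete_expired_childless_records

# Build the grace-period cleanup statement for GenerativeModel rows that have
# no SpanCost children, then render it for inspection.
stmt = _stmt_to_delete_expired_childless_records(
    models.GenerativeModel,    # table with a deleted_at column
    models.SpanCost.model_id,  # child foreign key that must have no rows
)
print(stmt.compile(dialect=sqlite.dialect()))
# Approximate output:
#   DELETE FROM generative_models
#   WHERE generative_models.deleted_at IS NOT NULL
#     AND generative_models.deleted_at < :deleted_at_1
#     AND NOT (EXISTS (SELECT * FROM span_costs
#                      WHERE generative_models.id = span_costs.model_id))
#   RETURNING generative_models.id
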
phoenix/db/insertion/helpers.py CHANGED
@@ -3,6 +3,7 @@ from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Se
  from enum import Enum, auto
  from typing import Any, Optional

+ from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
  from sqlalchemy import Insert
  from sqlalchemy.dialects.postgresql import insert as insert_postgresql
  from sqlalchemy.dialects.sqlite import insert as insert_sqlite
@@ -13,6 +14,7 @@ from typing_extensions import TypeAlias, assert_never
  from phoenix.db import models
  from phoenix.db.helpers import SupportedSQLDialect
  from phoenix.db.models import Base
+ from phoenix.trace.attributes import get_attribute_value


  class DataManipulationEvent(ABC):
@@ -97,3 +99,16 @@ def as_kv(obj: models.Base) -> Iterator[tuple[str, Any]]:
              # postgresql disallows None for primary key
              continue
          yield k, v
+
+
+ def should_calculate_span_cost(
+     attributes: Optional[Mapping[str, Any]],
+ ) -> bool:
+     return bool(
+         (span_kind := get_attribute_value(attributes, SpanAttributes.OPENINFERENCE_SPAN_KIND))
+         and isinstance(span_kind, str)
+         and span_kind == OpenInferenceSpanKindValues.LLM.value
+         and (llm_name := get_attribute_value(attributes, SpanAttributes.LLM_MODEL_NAME))
+         and isinstance(llm_name, str)
+         and llm_name.strip()
+     )
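
A small usage sketch for the new should_calculate_span_cost predicate, assuming arize-phoenix 11.0.0 is installed; the nested attribute layout below is an assumption (it presumes get_attribute_value walks dotted keys segment by segment), and the model name is hypothetical:

from phoenix.db.insertion.helpers import should_calculate_span_cost

# Assumed nested layout for span attributes: one mapping level per dotted-key segment.
llm_span_attributes = {
    "openinference": {"span": {"kind": "LLM"}},
    "llm": {"model_name": "gpt-4o"},  # hypothetical model name
}
chain_span_attributes = {
    "openinference": {"span": {"kind": "CHAIN"}},
}

print(should_calculate_span_cost(llm_span_attributes))    # expected: True
print(should_calculate_span_cost(chain_span_attributes))  # expected: False
print(should_calculate_span_cost(None))                   # expected: False
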
phoenix/db/insertion/span.py CHANGED
@@ -14,6 +14,8 @@ from phoenix.trace.schemas import Span, SpanStatusCode

  class SpanInsertionEvent(NamedTuple):
      project_rowid: int
+     span_rowid: int
+     trace_rowid: int


  class ClearProjectSpansEvent(NamedTuple):
@@ -190,4 +192,4 @@ async def insert_span(
                  + cumulative_llm_token_count_completion,
              )
          )
-     return SpanInsertionEvent(project_rowid)
+     return SpanInsertionEvent(project_rowid, span_rowid, trace.id)
phoenix/db/migrations/versions/a20694b15f82_cost.py ADDED
@@ -0,0 +1,196 @@
+ """Cost-related tables
+
+ Revision ID: a20694b15f82
+ Revises: migrations/versions/6a88424799fe_update_users_with_auth_method.py
+ Create Date: 2025-05-30 17:15:12.663565
+
+ """
+
+ from typing import Sequence, Union
+
+ import sqlalchemy as sa
+ from alembic import op
+
+ # revision identifiers, used by Alembic.
+ revision: str = "a20694b15f82"
+ down_revision: Union[str, None] = "6a88424799fe"
+ branch_labels: Union[str, Sequence[str], None] = None
+ depends_on: Union[str, Sequence[str], None] = None
+
+ _Integer = sa.Integer().with_variant(
+     sa.BigInteger(),
+     "postgresql",
+ )
+
+
+ def upgrade() -> None:
+     op.create_table(
+         "generative_models",
+         sa.Column(
+             "id",
+             _Integer,
+             primary_key=True,
+         ),
+         sa.Column(
+             "name",
+             sa.String,
+             nullable=False,
+         ),
+         sa.Column(
+             "provider",
+             sa.String,
+             nullable=False,
+         ),
+         sa.Column(
+             "name_pattern",
+             sa.String,
+             nullable=False,
+         ),
+         sa.Column(
+             "is_built_in",
+             sa.Boolean,
+             nullable=False,
+         ),
+         sa.Column(
+             "start_time",
+             sa.TIMESTAMP(timezone=True),
+         ),
+         sa.Column(
+             "created_at",
+             sa.TIMESTAMP(timezone=True),
+             nullable=False,
+             server_default=sa.func.now(),
+         ),
+         sa.Column(
+             "updated_at",
+             sa.TIMESTAMP(timezone=True),
+             nullable=False,
+             server_default=sa.func.now(),
+             onupdate=sa.func.now(),
+         ),
+         sa.Column(
+             "deleted_at",
+             sa.TIMESTAMP(timezone=True),
+             nullable=True,
+         ),
+     )
+     op.create_index(
+         "ix_generative_models_match_criteria",
+         "generative_models",
+         ["name_pattern", "provider", "is_built_in"],
+         unique=True,
+         postgresql_where=sa.text("deleted_at IS NULL"),
+         sqlite_where=sa.text("deleted_at IS NULL"),
+     )
+     op.create_index(
+         "ix_generative_models_name_is_built_in",
+         "generative_models",
+         ["name", "is_built_in"],
+         unique=True,
+         postgresql_where=sa.text("deleted_at IS NULL"),
+         sqlite_where=sa.text("deleted_at IS NULL"),
+     )
+     op.create_table(
+         "token_prices",
+         sa.Column(
+             "id",
+             _Integer,
+             primary_key=True,
+         ),
+         sa.Column(
+             "model_id",
+             _Integer,
+             sa.ForeignKey("generative_models.id", ondelete="CASCADE"),
+             nullable=False,
+             index=True,
+         ),
+         sa.Column("token_type", sa.String, nullable=False),
+         sa.Column("is_prompt", sa.Boolean, nullable=False),
+         sa.Column("base_rate", sa.Float, nullable=False),
+         sa.Column("customization", sa.JSON),
+         sa.UniqueConstraint(
+             "model_id",
+             "token_type",
+             "is_prompt",
+         ),
+     )
+     op.create_table(
+         "span_costs",
+         sa.Column(
+             "id",
+             _Integer,
+             primary_key=True,
+         ),
+         sa.Column(
+             "span_rowid",
+             _Integer,
+             sa.ForeignKey("spans.id", ondelete="CASCADE"),
+             nullable=False,
+             index=True,
+         ),
+         sa.Column(
+             "trace_rowid",
+             _Integer,
+             sa.ForeignKey("traces.id", ondelete="CASCADE"),
+             nullable=False,
+             index=True,
+         ),
+         sa.Column(
+             "model_id",
+             _Integer,
+             sa.ForeignKey(
+                 "generative_models.id",
+                 ondelete="RESTRICT",
+             ),
+             nullable=True,
+         ),
+         sa.Column(
+             "span_start_time",
+             sa.TIMESTAMP(timezone=True),
+             nullable=False,
+             index=True,
+         ),
+         sa.Column("total_cost", sa.Float),
+         sa.Column("total_tokens", sa.Float),
+         sa.Column("prompt_cost", sa.Float),
+         sa.Column("prompt_tokens", sa.Float),
+         sa.Column("completion_cost", sa.Float),
+         sa.Column("completion_tokens", sa.Float),
+     )
+     op.create_index(
+         "ix_span_costs_model_id_span_start_time",
+         "span_costs",
+         ["model_id", "span_start_time"],
+     )
+     op.create_table(
+         "span_cost_details",
+         sa.Column(
+             "id",
+             _Integer,
+             primary_key=True,
+         ),
+         sa.Column(
+             "span_cost_id",
+             _Integer,
+             sa.ForeignKey("span_costs.id", ondelete="CASCADE"),
+             nullable=False,
+             index=True,
+         ),
+         sa.Column("token_type", sa.String, nullable=False, index=True),
+         sa.Column("is_prompt", sa.Boolean, nullable=False),
+         sa.Column("cost", sa.Float),
+         sa.Column("tokens", sa.Float),
+         sa.Column("cost_per_token", sa.Float),
+         sa.UniqueConstraint(
+             "span_cost_id",
+             "token_type",
+             "is_prompt",
+         ),
+     )
+
+
+ def downgrade() -> None:
+     op.drop_table("span_cost_details")
+     op.drop_table("span_costs")
+     op.drop_table("token_prices")
+     op.drop_table("generative_models")
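
The generative_models and token_prices tables created above are populated at startup by _ensure_model_costs from model_cost_manifest.json. A rough sketch of that mapping with a hypothetical manifest entry; the keys mirror what the sync code reads, while the model name, regex, and rates are made up for illustration:

# Hypothetical manifest entry; _ensure_model_costs reads "model", "provider",
# and "regex", plus at most one rate per token type key.
manifest_entry = {
    "model": "example-model-v1",     # becomes generative_models.name (hypothetical)
    "provider": "example-provider",  # becomes generative_models.provider
    "regex": r"^example-model-v1$",  # compiled and stored as name_pattern
    "input": 2.5e-06,                # -> token_prices row ("input", is_prompt=True)
    "cache_read": 1.25e-06,          # -> token_prices row ("cache_read", is_prompt=True)
    "output": 1.0e-05,               # -> token_prices row ("output", is_prompt=False)
}
# Each present rate becomes one token_prices row keyed by (token_type, is_prompt)
# and linked via model_id to the generative_models row for "example-model-v1".
# Absent keys ("cache_write", "audio") produce no row; an "audio" rate would
# produce two rows, one with is_prompt=True and one with is_prompt=False.
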