arize-phoenix 9.5.0__py3-none-any.whl → 9.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-9.5.0.dist-info → arize_phoenix-9.6.1.dist-info}/METADATA +1 -1
- {arize_phoenix-9.5.0.dist-info → arize_phoenix-9.6.1.dist-info}/RECORD +18 -14
- phoenix/server/api/helpers/playground_clients.py +1 -1
- phoenix/server/api/routers/v1/datasets.py +11 -3
- phoenix/server/api/types/Span.py +49 -0
- phoenix/server/api/types/TokenCountPromptDetails.py +10 -0
- phoenix/server/cost_tracking/__init__.py +0 -0
- phoenix/server/cost_tracking/cost_lookup.py +255 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +830 -0
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-DpK7N6zE.js → components-CDvTuTqd.js} +3 -3
- phoenix/server/static/assets/{index-BXA0RjaV.js → index-DpcxdHu4.js} +1 -1
- phoenix/server/static/assets/{pages-jHwPRLA2.js → pages-Bcs41-Zv.js} +440 -396
- phoenix/version.py +1 -1
- {arize_phoenix-9.5.0.dist-info → arize_phoenix-9.6.1.dist-info}/WHEEL +0 -0
- {arize_phoenix-9.5.0.dist-info → arize_phoenix-9.6.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-9.5.0.dist-info → arize_phoenix-9.6.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-9.5.0.dist-info → arize_phoenix-9.6.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,7 @@ phoenix/exceptions.py,sha256=n2L2KKuecrdflB9MsCdAYCiSEvGJptIsfRkXMoJle7A,169
|
|
|
6
6
|
phoenix/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
7
7
|
phoenix/services.py,sha256=ngkyKGVatX3cO2WJdo2hKdaVKP-xJCMvqthvga6kJss,5196
|
|
8
8
|
phoenix/settings.py,sha256=x87BX7hWGQQZbrW_vrYqFR_izCGfO9gFc--JXUG4Tdk,754
|
|
9
|
-
phoenix/version.py,sha256=
|
|
9
|
+
phoenix/version.py,sha256=gkjhVoAFhlcpBLzRIiqhKP7hOWAIqvxp25eVs0y914g,22
|
|
10
10
|
phoenix/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
phoenix/core/embedding_dimension.py,sha256=zKGbcvwOXgLf-yrJBpQyKtd-LEOPRKHnUToyAU8Owis,87
|
|
12
12
|
phoenix/core/model.py,sha256=qBFraOtmwCCnWJltKNP18DDG0mULXigytlFsa6YOz6k,4837
|
|
@@ -157,7 +157,7 @@ phoenix/server/api/helpers/__init__.py,sha256=m2-xaSPqUiSs91k62JaRDjFNfl-1byxBfY
|
|
|
157
157
|
phoenix/server/api/helpers/annotations.py,sha256=9gMXKpMTfWEChoSCnvdWYuyB0hlSnNOp-qUdar9Vono,262
|
|
158
158
|
phoenix/server/api/helpers/dataset_helpers.py,sha256=DoMBTg-qXTnC_K4Evx1WKpCCYgRbITpVqyY-8efJRf0,8984
|
|
159
159
|
phoenix/server/api/helpers/experiment_run_filters.py,sha256=DOnVwrmn39eAkk2mwuZP8kIcAnR5jrOgllEwWSjsw94,29893
|
|
160
|
-
phoenix/server/api/helpers/playground_clients.py,sha256=
|
|
160
|
+
phoenix/server/api/helpers/playground_clients.py,sha256=C-GPq4wklcnGXiW5-7-ipx5wjowDuwSKzqbGHta2QEc,41888
|
|
161
161
|
phoenix/server/api/helpers/playground_registry.py,sha256=CPLMziFB2wmr-dfbx7VbzO2f8YIG_k5RftzvGXYGQ1w,2570
|
|
162
162
|
phoenix/server/api/helpers/playground_spans.py,sha256=ObAhvV_yNwEQDkjzgU5G73wfIisc8q4cpB0OFH5cd24,16974
|
|
163
163
|
phoenix/server/api/helpers/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -231,7 +231,7 @@ phoenix/server/api/routers/utils.py,sha256=M41BoH-fl37izhRuN2aX7lWm7jOC20A_3uClv
|
|
|
231
231
|
phoenix/server/api/routers/v1/__init__.py,sha256=0oOcsKJkQtBXAjZAo3AMtfjyW3OGCU4MI4TGW5nV6lo,2614
|
|
232
232
|
phoenix/server/api/routers/v1/annotation_configs.py,sha256=rZ3yJm7m75BVegSjSHqsdqf7n26roGg7vYYiiKfWA3A,15898
|
|
233
233
|
phoenix/server/api/routers/v1/annotations.py,sha256=oeafR2tCLu-uIwM9J72gN3MX5WDhrOMU3Jqd1uIiFqg,5921
|
|
234
|
-
phoenix/server/api/routers/v1/datasets.py,sha256=
|
|
234
|
+
phoenix/server/api/routers/v1/datasets.py,sha256=Wqiy6ZKqn4BZSFyn93gzuhWx3mGn7kOkNncHzCWuBq8,37325
|
|
235
235
|
phoenix/server/api/routers/v1/evaluations.py,sha256=GFTo42aIEX0Htn0EjjoE1JZDYlvryeZ_CK9kowhwzGw,12830
|
|
236
236
|
phoenix/server/api/routers/v1/experiment_evaluations.py,sha256=xSs004jNYsOl3eg-6Zjo2tt9TefTd7WR3twWYrsNQNk,4828
|
|
237
237
|
phoenix/server/api/routers/v1/experiment_runs.py,sha256=jqpquCygtUYNNN7lgSvGvOlXCE7KTleDRFjxJ7bbDfM,6400
|
|
@@ -305,11 +305,12 @@ phoenix/server/api/types/Retrieval.py,sha256=OhMK2ncjoyp5h1yjKhjlKpoTbQrMHuxmgSF
|
|
|
305
305
|
phoenix/server/api/types/ScalarDriftMetricEnum.py,sha256=IUAcRPpgL41WdoIgK6cNk2Te38SspXGyEs-S1fY23_A,232
|
|
306
306
|
phoenix/server/api/types/Segments.py,sha256=vT2v0efoa5cuBKxLtxTnsUP5YJJCZfTloM71Spu0tMI,2915
|
|
307
307
|
phoenix/server/api/types/SortDir.py,sha256=OUpXhlCzCxPoXSDkJJygEs9Rw9pMymfaZUG5zPTrw4Y,152
|
|
308
|
-
phoenix/server/api/types/Span.py,sha256=
|
|
308
|
+
phoenix/server/api/types/Span.py,sha256=ZaDUBOPk4YE9nV0379Yc1NEZKnItWmIdCbJVgXaBgAU,30482
|
|
309
309
|
phoenix/server/api/types/SpanAnnotation.py,sha256=uPWu7Z8rmpfKhaaxbged4_o00pPCR3nkn7Gji9vB8jY,1959
|
|
310
310
|
phoenix/server/api/types/SpanIOValue.py,sha256=c5TWdZZN3v0gHI5xWeY7gjD-sE9ugWlGGAio-gDS-Uo,1653
|
|
311
311
|
phoenix/server/api/types/SystemApiKey.py,sha256=2ym8EgsTBIvxx1l9xZ-2YMovz58ZwYb_MaHBTJ9NH2E,166
|
|
312
312
|
phoenix/server/api/types/TimeSeries.py,sha256=nuuZtfHmOhTjeB8_SvZ5PUQexAkTcPScwYeFC5RUlRU,5491
|
|
313
|
+
phoenix/server/api/types/TokenCountPromptDetails.py,sha256=CWDWLrYoufrR1ePWfbq0-AgAkdjmGnJQt4_wNIt6bOQ,183
|
|
313
314
|
phoenix/server/api/types/TokenUsage.py,sha256=g-PjAGVigpchQgkXAuC5sc53fn2YwAgfeXkGmFPi_TE,201
|
|
314
315
|
phoenix/server/api/types/ToolDefinition.py,sha256=T6UH2vcbuPBDy7jKYOqMth2NdqxMPgDBf11Tpbt5Yb8,187
|
|
315
316
|
phoenix/server/api/types/Trace.py,sha256=fx1ozxiFMu-9AUyJ9LyMr6QtMqxzGEkucu7eE_dDZBM,8195
|
|
@@ -323,6 +324,9 @@ phoenix/server/api/types/VectorDriftMetricEnum.py,sha256=etiJM5ZjQuD-oE7sY-FbdIK
|
|
|
323
324
|
phoenix/server/api/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
324
325
|
phoenix/server/api/types/node.py,sha256=BLl_IOFr0zrqUxaAtGLGui5aeM5VNVXFTzGeAKrztr0,822
|
|
325
326
|
phoenix/server/api/types/pagination.py,sha256=BXm46gXZfrBS4hpiLvVSEdsbb29ctUMVJYjKXlOLxUA,9064
|
|
327
|
+
phoenix/server/cost_tracking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
328
|
+
phoenix/server/cost_tracking/cost_lookup.py,sha256=c9COURDSW-LFAeuX1k2PX-kKpy8WZeIiwwjJr_YZOqY,9416
|
|
329
|
+
phoenix/server/cost_tracking/model_cost_manifest.json,sha256=tlOYj69-K0ru53ql3UtX-ynRU_J3C_g5BUGZR6aSirM,19270
|
|
326
330
|
phoenix/server/email/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
327
331
|
phoenix/server/email/sender.py,sha256=eC6RcLANVJH0mh20mGZ2qr-bU-OWo9po2e5og2tMzJw,4127
|
|
328
332
|
phoenix/server/email/types.py,sha256=IO2bTtCh-1cve-xiM4MWnunCCVNOQ3Z2cqTqF7vH-do,466
|
|
@@ -342,10 +346,10 @@ phoenix/server/static/apple-touch-icon-76x76.png,sha256=CT_xT12I0u2i0WU8JzBZBuOQ
|
|
|
342
346
|
phoenix/server/static/apple-touch-icon.png,sha256=fOfpjqGpWYbJ0eAurKsyoZP1EAs6ZVooBJ_SGk2ZkDs,3801
|
|
343
347
|
phoenix/server/static/favicon.ico,sha256=bY0vvCKRftemZfPShwZtE93DiiQdaYaozkPGwNFr6H8,34494
|
|
344
348
|
phoenix/server/static/modernizr.js,sha256=mvK-XtkNqjOral-QvzoqsyOMECXIMu5BQwSVN_wcU9c,2564
|
|
345
|
-
phoenix/server/static/.vite/manifest.json,sha256=
|
|
346
|
-
phoenix/server/static/assets/components-
|
|
347
|
-
phoenix/server/static/assets/index-
|
|
348
|
-
phoenix/server/static/assets/pages-
|
|
349
|
+
phoenix/server/static/.vite/manifest.json,sha256=hW3yshzfVwBhZmcRVFOv6lgVC5qy7v5U59K207nsiVI,2165
|
|
350
|
+
phoenix/server/static/assets/components-CDvTuTqd.js,sha256=lJUX_imM4QeN2DzlgfWXoC-tJ3eci8aeS-YaU6Bgy1Y,535701
|
|
351
|
+
phoenix/server/static/assets/index-DpcxdHu4.js,sha256=qvTZErnPG4_mjpaNZbTi53rL_9s9k25udpb0AElehBM,60240
|
|
352
|
+
phoenix/server/static/assets/pages-Bcs41-Zv.js,sha256=o5FR83BYGFyox-6dXabSgsHCGxVGblMzqnlKVpCsRjY,1038496
|
|
349
353
|
phoenix/server/static/assets/vendor-CToBXdDM.js,sha256=q_UwZrhCRrNhrvFyv3OO6bW52jM1TDiYk3aTj-NgdLU,2744392
|
|
350
354
|
phoenix/server/static/assets/vendor-WIZid84E.css,sha256=spZD2r7XL5GfLO13ln-IuXfnjAref8l6g_n_AvxxOlI,5517
|
|
351
355
|
phoenix/server/static/assets/vendor-arizeai-BhbMHqQs.js,sha256=l3G1o-P_IYcqQWOHBcSpT5RextOH2myGl58ZSN7NvcQ,193248
|
|
@@ -392,9 +396,9 @@ phoenix/utilities/project.py,sha256=auVpARXkDb-JgeX5f2aStyFIkeKvGwN9l7qrFeJMVxI,
|
|
|
392
396
|
phoenix/utilities/re.py,sha256=6YyUWIkv0zc2SigsxfOWIHzdpjKA_TZo2iqKq7zJKvw,2081
|
|
393
397
|
phoenix/utilities/span_store.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
394
398
|
phoenix/utilities/template_formatters.py,sha256=gh9PJD6WEGw7TEYXfSst1UR4pWWwmjxMLrDVQ_CkpkQ,2779
|
|
395
|
-
arize_phoenix-9.
|
|
396
|
-
arize_phoenix-9.
|
|
397
|
-
arize_phoenix-9.
|
|
398
|
-
arize_phoenix-9.
|
|
399
|
-
arize_phoenix-9.
|
|
400
|
-
arize_phoenix-9.
|
|
399
|
+
arize_phoenix-9.6.1.dist-info/METADATA,sha256=WERys-C4em-Qhasa3VK1FGQoe33vfqXhySCYQ3mvX5A,25590
|
|
400
|
+
arize_phoenix-9.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
401
|
+
arize_phoenix-9.6.1.dist-info/entry_points.txt,sha256=Pgpn8Upxx9P8z8joPXZWl2LlnAlGc3gcQoVchb06X1Q,94
|
|
402
|
+
arize_phoenix-9.6.1.dist-info/licenses/IP_NOTICE,sha256=JBqyyCYYxGDfzQ0TtsQgjts41IJoa-hiwDrBjCb9gHM,469
|
|
403
|
+
arize_phoenix-9.6.1.dist-info/licenses/LICENSE,sha256=HFkW9REuMOkvKRACuwLPT0hRydHb3zNg-fdFt94td18,3794
|
|
404
|
+
arize_phoenix-9.6.1.dist-info/RECORD,,
|
|
@@ -701,7 +701,7 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
|
|
|
701
701
|
provider_key=GenerativeProviderKey.ANTHROPIC,
|
|
702
702
|
model_names=[
|
|
703
703
|
PROVIDER_DEFAULT,
|
|
704
|
-
"claude-3-7-latest",
|
|
704
|
+
"claude-3-7-sonnet-latest",
|
|
705
705
|
"claude-3-7-sonnet-20250219",
|
|
706
706
|
"claude-3-5-sonnet-latest",
|
|
707
707
|
"claude-3-5-haiku-latest",
|
|
@@ -3,6 +3,7 @@ import gzip
|
|
|
3
3
|
import io
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
|
+
import urllib
|
|
6
7
|
import zlib
|
|
7
8
|
from asyncio import QueueFull
|
|
8
9
|
from collections import Counter
|
|
@@ -817,10 +818,11 @@ async def get_dataset_csv(
|
|
|
817
818
|
except ValueError as e:
|
|
818
819
|
raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
|
819
820
|
content = await run_in_threadpool(_get_content_csv, examples)
|
|
821
|
+
encoded_dataset_name = urllib.parse.quote(dataset_name)
|
|
820
822
|
return Response(
|
|
821
823
|
content=content,
|
|
822
824
|
headers={
|
|
823
|
-
"content-disposition": f
|
|
825
|
+
"content-disposition": f"attachment; filename*=UTF-8''{encoded_dataset_name}.csv",
|
|
824
826
|
"content-type": "text/csv",
|
|
825
827
|
},
|
|
826
828
|
)
|
|
@@ -859,7 +861,10 @@ async def get_dataset_jsonl_openai_ft(
|
|
|
859
861
|
except ValueError as e:
|
|
860
862
|
raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
|
861
863
|
content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
|
|
862
|
-
|
|
864
|
+
encoded_dataset_name = urllib.parse.quote(dataset_name)
|
|
865
|
+
response.headers["content-disposition"] = (
|
|
866
|
+
f"attachment; filename*=UTF-8''{encoded_dataset_name}.jsonl"
|
|
867
|
+
)
|
|
863
868
|
return content
|
|
864
869
|
|
|
865
870
|
|
|
@@ -896,7 +901,10 @@ async def get_dataset_jsonl_openai_evals(
|
|
|
896
901
|
except ValueError as e:
|
|
897
902
|
raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
|
|
898
903
|
content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
|
|
899
|
-
|
|
904
|
+
encoded_dataset_name = urllib.parse.quote(dataset_name)
|
|
905
|
+
response.headers["content-disposition"] = (
|
|
906
|
+
f"attachment; filename*=UTF-8''{encoded_dataset_name}.jsonl"
|
|
907
|
+
)
|
|
900
908
|
return content
|
|
901
909
|
|
|
902
910
|
|
phoenix/server/api/types/Span.py
CHANGED
|
@@ -44,6 +44,8 @@ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_
|
|
|
44
44
|
from phoenix.server.api.types.SpanIOValue import SpanIOValue, truncate_value
|
|
45
45
|
from phoenix.trace.attributes import get_attribute_value
|
|
46
46
|
|
|
47
|
+
from .TokenCountPromptDetails import TokenCountPromptDetails
|
|
48
|
+
|
|
47
49
|
if TYPE_CHECKING:
|
|
48
50
|
from phoenix.server.api.types.Project import Project
|
|
49
51
|
from phoenix.server.api.types.Trace import Trace
|
|
@@ -351,6 +353,48 @@ class Span(Node):
|
|
|
351
353
|
)
|
|
352
354
|
return cast(Optional[int], value)
|
|
353
355
|
|
|
356
|
+
@strawberry.field
async def token_prompt_details(
    self,
    info: Info[Context, None],
) -> TokenCountPromptDetails:
    # Breaks the prompt token count down into cache-read, cache-write and
    # audio components, each read from the span's attributes.
    #
    # Use the attributes already attached to this object when a database span
    # is present; otherwise fetch just the attributes column through the
    # span_fields data loader.
    if not self.db_span:
        attributes = await info.context.data_loaders.span_fields.load(
            (self.span_rowid, models.Span.attributes),
        )
    else:
        attributes = self.db_span.attributes

    def _int_attribute(key: str) -> Optional[int]:
        # Only integer attribute values are reported; a missing value or a
        # value of any other type is surfaced as None.
        raw = get_attribute_value(
            attributes=attributes,
            key=key,
        )
        return raw if isinstance(raw, int) else None

    return TokenCountPromptDetails(
        cache_read=_int_attribute(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ),
        cache_write=_int_attribute(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE),
        audio=_int_attribute(LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO),
    )
|
|
397
|
+
|
|
354
398
|
@strawberry.field
|
|
355
399
|
async def input(
|
|
356
400
|
self,
|
|
@@ -800,6 +844,11 @@ def _convert_metadata_to_string(metadata: Any) -> Optional[str]:
|
|
|
800
844
|
|
|
801
845
|
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
|
|
802
846
|
INPUT_VALUE = SpanAttributes.INPUT_VALUE
|
|
847
|
+
LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO
|
|
848
|
+
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
|
|
849
|
+
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = (
|
|
850
|
+
SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE
|
|
851
|
+
)
|
|
803
852
|
METADATA = SpanAttributes.METADATA
|
|
804
853
|
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
|
|
805
854
|
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
|
|
File without changes
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterator, Optional, Union
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class ModelTokenCost:
    """Token costs for a model, in USD, split by token category.

    Every field is optional and defaults to ``None``.
    """

    # Cost of input tokens.
    input: Optional[float] = None
    # Cost of output tokens.
    output: Optional[float] = None
    # Cost of tokens written to the cache.
    cache_write: Optional[float] = None
    # Cost of tokens read from the cache.
    cache_read: Optional[float] = None
    # Cost of audio tokens.
    audio: Optional[float] = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RegexDict:
    """Insertion-ordered mapping from regular expressions to values.

    Entries are registered under a regex (a pattern string or a compiled
    ``re.Pattern``) and looked up with a plain string: the value of the first
    registered pattern that *fully* matches the string is returned.

    Two keys denote the same entry when both their pattern text and their
    flags are equal. The same rule is used for assignment and for deletion, so
    ``re.compile("x")`` and ``re.compile("x", re.I)`` are distinct entries and
    deleting one never removes the other.
    """

    __slots__ = ("_entries",)

    def __init__(self) -> None:
        # (compiled pattern, value) pairs in insertion order; lookup is linear.
        self._entries: list[tuple[re.Pattern[str], Any]] = []

    @staticmethod
    def _compile(pattern: Union[str, re.Pattern[str]]) -> re.Pattern[str]:
        """Return *pattern* as a compiled regex, compiling strings on the fly.

        Raises:
            TypeError: if *pattern* is neither a ``str`` nor an ``re.Pattern``.
        """
        if isinstance(pattern, str):
            return re.compile(pattern)
        if isinstance(pattern, re.Pattern):
            return pattern
        raise TypeError("RegexDict key must be a str or re.Pattern")

    def _index_of(self, compiled: re.Pattern[str]) -> Optional[int]:
        """Return the position of the entry identical to *compiled*, or ``None``."""
        for idx, (existing_pat, _) in enumerate(self._entries):
            if existing_pat.pattern == compiled.pattern and existing_pat.flags == compiled.flags:
                return idx
        return None

    def __setitem__(self, pattern: Union[str, re.Pattern[str]], value: Any) -> None:
        """Register *value* under *pattern*, replacing an identical entry in place.

        A replaced entry keeps its original position, so lookup priority is
        unchanged; a new entry is appended and therefore has lowest priority.

        Raises:
            TypeError: if *pattern* is neither a ``str`` nor an ``re.Pattern``.
        """
        compiled = self._compile(pattern)
        idx = self._index_of(compiled)
        if idx is None:
            self._entries.append((compiled, value))
        else:
            self._entries[idx] = (compiled, value)

    def __delitem__(self, pattern: Union[str, re.Pattern[str]]) -> None:
        """Remove the entry registered under *pattern*.

        The entry is identified by pattern text *and* flags, exactly as in
        ``__setitem__``; a string key therefore removes the entry that
        assigning to that same string would have created.

        Raises:
            TypeError: if *pattern* is neither a ``str`` nor an ``re.Pattern``.
            KeyError: if no such entry is registered.
        """
        if isinstance(pattern, str):
            try:
                compiled = re.compile(pattern)
            except re.error:
                # A string that is not a valid regex can never have been
                # registered, so report it as a missing key.
                raise KeyError(pattern) from None
        else:
            compiled = self._compile(pattern)

        idx = self._index_of(compiled)
        if idx is None:
            raise KeyError(pattern)
        del self._entries[idx]

    def __getitem__(self, key: str) -> Any:
        """Return the value of the first registered pattern that fully matches *key*.

        Raises:
            KeyError: if no registered pattern fully matches *key*.
        """
        for pattern, value in self._entries:
            if pattern.fullmatch(key):
                return value
        raise KeyError(key)

    def __contains__(self, key: str) -> bool:
        """Return ``True`` if any registered pattern fully matches *key*."""
        try:
            _ = self[key]
            return True
        except KeyError:
            return False

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(pattern text, value)`` pairs in insertion order.

        Unlike a ``dict``, iteration yields pairs rather than bare keys.
        """
        for pattern, value in self._entries:
            yield pattern.pattern, value

    def __len__(self) -> int:
        """Return the number of registered patterns."""
        return len(self._entries)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ModelCostLookup:
    """Registry that resolves a (provider, model name) pair to token costs.

    Costs come from two sources:

    * *base patterns*, registered per provider with ``add_pattern``; and
    * *overrides*, registered with ``add_override``, which take precedence
      over base patterns. When several overrides match, the most recently
      added one wins.

    Model names are matched against patterns with ``fullmatch``. Successful
    lookups are memoized in a small bounded cache that is cleared whenever
    patterns or overrides change.
    """

    __slots__ = ("_provider_model_map", "_model_map", "_overrides", "_cache", "_max_cache_size")

    def __init__(self) -> None:
        # Each provider maps to a *RegexDict* of (pattern -> cost).
        self._provider_model_map: defaultdict[Optional[str], RegexDict] = defaultdict(RegexDict)
        # Map from *compiled pattern* to the set of providers registered with it.
        self._model_map: defaultdict[re.Pattern[str], set[Optional[str]]] = defaultdict(set)
        # A prioritized list of cost overrides (later overrides have higher priority).
        self._overrides: list[tuple[Optional[str], re.Pattern[str], ModelTokenCost]] = []
        # Cache for computed costs keyed by (provider, model_name). Relies on
        # dict insertion order: the first key is the least recently used.
        self._cache: dict[tuple[Optional[str], str], list[tuple[str, ModelTokenCost]]] = {}
        self._max_cache_size = 100

    @staticmethod
    def _require_compiled(pattern: Any) -> None:
        """Raise ``TypeError`` unless *pattern* is a compiled regex.

        An explicit raise is used instead of ``assert`` so the check still
        runs when Python is started with ``-O``.
        """
        if not isinstance(pattern, re.Pattern):
            raise TypeError("pattern must be a compiled regex")

    def add_pattern(
        self, provider: Optional[str], pattern: re.Pattern[str], cost: ModelTokenCost
    ) -> None:
        """Register a model pattern with its cost.

        Raises:
            TypeError: if *pattern* is not a compiled regex.
        """
        self._require_compiled(pattern)
        self._provider_model_map[provider][pattern] = cost
        self._model_map[pattern].add(provider)
        self._cache.clear()

    def remove_pattern(self, provider: Optional[str], pattern: re.Pattern[str]) -> None:
        """Remove a previously-registered model pattern.

        Does nothing when *provider* has no registered patterns at all.

        Raises:
            TypeError: if *pattern* is not a compiled regex.
            KeyError: if *provider* has patterns but *pattern* is not one of them.
        """
        self._require_compiled(pattern)
        if provider not in self._provider_model_map:
            return
        del self._provider_model_map[provider][pattern]
        self._model_map[pattern].discard(provider)
        # Drop containers that became empty so membership tests stay accurate.
        if not self._provider_model_map[provider]:
            del self._provider_model_map[provider]
        if not self._model_map[pattern]:
            del self._model_map[pattern]
        self._cache.clear()

    def get_cost(
        self, provider: Optional[str], model_name: str
    ) -> list[tuple[str, ModelTokenCost]]:
        """Return ``(provider, cost)`` pairs for *model_name*.

        With a concrete *provider* the result has exactly one pair. With
        ``provider=None`` the result has one pair per provider whose patterns
        (or overrides) match the model.

        Raises:
            KeyError: if nothing matches. Failed lookups are not cached.
        """
        key = (provider, model_name)
        if key in self._cache:
            # Re-insert so the key becomes the most recently used one.
            value = self._cache.pop(key)
            self._cache[key] = value
            return value

        result = self._lookup_cost(provider, model_name)

        if len(self._cache) >= self._max_cache_size:
            # Evict the least recently used entry (first key in the dict).
            self._cache.pop(next(iter(self._cache)))

        self._cache[key] = result
        return result

    def has_model(self, provider: Optional[str], model_name: str) -> bool:
        """Return ``True`` if a cost (either base or overridden) exists for the model."""

        return self._contains(provider, model_name)

    def pattern_count(self) -> int:
        """Return the number of registered *base* patterns (overrides not counted)."""

        return sum(len(regex_dict) for regex_dict in self._provider_model_map.values())

    def _lookup_cost(
        self, provider: Optional[str], model_name: str
    ) -> list[tuple[str, ModelTokenCost]]:
        """Resolve costs without consulting the cache; see ``get_cost``."""
        if not isinstance(model_name, str):
            raise TypeError("Lookup key must be a str")
        # 1) Provider-specific lookup
        if provider is not None:
            override_cost = self._lookup_override(provider, model_name)
            if override_cost is not None:
                return [(provider, override_cost)]

            regex_dict = self._provider_model_map.get(provider)
            if regex_dict is None:
                raise KeyError(provider)
            return [(provider, regex_dict[model_name])]

        # 2) provider-agnostic lookup
        provider_cost_map: dict[str, ModelTokenCost] = {}
        for p, regex_dict in self._provider_model_map.items():
            try:
                provider_cost_map[p] = regex_dict[model_name]  # type: ignore
            except KeyError:
                continue

        # Apply overrides oldest-first so that later (higher priority)
        # overrides overwrite earlier ones.
        for override_provider, override_pattern, override_cost in self._overrides:
            if override_pattern.fullmatch(model_name):
                if override_provider is None:
                    # A provider-agnostic override replaces the cost of every
                    # provider found so far, but does not add a provider.
                    for p in list(provider_cost_map):
                        provider_cost_map[p] = override_cost
                else:
                    provider_cost_map[override_provider] = override_cost

        if not provider_cost_map:
            raise KeyError(model_name)
        return list(provider_cost_map.items())

    def _contains(self, provider: Optional[str], model_name: str) -> bool:
        """Return whether a base pattern or an override matches the model."""
        if provider is None:
            # NOTE(review): this returns True when only a provider-agnostic
            # override matches, yet get_cost(None, ...) raises KeyError in
            # that case because such an override has no provider to report.
            # Confirm which of the two behaviors is intended.
            if any(pat.fullmatch(model_name) for _, pat, _ in self._overrides):
                return True
            return any(model_name in regex_dict for regex_dict in self._provider_model_map.values())

        if self._lookup_override(provider, model_name) is not None:
            return True

        regex_dict = self._provider_model_map.get(provider)
        if not regex_dict:
            return False
        return model_name in regex_dict

    def add_override(
        self, provider: Optional[str], pattern: re.Pattern[str], cost: ModelTokenCost
    ) -> None:
        """Register a *prioritized* cost override.

        Overrides added later take precedence over overrides added earlier.
        An override registered with ``provider=None`` applies to every
        provider.

        Raises:
            TypeError: if *pattern* is not a compiled regex.
        """
        self._require_compiled(pattern)
        self._overrides.append((provider, pattern, cost))
        self._cache.clear()

    def _lookup_override(
        self, provider: Optional[str], model_name: str
    ) -> Optional[ModelTokenCost]:
        """Return the cost from the highest-priority override that matches, or *None*."""

        # Newest override first: the last one added has the highest priority.
        for override_provider, override_pattern, override_cost in reversed(self._overrides):
            provider_matches = override_provider is None or override_provider == provider
            if provider_matches and override_pattern.fullmatch(model_name):
                return override_cost
        return None
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def create_cost_table(
    manifest_path: Optional[Union[str, "os.PathLike[str]"]] = None,
) -> "ModelCostLookup":
    """Build a ``ModelCostLookup`` from a JSON model-cost manifest.

    The manifest is read as a JSON array of objects. For each object the
    ``regex`` key is required; ``provider``, ``input``, ``output``,
    ``cache_write``, ``cache_read`` and ``audio`` are optional and default to
    ``None``. ``model`` is used only in error messages.

    Args:
        manifest_path: Location of the manifest. When omitted, the
            ``model_cost_manifest.json`` file next to this module is used.

    Returns:
        A lookup with one base pattern registered per manifest entry.

    Raises:
        FileNotFoundError: if the manifest file does not exist.
        ValueError: if the file is not valid JSON or an entry's ``regex``
            does not compile.
        KeyError: if an entry has no ``regex`` key.
    """
    if manifest_path is None:
        manifest_path = Path(__file__).with_name("model_cost_manifest.json")

    manifest_path = Path(manifest_path)

    if not manifest_path.exists():
        raise FileNotFoundError(f"Model cost manifest not found: {manifest_path}")

    with manifest_path.open("r", encoding="utf-8") as fp:
        try:
            manifest_entries: list[dict[str, Any]] = json.load(fp)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Failed to parse manifest JSON: {manifest_path}") from exc

    lookup = ModelCostLookup()

    for entry in manifest_entries:
        provider: Optional[str] = entry.get("provider")

        try:
            pattern = re.compile(entry["regex"])
        except re.error as exc:
            raise ValueError(
                f"Invalid regex in manifest for model {entry.get('model')}: {entry['regex']}"
            ) from exc

        cost = ModelTokenCost(
            input=entry.get("input"),
            output=entry.get("output"),
            cache_write=entry.get("cache_write"),
            cache_read=entry.get("cache_read"),
            # Previously never read, so an audio price in the manifest was
            # silently dropped; absent keys still yield None.
            audio=entry.get("audio"),
        )

        lookup.add_pattern(provider, pattern, cost)

    return lookup
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# Module-level lookup built at import time from the model_cost_manifest.json
# file located next to this module; importing the module therefore reads and
# parses that file and raises if it is missing or invalid.
COST_TABLE = create_cost_table()
|