genblaze-assemblyai 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,76 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+
7
+ # Distribution / packaging
8
+ dist/
9
+ build/
10
+ *.egg-info/
11
+ *.egg
12
+ wheels/
13
+ /MANIFEST
14
+
15
+ # Virtual environments
16
+ .venv/
17
+ venv/
18
+ env/
19
+ ENV/
20
+
21
+ # IDE / editors
22
+ .vscode/
23
+ .idea/
24
+ *.swp
25
+ *.swo
26
+ *~
27
+ .project
28
+ .settings/
29
+
30
+ # Testing / coverage
31
+ .pytest_cache/
32
+ .coverage
33
+ .coverage.*
34
+ htmlcov/
35
+ coverage.xml
36
+ *.cover
37
+
38
+ # Type checking
39
+ .mypy_cache/
40
+ .pytype/
41
+ .pyre/
42
+
43
+ # Linting
44
+ .ruff_cache/
45
+
46
+ # OS files
47
+ .DS_Store
48
+ Thumbs.db
49
+ ehthumbs.db
50
+
51
+ # Environment / secrets
52
+ .env
53
+ .env.*
54
+ !.env.example
55
+ *.pem
56
+ *.key
57
+ credentials.json
58
+
59
+ # Jupyter
60
+ .ipynb_checkpoints/
61
+
62
+ # MkDocs build output
63
+ site/
64
+
65
+ # Claude Code — local/ephemeral
66
+ .claude/settings.local.json
67
+ .claude/worktrees/
68
+ CLAUDE.local.md
69
+
70
+ # Hypothesis test cache
71
+ .hypothesis/
72
+
73
+ # Temp files
74
+ *.tmp
75
+ *.bak
76
+ *.log
@@ -0,0 +1,117 @@
1
+ Metadata-Version: 2.4
2
+ Name: genblaze-assemblyai
3
+ Version: 0.3.0
4
+ Summary: AssemblyAI (speech-to-text / transcription) provider adapter for genblaze
5
+ Project-URL: Homepage, https://github.com/backblaze-labs/genblaze
6
+ Project-URL: Documentation, https://github.com/backblaze-labs/genblaze
7
+ Project-URL: Repository, https://github.com/backblaze-labs/genblaze
8
+ Project-URL: Issues, https://github.com/backblaze-labs/genblaze/issues
9
+ Author-email: Eduardo Pavez <epavez@backblaze.com>
10
+ License-Expression: MIT
11
+ Keywords: ai,assemblyai,audio,c2pa-ready,genai,genblaze,manifest,media,pipeline,provenance,speech-to-text,transcription
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Classifier: Topic :: Multimedia
18
+ Classifier: Topic :: Software Development :: Libraries
19
+ Classifier: Typing :: Typed
20
+ Requires-Python: >=3.11
21
+ Requires-Dist: assemblyai<1,>=0.45
22
+ Requires-Dist: genblaze-core<0.4,>=0.3.0
23
+ Provides-Extra: dev
24
+ Requires-Dist: pytest>=7.0; extra == 'dev'
25
+ Description-Content-Type: text/markdown
26
+
27
+ <!-- last_verified: 2026-06-23 -->
28
+ # genblaze-assemblyai
29
+
30
+ [AssemblyAI](https://www.assemblyai.com/) **speech-to-text / transcription**
31
+ provider adapter for [genblaze](https://github.com/backblaze-labs/genblaze).
32
+
33
+ This connector is the inverse of every other genblaze adapter: instead of
34
+ *generating* media it *consumes* an audio URL and *produces* a text
35
+ transcript with word-level timings. The transcript lands as a hash-verified
36
+ **TEXT asset** with a provenance manifest, so it composes into pipelines
37
+ (generate audio → transcribe, caption a generated video) and persists to
38
+ Backblaze B2 / S3 like any other genblaze step.
39
+
40
+ ## Install
41
+
42
+ ```bash
43
+ pip install genblaze-assemblyai # or: pip install "genblaze[assemblyai]"
44
+ export ASSEMBLYAI_API_KEY="..." # from https://www.assemblyai.com/app/api-keys
45
+ ```
46
+
47
+ ## Usage
48
+
49
+ ```python
50
+ from genblaze_core import Modality, Pipeline
51
+ from genblaze_assemblyai import AssemblyAIProvider
52
+
53
+ run, manifest = (
54
+ Pipeline("transcribe")
55
+ .step(
56
+ AssemblyAIProvider(),
57
+ model="universal-3-pro", # speech_models: universal-3-pro | universal-2
58
+ prompt="https://example.com/podcast-episode.mp3", # or params={"audio_url": ...}, or a chained audio input
59
+ modality=Modality.TEXT,
60
+ speaker_labels=True, # any TranscriptionConfig flag passes through
61
+ )
62
+ .run()
63
+ )
64
+
65
+ asset = run.steps[0].assets[0]
66
+ print(asset.metadata["text"]) # the transcript
67
+ print(asset.audio.word_timings[:3]) # [WordTiming(word=..., start=..., end=...), ...] (seconds)
68
+ print(manifest.canonical_hash)
69
+ ```
70
+
71
+ The audio URL is resolved from (in priority order) `step.inputs[0].url` →
72
+ `step.params["audio_url"]` → `step.prompt`, then SSRF-validated (https:// or
73
+ file:// only) before submission. `step.model` selects the AssemblyAI speech
74
+ model (sent on the SDK's plural `speech_models` field — the live API retired
75
+ the singular `speech_model` field and the `best` / `nano` aliases; current
76
+ values are `universal-3-pro` / `universal-2`); every other
77
+ `TranscriptionConfig` flag (`speaker_labels`, `language_code`,
78
+ audio-intelligence options, …) passes through `step.params`.
79
+
80
+ ## Pricing
81
+
82
+ genblaze ships no hardcoded prices. AssemblyAI bills per minute of *input*
83
+ audio; the connector captures the input duration in
84
+ `step.provider_payload["audio_duration"]` (seconds) during `fetch_output`.
85
+ Register a recipe at runtime — see
86
+ [`docs/reference/pricing-recipes.md`](https://github.com/backblaze-labs/genblaze/blob/main/docs/reference/pricing-recipes.md)
87
+ ("AssemblyAI" section):
88
+
89
+ ```python
90
+ from genblaze_assemblyai import AssemblyAIProvider
+ from genblaze_core.providers import per_response_metric
91
+
92
+ RATE_PER_MINUTE = RATE # replace RATE with the per-minute USD rate from assemblyai.com/pricing
93
+
94
+
95
+ def per_minute(ctx):
96
+ seconds = ctx.provider_payload.get("audio_duration")
97
+ return (seconds / 60.0) * RATE_PER_MINUTE if seconds is not None else None
98
+
99
+
100
+ provider = AssemblyAIProvider(api_key="...")
101
+ # Speech-model slugs match the connector's family, so register against the
102
+ # concrete slug(s) you use (register_pricing layers onto the family spec).
103
+ for slug in ("universal-3-pro", "universal-2"):
104
+ provider.models.register_pricing(slug, per_response_metric(per_minute))
105
+ ```
106
+
107
+ ## Notes
108
+
109
+ - **Out of scope for v1:** real-time/streaming transcription, LeMUR
110
+ (LLM-over-transcript), and first-class SRT/VTT subtitle outputs. Audio-
111
+ intelligence flags pass through to the API, but v1 copies only the transcript text, language, confidence, utterances, and audio duration into `metadata`.
112
+
113
+ ## Docs
114
+
115
+ - [AssemblyAI docs](https://www.assemblyai.com/docs)
116
+ - [Transcript API reference](https://www.assemblyai.com/docs/api-reference/transcripts/get)
117
+ - genblaze [new-provider guide](https://github.com/backblaze-labs/genblaze/blob/main/docs/guides/new-provider.md)
@@ -0,0 +1,91 @@
1
+ <!-- last_verified: 2026-06-23 -->
2
+ # genblaze-assemblyai
3
+
4
+ [AssemblyAI](https://www.assemblyai.com/) **speech-to-text / transcription**
5
+ provider adapter for [genblaze](https://github.com/backblaze-labs/genblaze).
6
+
7
+ This connector is the inverse of every other genblaze adapter: instead of
8
+ *generating* media it *consumes* an audio URL and *produces* a text
9
+ transcript with word-level timings. The transcript lands as a hash-verified
10
+ **TEXT asset** with a provenance manifest, so it composes into pipelines
11
+ (generate audio → transcribe, caption a generated video) and persists to
12
+ Backblaze B2 / S3 like any other genblaze step.
13
+
14
+ ## Install
15
+
16
+ ```bash
17
+ pip install genblaze-assemblyai # or: pip install "genblaze[assemblyai]"
18
+ export ASSEMBLYAI_API_KEY="..." # from https://www.assemblyai.com/app/api-keys
19
+ ```
20
+
21
+ ## Usage
22
+
23
+ ```python
24
+ from genblaze_core import Modality, Pipeline
25
+ from genblaze_assemblyai import AssemblyAIProvider
26
+
27
+ run, manifest = (
28
+ Pipeline("transcribe")
29
+ .step(
30
+ AssemblyAIProvider(),
31
+ model="universal-3-pro", # speech_models: universal-3-pro | universal-2
32
+ prompt="https://example.com/podcast-episode.mp3", # or params={"audio_url": ...}, or a chained audio input
33
+ modality=Modality.TEXT,
34
+ speaker_labels=True, # any TranscriptionConfig flag passes through
35
+ )
36
+ .run()
37
+ )
38
+
39
+ asset = run.steps[0].assets[0]
40
+ print(asset.metadata["text"]) # the transcript
41
+ print(asset.audio.word_timings[:3]) # [WordTiming(word=..., start=..., end=...), ...] (seconds)
42
+ print(manifest.canonical_hash)
43
+ ```
44
+
45
+ The audio URL is resolved from (in priority order) `step.inputs[0].url` →
46
+ `step.params["audio_url"]` → `step.prompt`, then SSRF-validated (https:// or
47
+ file:// only) before submission. `step.model` selects the AssemblyAI speech
48
+ model (sent on the SDK's plural `speech_models` field — the live API retired
49
+ the singular `speech_model` field and the `best` / `nano` aliases; current
50
+ values are `universal-3-pro` / `universal-2`); every other
51
+ `TranscriptionConfig` flag (`speaker_labels`, `language_code`,
52
+ audio-intelligence options, …) passes through `step.params`.
53
+
54
+ ## Pricing
55
+
56
+ genblaze ships no hardcoded prices. AssemblyAI bills per minute of *input*
57
+ audio; the connector captures the input duration in
58
+ `step.provider_payload["audio_duration"]` (seconds) during `fetch_output`.
59
+ Register a recipe at runtime — see
60
+ [`docs/reference/pricing-recipes.md`](https://github.com/backblaze-labs/genblaze/blob/main/docs/reference/pricing-recipes.md)
61
+ ("AssemblyAI" section):
62
+
63
+ ```python
64
+ from genblaze_assemblyai import AssemblyAIProvider
+ from genblaze_core.providers import per_response_metric
65
+
66
+ RATE_PER_MINUTE = RATE # replace RATE with the per-minute USD rate from assemblyai.com/pricing
67
+
68
+
69
+ def per_minute(ctx):
70
+ seconds = ctx.provider_payload.get("audio_duration")
71
+ return (seconds / 60.0) * RATE_PER_MINUTE if seconds is not None else None
72
+
73
+
74
+ provider = AssemblyAIProvider(api_key="...")
75
+ # Speech-model slugs match the connector's family, so register against the
76
+ # concrete slug(s) you use (register_pricing layers onto the family spec).
77
+ for slug in ("universal-3-pro", "universal-2"):
78
+ provider.models.register_pricing(slug, per_response_metric(per_minute))
79
+ ```
80
+
81
+ ## Notes
82
+
83
+ - **Out of scope for v1:** real-time/streaming transcription, LeMUR
84
+ (LLM-over-transcript), and first-class SRT/VTT subtitle outputs. Audio-
85
+ intelligence flags pass through to the API, but v1 copies only the transcript text, language, confidence, utterances, and audio duration into `metadata`.
86
+
87
+ ## Docs
88
+
89
+ - [AssemblyAI docs](https://www.assemblyai.com/docs)
90
+ - [Transcript API reference](https://www.assemblyai.com/docs/api-reference/transcripts/get)
91
+ - genblaze [new-provider guide](https://github.com/backblaze-labs/genblaze/blob/main/docs/guides/new-provider.md)
@@ -0,0 +1,7 @@
1
+ """AssemblyAI (speech-to-text / transcription) provider adapter for genblaze."""
2
+
3
+ from genblaze_assemblyai.provider import AssemblyAIProvider
4
+
5
+ from ._version import __version__ # noqa: F401 — re-exported
6
+
7
+ __all__ = ["AssemblyAIProvider"]
@@ -0,0 +1,29 @@
1
+ """Shared AssemblyAI error mapping — used by provider.py.
2
+
3
+ The AssemblyAI Python SDK raises exceptions that may carry an HTTP
4
+ ``status_code`` attribute (e.g. ``aai.types.TranscriptError`` and the
5
+ underlying transport errors). We classify off the status code when present,
6
+ then fall back to the shared string-based classifier. The same function also
7
+ classifies the plain ``transcript.error`` string returned when a transcript
8
+ finishes with ``status == "error"``.
9
+ """
10
+
11
+ from genblaze_core.models.enums import ProviderErrorCode
12
+ from genblaze_core.providers.base import classify_api_error
13
+
14
+
15
+ def map_assemblyai_error(exc: Exception | str) -> ProviderErrorCode:
16
+ """Map an AssemblyAI API exception (or error string) to a ProviderErrorCode."""
17
+ status = getattr(exc, "status_code", None)
18
+ if isinstance(status, int):
19
+ if status == 429:
20
+ return ProviderErrorCode.RATE_LIMIT
21
+ if status in (401, 403):
22
+ return ProviderErrorCode.AUTH_FAILURE
23
+ if status in (400, 422):
24
+ return ProviderErrorCode.INVALID_INPUT
25
+ if status >= 500:
26
+ return ProviderErrorCode.SERVER_ERROR
27
+ # Fall back to shared string-based classifier (also handles the plain
28
+ # ``transcript.error`` string from a failed transcript).
29
+ return classify_api_error(exc)
@@ -0,0 +1,14 @@
1
+ """``genblaze-assemblyai`` package version — single source of truth via importlib.metadata.
2
+
3
+ Reading from ``importlib.metadata`` keeps the constant equal to whatever
4
+ wheel is installed; no manual edits per release.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from importlib.metadata import PackageNotFoundError, version
10
+
11
def _installed_version(dist_name: str = "genblaze-assemblyai") -> str:
    """Return the installed version of *dist_name*, or a dev sentinel if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:  # pragma: no cover — editable dev installs
        return "0.0.0+unknown"


__version__: str = _installed_version()
@@ -0,0 +1,409 @@
1
+ """AssemblyAIProvider — adapter for the AssemblyAI speech-to-text API.
2
+
3
+ AssemblyAI is the inverse of every other genblaze connector: it *consumes* an
4
+ audio URL and *produces* a text transcript (plus word-level timing, speaker
5
+ labels, and optional audio-intelligence). It fits genblaze's primitives with
6
+ zero core changes — the transcript lands as a **TEXT ``Asset``** following the
7
+ ``NvidiaChatProvider`` precedent (``url="text:{sha256}"``,
8
+ ``media_type="text/plain"``, payload in ``metadata["text"]``, sha256 over the
9
+ text bytes), and word timings populate ``AudioMetadata.word_timings``.
10
+
11
+ **API style:** genuinely async — the SDK's ``Transcriber().submit()`` is
12
+ non-blocking (returns a queued ``Transcript`` immediately) and
13
+ ``aai.Transcript.get_by_id(id)`` polls/fetches by id. So this is a
14
+ ``BaseProvider`` (submit / poll / fetch_output), which gets adaptive polling,
15
+ progress events, ``resume()`` crash-recovery, and poll caching for free.
16
+
17
+ **Catalog architecture (genblaze-core 0.3.0):** AssemblyAI exposes no live
18
+ ``GET /models`` catalog; the model set is small and stable. This connector
19
+ therefore declares ``DiscoverySupport.NONE`` and ships a single pattern-keyed
20
+ ``ModelFamily`` for the current ``universal-*`` speech models
21
+ (``universal-3-pro`` / ``universal-2``) plus a permissive TEXT fallback so any
22
+ slug passes through. ``step.model`` is the AssemblyAI speech model and is sent
23
+ on the SDK's plural ``speech_models`` field — the live API has deprecated the
24
+ singular ``speech_model`` field (and the legacy ``best`` / ``nano`` aliases).
25
+
26
+ **Pricing:** AssemblyAI bills per minute of *input* audio. The SDK ships zero
27
+ hardcoded prices — register a recipe at runtime that reads
28
+ ``step.provider_payload["audio_duration"]`` (seconds, captured during
29
+ ``fetch_output``); see ``docs/reference/pricing-recipes.md`` ("AssemblyAI").
30
+
31
+ Docs: https://www.assemblyai.com/docs
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import hashlib
37
+ import logging
38
+ import os
39
+ import re
40
+ from typing import Any
41
+ from urllib.parse import urlparse
42
+ from urllib.request import url2pathname
43
+
44
+ from genblaze_core.exceptions import ProviderError
45
+ from genblaze_core.models.asset import Asset, AudioMetadata, WordTiming
46
+ from genblaze_core.models.enums import Modality, ProviderErrorCode
47
+ from genblaze_core.models.step import Step
48
+ from genblaze_core.providers import (
49
+ BaseProvider,
50
+ DiscoverySupport,
51
+ ModelFamily,
52
+ ModelRegistry,
53
+ ModelSpec,
54
+ ProviderCapabilities,
55
+ RetryPolicy,
56
+ validate_chain_input_url,
57
+ )
58
+ from genblaze_core.providers.retry import retry_after_from_response
59
+ from genblaze_core.runnable.config import RunnableConfig
60
+
61
+ from ._errors import map_assemblyai_error
62
+
63
logger = logging.getLogger("genblaze.assemblyai")

# Terminal transcript statuses. AssemblyAI's lifecycle is
# queued → processing → completed | error.
_TERMINAL_STATUSES = frozenset({"completed", "error"})

# The AssemblyAI speech-model family covers the current ``universal-*`` line
# (``universal-3-pro``, ``universal-2``) — the only values the live API accepts
# on the ``speech_models`` field as of 2026-06. The legacy ``speech_model``
# tier aliases (``best`` / ``nano``) and the bare ``universal`` slug are now
# rejected server-side; they fall through to the permissive fallback (which
# still lets the request through — failures surface at submit, per NONE
# discovery). The pattern therefore requires the ``universal-`` prefix
# *including the hyphen*: a bare ``^universal`` would also match the retired
# ``universal`` slug and wrongly preflight it as a family member.
# Output is always a TEXT transcript, so the spec modality is TEXT.
# spec_template.pricing MUST be None (pricing is user-registered).
_ASSEMBLYAI_SPEECH_FAMILY = ModelFamily(
    name="assemblyai-speech",
    pattern=re.compile(r"^universal-"),
    spec_template=ModelSpec(model_id="*", modality=Modality.TEXT),
    description="AssemblyAI speech models — universal-3-pro / universal-2.",
    example_slugs=("universal-3-pro", "universal-2"),
)

# Permissive TEXT fallback: any slug outside the family still resolves.
_ASSEMBLYAI_FALLBACK = ModelSpec(model_id="*", modality=Modality.TEXT)
86
+
87
+
88
+ def _ms_to_s(value: Any) -> float | None:
89
+ """Convert an AssemblyAI millisecond timestamp to seconds, or None."""
90
+ if value is None:
91
+ return None
92
+ try:
93
+ return float(value) / 1000.0
94
+ except (TypeError, ValueError):
95
+ return None
96
+
97
+
98
+ def _status_str(transcript: Any) -> str:
99
+ """Return the transcript status as a lowercase string.
100
+
101
+ ``transcript.status`` is an ``aai.TranscriptStatus`` (a ``str`` enum) on
102
+ the real SDK; tests may use a plain string. Normalize both to the bare
103
+ value (``"completed"`` / ``"error"`` / …).
104
+ """
105
+ status = getattr(transcript, "status", None)
106
+ value = getattr(status, "value", None)
107
+ return str(value if value is not None else status).lower()
108
+
109
+
110
+ def _audio_ref_for_sdk(audio_url: str) -> str:
111
+ """Shape a resolved audio URL into the form ``Transcriber.submit()`` wants.
112
+
113
+ The AssemblyAI SDK treats any non-HTTP string as a *local file path* and
114
+ opens it with ``open(ref, "rb")``. A ``file://`` URI — the form chained
115
+ SyncProvider outputs take, percent-encoded via ``quote()`` — would be
116
+ opened literally as the filename ``file:///tmp/a.wav`` and fail. Convert
117
+ validated ``file://`` URIs to a real filesystem path (``url2pathname``
118
+ handles percent-decoding and platform path conversion); pass ``https://``
119
+ URLs through untouched so the SDK fetches them remotely.
120
+
121
+ Assumes the URL already passed ``validate_chain_input_url`` (so the scheme
122
+ is ``https`` or ``file`` and any ``file://`` netloc is empty/``localhost``).
123
+ """
124
+ parsed = urlparse(audio_url)
125
+ if parsed.scheme == "file":
126
+ return url2pathname(parsed.path)
127
+ return audio_url
128
+
129
+
130
+ def _serialize_utterances(utterances: Any) -> list[dict[str, Any]]:
131
+ """Flatten SDK utterance objects to canonical-JSON-safe dicts.
132
+
133
+ Times are converted ms → seconds to match ``word_timings``. Kept minimal
134
+ (speaker / text / start / end / confidence) — utterances are pass-through
135
+ context in ``metadata``, not a first-class shape in v1.
136
+ """
137
+ out: list[dict[str, Any]] = []
138
+ for u in utterances or []:
139
+ out.append(
140
+ {
141
+ "speaker": getattr(u, "speaker", None),
142
+ "text": getattr(u, "text", None),
143
+ "start": _ms_to_s(getattr(u, "start", None)),
144
+ "end": _ms_to_s(getattr(u, "end", None)),
145
+ "confidence": getattr(u, "confidence", None),
146
+ }
147
+ )
148
+ return out
149
+
150
+
151
class AssemblyAIProvider(BaseProvider):
    """Provider adapter for AssemblyAI speech-to-text transcription.

    Transcribes an audio URL into a hash-verified TEXT asset with word-level
    timings. The audio URL is resolved from (in priority order)
    ``step.inputs[0].url`` → ``step.params["audio_url"]`` → ``step.prompt`` and
    SSRF-validated via ``validate_chain_input_url`` before submission, so the
    same provider works both standalone and chained into a pipeline (e.g.
    generate audio → transcribe).

    ``step.model`` is the AssemblyAI speech model (``universal-3-pro`` /
    ``universal-2``) and is sent on the SDK's plural ``speech_models`` field;
    any other ``TranscriptionConfig`` kwarg (``speaker_labels``,
    ``language_code``, audio-intelligence flags, …) passes through
    ``step.params``.

    Args:
        api_key: AssemblyAI API key. Falls back to ``ASSEMBLYAI_API_KEY``.
        poll_interval: Base seconds between polls (the base class applies
            adaptive backoff on top).
        models: Optional custom ``ModelRegistry`` — overrides the class default.
        retry_policy: Optional retry policy override.
        probe_cache_ttl: Per-instance probe-cache TTL (no-op for NONE discovery
            but accepted for API uniformity with other providers).
        probe_cache_max_entries: Per-instance probe-cache size cap.
    """

    name = "assemblyai"
    discovery_support = DiscoverySupport.NONE
    """AssemblyAI exposes no ``GET /models`` catalog — the model set is small
    and stable. Family-matched ``universal-*`` slugs (``universal-3-pro`` /
    ``universal-2``) preflight as ``OK_PROVISIONAL``; everything else resolves
    through the permissive fallback as ``UNKNOWN_PERMISSIVE``."""

    @classmethod
    def create_registry(cls) -> ModelRegistry:
        """Build the default registry: the speech-model family plus a TEXT fallback."""
        return ModelRegistry(
            provider_families=(_ASSEMBLYAI_SPEECH_FAMILY,),
            fallback=_ASSEMBLYAI_FALLBACK,
        )

    def get_capabilities(self) -> ProviderCapabilities:
        """AssemblyAI: audio in, TEXT transcript out."""
        return ProviderCapabilities(
            supported_modalities=[Modality.TEXT],
            supported_inputs=["audio"],
            accepts_chain_input=True,
            models=self._models.known(),
            output_formats=["text/plain"],
        )

    def __init__(
        self,
        api_key: str | None = None,
        *,
        poll_interval: float = 3.0,
        models: ModelRegistry | None = None,
        retry_policy: RetryPolicy | None = None,
        probe_cache_ttl: float | None = None,
        probe_cache_max_entries: int | None = None,
    ):
        super().__init__(
            models=models,
            retry_policy=retry_policy,
            probe_cache_ttl=probe_cache_ttl,
            probe_cache_max_entries=probe_cache_max_entries,
        )
        self.poll_interval = poll_interval
        self._api_key = api_key
        # Lazily populated with the configured ``assemblyai`` module.
        self._client: Any = None

    def _get_client(self) -> Any:
        """Return the ``assemblyai`` module, configured with the API key.

        AssemblyAI's SDK is module-scoped: transcription goes through
        ``aai.Transcriber()`` / ``aai.Transcript.get_by_id()`` with the key set
        on ``aai.settings.api_key``. So the "client" is the module itself. The
        key is resolved here (the SDK does not auto-read ``ASSEMBLYAI_API_KEY``)
        and a missing key fails fast with ``AUTH_FAILURE`` rather than a
        deferred opaque 401.

        NOTE(review): ``aai.settings`` is process-global, so two provider
        instances configured with different keys share whichever key was set
        most recently — confirm whether per-instance keys need support.

        Raises:
            ProviderError: no API key is available (``AUTH_FAILURE``), or the
                ``assemblyai`` package is not installed.
        """
        if self._client is None:
            key = self._api_key or os.getenv("ASSEMBLYAI_API_KEY")
            if not key:
                raise ProviderError(
                    "No AssemblyAI API key found. Pass api_key=... or set ASSEMBLYAI_API_KEY.",
                    error_code=ProviderErrorCode.AUTH_FAILURE,
                )
            try:
                import assemblyai as aai
            except ImportError as exc:
                raise ProviderError(
                    "assemblyai package not installed. Run: pip install assemblyai"
                ) from exc
            aai.settings.api_key = key
            self._client = aai
        return self._client

    def normalize_params(
        self, params: dict[str, Any], modality: Modality | None = None
    ) -> dict[str, Any]:
        """Map standard names to AssemblyAI's native ``TranscriptionConfig`` keys.

        ``language`` → ``language_code`` (AssemblyAI native). Everything else
        passes through untouched. Idempotent via the ``if x in p and native
        not in p`` guard. Returns a new dict; *params* is not mutated.
        """
        p = dict(params)
        if "language" in p and "language_code" not in p:
            p["language_code"] = p.pop("language")
        return p

    def _resolve_audio_url(self, step: Step) -> str:
        """Resolve + SSRF-validate the audio URL to transcribe.

        Priority: ``step.inputs[0].url`` (chained pipeline output) →
        ``step.params["audio_url"]`` → ``step.prompt`` (standalone use). The
        chosen URL is validated with ``validate_chain_input_url`` (https:// or
        file:// only) before it leaves the process.

        Raises:
            ProviderError: no candidate URL is present (``INVALID_INPUT``).
        """
        if step.inputs and step.inputs[0].url:
            url = step.inputs[0].url
        elif step.params.get("audio_url"):
            url = str(step.params["audio_url"])
        elif step.prompt:
            url = step.prompt
        else:
            raise ProviderError(
                "AssemblyAI requires an audio URL via step.inputs[0], "
                "step.params['audio_url'], or step.prompt.",
                error_code=ProviderErrorCode.INVALID_INPUT,
            )
        validate_chain_input_url(url)
        return url

    def submit(self, step: Step, config: RunnableConfig | None = None) -> Any:
        """Submit a transcription job (non-blocking) and return its transcript id.

        Raises:
            ProviderError: ``INVALID_INPUT`` for a missing audio URL or params
                that ``TranscriptionConfig`` does not accept; otherwise the code
                mapped from the SDK / API failure. This includes submissions the
                SDK reports as an ``error`` transcript instead of raising.
        """
        aai = self._get_client()
        # Resolve + SSRF-validate, then shape for the SDK: a validated file://
        # chain input must reach Transcriber.submit() as a local filesystem
        # path, since the SDK open()s any non-HTTP string literally.
        audio_ref = _audio_ref_for_sdk(self._resolve_audio_url(step))
        try:
            cfg_kwargs = self.normalize_params(dict(step.params), step.modality)
            # audio_url is the submit() argument, not a TranscriptionConfig field.
            cfg_kwargs.pop("audio_url", None)
            # step.model is the speech model. The live API has deprecated the
            # singular ``speech_model`` field (and the legacy best/nano tier
            # aliases); the slug is sent on the plural ``speech_models`` list,
            # which currently accepts ``universal-3-pro`` / ``universal-2``.
            if step.model:
                cfg_kwargs["speech_models"] = [step.model]
            try:
                transcription_config = aai.TranscriptionConfig(**cfg_kwargs)
            except TypeError as exc:
                # An unknown kwarg in step.params is a caller error, not a
                # transient API failure — classify it so it is not retried.
                raise ProviderError(
                    f"AssemblyAI submit failed: {exc}",
                    error_code=ProviderErrorCode.INVALID_INPUT,
                ) from exc
            transcriber = aai.Transcriber()
            transcript = transcriber.submit(audio_ref, config=transcription_config)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"AssemblyAI submit failed: {exc}",
                error_code=map_assemblyai_error(exc),
                retry_after=retry_after_from_response(exc),
            ) from exc

        # The SDK catches request failures (auth, bad config, upload errors)
        # inside submit() and returns a Transcript with status == "error" and
        # no id instead of raising. Surface that here rather than handing a
        # None id to poll().
        transcript_id = getattr(transcript, "id", None)
        if _status_str(transcript) == "error" or not transcript_id:
            err = getattr(transcript, "error", None) or "AssemblyAI returned no transcript id"
            raise ProviderError(
                f"AssemblyAI submit failed: {err}",
                error_code=map_assemblyai_error(str(err)),
            )
        return transcript_id

    def poll(self, prediction_id: Any, config: RunnableConfig | None = None) -> bool:
        """Return True once the transcript reaches a terminal status.

        Terminal transcripts are cached so ``fetch_output`` can reuse them
        without a second round-trip.
        """
        aai = self._get_client()
        try:
            transcript = aai.Transcript.get_by_id(prediction_id)
            if _status_str(transcript) in _TERMINAL_STATUSES:
                self._cache_poll_result(prediction_id, transcript)
                return True
            return False
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"AssemblyAI poll failed: {exc}",
                error_code=map_assemblyai_error(exc),
                retry_after=retry_after_from_response(exc),
            ) from exc

    def fetch_output(self, prediction_id: Any, step: Step) -> Step:
        """Build the TEXT transcript asset from the completed transcript.

        Reuses the transcript cached by ``poll()`` when available, otherwise
        fetches it by id (e.g. on the ``resume()`` path). Records the input
        audio duration in ``step.provider_payload["audio_duration"]`` for the
        user-registered pricing recipe, then applies registry pricing.

        Raises:
            ProviderError: the transcript finished with ``status == "error"``
                (code mapped from ``transcript.error``), is not yet complete
                (``SERVER_ERROR``), or the SDK call failed.
        """
        aai = self._get_client()
        try:
            transcript = self._get_cached_poll_result(prediction_id)
            if transcript is None:
                transcript = aai.Transcript.get_by_id(prediction_id)

            status = _status_str(transcript)
            if status == "error":
                err = getattr(transcript, "error", None) or "AssemblyAI transcription failed"
                raise ProviderError(str(err), error_code=map_assemblyai_error(err))
            if status != "completed":
                # fetch_output is only reached after poll() confirms a terminal
                # status, so this guards the off-happy-path / resume case rather
                # than silently emitting an empty transcript for a running job.
                raise ProviderError(
                    f"AssemblyAI transcript {prediction_id} is not complete (status={status!r}).",
                    error_code=ProviderErrorCode.SERVER_ERROR,
                )

            # Fully populate the asset before attaching it to the step.
            step.assets = [self._build_transcript_asset(transcript)]

            # Seconds of input audio; the pricing recipe reads this.
            audio_duration = getattr(transcript, "audio_duration", None)
            if audio_duration is not None:
                step.provider_payload["audio_duration"] = audio_duration

            self._apply_registry_pricing(step)
            return step
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"AssemblyAI fetch_output failed: {exc}",
                error_code=map_assemblyai_error(exc),
                retry_after=retry_after_from_response(exc),
            ) from exc

    @staticmethod
    def _build_transcript_asset(transcript: Any) -> Asset:
        """Convert a completed SDK transcript into a synthetic, hash-addressed TEXT asset.

        Follows the ``NvidiaChatProvider`` precedent: ``url="text:{sha256}"``,
        ``media_type="text/plain"``, sha256 over the UTF-8 text bytes, payload
        in ``metadata["text"]``. Word timings populate ``AudioMetadata``;
        language, confidence, utterances and audio duration are copied into
        ``metadata`` when present.
        """
        text = getattr(transcript, "text", None) or ""
        text_bytes = text.encode("utf-8")
        digest = hashlib.sha256(text_bytes).hexdigest()

        asset = Asset(
            url=f"text:{digest}",  # synthetic TEXT asset (NvidiaChatProvider precedent)
            media_type="text/plain",
            sha256=digest,
            size_bytes=len(text_bytes),
        )

        words = getattr(transcript, "words", None)
        if words:
            # AssemblyAI word start/end are in MILLISECONDS; WordTiming
            # expects seconds — divide by 1000.
            asset.audio = AudioMetadata(
                word_timings=[
                    WordTiming(
                        word=getattr(w, "text", "") or "",
                        start=_ms_to_s(getattr(w, "start", None)) or 0.0,
                        end=_ms_to_s(getattr(w, "end", None)) or 0.0,
                        confidence=getattr(w, "confidence", None),
                    )
                    for w in words
                ]
            )

        asset.metadata["text"] = text
        language = getattr(transcript, "language_code", None)
        if language is not None:
            asset.metadata["language"] = language
        confidence = getattr(transcript, "confidence", None)
        if confidence is not None:
            asset.metadata["confidence"] = confidence
        utterances = getattr(transcript, "utterances", None)
        if utterances:
            asset.metadata["utterances"] = _serialize_utterances(utterances)
        audio_duration = getattr(transcript, "audio_duration", None)
        if audio_duration is not None:
            asset.metadata["audio_duration"] = audio_duration
        return asset
File without changes
@@ -0,0 +1,61 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "genblaze-assemblyai"
7
+ version = "0.3.0"
8
+ description = "AssemblyAI (speech-to-text / transcription) provider adapter for genblaze"
9
+ authors = [{name = "Eduardo Pavez", email = "epavez@backblaze.com"}]
10
+ readme = "README.md"
11
+ requires-python = ">=3.11"
12
+ license = "MIT"
13
+ classifiers = [
14
+ "Development Status :: 3 - Alpha",
15
+ "License :: OSI Approved :: MIT License",
16
+ "Typing :: Typed",
17
+ "Programming Language :: Python :: 3.11",
18
+ "Programming Language :: Python :: 3.12",
19
+ "Programming Language :: Python :: 3.13",
20
+ "Topic :: Multimedia",
21
+ "Topic :: Software Development :: Libraries",
22
+ ]
23
+ keywords = ["genblaze", "ai", "media", "manifest", "provenance", "c2pa-ready", "genai", "pipeline", "assemblyai", "speech-to-text", "transcription", "audio"]
24
+
25
+ dependencies = [
26
+ "genblaze-core>=0.3.0,<0.4",
27
+ # The provider uses Transcriber().submit() (non-blocking) and
28
+ # aai.Transcript.get_by_id(id) for polling; both have been stable across
29
+ # the 0.x line. The floor is 0.45, not 0.40: submit() always sends the
30
+ # speech model on TranscriptionConfig's plural ``speech_models`` field,
31
+ # which the SDK only added in 0.45.0 — 0.40.0–0.44.3 raise
32
+ # ``TypeError: ... unexpected keyword argument 'speech_models'`` before the
33
+ # request ever leaves the process. Bounded below 1.0 so a future major
34
+ # can't silently reshape these call sites.
35
+ "assemblyai>=0.45,<1",
36
+ ]
37
+
38
+ [project.urls]
39
+ Homepage = "https://github.com/backblaze-labs/genblaze"
40
+ Documentation = "https://github.com/backblaze-labs/genblaze"
41
+ Repository = "https://github.com/backblaze-labs/genblaze"
42
+ Issues = "https://github.com/backblaze-labs/genblaze/issues"
43
+
44
+ [project.optional-dependencies]
45
+ dev = [
46
+ "pytest>=7.0",
47
+ ]
48
+
49
+ [project.entry-points."genblaze.providers"]
50
+ assemblyai = "genblaze_assemblyai:AssemblyAIProvider"
51
+
52
+ [tool.hatch.build.targets.wheel]
53
+ packages = ["genblaze_assemblyai"]
54
+
55
+ [tool.pytest.ini_options]
56
+ testpaths = ["tests"]
57
+
58
+ [tool.deptry]
59
+ # Treat the `dev` extra as dev-only tooling so deptry does not flag pytest
60
+ # (and other test-only deps) as unused declared deps (DEP002).
61
+ optional_dependencies_dev_groups = ["dev"]
File without changes
@@ -0,0 +1,527 @@
1
+ """Tests for AssemblyAIProvider (mocked — no real API calls).
2
+
3
+ The provider's "client" is the ``assemblyai`` module itself (transcription
4
+ goes through ``aai.Transcriber()`` / ``aai.Transcript.get_by_id()`` with the
5
+ key on ``aai.settings.api_key``). Tests inject a fake module via
6
+ ``provider._client`` — the lazy ``import assemblyai`` in ``_get_client`` is
7
+ never reached on the happy path, so the real SDK is not required.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import hashlib
13
+ import sys
14
+ from types import SimpleNamespace
15
+ from unittest.mock import patch
16
+
17
+ import pytest
18
+ from genblaze_core.exceptions import ProviderError
19
+ from genblaze_core.models.asset import Asset
20
+ from genblaze_core.models.enums import Modality, ProviderErrorCode, StepStatus
21
+ from genblaze_core.models.step import Step
22
+ from genblaze_core.testing import ProviderComplianceTests
23
+
24
# Default https audio source; _make_step() uses it as the step prompt unless
# a test overrides ``prompt``.
AUDIO_URL = "https://example.com/audio.mp3"
25
+
26
+
27
+ # --- fakes ----------------------------------------------------------------
28
+
29
+
30
+ class _FakeWord:
31
+ def __init__(self, text: str, start: int, end: int, confidence: float = 0.99):
32
+ self.text = text
33
+ self.start = start # milliseconds (AssemblyAI native)
34
+ self.end = end # milliseconds
35
+ self.confidence = confidence
36
+
37
+
38
+ def _fake_transcript(
39
+ *,
40
+ status: str = "completed",
41
+ text: str = "hello world",
42
+ with_words: bool = True,
43
+ audio_duration: float | None = 12.5,
44
+ error: str | None = None,
45
+ language_code: str | None = "en_us",
46
+ confidence: float | None = 0.97,
47
+ utterances=None,
48
+ ):
49
+ words = (
50
+ [_FakeWord("hello", 0, 500, 0.99), _FakeWord("world", 500, 1000, 0.98)]
51
+ if with_words
52
+ else None
53
+ )
54
+ return SimpleNamespace(
55
+ id="transcript-abc123",
56
+ status=status,
57
+ text=text,
58
+ words=words,
59
+ audio_duration=audio_duration,
60
+ language_code=language_code,
61
+ confidence=confidence,
62
+ utterances=utterances,
63
+ error=error,
64
+ )
65
+
66
+
67
+ class _FakeTranscriber:
68
+ """Records the submitted (audio_url, config) and returns a queued id."""
69
+
70
+ last_submit: dict = {}
71
+
72
+ def submit(self, audio_url, config=None):
73
+ _FakeTranscriber.last_submit = {"audio_url": audio_url, "config": config}
74
+ return SimpleNamespace(id="transcript-abc123", status="queued")
75
+
76
+
77
def _make_fake_aai(transcript=None):
    """Return a fake ``assemblyai`` module covering the surface the provider uses.

    Exposes ``settings.api_key``, ``Transcriber``, ``Transcript.get_by_id``
    and ``TranscriptionConfig``. ``get_by_id`` ignores the id and always
    returns *transcript* (a default ``_fake_transcript()`` when ``None``).
    ``TranscriptionConfig`` records its keyword arguments as attributes.
    """
    if transcript is None:
        transcript = _fake_transcript()

    def _get_by_id(tid):
        return transcript

    def _transcription_config(**kwargs):
        return SimpleNamespace(**kwargs)

    return SimpleNamespace(
        settings=SimpleNamespace(api_key=None),
        Transcriber=_FakeTranscriber,
        Transcript=SimpleNamespace(get_by_id=_get_by_id),
        TranscriptionConfig=_transcription_config,
    )
86
+
87
+
88
def _make_provider(transcript=None):
    """Return an AssemblyAIProvider wired to a fake SDK module (no network).

    ``poll_interval=0.0`` keeps ``invoke()`` from sleeping between polls.
    Injecting ``_client`` bypasses the lazy SDK import entirely.
    """
    from genblaze_assemblyai import AssemblyAIProvider

    fake_sdk = _make_fake_aai(transcript)
    provider = AssemblyAIProvider(api_key="test-key", poll_interval=0.0)
    provider._client = fake_sdk
    return provider
94
+
95
+
96
def _make_step(**kwargs) -> Step:
    """Build an AssemblyAI transcription Step; explicit kwargs override defaults.

    Defaults: provider ``assemblyai``, model ``universal-3-pro``, TEXT
    modality, and ``AUDIO_URL`` as the prompt. An explicit ``prompt=None``
    is kept as ``None``.
    """
    defaults = {
        "provider": "assemblyai",
        "model": "universal-3-pro",
        "modality": Modality.TEXT,
        "prompt": AUDIO_URL,
    }
    return Step(**{**defaults, **kwargs})
102
+
103
+
104
@pytest.fixture(autouse=True)
def _reset_registry_cache():
    """Isolate class-level shared state between tests.

    ``BaseProvider`` caches one ``ModelRegistry`` per provider class
    (``models_default()``), shared across instances. Without a reset, a test
    that calls ``register_pricing`` would leak its user-registered spec into
    later tests, flipping a family-matched slug's validation outcome to
    AUTHORITATIVE. Clearing the cache gives each test a fresh registry built
    from the (unmutated) module-level family/fallback specs.

    ``_FakeTranscriber.last_submit`` is likewise class-level. It is cleared
    too, so a test that inspects the captured submission can never pass on
    data left behind by an earlier test (e.g. if the provider stopped
    calling ``Transcriber.submit``).
    """
    from genblaze_assemblyai import AssemblyAIProvider

    AssemblyAIProvider._models_cache = None
    _FakeTranscriber.last_submit = {}
    yield
    AssemblyAIProvider._models_cache = None
    _FakeTranscriber.last_submit = {}
120
+
121
+
122
+ # --- submit ---------------------------------------------------------------
123
+
124
+
125
def _submitted_audio_url(prompt: str) -> str:
    """Submit a step whose prompt is *prompt*; return the audio source the fake SDK received."""
    provider = _make_provider()
    provider.submit(_make_step(prompt=prompt))
    return _FakeTranscriber.last_submit["audio_url"]


def test_submit_returns_id_and_forwards_audio_url():
    provider = _make_provider()
    prediction_id = provider.submit(_make_step())
    assert prediction_id == "transcript-abc123"
    assert _FakeTranscriber.last_submit["audio_url"] == AUDIO_URL


def test_submit_passes_speech_model_to_config():
    provider = _make_provider()
    provider.submit(_make_step(model="universal-3-pro"))
    submitted_config = _FakeTranscriber.last_submit["config"]
    # The model travels on the plural ``speech_models`` field: the live API
    # deprecated both the singular ``speech_model`` field and the best/nano
    # aliases.
    assert submitted_config.speech_models == ["universal-3-pro"]


def test_submit_strips_audio_url_from_config():
    provider = _make_provider()
    step_params = {"audio_url": AUDIO_URL, "speaker_labels": True}
    provider.submit(_make_step(prompt=None, params=step_params))
    submitted_config = _FakeTranscriber.last_submit["config"]
    # audio_url is the transcription source, not a config field; other
    # params such as speaker_labels are still forwarded.
    assert not hasattr(submitted_config, "audio_url")
    assert submitted_config.speaker_labels is True


def test_submit_converts_file_uri_to_local_path():
    # The SDK open()s any non-HTTP string verbatim. A validated file:// chain
    # input (e.g. a generate-audio step's output) must therefore arrive as a
    # plain filesystem path. Forwarding "file:///tmp/clip.wav" unchanged would
    # make the SDK try to open a file literally named that.
    assert _submitted_audio_url("file:///srv/audio/clip.wav") == "/srv/audio/clip.wav"


def test_submit_decodes_percent_encoded_file_uri():
    # Sibling connectors build file:// URLs with quote(). A path containing
    # spaces therefore arrives percent-encoded and must be decoded.
    assert _submitted_audio_url("file:///srv/audio/my%20clip.wav") == "/srv/audio/my clip.wav"


def test_submit_file_uri_localhost_netloc_to_local_path():
    # RFC 8089 permits 'file://localhost/...', and validate_chain_input_url
    # accepts it. The localhost authority is dropped, leaving the bare path.
    assert _submitted_audio_url("file://localhost/srv/audio/clip.wav") == "/srv/audio/clip.wav"


def test_submit_https_url_passes_through_unchanged():
    # https:// sources are fetched remotely by the SDK rather than opened
    # locally, so they must reach submit() byte-for-byte.
    assert _submitted_audio_url(AUDIO_URL) == AUDIO_URL
188
+
189
+
190
+ # --- full lifecycle -------------------------------------------------------
191
+
192
+
193
def test_full_lifecycle_via_invoke():
    result = _make_provider().invoke(_make_step())
    assert result.status == StepStatus.SUCCEEDED
    assert len(result.assets) == 1
    transcript_asset = result.assets[0]
    # Transcripts are emitted as inline text assets addressed by "text:".
    assert transcript_asset.media_type == "text/plain"
    assert transcript_asset.url.startswith("text:")
    assert transcript_asset.metadata["text"] == "hello world"


def test_text_asset_hash_matches_content():
    provider = _make_provider()
    step = _make_step()
    result = provider.fetch_output(provider.submit(step), step)
    body = b"hello world"
    digest = hashlib.sha256(body).hexdigest()
    transcript_asset = result.assets[0]
    # The synthetic URL, sha256 and size are all derived from the UTF-8 text.
    assert transcript_asset.sha256 == digest
    assert transcript_asset.url == f"text:{digest}"
    assert transcript_asset.size_bytes == len(body)
215
+
216
+
217
+ # --- word timings (ms -> seconds) ----------------------------------------
218
+
219
+
220
def test_word_timings_converted_ms_to_seconds():
    provider = _make_provider()
    step = _make_step()
    result = provider.fetch_output(provider.submit(step), step)
    timings = result.assets[0].audio.word_timings
    assert [timing.word for timing in timings] == ["hello", "world"]
    first, second = timings[0], timings[1]
    # Milliseconds become seconds: 0 -> 0.0, 500 -> 0.5, 1000 -> 1.0.
    assert first.start == 0.0
    assert first.end == 0.5
    assert second.start == 0.5
    assert second.end == 1.0
    assert second.confidence == 0.98


def test_no_words_leaves_audio_metadata_unset():
    provider = _make_provider(_fake_transcript(with_words=False))
    step = _make_step()
    result = provider.fetch_output(provider.submit(step), step)
    # Without word-level data there is nothing to put in the audio metadata.
    assert result.assets[0].audio is None
241
+
242
+
243
+ # --- audio_duration / pricing payload ------------------------------------
244
+
245
+
246
def test_audio_duration_captured_in_provider_payload():
    provider = _make_provider(_fake_transcript(audio_duration=42.0))
    step = _make_step()
    result = provider.fetch_output(provider.submit(step), step)
    # audio_duration is surfaced both for pricing (payload) and on the asset.
    assert result.provider_payload["audio_duration"] == 42.0
    assert result.assets[0].metadata["audio_duration"] == 42.0


def test_user_registered_per_minute_pricing():
    from genblaze_core.providers import per_response_metric

    rate_per_minute = 0.12
    provider = _make_provider(_fake_transcript(audio_duration=120.0))  # two minutes

    def per_minute(ctx):
        seconds = ctx.provider_payload.get("audio_duration")
        if seconds is None:
            return None
        return seconds / 60.0 * rate_per_minute

    # "universal-3-pro" matches the speech family. Registering against the
    # concrete slug makes register_pricing fall through to the family spec
    # and layer the pricing on top.
    provider.models.register_pricing("universal-3-pro", per_response_metric(per_minute))
    result = provider.invoke(_make_step(model="universal-3-pro"))
    assert result.cost_usd == pytest.approx(2 * 0.12)
271
+
272
+
273
+ # --- audio-url resolution precedence -------------------------------------
274
+
275
+
276
def test_audio_url_precedence_inputs_first():
    prompt_url = "https://prompt.example/p.mp3"
    params_url = "https://params.example/p.mp3"
    input_url = "https://inputs.example/in.mp3"
    provider = _make_provider()
    step = _make_step(
        prompt=prompt_url,
        params={"audio_url": params_url},
        inputs=[Asset(url=input_url, media_type="audio/mpeg")],
    )
    # Chain inputs outrank both params["audio_url"] and the prompt.
    assert provider._resolve_audio_url(step) == input_url


def test_audio_url_precedence_params_over_prompt():
    prompt_url = "https://prompt.example/p.mp3"
    params_url = "https://params.example/p.mp3"
    provider = _make_provider()
    step = _make_step(prompt=prompt_url, params={"audio_url": params_url})
    assert provider._resolve_audio_url(step) == params_url


def test_audio_url_falls_back_to_prompt():
    prompt_url = "https://prompt.example/p.mp3"
    provider = _make_provider()
    step = _make_step(prompt=prompt_url)
    assert provider._resolve_audio_url(step) == prompt_url


def test_empty_input_url_falls_through_to_params():
    # An input asset whose url is empty must not short-circuit the precedence
    # chain. Resolution should move on to params["audio_url"] instead of
    # handing "" to the SSRF validator.
    params_url = "https://params.example/p.mp3"
    provider = _make_provider()
    step = _make_step(
        prompt=None,
        params={"audio_url": params_url},
        inputs=[Asset(url="", media_type="audio/mpeg")],
    )
    assert provider._resolve_audio_url(step) == params_url


def test_missing_audio_url_raises_invalid_input():
    provider = _make_provider()
    step = _make_step(prompt=None)
    with pytest.raises(ProviderError) as excinfo:
        provider.submit(step)
    assert excinfo.value.error_code == ProviderErrorCode.INVALID_INPUT


def test_unsafe_audio_url_rejected():
    # Plain http:// is refused by URL validation before anything is submitted.
    provider = _make_provider()
    step = _make_step(prompt="http://insecure.example/a.mp3")
    with pytest.raises(ProviderError):
        provider.submit(step)
327
+
328
+
329
+ # --- error status ---------------------------------------------------------
330
+
331
+
332
def test_error_status_raises_provider_error():
    failed_transcript = _fake_transcript(status="error", error="bad audio file")
    provider = _make_provider(failed_transcript)
    step = _make_step()
    prediction_id = provider.submit(step)
    # The transcript's error string must surface in the ProviderError message.
    with pytest.raises(ProviderError, match="bad audio file"):
        provider.fetch_output(prediction_id, step)


def test_poll_false_until_terminal():
    provider = _make_provider(_fake_transcript(status="processing"))
    prediction_id = provider.submit(_make_step())
    # "processing" is not terminal, so poll() must report False.
    assert provider.poll(prediction_id) is False


def test_fetch_output_non_terminal_raises():
    # Defensive guard: calling fetch_output on a still-running transcript must
    # raise, not emit an empty text:"" asset and report SUCCEEDED.
    running_transcript = _fake_transcript(status="processing", text=None)
    provider = _make_provider(running_transcript)
    step = _make_step()
    prediction_id = provider.submit(step)
    with pytest.raises(ProviderError, match="not complete") as excinfo:
        provider.fetch_output(prediction_id, step)
    assert excinfo.value.error_code == ProviderErrorCode.SERVER_ERROR


def test_submit_api_error_wrapped_with_code():
    from genblaze_assemblyai import AssemblyAIProvider

    class _RaisingTranscriber:
        """Transcriber whose submit() fails with an error carrying status_code=429."""

        def submit(self, audio_url, config=None):
            rate_limited = RuntimeError("boom")
            rate_limited.status_code = 429  # type: ignore[attr-defined]
            raise rate_limited

    fake_sdk = _make_fake_aai()
    fake_sdk.Transcriber = _RaisingTranscriber
    provider = AssemblyAIProvider(api_key="test-key", poll_interval=0.0)
    provider._client = fake_sdk
    step = _make_step()
    with pytest.raises(ProviderError, match="AssemblyAI submit failed") as excinfo:
        provider.submit(step)
    assert excinfo.value.error_code == ProviderErrorCode.RATE_LIMIT
375
+
376
+
377
+ # --- normalize_params -----------------------------------------------------
378
+
379
+
380
def test_normalize_params_language_alias():
    normalized = _make_provider().normalize_params({"language": "es", "speaker_labels": True})
    # The friendly ``language`` alias is renamed to the native
    # ``language_code``; unrelated params are left untouched.
    assert normalized["language_code"] == "es"
    assert "language" not in normalized
    assert normalized["speaker_labels"] is True


def test_normalize_params_idempotent():
    provider = _make_provider()
    first_pass = provider.normalize_params({"language": "es", "speaker_labels": True})
    second_pass = provider.normalize_params(first_pass)
    assert first_pass == second_pass


def test_normalize_params_respects_existing_language_code():
    normalized = _make_provider().normalize_params({"language": "es", "language_code": "en_us"})
    # An explicit native key wins over the alias. The alias itself is left in
    # place (idempotency guard).
    assert normalized["language_code"] == "en_us"
401
+
402
+
403
+ # --- credentials ----------------------------------------------------------
404
+
405
+
406
def test_missing_api_key_raises_auth_failure(monkeypatch):
    """With no key (argument or ASSEMBLYAI_API_KEY), submit() fails with AUTH_FAILURE.

    This is the only test that builds a provider without an injected
    ``_client``, so it reaches the lazy ``import assemblyai`` in
    ``_get_client``. That import is pointed at the fake SDK module so the
    test stays hermetic. It then exercises the missing-key path whether or
    not the real SDK is installed. NOTE(review): this also assumes nothing
    about whether the provider imports the SDK before or after checking for
    a key.
    """
    monkeypatch.delenv("ASSEMBLYAI_API_KEY", raising=False)
    monkeypatch.setitem(sys.modules, "assemblyai", _make_fake_aai())
    from genblaze_assemblyai import AssemblyAIProvider

    provider = AssemblyAIProvider(api_key=None)  # no injected client
    step = _make_step()
    with pytest.raises(ProviderError, match="No AssemblyAI API key") as excinfo:
        provider.submit(step)
    assert excinfo.value.error_code == ProviderErrorCode.AUTH_FAILURE
415
+
416
+
417
+ # --- error mapping --------------------------------------------------------
418
+
419
+
420
@pytest.mark.parametrize(
    ("status", "expected"),
    [
        (429, ProviderErrorCode.RATE_LIMIT),
        (401, ProviderErrorCode.AUTH_FAILURE),
        (403, ProviderErrorCode.AUTH_FAILURE),
        (400, ProviderErrorCode.INVALID_INPUT),
        (422, ProviderErrorCode.INVALID_INPUT),
        (500, ProviderErrorCode.SERVER_ERROR),
        (503, ProviderErrorCode.SERVER_ERROR),
    ],
)
def test_map_assemblyai_error_status_codes(status, expected):
    from genblaze_assemblyai._errors import map_assemblyai_error

    # Simulate an SDK exception exposing an HTTP status via ``status_code``.
    http_error = RuntimeError("err")
    http_error.status_code = status  # type: ignore[attr-defined]
    assert map_assemblyai_error(http_error) == expected


def test_map_assemblyai_error_no_status_falls_back():
    from genblaze_assemblyai._errors import map_assemblyai_error

    opaque = RuntimeError("totally opaque")
    assert map_assemblyai_error(opaque) == ProviderErrorCode.UNKNOWN


def test_map_assemblyai_error_accepts_error_string():
    from genblaze_assemblyai._errors import map_assemblyai_error

    # transcript.error is a plain string, not an exception; it is still classified.
    message = "rate limit exceeded"
    assert map_assemblyai_error(message) == ProviderErrorCode.RATE_LIMIT
451
+
452
+
453
+ # --- catalog decoupling ---------------------------------------------------
454
+
455
+
456
def test_declares_discovery_support_none():
    from genblaze_assemblyai import AssemblyAIProvider
    from genblaze_core.providers import DiscoverySupport

    declared = AssemblyAIProvider.discovery_support
    assert declared is DiscoverySupport.NONE


@pytest.mark.parametrize("slug", ["universal", "universal-2", "universal-3-pro"])
def test_speech_slug_matches_family(slug):
    from genblaze_core.providers import ValidationOutcome

    validation = _make_provider().validate_model(slug)
    assert validation.outcome is ValidationOutcome.OK_PROVISIONAL


def test_non_speech_slug_unknown_permissive():
    from genblaze_core.providers import ValidationOutcome

    validation = _make_provider().validate_model("some-other-model")
    assert validation.outcome is ValidationOutcome.UNKNOWN_PERMISSIVE


def test_family_and_fallback_carry_no_pricing():
    registry = _make_provider()._models
    # Neither the family-matched spec nor the permissive fallback ships a price.
    assert registry.get("universal-3-pro").pricing is None
    assert registry.get("anything-else").pricing is None


def test_capabilities_text_only():
    capabilities = _make_provider().get_capabilities()
    assert capabilities.supported_modalities == [Modality.TEXT]
    assert capabilities.accepts_chain_input is True
    assert capabilities.output_formats == ["text/plain"]
492
+
493
+
494
+ # --- compliance harness ---------------------------------------------------
495
+
496
+
497
class TestAssemblyAICompliance(ProviderComplianceTests):
    """Run the shared genblaze provider-contract suite against AssemblyAIProvider."""

    # AssemblyAI ships no hardcoded prices. Per-minute-of-input-audio pricing
    # is user-registered (see docs/reference/pricing-recipes.md), so cost_usd
    # stays None unless a strategy is registered. This is the same posture as
    # Hume.
    expects_cost = False

    @pytest.fixture(autouse=True)
    def _patch_sdk(self):
        # Safety net: should the lazy ``import assemblyai`` ever run, it
        # resolves to the fake module. make_provider() also injects
        # ``_client`` directly.
        fake_sdk = _make_fake_aai()
        with patch.dict(sys.modules, {"assemblyai": fake_sdk}):
            yield

    def make_provider(self):
        return _make_provider()

    def make_step(self):
        return _make_step()

    def test_assets_have_valid_urls(self) -> None:
        """Transcripts emit a synthetic ``text:{sha256}`` asset rather than an
        https:// or file:// URL (following the NvidiaChatProvider TEXT-asset
        precedent). This override asserts that scheme instead of the harness
        default."""
        result = self.make_provider().invoke(self.make_step())
        assert result.assets
        for emitted in result.assets:
            assert emitted.url.startswith("text:"), f"Expected text: asset URL, got {emitted.url}"