pipecat-replicate 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ # Replicate API token — https://replicate.com/account/api-tokens
2
+ REPLICATE_API_TOKEN=
3
+
4
+ # Optional: not used by examples/basic_image_gen.py; only needed if you add an OpenAI-compatible LLM to your pipeline
5
+ OPENAI_API_KEY=
@@ -0,0 +1,17 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+
7
+ jobs:
8
+ publish:
9
+ runs-on: ubuntu-latest
10
+ environment: pypi
11
+ permissions:
12
+ id-token: write
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+ - uses: astral-sh/setup-uv@v6
16
+ - run: uv build
17
+ - uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,22 @@
1
+ name: Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ matrix:
14
+ python-version: ["3.10", "3.12", "3.13"]
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+ - uses: astral-sh/setup-uv@v6
18
+ - run: uv python install ${{ matrix.python-version }}
19
+ - run: uv sync --group dev --python ${{ matrix.python-version }}
20
+ - run: uv run ruff check
21
+ - run: uv run ruff format --check
22
+ - run: uv run pytest tests/ -v
@@ -0,0 +1,16 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ .eggs/
8
+ *.egg
9
+ .env
10
+ .venv/
11
+ venv/
12
+ .python-version
13
+ uv.lock
14
+ .ruff_cache/
15
+ .mypy_cache/
16
+ .pytest_cache/
@@ -0,0 +1,10 @@
1
+ # Changelog
2
+
3
+ ## 0.1.0 (2026-04-15)
4
+
5
+ - Initial release
6
+ - Support for official Replicate models (`owner/name`) and versioned models (`owner/name:version`)
7
+ - Sync prediction requests with polling fallback
8
+ - Returns `URLImageRawFrame` on success
9
+ - Metrics support via `start_ttfb_metrics` / `stop_ttfb_metrics`
10
+ - Tested with Pipecat v0.0.108
@@ -0,0 +1,24 @@
1
+ BSD 2-Clause License
2
+
3
+ Copyright (c) 2026, Borislav Novikov
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,138 @@
1
+ Metadata-Version: 2.4
2
+ Name: pipecat-replicate
3
+ Version: 0.1.0
4
+ Summary: Replicate image generation integration for Pipecat
5
+ Project-URL: Repository, https://github.com/bnovik0v/pipecat-replicate
6
+ Project-URL: Issues, https://github.com/bnovik0v/pipecat-replicate/issues
7
+ Author-email: Borislav Novikov <borislav@polaro.com>
8
+ License-Expression: BSD-2-Clause
9
+ License-File: LICENSE
10
+ Requires-Python: <3.14,>=3.10
11
+ Requires-Dist: aiohttp>=3.9
12
+ Requires-Dist: loguru>=0.7
13
+ Requires-Dist: pillow>=10.0
14
+ Requires-Dist: pipecat-ai>=0.0.108
15
+ Description-Content-Type: text/markdown
16
+
17
+ # pipecat-replicate
18
+
19
+ [Replicate](https://replicate.com/) image generation integration for [Pipecat](https://github.com/pipecat-ai/pipecat) — a framework for building voice and multimodal conversational AI applications.
20
+
21
+ ## Pipecat Compatibility
22
+
23
+ **Tested with Pipecat v0.0.108**
24
+
25
+ ## Features
26
+
27
+ - Text-to-image generation using any Replicate-hosted model
28
+ - Official models (`owner/name`) and versioned community models (`owner/name:version`)
29
+ - Sync prediction requests with automatic polling fallback
30
+ - Returns `URLImageRawFrame` for direct use in Pipecat pipelines
31
+ - Configurable via the standard Pipecat `Settings` dataclass pattern
32
+
33
+ ## Installation
34
+
35
+ ### Using pip
36
+
37
+ ```bash
38
+ pip install pipecat-replicate
39
+ ```
40
+
41
+ ### Using uv
42
+
43
+ ```bash
44
+ uv add pipecat-replicate
45
+ ```
46
+
47
+ ### From source
48
+
49
+ ```bash
50
+ git clone https://github.com/bnovik0v/pipecat-replicate.git
51
+ cd pipecat-replicate
52
+ pip install -e .
53
+ ```
54
+
55
+ ## Quick Start
56
+
57
+ 1. Get your Replicate API token at https://replicate.com/account/api-tokens
58
+
59
+ 2. Set the environment variable:
60
+
61
+ ```bash
62
+ export REPLICATE_API_TOKEN=r8_...
63
+ ```
64
+
65
+ 3. Use in a Pipecat pipeline (the snippet below must run inside an `async` function):
66
+
67
+ ```python
68
+ import aiohttp
69
+ from pipecat.frames.frames import TextFrame
70
+ from pipecat.pipeline.pipeline import Pipeline
71
+ from pipecat.pipeline.runner import PipelineRunner
72
+ from pipecat.pipeline.task import PipelineTask
73
+ from pipecat_replicate import ReplicateImageGenService
74
+
75
+ async with aiohttp.ClientSession() as session:
76
+ imagegen = ReplicateImageGenService(
77
+ aiohttp_session=session,
78
+ settings=ReplicateImageGenService.Settings(
79
+ model="black-forest-labs/flux-schnell",
80
+ aspect_ratio="1:1",
81
+ ),
82
+ )
83
+
84
+ pipeline = Pipeline([imagegen, ...])
85
+ task = PipelineTask(pipeline)
86
+ await task.queue_frames([TextFrame("a cat in the style of a screenprint poster")])
87
+
88
+ runner = PipelineRunner()
89
+ await runner.run(task)
90
+ ```
91
+
92
+ ## Configuration
93
+
94
+ ### Settings
95
+
96
+ | Field | Type | Default | Description |
97
+ | ------------------------ | ------ | ---------------------------------- | ----------------------------------------- |
98
+ | `model` | `str` | `"black-forest-labs/flux-schnell"` | Replicate model identifier |
99
+ | `aspect_ratio` | `str` | `"1:1"` | Aspect ratio for generated images |
100
+ | `num_outputs` | `int` | `1` | Number of images to generate (1–4) |
101
+ | `num_inference_steps` | `int` | `4` | Number of denoising steps |
102
+ | `seed` | `int` | `None` | Random seed for reproducible generation |
103
+ | `output_format` | `str` | `"webp"` | Output image format |
104
+ | `output_quality` | `int` | `80` | Output quality (0–100) |
105
+ | `disable_safety_checker` | `bool` | `False` | Whether to disable the model safety check |
106
+ | `go_fast` | `bool` | `True` | Use the model's fast generation mode |
107
+ | `megapixels` | `str` | `"1"` | Approximate megapixel count |
108
+
109
+ ### Constructor Parameters
110
+
111
+ | Parameter | Type | Default | Description |
112
+ | -------------------- | ---------------------- | -------------------------------------- | ------------------------------------ |
113
+ | `aiohttp_session` | `aiohttp.ClientSession`| *(required)* | HTTP session for API requests |
114
+ | `api_token` | `str` | `$REPLICATE_API_TOKEN` | Replicate API token |
115
+ | `settings` | `Settings` | *(see defaults above)* | Generation settings |
116
+ | `base_url` | `str` | `"https://api.replicate.com/v1"` | API base URL |
117
+ | `wait_timeout_secs` | `int` | `60` | Sync wait timeout (Prefer header) |
118
+ | `poll_interval_secs` | `float` | `0.5` | Poll interval for async predictions |
119
+ | `max_poll_attempts` | `int` | `120` | Maximum number of polling attempts |
120
+
121
+ ## Examples
122
+
123
+ See [`examples/basic_image_gen.py`](examples/basic_image_gen.py) for a complete example that generates an image and displays it in a Tk window.
124
+
125
+ ```bash
126
+ REPLICATE_API_TOKEN=r8_... python examples/basic_image_gen.py
127
+ ```
128
+
129
+ ## Running Tests
130
+
131
+ ```bash
132
+ uv sync --group dev
133
+ uv run pytest
134
+ ```
135
+
136
+ ## License
137
+
138
+ BSD 2-Clause — see [LICENSE](LICENSE).
@@ -0,0 +1,122 @@
1
+ # pipecat-replicate
2
+
3
+ [Replicate](https://replicate.com/) image generation integration for [Pipecat](https://github.com/pipecat-ai/pipecat) — a framework for building voice and multimodal conversational AI applications.
4
+
5
+ ## Pipecat Compatibility
6
+
7
+ **Tested with Pipecat v0.0.108**
8
+
9
+ ## Features
10
+
11
+ - Text-to-image generation using any Replicate-hosted model
12
+ - Official models (`owner/name`) and versioned community models (`owner/name:version`)
13
+ - Sync prediction requests with automatic polling fallback
14
+ - Returns `URLImageRawFrame` for direct use in Pipecat pipelines
15
+ - Configurable via the standard Pipecat `Settings` dataclass pattern
16
+
17
+ ## Installation
18
+
19
+ ### Using pip
20
+
21
+ ```bash
22
+ pip install pipecat-replicate
23
+ ```
24
+
25
+ ### Using uv
26
+
27
+ ```bash
28
+ uv add pipecat-replicate
29
+ ```
30
+
31
+ ### From source
32
+
33
+ ```bash
34
+ git clone https://github.com/bnovik0v/pipecat-replicate.git
35
+ cd pipecat-replicate
36
+ pip install -e .
37
+ ```
38
+
39
+ ## Quick Start
40
+
41
+ 1. Get your Replicate API token at https://replicate.com/account/api-tokens
42
+
43
+ 2. Set the environment variable:
44
+
45
+ ```bash
46
+ export REPLICATE_API_TOKEN=r8_...
47
+ ```
48
+
49
+ 3. Use in a Pipecat pipeline (the snippet below must run inside an `async` function):
50
+
51
+ ```python
52
+ import aiohttp
53
+ from pipecat.frames.frames import TextFrame
54
+ from pipecat.pipeline.pipeline import Pipeline
55
+ from pipecat.pipeline.runner import PipelineRunner
56
+ from pipecat.pipeline.task import PipelineTask
57
+ from pipecat_replicate import ReplicateImageGenService
58
+
59
+ async with aiohttp.ClientSession() as session:
60
+ imagegen = ReplicateImageGenService(
61
+ aiohttp_session=session,
62
+ settings=ReplicateImageGenService.Settings(
63
+ model="black-forest-labs/flux-schnell",
64
+ aspect_ratio="1:1",
65
+ ),
66
+ )
67
+
68
+ pipeline = Pipeline([imagegen, ...])
69
+ task = PipelineTask(pipeline)
70
+ await task.queue_frames([TextFrame("a cat in the style of a screenprint poster")])
71
+
72
+ runner = PipelineRunner()
73
+ await runner.run(task)
74
+ ```
75
+
76
+ ## Configuration
77
+
78
+ ### Settings
79
+
80
+ | Field | Type | Default | Description |
81
+ | ------------------------ | ------ | ---------------------------------- | ----------------------------------------- |
82
+ | `model` | `str` | `"black-forest-labs/flux-schnell"` | Replicate model identifier |
83
+ | `aspect_ratio` | `str` | `"1:1"` | Aspect ratio for generated images |
84
+ | `num_outputs` | `int` | `1` | Number of images to generate (1–4) |
85
+ | `num_inference_steps` | `int` | `4` | Number of denoising steps |
86
+ | `seed` | `int` | `None` | Random seed for reproducible generation |
87
+ | `output_format` | `str` | `"webp"` | Output image format |
88
+ | `output_quality` | `int` | `80` | Output quality (0–100) |
89
+ | `disable_safety_checker` | `bool` | `False` | Whether to disable the model safety check |
90
+ | `go_fast` | `bool` | `True` | Use the model's fast generation mode |
91
+ | `megapixels` | `str` | `"1"` | Approximate megapixel count |
92
+
93
+ ### Constructor Parameters
94
+
95
+ | Parameter | Type | Default | Description |
96
+ | -------------------- | ---------------------- | -------------------------------------- | ------------------------------------ |
97
+ | `aiohttp_session` | `aiohttp.ClientSession`| *(required)* | HTTP session for API requests |
98
+ | `api_token` | `str` | `$REPLICATE_API_TOKEN` | Replicate API token |
99
+ | `settings` | `Settings` | *(see defaults above)* | Generation settings |
100
+ | `base_url` | `str` | `"https://api.replicate.com/v1"` | API base URL |
101
+ | `wait_timeout_secs` | `int` | `60` | Sync wait timeout (Prefer header) |
102
+ | `poll_interval_secs` | `float` | `0.5` | Poll interval for async predictions |
103
+ | `max_poll_attempts` | `int` | `120` | Maximum number of polling attempts |
104
+
105
+ ## Examples
106
+
107
+ See [`examples/basic_image_gen.py`](examples/basic_image_gen.py) for a complete example that generates an image and displays it in a Tk window.
108
+
109
+ ```bash
110
+ REPLICATE_API_TOKEN=r8_... python examples/basic_image_gen.py
111
+ ```
112
+
113
+ ## Running Tests
114
+
115
+ ```bash
116
+ uv sync --group dev
117
+ uv run pytest
118
+ ```
119
+
120
+ ## License
121
+
122
+ BSD 2-Clause — see [LICENSE](LICENSE).
@@ -0,0 +1,70 @@
1
+ """Basic Replicate image generation example with Pipecat.
2
+
3
+ Generates an image using the Replicate API and displays it in a Tk window.
4
+
5
+ Requirements:
6
+ pip install pipecat-replicate pipecat-ai[local]
7
+
8
+ Environment variables:
9
+ REPLICATE_API_TOKEN — your Replicate API token
10
+ """
11
+
12
+ import asyncio
13
+ import os
14
+ import sys
15
+ import tkinter as tk
16
+
17
+ import aiohttp
18
+ from dotenv import load_dotenv
19
+ from loguru import logger
20
+
21
+ from pipecat.frames.frames import TextFrame
22
+ from pipecat.pipeline.pipeline import Pipeline
23
+ from pipecat.pipeline.runner import PipelineRunner
24
+ from pipecat.pipeline.task import PipelineTask
25
+ from pipecat.transports.local.tk import TkLocalTransport, TkTransportParams
26
+ from pipecat_replicate import ReplicateImageGenService
27
+
28
# Load environment variables (e.g. REPLICATE_API_TOKEN) from a local .env file;
# override=True lets values in .env replace variables already set in the shell.
load_dotenv(override=True)

# Replace loguru's default handler (id 0) with a stderr handler at DEBUG level
# so the service's debug logging is visible when running the example.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
32
+
33
+
34
async def main():
    """Generate a single image with Replicate and display it in a Tk window.

    Builds a two-stage pipeline (Replicate image generation -> Tk video output),
    queues one text prompt, then runs the pipeline alongside a loop that pumps
    Tk events until the pipeline task reports it has finished.
    """
    async with aiohttp.ClientSession() as http:
        root = tk.Tk()
        root.title("Replicate Image Gen")

        tk_params = TkTransportParams(
            video_out_enabled=True,
            video_out_width=1024,
            video_out_height=1024,
        )
        transport = TkLocalTransport(root, tk_params)

        image_service = ReplicateImageGenService(
            settings=ReplicateImageGenService.Settings(
                model="black-forest-labs/flux-schnell",
                aspect_ratio="1:1",
            ),
            aiohttp_session=http,
            api_token=os.getenv("REPLICATE_API_TOKEN"),
        )

        task = PipelineTask(Pipeline([image_service, transport.output()]))
        await task.queue_frames([TextFrame("a cat in the style of a screenprint poster")])

        async def pump_tk_events():
            # Tk has no asyncio integration, so step its event loop manually and
            # hand control back to asyncio between iterations.
            while not task.has_finished():
                root.update()
                root.update_idletasks()
                await asyncio.sleep(0.1)

        await asyncio.gather(PipelineRunner().run(task), pump_tk_events())
67
+
68
+
69
if __name__ == "__main__":
    # Script entry point: run the async example on a fresh event loop.
    asyncio.run(main())
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "pipecat-replicate"
7
+ version = "0.1.0"
8
+ description = "Replicate image generation integration for Pipecat"
9
+ readme = "README.md"
10
+ license = "BSD-2-Clause"
11
+ authors = [
12
+ { name = "Borislav Novikov", email = "borislav@polaro.com" },
13
+ ]
14
+ requires-python = ">=3.10,<3.14"
15
+ dependencies = [
16
+ "pipecat-ai>=0.0.108",
17
+ "aiohttp>=3.9",
18
+ "loguru>=0.7",
19
+ "Pillow>=10.0",
20
+ ]
21
+
22
+ [project.urls]
23
+ Repository = "https://github.com/bnovik0v/pipecat-replicate"
24
+ Issues = "https://github.com/bnovik0v/pipecat-replicate/issues"
25
+
26
+ [dependency-groups]
27
+ dev = [
28
+ "pytest>=8",
29
+ "pytest-asyncio>=0.23",
30
+ "pytest-aiohttp>=1.0",
31
+ "ruff>=0.4",
32
+ "websockets>=13.0",
33
+ ]
34
+
35
+ [tool.ruff]
36
+ line-length = 100
37
+
38
+ [tool.pytest.ini_options]
39
+ asyncio_mode = "auto"
@@ -0,0 +1,5 @@
1
+ """Replicate image generation integration for Pipecat."""
2
+
3
+ from pipecat_replicate.image import ReplicateImageGenService, ReplicateImageGenSettings
4
+
5
# Public API of the package, re-exported from pipecat_replicate.image.
__all__ = ["ReplicateImageGenService", "ReplicateImageGenSettings"]
@@ -0,0 +1,356 @@
1
+ """Replicate image generation service implementation.
2
+
3
+ This module provides integration with Replicate-hosted image generation models
4
+ for creating images from text prompts.
5
+ """
6
+
7
+ import asyncio
8
+ import base64
9
+ import io
10
+ import os
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, AsyncGenerator, Dict, Optional
13
+
14
+ import aiohttp
15
+ from loguru import logger
16
+ from PIL import Image
17
+ from pydantic import BaseModel, Field
18
+
19
+ from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
20
+ from pipecat.services.image_service import ImageGenService
21
+ from pipecat.services.settings import NOT_GIVEN, ImageGenSettings, _NotGiven
22
+
23
+
24
@dataclass
class ReplicateImageGenSettings(ImageGenSettings):
    """Settings for the Replicate image generation service.

    Parameters:
        model: Replicate model identifier. Use ``owner/name`` for official
            models or ``owner/name:version`` for versioned community models.
        aspect_ratio: Aspect ratio for generated images.
        num_outputs: Number of images to generate.
        num_inference_steps: Number of denoising steps for the model.
        seed: Random seed for reproducible generation. ``None`` uses a random seed.
        output_format: Image format requested from the model.
        output_quality: Output quality value supported by the model.
        disable_safety_checker: Whether to disable the model safety checker.
        go_fast: Whether to use the model's faster generation mode.
        megapixels: Approximate megapixel count for generated images.
    """

    aspect_ratio: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    num_outputs: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    num_inference_steps: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    seed: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    output_format: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    output_quality: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    disable_safety_checker: bool | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    go_fast: bool | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    megapixels: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)

    def to_api_input(self) -> Dict[str, Any]:
        """Build the Replicate ``input`` payload from settings.

        Fields that are unset (``NOT_GIVEN``) or ``None`` are omitted. This keeps
        the sentinel, which is not JSON-serializable, out of the request body and
        leaves omitted inputs to the model's own defaults. Entries in ``extra``
        are merged last, so they override the named fields.

        Returns:
            A dict suitable for the ``input`` field of a prediction request.
        """
        candidates: Dict[str, Any] = {
            "aspect_ratio": self.aspect_ratio,
            "num_outputs": self.num_outputs,
            "num_inference_steps": self.num_inference_steps,
            "output_format": self.output_format,
            "output_quality": self.output_quality,
            "disable_safety_checker": self.disable_safety_checker,
            "go_fast": self.go_fast,
            "megapixels": self.megapixels,
            "seed": self.seed,
        }
        payload: Dict[str, Any] = {
            key: value
            for key, value in candidates.items()
            if value is not None and not isinstance(value, _NotGiven)
        }
        payload.update(self.extra)
        return payload
68
+
69
+
70
class ReplicateImageGenService(ImageGenService):
    """Replicate image generation service.

    Provides text-to-image generation using Replicate-hosted models. Official
    models use an ``owner/name`` identifier. Versioned community models can be
    addressed with ``owner/name:version``.

    Each generation creates a prediction with a ``Prefer: wait=<secs>`` header.
    If the create response carries no output yet, the prediction is polled until
    it has output, completes, fails, or the poll budget is exhausted.
    """

    Settings = ReplicateImageGenSettings
    _settings: Settings

    class InputParams(BaseModel):
        """Input parameters for Replicate image generation.

        .. deprecated:: 0.1.0
            Use ``settings=ReplicateImageGenService.Settings(...)`` instead.

        Parameters:
            aspect_ratio: Aspect ratio for generated images. Defaults to ``"1:1"``.
            num_outputs: Number of images to generate. Defaults to ``1``.
            num_inference_steps: Number of denoising steps. Defaults to ``4``.
            seed: Random seed for reproducible generation. Defaults to ``None``.
            output_format: Output image format. Defaults to ``"webp"``.
            output_quality: Output quality value. Defaults to ``80``.
            disable_safety_checker: Whether to disable the safety checker. Defaults to ``False``.
            go_fast: Whether to use the fast model path. Defaults to ``True``.
            megapixels: Approximate megapixel count. Defaults to ``"1"``.
        """

        aspect_ratio: str = "1:1"
        num_outputs: int = Field(default=1, ge=1, le=4)
        num_inference_steps: int = Field(default=4, ge=1)
        seed: Optional[int] = None
        output_format: str = "webp"
        output_quality: int = Field(default=80, ge=0, le=100)
        disable_safety_checker: bool = False
        go_fast: bool = True
        megapixels: str = "1"

    # Prediction statuses after which no output will ever be produced.
    # Both spellings of "canceled" are accepted defensively.
    _TERMINAL_ERROR_STATUSES = frozenset({"failed", "canceled", "cancelled"})
    # Status of a completed prediction; polling further cannot add output.
    _SUCCEEDED_STATUS = "succeeded"

    def __init__(
        self,
        *,
        params: Optional[InputParams] = None,
        aiohttp_session: aiohttp.ClientSession,
        api_token: Optional[str] = None,
        model: Optional[str] = None,
        settings: Optional[Settings] = None,
        base_url: str = "https://api.replicate.com/v1",
        wait_timeout_secs: int = 60,
        poll_interval_secs: float = 0.5,
        max_poll_attempts: int = 120,
        **kwargs,
    ):
        """Initialize the ReplicateImageGenService.

        Args:
            params: Input parameters for image generation configuration.

                .. deprecated:: 0.1.0
                    Use ``settings=ReplicateImageGenService.Settings(...)`` instead.

            aiohttp_session: HTTP client session for Replicate requests and image downloads.
            api_token: Optional Replicate API token. Falls back to the
                ``REPLICATE_API_TOKEN`` environment variable. If provided, it is
                also written to ``REPLICATE_API_TOKEN`` in ``os.environ``.
            model: Replicate model identifier. Defaults to
                ``"black-forest-labs/flux-schnell"``.

                .. deprecated:: 0.1.0
                    Use ``settings=ReplicateImageGenService.Settings(model=...)`` instead.

            settings: Runtime-configurable generation settings. When provided
                alongside deprecated parameters, ``settings`` values take precedence.
            base_url: Base URL for the Replicate HTTP API.
            wait_timeout_secs: Sync wait duration passed in the ``Prefer`` header.
            poll_interval_secs: Poll interval used when the initial sync request
                returns before output is available.
            max_poll_attempts: Maximum number of follow-up prediction polls.
            **kwargs: Additional arguments passed to parent ImageGenService.
        """
        default_settings = self.Settings(
            model="black-forest-labs/flux-schnell",
            aspect_ratio="1:1",
            num_outputs=1,
            num_inference_steps=4,
            seed=None,
            output_format="webp",
            output_quality=80,
            disable_safety_checker=False,
            go_fast=True,
            megapixels="1",
        )

        if model is not None:
            self._warn_init_param_moved_to_settings("model", "model")
            default_settings.model = model

        if params is not None:
            self._warn_init_param_moved_to_settings("params")
            # Deprecated params only apply when no explicit settings object is given.
            if settings is None:
                default_settings.aspect_ratio = params.aspect_ratio
                default_settings.num_outputs = params.num_outputs
                default_settings.num_inference_steps = params.num_inference_steps
                default_settings.seed = params.seed
                default_settings.output_format = params.output_format
                default_settings.output_quality = params.output_quality
                default_settings.disable_safety_checker = params.disable_safety_checker
                default_settings.go_fast = params.go_fast
                default_settings.megapixels = params.megapixels

        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(settings=default_settings, **kwargs)
        self._aiohttp_session = aiohttp_session
        self._api_token = api_token or os.getenv("REPLICATE_API_TOKEN", "")
        self._base_url = base_url.rstrip("/")
        self._wait_timeout_secs = wait_timeout_secs
        self._poll_interval_secs = poll_interval_secs
        self._max_poll_attempts = max_poll_attempts
        if api_token:
            # Documented behavior: an explicit token is also exported process-wide.
            os.environ["REPLICATE_API_TOKEN"] = api_token

    def _prediction_request(self, prompt: str) -> tuple[str, dict[str, str], dict[str, Any]]:
        """Build the Replicate prediction URL, headers, and request body.

        Official models (``owner/name``) are posted to the model-scoped endpoint;
        versioned models (``owner/name:version``) are posted to ``/predictions``
        with an explicit ``version`` field.

        Raises:
            ValueError: If the model identifier is not in a supported format.
        """
        model = self._settings.model or ""
        if "/" not in model:
            raise ValueError("Replicate model must use 'owner/name' or 'owner/name:version' format")

        input_payload = {"prompt": prompt, **self._settings.to_api_input()}
        headers = {
            "Authorization": f"Bearer {self._api_token}",
            "Content-Type": "application/json",
            "Prefer": f"wait={self._wait_timeout_secs}",
        }

        if ":" in model:
            _, version = model.rsplit(":", maxsplit=1)
            if not version:
                raise ValueError("Versioned Replicate models must use 'owner/name:version' format")
            return (
                f"{self._base_url}/predictions",
                headers,
                {"version": version, "input": input_payload},
            )

        owner, name = model.split("/", maxsplit=1)
        return (
            f"{self._base_url}/models/{owner}/{name}/predictions",
            headers,
            {"input": input_payload},
        )

    def _extract_output_urls(self, prediction: Dict[str, Any]) -> list[str]:
        """Extract output URLs from a Replicate prediction response.

        Accepts ``output`` as a single string, a ``{"url": ...}`` dict, or a list
        of strings / ``{"url": ...}`` dicts; anything else yields no URLs.
        """
        output = prediction.get("output")
        if output is None:
            return []
        if isinstance(output, str):
            return [output]
        if isinstance(output, dict):
            url = output.get("url")
            return [url] if isinstance(url, str) else []

        urls: list[str] = []
        if isinstance(output, list):
            for item in output:
                if isinstance(item, str):
                    urls.append(item)
                elif isinstance(item, dict) and isinstance(item.get("url"), str):
                    urls.append(item["url"])
        return urls

    async def _poll_prediction(
        self,
        prediction: Dict[str, Any],
        headers: Dict[str, str],
        *,
        fallback_url: str | None,
    ) -> Dict[str, Any]:
        """Poll a prediction until it has output, completes, or fails.

        Args:
            prediction: The prediction returned by the create request.
            headers: Headers (authorization) sent with each polling GET.
            fallback_url: URL to poll when the prediction has no ``urls.get``
                entry (the create response's ``Location`` header).

        Returns:
            The latest prediction. It contains output URLs unless Replicate
            reported success without output or there was no URL to poll.

        Raises:
            RuntimeError: If the prediction failed or was canceled, or a poll
                request returned a non-200 status.
            TimeoutError: If ``max_poll_attempts`` polls finish without output.
        """
        urls = prediction.get("urls")
        # ``urls`` may be missing or JSON null; only trust it when it is a dict.
        prediction_url = (urls.get("get") if isinstance(urls, dict) else None) or fallback_url

        polls = 0
        while True:
            if self._extract_output_urls(prediction):
                return prediction

            # Re-checked after every poll, so a failure on the last attempt is
            # reported as a failure rather than as a timeout.
            status = str(prediction.get("status", "")).lower()
            if status in self._TERMINAL_ERROR_STATUSES:
                error = prediction.get("error") or "prediction failed"
                raise RuntimeError(f"Replicate prediction failed: {error}")

            # A completed prediction will never gain output, and without a URL
            # there is nothing to poll: let the caller report the empty result.
            if status == self._SUCCEEDED_STATUS or not prediction_url:
                return prediction

            if polls >= self._max_poll_attempts:
                raise TimeoutError(
                    "Replicate image generation timed out while waiting for output"
                )
            polls += 1

            await asyncio.sleep(self._poll_interval_secs)
            async with self._aiohttp_session.get(prediction_url, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(
                        f"Replicate polling error ({response.status}): {error_text}"
                    )
                prediction = await response.json()

    @staticmethod
    def _load_image_frame(image_url: str | None, encoded_image: bytes) -> URLImageRawFrame:
        """Decode encoded image bytes and build a Pipecat image frame.

        The PIL image is closed once its raw pixels, size, and format are copied.
        """
        with Image.open(io.BytesIO(encoded_image)) as image:
            return URLImageRawFrame(
                url=image_url,
                image=image.tobytes(),
                size=image.size,
                format=image.format,
            )

    @staticmethod
    def _decode_data_url(data_url: str) -> bytes:
        """Decode a ``data:`` URL (RFC 2397) into raw bytes.

        Payloads marked ``;base64`` in the header are base64-decoded; all others
        are treated as percent-encoded data.

        Raises:
            ValueError: If the URL has no data section.
        """
        from urllib.parse import unquote_to_bytes

        header, _, data = data_url.partition(",")
        if not data:
            raise ValueError("Replicate returned an invalid data URL")
        if header.lower().endswith(";base64"):
            return base64.b64decode(data)
        return unquote_to_bytes(data)

    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        """Generate images from a text prompt using Replicate.

        Args:
            prompt: The text prompt to generate images from.

        Yields:
            URLImageRawFrame: Frame containing generated image data and metadata.
            ErrorFrame: If image generation fails.
        """
        logger.debug(f"Generating image from prompt with Replicate: {prompt}")
        await self.start_ttfb_metrics()

        try:
            url, headers, payload = self._prediction_request(prompt)
            async with self._aiohttp_session.post(url, json=payload, headers=headers) as response:
                if response.status not in {200, 201}:
                    error_text = await response.text()
                    await self.stop_ttfb_metrics()
                    yield ErrorFrame(error=f"Replicate API error ({response.status}): {error_text}")
                    return
                prediction = await response.json()
                fallback_url = response.headers.get("Location")

            prediction = await self._poll_prediction(
                prediction,
                headers={"Authorization": headers["Authorization"]},
                fallback_url=fallback_url,
            )

            output_urls = self._extract_output_urls(prediction)
            # TTFB ends once Replicate has answered, whether or not it produced output.
            await self.stop_ttfb_metrics()
            if not output_urls:
                yield ErrorFrame("Replicate image generation failed: no output returned")
                return

            for image_url in output_urls:
                if image_url.startswith("data:"):
                    encoded_image = self._decode_data_url(image_url)
                else:
                    async with self._aiohttp_session.get(image_url) as response:
                        if response.status != 200:
                            error_text = await response.text()
                            yield ErrorFrame(
                                error=(
                                    f"Replicate image download error"
                                    f" ({response.status}): {error_text}"
                                )
                            )
                            continue
                        encoded_image = await response.read()

                # PIL decoding is CPU-bound; keep it off the event loop.
                frame = await asyncio.to_thread(self._load_image_frame, image_url, encoded_image)
                yield frame

        except Exception as e:
            # Ensure a failure mid-request does not leave TTFB metrics running
            # (stopping metrics that were already stopped is a no-op).
            await self.stop_ttfb_metrics()
            yield ErrorFrame(f"Replicate image generation error: {e}")
File without changes
@@ -0,0 +1,202 @@
1
+ """Tests for ReplicateImageGenService."""
2
+
3
+ import io
4
+
5
+ import aiohttp
6
+ import pytest
7
+ from aiohttp import web
8
+ from PIL import Image
9
+
10
+ from pipecat.frames.frames import ErrorFrame, TextFrame, URLImageRawFrame
11
+ from pipecat.tests.utils import run_test
12
+ from pipecat_replicate import ReplicateImageGenService
13
+
14
+
15
def _make_test_image_bytes(format: str = "PNG") -> bytes:
    """Encode a tiny solid-red 2x2 RGB image and return the encoded file bytes.

    Args:
        format: Pillow format name passed to ``Image.save`` (e.g. ``"PNG"``).

    Returns:
        The image serialized in the requested format.
    """
    with io.BytesIO() as out:
        Image.new("RGB", (2, 2), color=(255, 0, 0)).save(out, format=format)
        return out.getvalue()
20
+
21
+
22
@pytest.mark.asyncio
async def test_run_replicate_image_success_official_model(aiohttp_client):
    """Sync predictions against an official (``owner/name``) model yield an image frame."""

    png_bytes = _make_test_image_bytes()

    async def handle_prediction(request):
        # Validate the outgoing request before answering with a prediction payload.
        assert request.headers["Authorization"] == "Bearer test-token"
        assert request.headers["Prefer"] == "wait=60"
        body = await request.json()
        model_input = body["input"]
        assert model_input["prompt"] == "a red square"
        assert model_input["aspect_ratio"] == "1:1"
        image_url = str(request.url.with_path("/image.png"))
        status_url = str(request.url.with_path("/prediction-status"))
        return web.json_response(
            {
                "status": "processing",
                "output": [image_url],
                "urls": {"get": status_url},
            }
        )

    async def handle_image(_request):
        return web.Response(body=png_bytes, content_type="image/png")

    app = web.Application()
    routes = app.router
    routes.add_post("/v1/models/black-forest-labs/flux-schnell/predictions", handle_prediction)
    routes.add_get("/image.png", handle_image)
    test_client = await aiohttp_client(app)
    api_base = str(test_client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as http:
        service = ReplicateImageGenService(
            aiohttp_session=http,
            api_token="test-token",
            base_url=api_base,
        )
        downstream, upstream = await run_test(service, frames_to_send=[TextFrame("a red square")])

    assert not upstream
    assert isinstance(downstream[0], TextFrame)
    image_frame = downstream[1]
    assert isinstance(image_frame, URLImageRawFrame)
    assert image_frame.size == (2, 2)
    assert image_frame.format == "PNG"
67
+
68
+
69
@pytest.mark.asyncio
async def test_run_replicate_image_success_versioned_model(aiohttp_client):
    """Versioned community models should use the generic predictions endpoint."""

    image_bytes = _make_test_image_bytes()
    version = "53d5d1586a229bd033e060941789bfb0c177cefd5ef638f34b3099658343a897"

    async def prediction_handler(request):
        # ``owner/name:version`` models are created via POST /v1/predictions, with the
        # version hash carried in the request body rather than in the URL path.
        assert request.headers["Authorization"] == "Bearer test-token"
        payload = await request.json()
        assert payload["version"] == version
        assert payload["input"]["prompt"] == "a blue square"
        # "succeeded" is the status Replicate reports for a completed prediction.
        return web.json_response(
            {
                "status": "succeeded",
                "output": [str(request.url.with_path("/image-versioned.png"))],
            }
        )

    async def image_handler(_request):
        return web.Response(body=image_bytes, content_type="image/png")

    app = web.Application()
    app.router.add_post("/v1/predictions", prediction_handler)
    app.router.add_get("/image-versioned.png", image_handler)
    client = await aiohttp_client(app)
    base_url = str(client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as session:
        image_gen = ReplicateImageGenService(
            aiohttp_session=session,
            api_token="test-token",
            base_url=base_url,
            settings=ReplicateImageGenService.Settings(
                model=f"black-forest-labs/flux-schnell:{version}",
                aspect_ratio="1:1",
                num_outputs=1,
                num_inference_steps=4,
                seed=None,
                output_format="webp",
                output_quality=80,
                disable_safety_checker=False,
                go_fast=True,
                megapixels="1",
            ),
        )

        down_frames, up_frames = await run_test(
            image_gen, frames_to_send=[TextFrame("a blue square")]
        )

    # Same expectations as the official-model test: the prompt frame passes through
    # first, then the decoded 2x2 image served by the mock.
    assert not up_frames
    assert isinstance(down_frames[0], TextFrame)
    assert isinstance(down_frames[1], URLImageRawFrame)
    assert down_frames[1].size == (2, 2)
121
+
122
+
123
@pytest.mark.asyncio
async def test_run_replicate_image_polls_when_sync_response_has_no_output(aiohttp_client):
    """An early sync response without output must trigger polling of the status URL."""

    png_bytes = _make_test_image_bytes()
    # One entry is recorded per request to the prediction-status endpoint.
    poll_hits = []

    async def handle_create(request):
        status_url = str(request.url.with_path("/v1/predictions/test-id"))
        return web.json_response(
            {"status": "processing", "output": None, "urls": {"get": status_url}}
        )

    async def handle_status(request):
        poll_hits.append(request.path)
        image_url = str(request.url.with_path("/image-polled.png"))
        return web.json_response({"status": "processing", "output": [image_url]})

    async def handle_image(_request):
        return web.Response(body=png_bytes, content_type="image/png")

    app = web.Application()
    app.router.add_post("/v1/models/black-forest-labs/flux-schnell/predictions", handle_create)
    app.router.add_get("/v1/predictions/test-id", handle_status)
    app.router.add_get("/image-polled.png", handle_image)
    test_client = await aiohttp_client(app)
    api_base = str(test_client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as http:
        service = ReplicateImageGenService(
            aiohttp_session=http,
            api_token="test-token",
            base_url=api_base,
            poll_interval_secs=0.001,
            max_poll_attempts=2,
        )
        downstream, upstream = await run_test(service, frames_to_send=[TextFrame("poll me")])

    assert not upstream
    assert len(poll_hits) == 1
    assert isinstance(downstream[1], URLImageRawFrame)
173
+
174
+
175
@pytest.mark.asyncio
async def test_run_replicate_image_error(aiohttp_client):
    """A non-2xx prediction response should surface as a single upstream ErrorFrame."""

    async def reject_prediction(_request):
        return web.Response(status=401, text="unauthorized")

    app = web.Application()
    app.router.add_post("/v1/models/black-forest-labs/flux-schnell/predictions", reject_prediction)
    test_client = await aiohttp_client(app)
    api_base = str(test_client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as http:
        service = ReplicateImageGenService(
            aiohttp_session=http,
            api_token="bad-token",
            base_url=api_base,
        )
        downstream, upstream = await run_test(
            service, frames_to_send=[TextFrame("this should fail")]
        )

    assert isinstance(downstream[0], TextFrame)
    assert len(upstream) == 1
    error_frame = upstream[0]
    assert isinstance(error_frame, ErrorFrame)
    assert "401" in error_frame.error