pipecat-replicate 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pipecat_replicate-0.1.0/.env.example +5 -0
- pipecat_replicate-0.1.0/.github/workflows/publish.yml +17 -0
- pipecat_replicate-0.1.0/.github/workflows/test.yml +22 -0
- pipecat_replicate-0.1.0/.gitignore +16 -0
- pipecat_replicate-0.1.0/CHANGELOG.md +10 -0
- pipecat_replicate-0.1.0/LICENSE +24 -0
- pipecat_replicate-0.1.0/PKG-INFO +138 -0
- pipecat_replicate-0.1.0/README.md +122 -0
- pipecat_replicate-0.1.0/examples/basic_image_gen.py +70 -0
- pipecat_replicate-0.1.0/pyproject.toml +39 -0
- pipecat_replicate-0.1.0/src/pipecat_replicate/__init__.py +5 -0
- pipecat_replicate-0.1.0/src/pipecat_replicate/image.py +356 -0
- pipecat_replicate-0.1.0/tests/__init__.py +0 -0
- pipecat_replicate-0.1.0/tests/test_replicate_image.py +202 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
|
|
7
|
+
jobs:
|
|
8
|
+
publish:
|
|
9
|
+
runs-on: ubuntu-latest
|
|
10
|
+
environment: pypi
|
|
11
|
+
permissions:
|
|
12
|
+
id-token: write
|
|
13
|
+
steps:
|
|
14
|
+
- uses: actions/checkout@v4
|
|
15
|
+
- uses: astral-sh/setup-uv@v6
|
|
16
|
+
- run: uv build
|
|
17
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
name: Tests
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
test:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
strategy:
|
|
13
|
+
matrix:
|
|
14
|
+
python-version: ["3.10", "3.12", "3.13"]
|
|
15
|
+
steps:
|
|
16
|
+
- uses: actions/checkout@v4
|
|
17
|
+
- uses: astral-sh/setup-uv@v6
|
|
18
|
+
- run: uv python install ${{ matrix.python-version }}
|
|
19
|
+
- run: uv sync --group dev --python ${{ matrix.python-version }}
|
|
20
|
+
- run: uv run ruff check
|
|
21
|
+
- run: uv run ruff format --check
|
|
22
|
+
- run: uv run pytest tests/ -v
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
## 0.1.0 (2026-04-15)
|
|
4
|
+
|
|
5
|
+
- Initial release
|
|
6
|
+
- Support for official Replicate models (`owner/name`) and versioned models (`owner/name:version`)
|
|
7
|
+
- Sync prediction requests with polling fallback
|
|
8
|
+
- Returns `URLImageRawFrame` on success
|
|
9
|
+
- Metrics support via `start_ttfb_metrics` / `stop_ttfb_metrics`
|
|
10
|
+
- Tested with Pipecat v0.0.108
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
BSD 2-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026, Borislav Novikov
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
16
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
17
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
18
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
19
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
20
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
21
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
22
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
23
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
24
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pipecat-replicate
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Replicate image generation integration for Pipecat
|
|
5
|
+
Project-URL: Repository, https://github.com/bnovik0v/pipecat-replicate
|
|
6
|
+
Project-URL: Issues, https://github.com/bnovik0v/pipecat-replicate/issues
|
|
7
|
+
Author-email: Borislav Novikov <borislav@polaro.com>
|
|
8
|
+
License-Expression: BSD-2-Clause
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Python: <3.14,>=3.10
|
|
11
|
+
Requires-Dist: aiohttp>=3.9
|
|
12
|
+
Requires-Dist: loguru>=0.7
|
|
13
|
+
Requires-Dist: pillow>=10.0
|
|
14
|
+
Requires-Dist: pipecat-ai>=0.0.108
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# pipecat-replicate
|
|
18
|
+
|
|
19
|
+
[Replicate](https://replicate.com/) image generation integration for [Pipecat](https://github.com/pipecat-ai/pipecat) — a framework for building voice and multimodal conversational AI applications.
|
|
20
|
+
|
|
21
|
+
## Pipecat Compatibility
|
|
22
|
+
|
|
23
|
+
**Tested with Pipecat v0.0.108**
|
|
24
|
+
|
|
25
|
+
## Features
|
|
26
|
+
|
|
27
|
+
- Text-to-image generation using any Replicate-hosted model
|
|
28
|
+
- Official models (`owner/name`) and versioned community models (`owner/name:version`)
|
|
29
|
+
- Sync prediction requests with automatic polling fallback
|
|
30
|
+
- Returns `URLImageRawFrame` for direct use in Pipecat pipelines
|
|
31
|
+
- Configurable via the standard Pipecat `Settings` dataclass pattern
|
|
32
|
+
|
|
33
|
+
## Installation
|
|
34
|
+
|
|
35
|
+
### Using pip
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install pipecat-replicate
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
### Using uv
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
uv add pipecat-replicate
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
### From source
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
git clone https://github.com/bnovik0v/pipecat-replicate.git
|
|
51
|
+
cd pipecat-replicate
|
|
52
|
+
pip install -e .
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Quick Start
|
|
56
|
+
|
|
57
|
+
1. Get your Replicate API token at https://replicate.com/account/api-tokens
|
|
58
|
+
|
|
59
|
+
2. Set the environment variable:
|
|
60
|
+
|
|
61
|
+
```bash
|
|
62
|
+
export REPLICATE_API_TOKEN=r8_...
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
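3. Use in a Pipecat pipeline (the snippet uses `async with` and `await`, so run it inside an `async` function, e.g. one started with `asyncio.run(...)`):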
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import aiohttp
|
|
69
|
+
from pipecat.frames.frames import TextFrame
|
|
70
|
+
from pipecat.pipeline.pipeline import Pipeline
|
|
71
|
+
from pipecat.pipeline.runner import PipelineRunner
|
|
72
|
+
from pipecat.pipeline.task import PipelineTask
|
|
73
|
+
from pipecat_replicate import ReplicateImageGenService
|
|
74
|
+
|
|
75
|
+
async with aiohttp.ClientSession() as session:
|
|
76
|
+
imagegen = ReplicateImageGenService(
|
|
77
|
+
aiohttp_session=session,
|
|
78
|
+
settings=ReplicateImageGenService.Settings(
|
|
79
|
+
model="black-forest-labs/flux-schnell",
|
|
80
|
+
aspect_ratio="1:1",
|
|
81
|
+
),
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
pipeline = Pipeline([imagegen, ...])
|
|
85
|
+
task = PipelineTask(pipeline)
|
|
86
|
+
await task.queue_frames([TextFrame("a cat in the style of a screenprint poster")])
|
|
87
|
+
|
|
88
|
+
runner = PipelineRunner()
|
|
89
|
+
await runner.run(task)
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
## Configuration
|
|
93
|
+
|
|
94
|
+
### Settings
|
|
95
|
+
|
|
96
|
+
| Field | Type | Default | Description |
|
|
97
|
+
| ------------------------ | ------ | ---------------------------------- | ----------------------------------------- |
|
|
98
|
+
| `model` | `str` | `"black-forest-labs/flux-schnell"` | Replicate model identifier |
|
|
99
|
+
| `aspect_ratio` | `str` | `"1:1"` | Aspect ratio for generated images |
|
|
100
|
+
| `num_outputs` | `int` | `1` | Number of images to generate (1–4) |
|
|
101
|
+
| `num_inference_steps` | `int` | `4` | Number of denoising steps |
|
|
102
|
+
| `seed` | `int` | `None` | Random seed for reproducible generation |
|
|
103
|
+
| `output_format` | `str` | `"webp"` | Output image format |
|
|
104
|
+
| `output_quality` | `int` | `80` | Output quality (0–100) |
|
|
105
|
+
| `disable_safety_checker` | `bool` | `False` | Whether to disable the model safety check |
|
|
106
|
+
| `go_fast` | `bool` | `True` | Use the model's fast generation mode |
|
|
107
|
+
| `megapixels` | `str` | `"1"` | Approximate megapixel count |
|
|
108
|
+
|
|
109
|
+
### Constructor Parameters
|
|
110
|
+
|
|
111
|
+
| Parameter | Type | Default | Description |
|
|
112
|
+
| -------------------- | ---------------------- | -------------------------------------- | ------------------------------------ |
|
|
113
|
+
| `aiohttp_session` | `aiohttp.ClientSession`| *(required)* | HTTP session for API requests |
|
|
114
|
+
| `api_token` | `str` | `$REPLICATE_API_TOKEN` | Replicate API token |
|
|
115
|
+
| `settings` | `Settings` | *(see defaults above)* | Generation settings |
|
|
116
|
+
| `base_url` | `str` | `"https://api.replicate.com/v1"` | API base URL |
|
|
117
|
+
| `wait_timeout_secs` | `int` | `60` | Sync wait timeout (Prefer header) |
|
|
118
|
+
| `poll_interval_secs` | `float` | `0.5` | Poll interval for async predictions |
|
|
119
|
+
| `max_poll_attempts` | `int` | `120` | Maximum number of polling attempts |
|
|
120
|
+
|
|
121
|
+
## Examples
|
|
122
|
+
|
|
123
|
+
See [`examples/basic_image_gen.py`](examples/basic_image_gen.py) for a complete example that generates an image and displays it in a Tk window.
|
|
124
|
+
|
|
125
|
+
```bash
|
|
126
|
+
REPLICATE_API_TOKEN=r8_... python examples/basic_image_gen.py
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
## Running Tests
|
|
130
|
+
|
|
131
|
+
```bash
|
|
132
|
+
uv sync --group dev
|
|
133
|
+
uv run pytest
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
## License
|
|
137
|
+
|
|
138
|
+
BSD 2-Clause — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# pipecat-replicate
|
|
2
|
+
|
|
3
|
+
[Replicate](https://replicate.com/) image generation integration for [Pipecat](https://github.com/pipecat-ai/pipecat) — a framework for building voice and multimodal conversational AI applications.
|
|
4
|
+
|
|
5
|
+
## Pipecat Compatibility
|
|
6
|
+
|
|
7
|
+
**Tested with Pipecat v0.0.108**
|
|
8
|
+
|
|
9
|
+
## Features
|
|
10
|
+
|
|
11
|
+
- Text-to-image generation using any Replicate-hosted model
|
|
12
|
+
- Official models (`owner/name`) and versioned community models (`owner/name:version`)
|
|
13
|
+
- Sync prediction requests with automatic polling fallback
|
|
14
|
+
- Returns `URLImageRawFrame` for direct use in Pipecat pipelines
|
|
15
|
+
- Configurable via the standard Pipecat `Settings` dataclass pattern
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
### Using pip
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
pip install pipecat-replicate
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
### Using uv
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
uv add pipecat-replicate
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
### From source
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
git clone https://github.com/bnovik0v/pipecat-replicate.git
|
|
35
|
+
cd pipecat-replicate
|
|
36
|
+
pip install -e .
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Quick Start
|
|
40
|
+
|
|
41
|
+
1. Get your Replicate API token at https://replicate.com/account/api-tokens
|
|
42
|
+
|
|
43
|
+
2. Set the environment variable:
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
export REPLICATE_API_TOKEN=r8_...
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
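3. Use in a Pipecat pipeline (the snippet uses `async with` and `await`, so run it inside an `async` function, e.g. one started with `asyncio.run(...)`):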
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
import aiohttp
|
|
53
|
+
from pipecat.frames.frames import TextFrame
|
|
54
|
+
from pipecat.pipeline.pipeline import Pipeline
|
|
55
|
+
from pipecat.pipeline.runner import PipelineRunner
|
|
56
|
+
from pipecat.pipeline.task import PipelineTask
|
|
57
|
+
from pipecat_replicate import ReplicateImageGenService
|
|
58
|
+
|
|
59
|
+
async with aiohttp.ClientSession() as session:
|
|
60
|
+
imagegen = ReplicateImageGenService(
|
|
61
|
+
aiohttp_session=session,
|
|
62
|
+
settings=ReplicateImageGenService.Settings(
|
|
63
|
+
model="black-forest-labs/flux-schnell",
|
|
64
|
+
aspect_ratio="1:1",
|
|
65
|
+
),
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
pipeline = Pipeline([imagegen, ...])
|
|
69
|
+
task = PipelineTask(pipeline)
|
|
70
|
+
await task.queue_frames([TextFrame("a cat in the style of a screenprint poster")])
|
|
71
|
+
|
|
72
|
+
runner = PipelineRunner()
|
|
73
|
+
await runner.run(task)
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
## Configuration
|
|
77
|
+
|
|
78
|
+
### Settings
|
|
79
|
+
|
|
80
|
+
| Field | Type | Default | Description |
|
|
81
|
+
| ------------------------ | ------ | ---------------------------------- | ----------------------------------------- |
|
|
82
|
+
| `model` | `str` | `"black-forest-labs/flux-schnell"` | Replicate model identifier |
|
|
83
|
+
| `aspect_ratio` | `str` | `"1:1"` | Aspect ratio for generated images |
|
|
84
|
+
| `num_outputs` | `int` | `1` | Number of images to generate (1–4) |
|
|
85
|
+
| `num_inference_steps` | `int` | `4` | Number of denoising steps |
|
|
86
|
+
| `seed` | `int` | `None` | Random seed for reproducible generation |
|
|
87
|
+
| `output_format` | `str` | `"webp"` | Output image format |
|
|
88
|
+
| `output_quality` | `int` | `80` | Output quality (0–100) |
|
|
89
|
+
| `disable_safety_checker` | `bool` | `False` | Whether to disable the model safety check |
|
|
90
|
+
| `go_fast` | `bool` | `True` | Use the model's fast generation mode |
|
|
91
|
+
| `megapixels` | `str` | `"1"` | Approximate megapixel count |
|
|
92
|
+
|
|
93
|
+
### Constructor Parameters
|
|
94
|
+
|
|
95
|
+
| Parameter | Type | Default | Description |
|
|
96
|
+
| -------------------- | ---------------------- | -------------------------------------- | ------------------------------------ |
|
|
97
|
+
| `aiohttp_session` | `aiohttp.ClientSession`| *(required)* | HTTP session for API requests |
|
|
98
|
+
| `api_token` | `str` | `$REPLICATE_API_TOKEN` | Replicate API token |
|
|
99
|
+
| `settings` | `Settings` | *(see defaults above)* | Generation settings |
|
|
100
|
+
| `base_url` | `str` | `"https://api.replicate.com/v1"` | API base URL |
|
|
101
|
+
| `wait_timeout_secs` | `int` | `60` | Sync wait timeout (Prefer header) |
|
|
102
|
+
| `poll_interval_secs` | `float` | `0.5` | Poll interval for async predictions |
|
|
103
|
+
| `max_poll_attempts` | `int` | `120` | Maximum number of polling attempts |
|
|
104
|
+
|
|
105
|
+
## Examples
|
|
106
|
+
|
|
107
|
+
See [`examples/basic_image_gen.py`](examples/basic_image_gen.py) for a complete example that generates an image and displays it in a Tk window.
|
|
108
|
+
|
|
109
|
+
```bash
|
|
110
|
+
REPLICATE_API_TOKEN=r8_... python examples/basic_image_gen.py
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Running Tests
|
|
114
|
+
|
|
115
|
+
```bash
|
|
116
|
+
uv sync --group dev
|
|
117
|
+
uv run pytest
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
## License
|
|
121
|
+
|
|
122
|
+
BSD 2-Clause — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Basic Replicate image generation example with Pipecat.
|
|
2
|
+
|
|
3
|
+
Generates an image using the Replicate API and displays it in a Tk window.
|
|
4
|
+
|
|
5
|
+
Requirements:
|
|
6
|
+
pip install pipecat-replicate 'pipecat-ai[local]' python-dotenv
|
|
7
|
+
|
|
8
|
+
Environment variables:
|
|
9
|
+
REPLICATE_API_TOKEN — your Replicate API token
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
import tkinter as tk
|
|
16
|
+
|
|
17
|
+
import aiohttp
|
|
18
|
+
from dotenv import load_dotenv
|
|
19
|
+
from loguru import logger
|
|
20
|
+
|
|
21
|
+
from pipecat.frames.frames import TextFrame
|
|
22
|
+
from pipecat.pipeline.pipeline import Pipeline
|
|
23
|
+
from pipecat.pipeline.runner import PipelineRunner
|
|
24
|
+
from pipecat.pipeline.task import PipelineTask
|
|
25
|
+
from pipecat.transports.local.tk import TkLocalTransport, TkTransportParams
|
|
26
|
+
from pipecat_replicate import ReplicateImageGenService
|
|
27
|
+
|
|
28
|
+
load_dotenv(override=True)

# Swap loguru's default sink (handler id 0) for a DEBUG-level stderr sink.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
    """Generate a single image with Replicate and show it in a Tk window.

    Pipeline: ``ReplicateImageGenService -> TkLocalTransport.output()``. One
    prompt is queued before the runner starts, and the Tk event loop is pumped
    from an asyncio coroutine until the pipeline task reports it has finished.
    """
    async with aiohttp.ClientSession() as http:
        window = tk.Tk()
        window.title("Replicate Image Gen")

        tk_params = TkTransportParams(
            video_out_enabled=True, video_out_width=1024, video_out_height=1024
        )
        transport = TkLocalTransport(window, tk_params)

        image_settings = ReplicateImageGenService.Settings(
            model="black-forest-labs/flux-schnell",
            aspect_ratio="1:1",
        )
        imagegen = ReplicateImageGenService(
            settings=image_settings,
            aiohttp_session=http,
            api_token=os.getenv("REPLICATE_API_TOKEN"),
        )

        task = PipelineTask(Pipeline([imagegen, transport.output()]))
        await task.queue_frames([TextFrame("a cat in the style of a screenprint poster")])

        runner = PipelineRunner()

        async def pump_tk_events():
            # Tk's own mainloop would block asyncio, so drive it manually in
            # short steps until the pipeline task is done.
            while not task.has_finished():
                window.update()
                window.update_idletasks()
                await asyncio.sleep(0.1)

        await asyncio.gather(runner.run(task), pump_tk_events())


if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "pipecat-replicate"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Replicate image generation integration for Pipecat"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "BSD-2-Clause"
|
|
11
|
+
authors = [
|
|
12
|
+
{ name = "Borislav Novikov", email = "borislav@polaro.com" },
|
|
13
|
+
]
|
|
14
|
+
requires-python = ">=3.10,<3.14"
|
|
15
|
+
dependencies = [
|
|
16
|
+
"pipecat-ai>=0.0.108",
|
|
17
|
+
"aiohttp>=3.9",
|
|
18
|
+
"loguru>=0.7",
|
|
19
|
+
"Pillow>=10.0",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[project.urls]
|
|
23
|
+
Repository = "https://github.com/bnovik0v/pipecat-replicate"
|
|
24
|
+
Issues = "https://github.com/bnovik0v/pipecat-replicate/issues"
|
|
25
|
+
|
|
26
|
+
[dependency-groups]
|
|
27
|
+
dev = [
|
|
28
|
+
"pytest>=8",
|
|
29
|
+
"pytest-asyncio>=0.23",
|
|
30
|
+
"pytest-aiohttp>=1.0",
|
|
31
|
+
"ruff>=0.4",
|
|
32
|
+
"websockets>=13.0",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[tool.ruff]
|
|
36
|
+
line-length = 100
|
|
37
|
+
|
|
38
|
+
[tool.pytest.ini_options]
|
|
39
|
+
asyncio_mode = "auto"
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
"""Replicate image generation service implementation.
|
|
2
|
+
|
|
3
|
+
This module provides integration with Replicate-hosted image generation models
|
|
4
|
+
for creating images from text prompts.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import base64
|
|
9
|
+
import io
|
|
10
|
+
import os
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Any, AsyncGenerator, Dict, Optional
|
|
13
|
+
|
|
14
|
+
import aiohttp
|
|
15
|
+
from loguru import logger
|
|
16
|
+
from PIL import Image
|
|
17
|
+
from pydantic import BaseModel, Field
|
|
18
|
+
|
|
19
|
+
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
|
|
20
|
+
from pipecat.services.image_service import ImageGenService
|
|
21
|
+
from pipecat.services.settings import NOT_GIVEN, ImageGenSettings, _NotGiven
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class ReplicateImageGenSettings(ImageGenSettings):
    """Settings for the Replicate image generation service.

    Every field defaults to ``NOT_GIVEN`` so an instance can describe a partial
    set of settings; unset fields are left out of the API payload.

    Parameters:
        model: Replicate model identifier. Use ``owner/name`` for official
            models or ``owner/name:version`` for versioned community models.
        aspect_ratio: Aspect ratio for generated images.
        num_outputs: Number of images to generate.
        num_inference_steps: Number of denoising steps for the model.
        seed: Random seed for reproducible generation. ``None`` uses a random seed.
        output_format: Image format requested from the model.
        output_quality: Output quality value supported by the model.
        disable_safety_checker: Whether to disable the model safety checker.
        go_fast: Whether to use the model's faster generation mode.
        megapixels: Approximate megapixel count for generated images.
    """

    aspect_ratio: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    num_outputs: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    num_inference_steps: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    seed: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    output_format: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    output_quality: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    disable_safety_checker: bool | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    go_fast: bool | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    megapixels: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)

    def to_api_input(self) -> Dict[str, Any]:
        """Build the Replicate ``input`` payload from these settings.

        Fields still set to ``NOT_GIVEN`` are omitted. Sending the sentinel
        would make the payload non-JSON-serializable, whereas omitting it lets
        the model use its own default. ``seed`` is also omitted when it is
        ``None``, which means "random seed". Finally, ``self.extra`` is merged
        in with ``dict.update``, so its keys override the named fields.

        Returns:
            A dict suitable for the ``input`` field of a Replicate prediction.
        """
        candidates: Dict[str, Any] = {
            "aspect_ratio": self.aspect_ratio,
            "num_outputs": self.num_outputs,
            "num_inference_steps": self.num_inference_steps,
            "output_format": self.output_format,
            "output_quality": self.output_quality,
            "disable_safety_checker": self.disable_safety_checker,
            "go_fast": self.go_fast,
            "megapixels": self.megapixels,
            "seed": self.seed,
        }
        payload: Dict[str, Any] = {
            key: value for key, value in candidates.items() if not isinstance(value, _NotGiven)
        }
        # A ``None`` seed means "let the model pick one", so don't send it.
        if payload.get("seed") is None:
            payload.pop("seed", None)
        payload.update(self.extra)
        return payload
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ReplicateImageGenService(ImageGenService):
|
|
71
|
+
"""Replicate image generation service.
|
|
72
|
+
|
|
73
|
+
Provides text-to-image generation using Replicate-hosted models. Official
|
|
74
|
+
models use an ``owner/name`` identifier. Versioned community models can be
|
|
75
|
+
addressed with ``owner/name:version``.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
# Public alias so callers can write ``ReplicateImageGenService.Settings(...)``.
Settings = ReplicateImageGenSettings
# Annotation only (no value assigned here): narrows the type of
# ``self._settings``, which is presumably populated by the base
# ``ImageGenService`` from the settings passed to ``super().__init__``.
_settings: Settings

class InputParams(BaseModel):
    """Legacy generation parameters for the Replicate service.

    .. deprecated:: 0.1.0
        Use ``settings=ReplicateImageGenService.Settings(...)`` instead.

    Parameters:
        aspect_ratio: Aspect ratio of the generated images (default ``"1:1"``).
        num_outputs: How many images to generate, 1 to 4 (default ``1``).
        num_inference_steps: Denoising steps, at least 1 (default ``4``).
        seed: Seed for reproducible output; ``None`` means random (default ``None``).
        output_format: Requested image format (default ``"webp"``).
        output_quality: Quality value, 0 to 100 (default ``80``).
        disable_safety_checker: Turn off the model safety checker (default ``False``).
        go_fast: Use the model's fast generation path (default ``True``).
        megapixels: Approximate megapixel count (default ``"1"``).
    """

    aspect_ratio: str = Field(default="1:1")
    num_outputs: int = Field(default=1, ge=1, le=4)
    num_inference_steps: int = Field(default=4, ge=1)
    seed: Optional[int] = Field(default=None)
    output_format: str = Field(default="webp")
    output_quality: int = Field(default=80, ge=0, le=100)
    disable_safety_checker: bool = Field(default=False)
    go_fast: bool = Field(default=True)
    megapixels: str = Field(default="1")

# Lower-cased prediction statuses that ``_poll_prediction`` reports as a failure.
_TERMINAL_ERROR_STATUSES = {"failed", "canceled", "cancelled"}
|
|
110
|
+
|
|
111
|
+
def __init__(
    self,
    *,
    params: Optional[InputParams] = None,
    aiohttp_session: aiohttp.ClientSession,
    api_token: Optional[str] = None,
    model: Optional[str] = None,
    settings: Optional[Settings] = None,
    base_url: str = "https://api.replicate.com/v1",
    wait_timeout_secs: int = 60,
    poll_interval_secs: float = 0.5,
    max_poll_attempts: int = 120,
    **kwargs,
):
    """Create a Replicate image generation service.

    Settings are resolved in three layers:

    1. Built-in defaults.
    2. The deprecated ``model`` and ``params`` arguments.
    3. ``settings``, applied last via ``apply_update`` so its values win.

    Args:
        params: Legacy generation parameters. They are copied into the
            settings only when ``settings`` is not supplied.

            .. deprecated:: 0.1.0
                Use ``settings=ReplicateImageGenService.Settings(...)`` instead.

        aiohttp_session: HTTP session used for Replicate requests and image downloads.
        api_token: Replicate API token. Falls back to ``$REPLICATE_API_TOKEN``
            and, when given, is also written to the ``REPLICATE_API_TOKEN``
            environment variable.
        model: Replicate model identifier (default ``"black-forest-labs/flux-schnell"``).

            .. deprecated:: 0.1.0
                Use ``settings=ReplicateImageGenService.Settings(model=...)`` instead.

        settings: Runtime-configurable generation settings; these take
            precedence over the deprecated arguments.
        base_url: Root of the Replicate HTTP API; any trailing ``/`` is stripped.
        wait_timeout_secs: Seconds sent in the ``Prefer: wait=...`` header.
        poll_interval_secs: Delay between polls when the initial request
            returns before output is available.
        max_poll_attempts: Upper bound on follow-up prediction polls.
        **kwargs: Forwarded to the parent ``ImageGenService``.
    """
    resolved = self.Settings(
        model="black-forest-labs/flux-schnell",
        aspect_ratio="1:1",
        num_outputs=1,
        num_inference_steps=4,
        seed=None,
        output_format="webp",
        output_quality=80,
        disable_safety_checker=False,
        go_fast=True,
        megapixels="1",
    )

    if model is not None:
        self._warn_init_param_moved_to_settings("model", "model")
        resolved.model = model

    if params is not None:
        self._warn_init_param_moved_to_settings("params")
        # Legacy params only seed the settings when no explicit settings exist.
        if not settings:
            for name in (
                "aspect_ratio",
                "num_outputs",
                "num_inference_steps",
                "seed",
                "output_format",
                "output_quality",
                "disable_safety_checker",
                "go_fast",
                "megapixels",
            ):
                setattr(resolved, name, getattr(params, name))

    if settings is not None:
        resolved.apply_update(settings)

    super().__init__(settings=resolved, **kwargs)
    self._aiohttp_session = aiohttp_session
    self._api_token = api_token or os.getenv("REPLICATE_API_TOKEN", "")
    self._base_url = base_url.rstrip("/")
    self._wait_timeout_secs = wait_timeout_secs
    self._poll_interval_secs = poll_interval_secs
    self._max_poll_attempts = max_poll_attempts
    if api_token:
        os.environ["REPLICATE_API_TOKEN"] = api_token
|
|
193
|
+
|
|
194
|
+
def _prediction_request(self, prompt: str) -> tuple[str, dict[str, str], dict[str, Any]]:
|
|
195
|
+
"""Build the Replicate prediction URL, headers, and request body."""
|
|
196
|
+
model = self._settings.model or ""
|
|
197
|
+
if "/" not in model:
|
|
198
|
+
raise ValueError("Replicate model must use 'owner/name' or 'owner/name:version' format")
|
|
199
|
+
|
|
200
|
+
input_payload = {"prompt": prompt, **self._settings.to_api_input()}
|
|
201
|
+
headers = {
|
|
202
|
+
"Authorization": f"Bearer {self._api_token}",
|
|
203
|
+
"Content-Type": "application/json",
|
|
204
|
+
"Prefer": f"wait={self._wait_timeout_secs}",
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
if ":" in model:
|
|
208
|
+
_, version = model.rsplit(":", maxsplit=1)
|
|
209
|
+
if not version:
|
|
210
|
+
raise ValueError("Versioned Replicate models must use 'owner/name:version' format")
|
|
211
|
+
return (
|
|
212
|
+
f"{self._base_url}/predictions",
|
|
213
|
+
headers,
|
|
214
|
+
{"version": version, "input": input_payload},
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
owner, name = model.split("/", maxsplit=1)
|
|
218
|
+
return (
|
|
219
|
+
f"{self._base_url}/models/{owner}/{name}/predictions",
|
|
220
|
+
headers,
|
|
221
|
+
{"input": input_payload},
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
def _extract_output_urls(self, prediction: Dict[str, Any]) -> list[str]:
|
|
225
|
+
"""Extract output URLs from a Replicate prediction response."""
|
|
226
|
+
output = prediction.get("output")
|
|
227
|
+
if output is None:
|
|
228
|
+
return []
|
|
229
|
+
if isinstance(output, str):
|
|
230
|
+
return [output]
|
|
231
|
+
if isinstance(output, dict):
|
|
232
|
+
url = output.get("url")
|
|
233
|
+
return [url] if isinstance(url, str) else []
|
|
234
|
+
|
|
235
|
+
urls: list[str] = []
|
|
236
|
+
if isinstance(output, list):
|
|
237
|
+
for item in output:
|
|
238
|
+
if isinstance(item, str):
|
|
239
|
+
urls.append(item)
|
|
240
|
+
elif isinstance(item, dict) and isinstance(item.get("url"), str):
|
|
241
|
+
urls.append(item["url"])
|
|
242
|
+
return urls
|
|
243
|
+
|
|
244
|
+
async def _poll_prediction(
|
|
245
|
+
self,
|
|
246
|
+
prediction: Dict[str, Any],
|
|
247
|
+
headers: Dict[str, str],
|
|
248
|
+
*,
|
|
249
|
+
fallback_url: str | None,
|
|
250
|
+
) -> Dict[str, Any]:
|
|
251
|
+
"""Poll the prediction endpoint until output is available or it fails."""
|
|
252
|
+
output_urls = self._extract_output_urls(prediction)
|
|
253
|
+
if output_urls:
|
|
254
|
+
return prediction
|
|
255
|
+
|
|
256
|
+
prediction_url = prediction.get("urls", {}).get("get") or fallback_url
|
|
257
|
+
|
|
258
|
+
for _ in range(self._max_poll_attempts):
|
|
259
|
+
status = str(prediction.get("status", "")).lower()
|
|
260
|
+
if status in self._TERMINAL_ERROR_STATUSES:
|
|
261
|
+
error = prediction.get("error") or "prediction failed"
|
|
262
|
+
raise RuntimeError(f"Replicate prediction failed: {error}")
|
|
263
|
+
|
|
264
|
+
if not prediction_url:
|
|
265
|
+
return prediction
|
|
266
|
+
|
|
267
|
+
await asyncio.sleep(self._poll_interval_secs)
|
|
268
|
+
async with self._aiohttp_session.get(prediction_url, headers=headers) as response:
|
|
269
|
+
if response.status != 200:
|
|
270
|
+
error_text = await response.text()
|
|
271
|
+
raise RuntimeError(f"Replicate polling error ({response.status}): {error_text}")
|
|
272
|
+
prediction = await response.json()
|
|
273
|
+
output_urls = self._extract_output_urls(prediction)
|
|
274
|
+
if output_urls:
|
|
275
|
+
return prediction
|
|
276
|
+
|
|
277
|
+
raise TimeoutError("Replicate image generation timed out while waiting for output")
|
|
278
|
+
|
|
279
|
+
@staticmethod
def _load_image_frame(image_url: str | None, encoded_image: bytes) -> URLImageRawFrame:
    """Decode downloaded image bytes into a ``URLImageRawFrame``.

    Args:
        image_url: Source URL recorded on the frame (may be ``None``).
        encoded_image: Encoded image file bytes, in any format Pillow can open.

    Returns:
        A frame carrying the following:

        * ``image``: Pillow's raw pixel bytes (``Image.tobytes()``).
        * ``size``: the decoded image size.
        * ``format``: the format name Pillow reports for the source file.
    """
    decoded = Image.open(io.BytesIO(encoded_image))
    raw_pixels = decoded.tobytes()
    return URLImageRawFrame(
        url=image_url,
        image=raw_pixels,
        size=decoded.size,
        format=decoded.format,
    )
|
|
289
|
+
|
|
290
|
+
@staticmethod
|
|
291
|
+
def _decode_data_url(data_url: str) -> bytes:
|
|
292
|
+
"""Decode a ``data:`` URL into raw bytes."""
|
|
293
|
+
_, _, data = data_url.partition(",")
|
|
294
|
+
if not data:
|
|
295
|
+
raise ValueError("Replicate returned an invalid data URL")
|
|
296
|
+
return base64.b64decode(data)
|
|
297
|
+
|
|
298
|
+
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
    """Generate images from a text prompt using Replicate.

    Creates a prediction and polls it until output URLs are available. Each
    output is then downloaded, or decoded in place if it is a ``data:`` URL.

    Args:
        prompt: The text prompt to generate images from.

    Yields:
        URLImageRawFrame: Frame containing generated image data and metadata.
        ErrorFrame: If image generation fails. If one output image fails to
            download, an ErrorFrame is yielded and the remaining outputs are
            still processed.
    """
    logger.debug(f"Generating image from prompt with Replicate: {prompt}")
    await self.start_ttfb_metrics()

    try:
        url, headers, payload = self._prediction_request(prompt)
        async with self._aiohttp_session.post(url, json=payload, headers=headers) as response:
            if response.status not in {200, 201}:
                error_text = await response.text()
                await self.stop_ttfb_metrics()
                yield ErrorFrame(error=f"Replicate API error ({response.status}): {error_text}")
                return
            prediction = await response.json()
            # Used for polling when the payload carries no ``urls.get`` entry.
            fallback_url = response.headers.get("Location")

        # Only the Authorization header is forwarded to the polling requests.
        prediction = await self._poll_prediction(
            prediction,
            headers={"Authorization": headers["Authorization"]},
            fallback_url=fallback_url,
        )

        output_urls = self._extract_output_urls(prediction)
        if not output_urls:
            await self.stop_ttfb_metrics()
            yield ErrorFrame("Replicate image generation failed: no output returned")
            return

        await self.stop_ttfb_metrics()

        for image_url in output_urls:
            if image_url.startswith("data:"):
                encoded_image = self._decode_data_url(image_url)
            else:
                async with self._aiohttp_session.get(image_url) as image_response:
                    if image_response.status != 200:
                        error_text = await image_response.text()
                        yield ErrorFrame(
                            error=(
                                f"Replicate image download error"
                                f" ({image_response.status}): {error_text}"
                            )
                        )
                        continue
                    encoded_image = await image_response.read()

            # Image decoding runs in a worker thread to keep the event loop free.
            frame = await asyncio.to_thread(self._load_image_frame, image_url, encoded_image)
            yield frame

    except Exception as e:
        # Request, polling, and timeout failures land here before the explicit
        # stop calls above. Stop TTFB so the metric is not left running.
        # NOTE(review): assumes a repeated stop is a no-op when metrics were
        # already stopped (pipecat's metrics ignore a stop without a start); confirm.
        await self.stop_ttfb_metrics()
        yield ErrorFrame(f"Replicate image generation error: {e}")
|
|
File without changes
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""Tests for ReplicateImageGenService."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
|
|
5
|
+
import aiohttp
|
|
6
|
+
import pytest
|
|
7
|
+
from aiohttp import web
|
|
8
|
+
from PIL import Image
|
|
9
|
+
|
|
10
|
+
from pipecat.frames.frames import ErrorFrame, TextFrame, URLImageRawFrame
|
|
11
|
+
from pipecat.tests.utils import run_test
|
|
12
|
+
from pipecat_replicate import ReplicateImageGenService
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _make_test_image_bytes(format: str = "PNG") -> bytes:
    """Return the encoded bytes of a tiny solid-red 2x2 RGB image.

    Args:
        format: Pillow format name used to encode the image (default ``"PNG"``).
    """
    with io.BytesIO() as sink:
        Image.new("RGB", (2, 2), color=(255, 0, 0)).save(sink, format=format)
        return sink.getvalue()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.mark.asyncio
async def test_run_replicate_image_success_official_model(aiohttp_client):
    """Official Replicate models should return image frames from sync predictions."""
    png_bytes = _make_test_image_bytes()

    async def on_create_prediction(request):
        # The create request must be authenticated and ask Replicate to wait.
        assert request.headers["Authorization"] == "Bearer test-token"
        assert request.headers["Prefer"] == "wait=60"
        body = await request.json()
        model_input = body["input"]
        assert model_input["prompt"] == "a red square"
        assert model_input["aspect_ratio"] == "1:1"
        # Output is already present, so the "get" URL is never requested.
        return web.json_response(
            {
                "status": "processing",
                "output": [str(request.url.with_path("/image.png"))],
                "urls": {"get": str(request.url.with_path("/prediction-status"))},
            }
        )

    async def on_download_image(_request):
        return web.Response(body=png_bytes, content_type="image/png")

    app = web.Application()
    app.router.add_post("/v1/models/black-forest-labs/flux-schnell/predictions", on_create_prediction)
    app.router.add_get("/image.png", on_download_image)
    client = await aiohttp_client(app)
    base_url = str(client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as session:
        service = ReplicateImageGenService(
            aiohttp_session=session,
            api_token="test-token",
            base_url=base_url,
        )
        down_frames, up_frames = await run_test(service, frames_to_send=[TextFrame("a red square")])

    assert not up_frames
    assert isinstance(down_frames[0], TextFrame)
    image_frame = down_frames[1]
    assert isinstance(image_frame, URLImageRawFrame)
    assert image_frame.size == (2, 2)
    assert image_frame.format == "PNG"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.mark.asyncio
async def test_run_replicate_image_success_versioned_model(aiohttp_client):
    """Versioned community models should use the generic predictions endpoint."""
    png_bytes = _make_test_image_bytes()
    version_hash = "53d5d1586a229bd033e060941789bfb0c177cefd5ef638f34b3099658343a897"

    async def on_create_prediction(request):
        body = await request.json()
        # The version hash travels in the request body, not in the URL path.
        assert body["version"] == version_hash
        assert body["input"]["prompt"] == "a blue square"
        image_url = str(request.url.with_path("/image-versioned.png"))
        return web.json_response({"status": "successful", "output": [image_url]})

    async def on_download_image(_request):
        return web.Response(body=png_bytes, content_type="image/png")

    app = web.Application()
    app.router.add_post("/v1/predictions", on_create_prediction)
    app.router.add_get("/image-versioned.png", on_download_image)
    client = await aiohttp_client(app)
    base_url = str(client.make_url("/v1")).rstrip("/")

    versioned_settings = ReplicateImageGenService.Settings(
        model=f"black-forest-labs/flux-schnell:{version_hash}",
        aspect_ratio="1:1",
        num_outputs=1,
        num_inference_steps=4,
        seed=None,
        output_format="webp",
        output_quality=80,
        disable_safety_checker=False,
        go_fast=True,
        megapixels="1",
    )

    async with aiohttp.ClientSession() as session:
        service = ReplicateImageGenService(
            aiohttp_session=session,
            api_token="test-token",
            base_url=base_url,
            settings=versioned_settings,
        )
        down_frames, up_frames = await run_test(service, frames_to_send=[TextFrame("a blue square")])

    assert not up_frames
    assert isinstance(down_frames[1], URLImageRawFrame)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@pytest.mark.asyncio
async def test_run_replicate_image_polls_when_sync_response_has_no_output(aiohttp_client):
    """The service should poll the prediction URL if sync mode returns early.

    The polling request must also carry the API token, because the service
    forwards an ``Authorization`` header on every poll.
    """

    image_bytes = _make_test_image_bytes()
    poll_count = 0

    async def prediction_handler(request):
        return web.json_response(
            {
                "status": "processing",
                "output": None,
                "urls": {"get": str(request.url.with_path("/v1/predictions/test-id"))},
            }
        )

    async def prediction_status_handler(request):
        nonlocal poll_count
        poll_count += 1
        # A failed assertion here becomes a 500 response, which surfaces as an
        # upstream ErrorFrame and fails ``assert not up_frames`` below.
        assert request.headers.get("Authorization") == "Bearer test-token"
        return web.json_response(
            {
                "status": "processing",
                "output": [str(request.url.with_path("/image-polled.png"))],
            }
        )

    async def image_handler(_request):
        return web.Response(body=image_bytes, content_type="image/png")

    app = web.Application()
    app.router.add_post("/v1/models/black-forest-labs/flux-schnell/predictions", prediction_handler)
    app.router.add_get("/v1/predictions/test-id", prediction_status_handler)
    app.router.add_get("/image-polled.png", image_handler)
    client = await aiohttp_client(app)
    base_url = str(client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as session:
        image_gen = ReplicateImageGenService(
            aiohttp_session=session,
            api_token="test-token",
            base_url=base_url,
            poll_interval_secs=0.001,
            max_poll_attempts=2,
        )

        down_frames, up_frames = await run_test(image_gen, frames_to_send=[TextFrame("poll me")])

    assert not up_frames
    assert poll_count == 1
    assert isinstance(down_frames[1], URLImageRawFrame)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@pytest.mark.asyncio
async def test_run_replicate_image_error(aiohttp_client):
    """Non-success responses should propagate an ErrorFrame upstream."""

    async def reject_prediction(_request):
        return web.Response(status=401, text="unauthorized")

    app = web.Application()
    app.router.add_post("/v1/models/black-forest-labs/flux-schnell/predictions", reject_prediction)
    client = await aiohttp_client(app)
    base_url = str(client.make_url("/v1")).rstrip("/")

    async with aiohttp.ClientSession() as session:
        service = ReplicateImageGenService(
            aiohttp_session=session,
            api_token="bad-token",
            base_url=base_url,
        )
        down_frames, up_frames = await run_test(
            service,
            frames_to_send=[TextFrame("this should fail")],
        )

    # The prompt frame still flows downstream; exactly one error goes upstream.
    assert isinstance(down_frames[0], TextFrame)
    assert len(up_frames) == 1
    assert isinstance(up_frames[0], ErrorFrame)
    assert "401" in up_frames[0].error
|