yamlgraph 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of yamlgraph might be problematic. Click here for more details.

Files changed (111)
  1. examples/__init__.py +1 -0
  2. examples/storyboard/__init__.py +1 -0
  3. examples/storyboard/generate_videos.py +335 -0
  4. examples/storyboard/nodes/__init__.py +10 -0
  5. examples/storyboard/nodes/animated_character_node.py +248 -0
  6. examples/storyboard/nodes/animated_image_node.py +138 -0
  7. examples/storyboard/nodes/character_node.py +162 -0
  8. examples/storyboard/nodes/image_node.py +118 -0
  9. examples/storyboard/nodes/replicate_tool.py +238 -0
  10. examples/storyboard/retry_images.py +118 -0
  11. tests/__init__.py +1 -0
  12. tests/conftest.py +178 -0
  13. tests/integration/__init__.py +1 -0
  14. tests/integration/test_animated_storyboard.py +63 -0
  15. tests/integration/test_cli_commands.py +242 -0
  16. tests/integration/test_map_demo.py +50 -0
  17. tests/integration/test_memory_demo.py +281 -0
  18. tests/integration/test_pipeline_flow.py +105 -0
  19. tests/integration/test_providers.py +163 -0
  20. tests/integration/test_resume.py +75 -0
  21. tests/unit/__init__.py +1 -0
  22. tests/unit/test_agent_nodes.py +200 -0
  23. tests/unit/test_checkpointer.py +212 -0
  24. tests/unit/test_cli.py +121 -0
  25. tests/unit/test_cli_package.py +81 -0
  26. tests/unit/test_compile_graph_map.py +132 -0
  27. tests/unit/test_conditions_routing.py +253 -0
  28. tests/unit/test_config.py +93 -0
  29. tests/unit/test_conversation_memory.py +270 -0
  30. tests/unit/test_database.py +145 -0
  31. tests/unit/test_deprecation.py +104 -0
  32. tests/unit/test_executor.py +60 -0
  33. tests/unit/test_executor_async.py +179 -0
  34. tests/unit/test_export.py +150 -0
  35. tests/unit/test_expressions.py +178 -0
  36. tests/unit/test_format_prompt.py +145 -0
  37. tests/unit/test_generic_report.py +200 -0
  38. tests/unit/test_graph_commands.py +327 -0
  39. tests/unit/test_graph_loader.py +299 -0
  40. tests/unit/test_graph_schema.py +193 -0
  41. tests/unit/test_inline_schema.py +151 -0
  42. tests/unit/test_issues.py +164 -0
  43. tests/unit/test_jinja2_prompts.py +85 -0
  44. tests/unit/test_langsmith.py +319 -0
  45. tests/unit/test_llm_factory.py +109 -0
  46. tests/unit/test_llm_factory_async.py +118 -0
  47. tests/unit/test_loops.py +403 -0
  48. tests/unit/test_map_node.py +144 -0
  49. tests/unit/test_no_backward_compat.py +56 -0
  50. tests/unit/test_node_factory.py +225 -0
  51. tests/unit/test_prompts.py +166 -0
  52. tests/unit/test_python_nodes.py +198 -0
  53. tests/unit/test_reliability.py +298 -0
  54. tests/unit/test_result_export.py +234 -0
  55. tests/unit/test_router.py +296 -0
  56. tests/unit/test_sanitize.py +99 -0
  57. tests/unit/test_schema_loader.py +295 -0
  58. tests/unit/test_shell_tools.py +229 -0
  59. tests/unit/test_state_builder.py +331 -0
  60. tests/unit/test_state_builder_map.py +104 -0
  61. tests/unit/test_state_config.py +197 -0
  62. tests/unit/test_template.py +190 -0
  63. tests/unit/test_tool_nodes.py +129 -0
  64. yamlgraph/__init__.py +35 -0
  65. yamlgraph/builder.py +110 -0
  66. yamlgraph/cli/__init__.py +139 -0
  67. yamlgraph/cli/__main__.py +6 -0
  68. yamlgraph/cli/commands.py +232 -0
  69. yamlgraph/cli/deprecation.py +92 -0
  70. yamlgraph/cli/graph_commands.py +382 -0
  71. yamlgraph/cli/validators.py +37 -0
  72. yamlgraph/config.py +67 -0
  73. yamlgraph/constants.py +66 -0
  74. yamlgraph/error_handlers.py +226 -0
  75. yamlgraph/executor.py +275 -0
  76. yamlgraph/executor_async.py +122 -0
  77. yamlgraph/graph_loader.py +337 -0
  78. yamlgraph/map_compiler.py +138 -0
  79. yamlgraph/models/__init__.py +36 -0
  80. yamlgraph/models/graph_schema.py +141 -0
  81. yamlgraph/models/schemas.py +124 -0
  82. yamlgraph/models/state_builder.py +236 -0
  83. yamlgraph/node_factory.py +240 -0
  84. yamlgraph/routing.py +87 -0
  85. yamlgraph/schema_loader.py +160 -0
  86. yamlgraph/storage/__init__.py +17 -0
  87. yamlgraph/storage/checkpointer.py +72 -0
  88. yamlgraph/storage/database.py +320 -0
  89. yamlgraph/storage/export.py +269 -0
  90. yamlgraph/tools/__init__.py +1 -0
  91. yamlgraph/tools/agent.py +235 -0
  92. yamlgraph/tools/nodes.py +124 -0
  93. yamlgraph/tools/python_tool.py +178 -0
  94. yamlgraph/tools/shell.py +205 -0
  95. yamlgraph/utils/__init__.py +47 -0
  96. yamlgraph/utils/conditions.py +157 -0
  97. yamlgraph/utils/expressions.py +111 -0
  98. yamlgraph/utils/langsmith.py +308 -0
  99. yamlgraph/utils/llm_factory.py +118 -0
  100. yamlgraph/utils/llm_factory_async.py +105 -0
  101. yamlgraph/utils/logging.py +127 -0
  102. yamlgraph/utils/prompts.py +116 -0
  103. yamlgraph/utils/sanitize.py +98 -0
  104. yamlgraph/utils/template.py +102 -0
  105. yamlgraph/utils/validators.py +181 -0
  106. yamlgraph-0.1.1.dist-info/METADATA +854 -0
  107. yamlgraph-0.1.1.dist-info/RECORD +111 -0
  108. yamlgraph-0.1.1.dist-info/WHEEL +5 -0
  109. yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
  110. yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
  111. yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
@@ -0,0 +1,138 @@
1
+ """Animated storyboard node for generating frame images.
2
+
3
+ Generates 3 images per panel: first_frame, original, last_frame.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import logging
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from .replicate_tool import generate_image
15
+
16
logger = logging.getLogger(__name__)

# Graph state is a plain mapping from state keys to values.
GraphState = dict[str, Any]


def generate_animated_images_node(state: GraphState) -> dict:
    """Generate images for all animated panel frames.

    Reads ``animated_panels`` from state and generates one image per frame
    prompt (``first_frame``, ``original``, ``last_frame``) of each panel.
    Images and an ``animated_story.json`` metadata file are saved to
    ``outputs/storyboard/{thread_id}/animated/``.

    Args:
        state: Graph state. Reads ``animated_panels`` (list of dicts or
            Pydantic models carrying ``first_frame``/``original``/
            ``last_frame`` prompts) and optionally ``thread_id``, ``model``
            (default ``"z-image"``), ``story`` and ``concept``.

    Returns:
        State update with ``current_step``, ``images`` (one
        ``{"panel": n, "frames": {frame_key: path or None}}`` entry per
        processed panel, ``n`` being the 1-based panel position) and
        ``output_dir``. If there are no panels, ``images`` is empty and
        ``error`` is set instead of ``output_dir``.
    """
    animated_panels = state.get("animated_panels", [])
    if not animated_panels:
        logger.error("No animated_panels in state")
        return {
            "current_step": "generate_animated_images",
            "images": [],
            "error": "No animated panels to generate",
        }

    # Sort by _map_index if present to maintain order (only when the panels
    # are dicts; the tag is presumably added by an upstream map node).
    if isinstance(animated_panels[0], dict):
        animated_panels = sorted(
            animated_panels,
            key=lambda x: x.get("_map_index", 0) if isinstance(x, dict) else 0,
        )

    # A missing/None thread_id falls back to a timestamp; str() guards against
    # non-string ids, which Path cannot join.
    thread_id = str(
        state.get("thread_id") or datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    output_dir = Path("outputs/storyboard") / thread_id / "animated"
    output_dir.mkdir(parents=True, exist_ok=True)

    total_images = len(animated_panels) * 3
    logger.info(
        f"🎬 Generating {total_images} images ({len(animated_panels)} panels × 3 frames)"
    )

    # Get model selection from state (default: z-image)
    model_name = state.get("model", "z-image")
    logger.info(f"🖼️ Using model: {model_name}")

    all_results: list[dict] = []
    # Prompts used for each panel (keyed by panel index) for the metadata
    # file. Kept apart from all_results so the returned "images" entries keep
    # their {"panel", "frames"} shape, and so Pydantic panels are recorded too.
    panel_prompts: dict[int, dict[str, Any]] = {}
    frame_keys = ["first_frame", "original", "last_frame"]

    for panel_idx, panel in enumerate(animated_panels, 1):
        # Handle Pydantic model or dict
        if hasattr(panel, "model_dump"):
            panel_dict = panel.model_dump()
        elif isinstance(panel, dict):
            panel_dict = panel
        else:
            logger.warning(f"Panel {panel_idx} has unexpected type: {type(panel)}")
            continue

        panel_prompts[panel_idx] = {k: panel_dict.get(k, "") for k in frame_keys}
        panel_result = {"panel": panel_idx, "frames": {}}

        for frame_key in frame_keys:
            prompt = panel_dict.get(frame_key, "")
            if not prompt:
                logger.warning(f"Panel {panel_idx} missing {frame_key}")
                continue

            output_path = output_dir / f"panel_{panel_idx}_{frame_key}.png"
            logger.info(f"📸 Panel {panel_idx} {frame_key}: {prompt[:50]}...")

            result = generate_image(prompt, output_path, model_name=model_name)

            if result.success and result.path:
                panel_result["frames"][frame_key] = result.path
            else:
                logger.error(f"Panel {panel_idx} {frame_key} failed: {result.error}")
                panel_result["frames"][frame_key] = None

        all_results.append(panel_result)

    # Save metadata
    story = state.get("story", {})
    if hasattr(story, "model_dump"):
        story_dict = story.model_dump()
    elif isinstance(story, dict):
        story_dict = story
    else:
        story_dict = {}

    metadata_path = output_dir / "animated_story.json"
    metadata = {
        "concept": state.get("concept", ""),
        "title": story_dict.get("title", ""),
        "narrative": story_dict.get("narrative", ""),
        "panels": [
            {
                "index": r["panel"],
                "frames": r["frames"],
                "prompts": panel_prompts[r["panel"]],
            }
            for r in all_results
        ],
        "generated_at": datetime.now().isoformat(),
    }
    metadata_path.write_text(json.dumps(metadata, indent=2))
    logger.info(f"📝 Metadata saved: {metadata_path}")

    # Count successes (failed frames are stored as None)
    success_count = sum(1 for r in all_results for path in r["frames"].values() if path)
    logger.info(f"✅ Generated {success_count}/{total_images} images")

    return {
        "current_step": "generate_animated_images",
        "images": all_results,
        "output_dir": str(output_dir),
    }
@@ -0,0 +1,162 @@
1
+ """Character-consistent storyboard node.
2
+
3
+ This node:
4
+ 1. Generates a character image from description (step 0)
5
+ 2. Uses image-to-image editing to place character in each panel scene
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import logging
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ from .replicate_tool import ImageResult, edit_image, generate_image
17
+
18
logger = logging.getLogger(__name__)

# Type alias for state
GraphState = dict[str, Any]


def generate_character_storyboard(state: GraphState) -> dict:
    """Generate a character-consistent storyboard.

    Step 0 generates a base character image from the story's
    ``character_prompt`` with the text-to-image model. Then up to the first
    3 panel prompts are rendered by image-to-image editing of that character
    image (16:9). Images and a ``story.json`` metadata file are saved to
    ``outputs/storyboard/{thread_id}/``.

    Args:
        state: Graph state. Reads ``story`` (dict or Pydantic model with
            ``character_prompt``, ``panels`` and optional ``title``/
            ``narrative``) and optionally ``thread_id``, ``model`` (default
            ``"z-image"``, used for the character image only) and
            ``concept``.

    Returns:
        State update with ``current_step``, ``images`` (the character image
        path followed by each successfully generated panel path),
        ``character_image`` and ``output_dir``. When the story or its
        prompts are missing, or character generation fails, ``images`` is
        empty and ``error`` is set.
    """
    story = state.get("story")
    if not story:
        logger.error("No story in state")
        return {
            "current_step": "generate_character_storyboard",
            "images": [],
            "error": "No story in state",
        }

    # Handle Pydantic model or dict; any other type yields no prompts
    if hasattr(story, "model_dump"):
        story_dict = story.model_dump()
    elif isinstance(story, dict):
        story_dict = story
    else:
        story_dict = {}

    character_prompt = story_dict.get("character_prompt", "")
    panels = story_dict.get("panels", [])

    if not character_prompt:
        logger.error("No character_prompt in story")
        return {
            "current_step": "generate_character_storyboard",
            "images": [],
            "error": "No character_prompt provided",
        }

    if not panels:
        logger.error("No panels in story")
        return {
            "current_step": "generate_character_storyboard",
            "images": [],
            "error": "No panel prompts provided",
        }

    # A missing/None thread_id falls back to a timestamp; str() guards against
    # non-string ids, which Path cannot join.
    thread_id = str(
        state.get("thread_id") or datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    output_dir = Path("outputs/storyboard") / thread_id
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"🎬 Generating character-consistent storyboard in {output_dir}")

    # Get model selection from state (default: z-image for character)
    model_name = state.get("model", "z-image")
    logger.info(f"🖼️ Using model for character: {model_name}")

    image_paths: list[str] = []
    results: list[ImageResult] = []

    # Step 0: Generate base character image
    character_path = output_dir / "character.png"
    logger.info(f"👤 Step 0 - Creating character: {character_prompt[:60]}...")

    character_result = generate_image(
        prompt=character_prompt,
        output_path=character_path,
        model_name=model_name,
    )
    results.append(character_result)

    if not character_result.success:
        logger.error(f"Character generation failed: {character_result.error}")
        return {
            "current_step": "generate_character_storyboard",
            "images": [],
            "error": f"Character generation failed: {character_result.error}",
        }

    image_paths.append(str(character_path))
    logger.info(f"✓ Character created: {character_path}")

    selected_panels = panels[:3]  # Max 3 panels
    # One entry per selected panel (None when skipped or failed), index-aligned
    # with selected_panels so the metadata pairs each prompt with its own image
    # even if an earlier panel produced nothing.
    panel_images: list[str | None] = []

    # Panels 1-3: Image-to-image editing with character as base
    for i, panel_prompt in enumerate(selected_panels, 1):
        if not panel_prompt:
            logger.warning(f"Panel {i} has no prompt, skipping")
            panel_images.append(None)
            continue

        panel_path = output_dir / f"panel_{i}.png"
        logger.info(f"📸 Panel {i}: {panel_prompt[:60]}...")

        panel_result = edit_image(
            input_image=character_path,
            prompt=panel_prompt,
            output_path=panel_path,
            aspect_ratio="16:9",
        )
        results.append(panel_result)

        if panel_result.success and panel_result.path:
            image_paths.append(panel_result.path)
            panel_images.append(panel_result.path)
            logger.info(f"✓ Panel {i} created")
        else:
            logger.error(f"Panel {i} failed: {panel_result.error}")
            panel_images.append(None)

    # Save metadata
    metadata_path = output_dir / "story.json"
    metadata = {
        "concept": state.get("concept", ""),
        "title": story_dict.get("title", ""),
        "narrative": story_dict.get("narrative", ""),
        "character_prompt": character_prompt,
        "character_image": str(character_path),
        "panels": [
            {"prompt": prompt, "image": image}
            for prompt, image in zip(selected_panels, panel_images)
        ],
        "generated_at": datetime.now().isoformat(),
    }
    metadata_path.write_text(json.dumps(metadata, indent=2))
    logger.info(f"📝 Metadata saved: {metadata_path}")

    success_count = sum(1 for r in results if r.success)
    logger.info(
        f"✅ Generated {success_count}/{len(results)} images (1 character + {len(selected_panels)} panels)"
    )

    return {
        "current_step": "generate_character_storyboard",
        "images": image_paths,
        "character_image": str(character_path),
        "output_dir": str(output_dir),
    }
@@ -0,0 +1,118 @@
1
+ """Storyboard node for generating panel images.
2
+
3
+ This node takes story panels from the LLM and generates images via Replicate.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import logging
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from .replicate_tool import ImageResult, generate_image
15
+
16
logger = logging.getLogger(__name__)

# Type alias for state
GraphState = dict[str, Any]


def generate_images_node(state: GraphState) -> dict:
    """Generate one image per story panel.

    Reads panel prompts from ``state["story"]`` and generates an image for
    each non-empty prompt. Images and a ``story.json`` metadata file are saved
    to ``outputs/storyboard/{thread_id}/``.

    Args:
        state: Graph state. Reads ``story`` (dict or Pydantic model with a
            ``panels`` list, or legacy ``panel_1``..``panel_3`` keys; any
            other value is used as a single panel prompt via ``str()``) and
            optionally ``thread_id``, ``model`` (default ``"z-image"``),
            ``title``/``narrative`` inside the story, and ``concept``.

    Returns:
        State update with ``current_step``, ``images`` (paths of the
        successfully generated panels, in panel order) and ``output_dir``.
        If there is no story, ``images`` is empty and ``error`` is set.
    """
    story = state.get("story")
    if not story:
        logger.error("No story in state")
        return {
            "current_step": "generate_images",
            "images": [],
            "error": "No story panels to generate",
        }

    # Handle Pydantic model or dict
    if hasattr(story, "model_dump"):
        story_dict = story.model_dump()
    elif isinstance(story, dict):
        story_dict = story
    else:
        story_dict = {"panels": [str(story)]}

    # Extract panel prompts (supports dynamic list)
    panels = story_dict.get("panels", [])
    if not panels:
        # Fallback for legacy panel_1/2/3 format
        panels = [
            story_dict.get("panel_1", ""),
            story_dict.get("panel_2", ""),
            story_dict.get("panel_3", ""),
        ]
        panels = [p for p in panels if p]  # Remove empty

    # A missing/None thread_id falls back to a timestamp; str() guards against
    # non-string ids, which Path cannot join.
    thread_id = str(
        state.get("thread_id") or datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    output_dir = Path("outputs/storyboard") / thread_id
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"🎬 Generating {len(panels)}-panel storyboard in {output_dir}")

    # Get model selection from state (default: z-image)
    model_name = state.get("model", "z-image")
    logger.info(f"🖼️ Using model: {model_name}")

    results: list[ImageResult] = []
    image_paths: list[str] = []
    # One entry per panel (None when skipped or failed), index-aligned with
    # panels so the metadata pairs each prompt with its own image even after
    # an earlier panel failed.
    panel_images: list[str | None] = []

    for i, prompt in enumerate(panels, 1):
        if not prompt:
            logger.warning(f"Panel {i} has no prompt, skipping")
            panel_images.append(None)
            continue

        output_path = output_dir / f"panel_{i}.png"
        logger.info(f"📸 Panel {i}: {prompt[:60]}...")

        result = generate_image(prompt, output_path, model_name=model_name)
        results.append(result)

        if result.success and result.path:
            image_paths.append(result.path)
            panel_images.append(result.path)
        else:
            logger.error(f"Panel {i} failed: {result.error}")
            panel_images.append(None)

    # Save story metadata
    metadata_path = output_dir / "story.json"
    metadata = {
        "concept": state.get("concept", ""),
        "title": story_dict.get("title", ""),
        "narrative": story_dict.get("narrative", ""),
        "panels": [
            {"prompt": prompt, "image": image}
            for prompt, image in zip(panels, panel_images)
        ],
        "generated_at": datetime.now().isoformat(),
    }
    metadata_path.write_text(json.dumps(metadata, indent=2))
    logger.info(f"📝 Metadata saved: {metadata_path}")

    success_count = sum(1 for r in results if r.success)
    logger.info(f"✅ Generated {success_count}/{len(panels)} images")

    return {
        "current_step": "generate_images",
        "images": image_paths,
        "output_dir": str(output_dir),
    }
@@ -0,0 +1,238 @@
1
+ """Replicate image generation tool for storyboard workflow.
2
+
3
+ Supports multiple models:
4
+ - z-image: Fast, good for realistic/photographic (default)
5
+ - hidream: Better for cartoons, illustrations, stylized art
6
+ - p-image-edit: Image-to-image editing for character consistency
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ import os
13
+ from dataclasses import dataclass
14
+ from pathlib import Path
15
+
16
+ import httpx
17
+
18
logger = logging.getLogger(__name__)

# Registry of supported Replicate models, keyed by the short name callers
# pass as ``model_name``. Each entry carries the Replicate model reference
# ("id"), its output-size settings, and extra input fields ("params") that
# are merged into every request for that model.
MODELS = {
    # Fast text-to-image model; output size given as explicit pixels.
    "z-image": {
        "id": "prunaai/z-image-turbo",
        "width": 1344,
        "height": 768,
        "params": {
            "guidance_scale": 0,
            "num_inference_steps": 8,
        },
    },
    # Version-pinned model; output size is selected by a resolution label
    # (presumably one of the model's enumerated options — keep verbatim).
    "hidream": {
        "id": "prunaai/hidream-l1-fast:f67f0ec7ef9fe91b74e8a68d34efaa9145bec28675cb190cbff8a70f0490256e",
        "resolution": "1360 \u00d7 768 (Landscape)",
        "params": {
            "model_type": "fast",
            "speed_mode": "Juiced \U0001f525 (more speed)",
        },
    },
}

# Registry key used when no model is requested.
DEFAULT_MODEL = "z-image"

# ``replicate`` is optional: remember whether it imported so the generator
# functions can return an error result instead of failing at import time.
try:
    import replicate
except ImportError:
    REPLICATE_AVAILABLE = False
    logger.warning("replicate package not installed. Run: pip install replicate")
else:
    REPLICATE_AVAILABLE = True
51
+
52
+
53
@dataclass()
class ImageResult:
    """Outcome of one image generation or edit request.

    Attributes:
        success: True when an image was produced and written to disk.
        path: Filesystem path of the saved image (set on success).
        error: Human-readable failure description (set on failure).
    """

    success: bool
    # Location of the written image; None when nothing was saved.
    path: str | None = None
    # Reason the request failed; None on success.
    error: str | None = None
60
+
61
+
62
def generate_image(
    prompt: str,
    output_path: str | Path,
    model_name: str = DEFAULT_MODEL,
) -> ImageResult:
    """Generate an image with the Replicate API and save it to disk.

    Args:
        prompt: Text prompt for image generation.
        output_path: File path for the generated image; parent directories
            are created if missing.
        model_name: Key into ``MODELS`` (``"z-image"`` or ``"hidream"``).
            Unknown names fall back to ``DEFAULT_MODEL`` with a warning.

    Returns:
        ImageResult with ``success=True`` and ``path`` set, or
        ``success=False`` and ``error`` describing the problem. Failures are
        reported through the result rather than raised.
    """
    if not REPLICATE_AVAILABLE:
        return ImageResult(success=False, error="replicate package not installed")

    api_token = os.environ.get("REPLICATE_API_TOKEN")
    if not api_token:
        return ImageResult(
            success=False, error="REPLICATE_API_TOKEN not set in environment"
        )

    # Normalize unknown names so the logged model matches the one actually used.
    if model_name not in MODELS:
        logger.warning(
            f"Unknown model '{model_name}', falling back to '{DEFAULT_MODEL}'"
        )
        model_name = DEFAULT_MODEL
    model_config = MODELS[model_name]
    model_id = model_config["id"]

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        logger.info(f"🎨 Generating image with {model_name}: {prompt[:50]}...")

        # Build input params based on model: hidream takes a resolution
        # label, the others take explicit width/height.
        if model_name == "hidream":
            input_params = {
                "prompt": prompt,
                "seed": -1,
                "resolution": model_config["resolution"],
                "output_format": "png",
                "output_quality": 80,
                "disable_safety_checker": True,
                **model_config["params"],
            }
        else:
            # z-image and default
            input_params = {
                "prompt": prompt,
                "width": model_config.get("width", 1344),
                "height": model_config.get("height", 768),
                "output_format": "png",
                "output_quality": 80,
                "disable_safety_checker": True,
                **model_config.get("params", {}),
            }

        client = replicate.Client(api_token=api_token)
        output = client.run(model_id, input=input_params)

        # run() may return a list of outputs depending on the model/client
        # version; str() of a list would not be a usable URL, so take the first.
        if isinstance(output, (list, tuple)):
            if not output:
                return ImageResult(success=False, error="Model returned no output")
            output = output[0]

        # Expected: a URL string, or an object whose str() is its URL.
        image_url = output if isinstance(output, str) else str(output)

        # Download the image. httpx does not follow redirects by default and
        # raise_for_status() rejects 3xx, so opt in explicitly.
        logger.info(f"📥 Downloading image to {output_path}")
        response = httpx.get(image_url, timeout=60.0, follow_redirects=True)
        response.raise_for_status()

        output_path.write_bytes(response.content)
        logger.info(f"✓ Image saved: {output_path}")

        return ImageResult(success=True, path=str(output_path))

    except Exception as e:
        # Best-effort boundary: callers inspect ImageResult instead of catching.
        logger.error(f"Image generation failed: {e}")
        return ImageResult(success=False, error=str(e))
139
+
140
+
141
def edit_image(
    input_image: str | Path,
    prompt: str,
    output_path: str | Path,
    aspect_ratio: str = "16:9",
    turbo: bool = True,
    magic: float | None = None,
) -> ImageResult:
    """Edit an image with the Replicate ``prunaai/p-image-edit`` model.

    The input image is the base and the prompt describes the modifications.
    The storyboard uses this to keep a character consistent across panels.

    Args:
        input_image: Path to the source image; must exist.
        prompt: Edit instructions (what to change/add).
        output_path: File path for the edited image; parent directories are
            created if missing.
        aspect_ratio: Output aspect ratio (default ``"16:9"``).
        turbo: Sent to the model as ``turbo`` (faster generation).
        magic: Optional prompt strength 0-1 (lower = closer to the original,
            higher = follows the prompt more); omitted from the request when
            None.

    Returns:
        ImageResult with ``success=True`` and ``path`` set, or
        ``success=False`` and ``error`` set. Failures are reported through
        the result rather than raised.
    """
    if not REPLICATE_AVAILABLE:
        return ImageResult(success=False, error="replicate package not installed")

    api_token = os.environ.get("REPLICATE_API_TOKEN")
    if not api_token:
        return ImageResult(
            success=False, error="REPLICATE_API_TOKEN not set in environment"
        )

    input_image = Path(input_image)
    if not input_image.exists():
        return ImageResult(success=False, error=f"Input image not found: {input_image}")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        logger.info(f"✏️ Editing image: {prompt[:50]}...")

        client = replicate.Client(api_token=api_token)

        # The file handle is passed as model input, so keep it open until
        # run() has returned.
        with open(input_image, "rb") as f:
            input_params = {
                "turbo": turbo,
                "images": [f],
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "disable_safety_checker": True,
            }
            if magic is not None:
                input_params["magic"] = magic

            output = client.run(
                "prunaai/p-image-edit",
                input=input_params,
            )

        # run() may return a list of outputs depending on the model/client
        # version; use the first one.
        if isinstance(output, (list, tuple)):
            if not output:
                return ImageResult(success=False, error="Model returned no output")
            output = output[0]

        # File-like outputs are read directly; anything else is treated as a
        # URL and downloaded.
        if hasattr(output, "read"):
            image_bytes = output.read()
        else:
            response = httpx.get(str(output), timeout=60.0, follow_redirects=True)
            response.raise_for_status()
            image_bytes = response.content

        output_path.write_bytes(image_bytes)

        logger.info(f"✓ Edited image saved: {output_path}")
        return ImageResult(success=True, path=str(output_path))

    except Exception as e:
        # Best-effort boundary: callers inspect ImageResult instead of catching.
        logger.error(f"Image editing failed: {e}")
        return ImageResult(success=False, error=str(e))
212
+
213
+
214
def generate_storyboard_images(
    panel_prompts: list[str],
    output_dir: str | Path,
    prefix: str = "panel",
    model_name: str = DEFAULT_MODEL,
) -> list[ImageResult]:
    """Generate one image per prompt for a storyboard.

    Images are written as ``{output_dir}/{prefix}_{n}.png`` with ``n``
    starting at 1.

    Args:
        panel_prompts: List of prompts, one per panel.
        output_dir: Directory to save images (created if missing).
        prefix: Filename prefix.
        model_name: Model key forwarded to ``generate_image`` (defaults to
            ``DEFAULT_MODEL``, matching the previous behavior).

    Returns:
        List of ImageResult, one per prompt, in prompt order.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    return [
        generate_image(prompt, output_dir / f"{prefix}_{i}.png", model_name=model_name)
        for i, prompt in enumerate(panel_prompts, 1)
    ]