diffusion-prompt-embedder 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,17 @@
1
+ """
2
+ diffusion_prompt_embedder: A library for parsing and processing text prompts with attention weights.
3
+
4
+ This package provides tools for parsing text prompts with attention weights syntax,
5
+ tokenizing prompts, and generating embeddings for use with Stable Diffusion models.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from diffusion_prompt_embedder.core.embedding import get_embeddings_sd15, get_embeddings_sd_15_batch
11
+ from diffusion_prompt_embedder.core.parser import parse_prompt_attention
12
+
13
+ __all__ = [
14
+ "get_embeddings_sd15",
15
+ "get_embeddings_sd_15_batch",
16
+ "parse_prompt_attention",
17
+ ]
@@ -0,0 +1,13 @@
1
+ """
2
+ CLIP tokenization helpers used for embedding generation.
3
+ """
4
+
5
+ from diffusion_prompt_embedder.clip.tokenization import (
6
+ get_prompts_tokens_with_weights,
7
+ group_tokens_and_weights,
8
+ )
9
+
10
+ __all__ = [
11
+ "get_prompts_tokens_with_weights",
12
+ "group_tokens_and_weights",
13
+ ]
@@ -0,0 +1,123 @@
1
+ from transformers import CLIPTokenizer
2
+
3
+ from diffusion_prompt_embedder.core.parser import parse_prompt_attention
4
+
5
+
6
def group_tokens_and_weights(
    token_ids: list[int],
    weights: list[float],
    *,
    pad_last_block: bool = True,
) -> tuple[list[list[int]], list[list[float]]]:
    """
    Group tokenized IDs and weights into CLIP-compatible chunks.

    The inputs are split into consecutive runs of at most 75 content tokens.
    Each run is wrapped with a BOS token in front and an EOS token behind, so a
    full chunk holds 77 entries. The wrapping BOS/EOS entries (and any padding)
    always receive weight 1.0.

    The input lists are only read, never modified.

    Args:
        token_ids (list): Token IDs generated from the CLIP tokenizer
        weights (list): Corresponding weights for each token
        pad_last_block (bool): Whether to pad the last (partial) block to 75
            content tokens with EOS tokens. When False the last block is
            ``len(remaining) + 2`` entries long.

    Returns:
        tuple: A tuple containing:
            - list[list[int]]: Grouped token IDs, one sublist per chunk
            - list[list[float]]: Grouped weights matching the token IDs structure

        Empty input yields two empty lists.

    Raises:
        ValueError: If ``token_ids`` and ``weights`` differ in length

    Example:
        token_groups, weight_groups = group_tokens_and_weights(
            token_ids=token_id_list,
            weights=token_weight_list
        )
    """
    if len(token_ids) != len(weights):
        msg = f"token_ids and weights must have the same length, got {len(token_ids)} and {len(weights)}"
        raise ValueError(msg)

    # Beginning-of-sequence and end-of-sequence token IDs
    bos, eos = 49406, 49407
    # Number of content tokens per chunk (77 minus BOS and EOS)
    chunk_size = 75

    new_token_ids: list[list[int]] = []
    new_weights: list[list[float]] = []

    # Slice instead of popping from the front: keeps the caller's lists intact
    # and avoids the quadratic cost of repeated list.pop(0).
    for start in range(0, len(token_ids), chunk_size):
        chunk_tokens = token_ids[start : start + chunk_size]
        chunk_weights = weights[start : start + chunk_size]

        # Only a trailing partial chunk can be shorter than chunk_size
        padding_len = chunk_size - len(chunk_tokens) if pad_last_block else 0

        # The "+ 1" is the closing EOS that every chunk gets
        new_token_ids.append([bos, *chunk_tokens] + [eos] * (padding_len + 1))
        new_weights.append([1.0, *chunk_weights] + [1.0] * (padding_len + 1))

    return new_token_ids, new_weights
69
+
70
+
71
def get_prompts_tokens_with_weights(
    clip_tokenizer: CLIPTokenizer,
    prompt: str | None,
) -> tuple[list[int], list[float]]:
    """
    Tokenize a prompt with attention weights into token IDs and their corresponding weights.

    The prompt is parsed with ``parse_prompt_attention`` into ``[text, weight]``
    pairs (e.g. "a (cat:1.2) in the garden"). Each text fragment is tokenized
    without BOS/EOS and every resulting token inherits the fragment's weight.

    ``BREAK`` markers, which the parser emits as ``["BREAK", -1]``, are not
    tokenized as text. Instead the token list is padded with EOS tokens
    (weight 1.0) up to the next multiple of 75, so the text following the
    marker starts in a fresh chunk.

    Args:
        clip_tokenizer (CLIPTokenizer): The CLIP tokenizer instance
        prompt (str | None): A prompt string with optional weights in parentheses
            If None or empty, defaults to "empty"

    Returns:
        tuple: A tuple containing:
            - list[int]: List of token IDs
            - list[float]: List of weights corresponding to each token

    Example:
        token_id_list, token_weight_list = get_prompts_tokens_with_weights(
            clip_tokenizer=clip_tokenizer,
            prompt="a (red:1.5) cat"
        )
    """
    # Content tokens per CLIP chunk (77 minus BOS and EOS)
    chunk_size = 75

    # Use "empty" as default if prompt is None or empty
    if not prompt:
        prompt = "empty"

    text_tokens: list[int] = []
    text_weights: list[float] = []

    for word, weight in parse_prompt_attention(prompt):
        if word == "BREAK" and weight == -1:
            # Chunk separator, not prompt text: fill the current chunk so the
            # next fragment begins on a chunk boundary. Without this the word
            # "BREAK" itself would be embedded with a weight of -1.
            padding_len = -len(text_tokens) % chunk_size
            text_tokens.extend([clip_tokenizer.eos_token_id] * padding_len)
            text_weights.extend([1.0] * padding_len)
            continue

        # Tokenize the text chunk, removing BOS/EOS tokens (positions 0 and -1)
        token = clip_tokenizer(
            word,
            truncation=False,  # Allow processing prompts of any length
        ).input_ids[1:-1]

        # Extend in place; rebuilding the lists on every fragment is quadratic
        text_tokens.extend(token)
        text_weights.extend([weight] * len(token))

    return text_tokens, text_weights
@@ -0,0 +1,23 @@
1
+ """
2
+ Core prompt parsing and embedding functionality.
3
+ """
4
+
5
+ from diffusion_prompt_embedder.core.embedding import (
6
+ get_embeddings_sd15,
7
+ get_embeddings_sd_15_batch,
8
+ )
9
+ from diffusion_prompt_embedder.core.parser import (
10
+ apply_multiplier_to_range,
11
+ merge_identical_weights,
12
+ parse_prompt_attention,
13
+ process_text_token,
14
+ )
15
+
16
+ __all__ = [
17
+ "apply_multiplier_to_range",
18
+ "get_embeddings_sd15",
19
+ "get_embeddings_sd_15_batch",
20
+ "merge_identical_weights",
21
+ "parse_prompt_attention",
22
+ "process_text_token",
23
+ ]
@@ -0,0 +1,309 @@
1
+ import torch
2
+ from transformers import CLIPTextModel, CLIPTokenizer
3
+
4
+ from diffusion_prompt_embedder.clip.tokenization import get_prompts_tokens_with_weights, group_tokens_and_weights
5
+
6
+
7
+ def _encode_tokens_with_weights(
8
+ text_encoder: CLIPTextModel,
9
+ token_groups: list[list[int]],
10
+ weight_groups: list[list[float]],
11
+ device: torch.device,
12
+ dtype: torch.dtype,
13
+ ) -> list[torch.Tensor]:
14
+ """
15
+ Internal helper function to encode token groups and apply weights.
16
+
17
+ Args:
18
+ text_encoder: The CLIP text encoder model
19
+ token_groups: Grouped token IDs, each group has 77 tokens
20
+ weight_groups: Grouped weights matching the token IDs
21
+ device: Device to run encoding on
22
+ dtype: Data type for tensors
23
+
24
+ Returns:
25
+ list[torch.Tensor]: List of encoded embeddings for each token group
26
+ """
27
+ embeds = []
28
+
29
+ # Process each token group through the text encoder
30
+ for i in range(len(token_groups)):
31
+ # Process tokens
32
+ token_tensor = torch.tensor(
33
+ [token_groups[i]],
34
+ dtype=torch.long,
35
+ device=device,
36
+ )
37
+ weight_tensor = torch.tensor(
38
+ weight_groups[i],
39
+ dtype=dtype,
40
+ device=device,
41
+ )
42
+
43
+ # Get embeddings from text encoder
44
+ token_embedding = text_encoder(token_tensor)[0].squeeze(0)
45
+
46
+ # Apply attention weights to token embeddings
47
+ for j in range(len(weight_tensor)):
48
+ token_embedding[j] = token_embedding[j] * weight_tensor[j]
49
+
50
+ # Add batch dimension back and append to results
51
+ token_embedding = token_embedding.unsqueeze(0)
52
+ embeds.append(token_embedding)
53
+
54
+ return embeds
55
+
56
+
57
+ def _setup_clip_for_embedding(
58
+ text_encoder: CLIPTextModel,
59
+ clip_skip: int = 0,
60
+ ) -> tuple[torch.device, torch.dtype, object | None, int]:
61
+ """
62
+ Setup CLIP model for embedding generation and return common parameters.
63
+
64
+ Args:
65
+ text_encoder: The CLIP text encoder model
66
+ clip_skip: Number of layers to skip in CLIP model
67
+
68
+ Returns:
69
+ tuple: (device, dtype, original_clip_layers, clip_skip_applied)
70
+ """
71
+ # Get the device and dtype from the text encoder
72
+ device = text_encoder.device
73
+ dtype = text_encoder.dtype
74
+
75
+ # Store original layers for clip skip feature
76
+ original_clip_layers = None
77
+ if clip_skip > 0 and hasattr(text_encoder, "text_model"):
78
+ original_clip_layers = text_encoder.text_model.encoder.layers
79
+ text_encoder.text_model.encoder.layers = original_clip_layers[:-clip_skip]
80
+
81
+ return device, dtype, original_clip_layers, clip_skip
82
+
83
+
84
def get_embeddings_sd15(  # noqa: PLR0913
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    *,
    prompt: str = "",
    neg_prompt: str = "",
    pad_last_block: bool = False,
    clip_skip: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate weighted text embeddings for Stable Diffusion 1.5 models.

    This function processes both positive and negative prompts with weights and
    generates CLIP text embeddings for use in Stable Diffusion inference. It can
    handle arbitrarily long prompts by processing them in chunks and supports
    clip-skip for style control.

    The shorter of the two token sequences is padded with EOS tokens (weight
    1.0) so both prompts produce the same number of chunks.

    When clip_skip is active the encoder's layer list is shortened for the
    duration of the call and restored afterwards, even if encoding raises.

    Args:
        tokenizer (CLIPTokenizer): The CLIP tokenizer instance
        text_encoder (CLIPTextModel): The CLIP text encoder model
        prompt (str): The positive prompt with optional weights in parentheses
        neg_prompt (str): The negative prompt with optional weights in parentheses
        pad_last_block (bool): Whether to pad the last token block to full length
        clip_skip (int): Number of layers to skip in CLIP model for style control

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - prompt_embeds: Tensor of positive prompt embeddings
            - neg_prompt_embeds: Tensor of negative prompt embeddings

    Example:
        from transformers import CLIPTokenizer, CLIPTextModel

        tokenizer = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14",
        )
        text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            torch_dtype=torch.float16
        ).to("cuda")

        prompt_embeds, neg_prompt_embeds = get_embeddings_sd15(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt="a (white:1.2) cat",
            neg_prompt="blur, bad quality",
        )
    """
    # Setup CLIP model and get common parameters (may shorten the layer list)
    device, dtype, original_clip_layers, _ = _setup_clip_for_embedding(
        text_encoder,
        clip_skip,
    )

    try:
        # Get the eos token id from tokenizer
        eos = tokenizer.eos_token_id

        # Tokenize prompts with weights
        prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
            tokenizer,
            prompt,
        )
        neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
            tokenizer,
            neg_prompt,
        )

        # Pad the shorter prompt to match the longer one. The concatenations
        # below always build new lists, so the grouping step never sees the
        # tokenizer results themselves.
        target_len = max(len(prompt_tokens), len(neg_prompt_tokens))
        prompt_padding = target_len - len(prompt_tokens)
        neg_prompt_padding = target_len - len(neg_prompt_tokens)

        prompt_tokens = prompt_tokens + [eos] * prompt_padding
        prompt_weights = prompt_weights + [1.0] * prompt_padding
        neg_prompt_tokens = neg_prompt_tokens + [eos] * neg_prompt_padding
        neg_prompt_weights = neg_prompt_weights + [1.0] * neg_prompt_padding

        # Group tokens for processing in CLIP-compatible chunks (77 tokens per chunk)
        prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
            prompt_tokens,
            prompt_weights,
            pad_last_block=pad_last_block,
        )
        neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
            neg_prompt_tokens,
            neg_prompt_weights,
            pad_last_block=pad_last_block,
        )

        # Process token groups through the shared encoder function
        embeds = _encode_tokens_with_weights(
            text_encoder,
            prompt_token_groups,
            prompt_weight_groups,
            device,
            dtype,
        )
        neg_embeds = _encode_tokens_with_weights(
            text_encoder,
            neg_prompt_token_groups,
            neg_prompt_weight_groups,
            device,
            dtype,
        )

        # Concatenate all token group embeddings along the sequence dimension
        prompt_embeds = torch.cat(embeds, dim=1)
        neg_prompt_embeds = torch.cat(neg_embeds, dim=1)
    finally:
        # Restore original CLIP layers if clip_skip was used; done in `finally`
        # so a failure above cannot leave the shared encoder truncated.
        if clip_skip > 0 and original_clip_layers is not None:
            text_encoder.text_model.encoder.layers = original_clip_layers

    return prompt_embeds, neg_prompt_embeds
201
+
202
+
203
def get_embeddings_sd_15_batch(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    *,
    prompts: list[str],
    pad_last_block: bool = True,
    clip_skip: int = 0,
) -> torch.Tensor:
    """
    Generate weighted text embeddings for multiple prompts in a batch.

    This function processes a list of prompts with weights and generates CLIP text
    embeddings for use in batch inference. It handles arbitrarily long prompts
    by processing them in chunks, pads all prompts to the same length, and supports
    clip-skip for style control.

    Prompts are encoded one after another; the per-prompt results are then
    concatenated along the batch dimension.

    When clip_skip is active the encoder's layer list is shortened for the
    duration of the call and restored afterwards, even if encoding raises.

    Args:
        tokenizer (CLIPTokenizer): The CLIP tokenizer instance
        text_encoder (CLIPTextModel): The CLIP text encoder model
        prompts (list[str]): List of prompts, each with optional weights in parentheses.
            An empty list results in torch.cat being called with no tensors.
        pad_last_block (bool): Whether to pad the last token block to full length
        clip_skip (int): Number of layers to skip in CLIP model for style control

    Returns:
        torch.Tensor: Tensor of embeddings for all prompts, shape [batch_size, seq_len, hidden_size]

    Example:
        from transformers import CLIPTokenizer, CLIPTextModel

        tokenizer = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14",
        )
        text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            torch_dtype=torch.float16
        ).to("cuda")

        prompt_embeds = get_embeddings_sd_15_batch(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompts=["a (white:1.2) cat", "a (blue:1.4) dog", "a red bird"],
        )
    """
    # Setup CLIP model and get common parameters (may shorten the layer list)
    device, dtype, original_clip_layers, _ = _setup_clip_for_embedding(
        text_encoder,
        clip_skip,
    )

    try:
        # Get the eos token id from tokenizer
        eos = tokenizer.eos_token_id

        # Tokenize all prompts with weights
        tokenized = [get_prompts_tokens_with_weights(tokenizer, prompt) for prompt in prompts]
        max_token_len = max((len(tokens) for tokens, _ in tokenized), default=0)

        all_embeds = []

        # Process each prompt separately
        for prompt_tokens, prompt_weights in tokenized:
            # Pad to the longest prompt; the concatenation builds new lists, so
            # the tokenizer results are left untouched by the grouping step.
            padding_len = max_token_len - len(prompt_tokens)
            padded_tokens = prompt_tokens + [eos] * padding_len
            padded_weights = prompt_weights + [1.0] * padding_len

            # Group tokens for processing in CLIP-compatible chunks (77 tokens per chunk)
            prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
                padded_tokens,
                padded_weights,
                pad_last_block=pad_last_block,
            )

            # Process token groups through the shared encoder function
            embeds = _encode_tokens_with_weights(
                text_encoder,
                prompt_token_groups,
                prompt_weight_groups,
                device,
                dtype,
            )

            # Concatenate all token group embeddings for this prompt
            all_embeds.append(torch.cat(embeds, dim=1))

        # Join all prompt embeddings along the batch dimension
        batched_embeds = torch.cat(all_embeds, dim=0)
    finally:
        # Restore original CLIP layers if clip_skip was used; done in `finally`
        # so a failure above cannot leave the shared encoder truncated.
        if clip_skip > 0 and original_clip_layers is not None:
            text_encoder.text_model.encoder.layers = original_clip_layers

    return batched_embeds
@@ -0,0 +1,178 @@
1
+ import re
2
+
3
# Regular expressions for prompt processing
# Matches the "AND" keyword (used to split prompts)
# NOTE(review): not referenced by any function in the visible source.
re_and = re.compile(r"\bAND\b")
# Matches weight format: "text:1.5", captures text and weight value
# Group 1 is the lazily matched text, group 2 the optional signed number after the colon.
# NOTE(review): not referenced by any function in the visible source.
re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
# Matches the "BREAK" keyword (used to insert separators in prompts)
# Whitespace around the keyword is part of the match, so splitting on this
# pattern also strips it from the neighbouring text.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.DOTALL)

# Complex regular expression for parsing attention markers
# This regex identifies various brackets and weight markers used to enhance or reduce specific parts of prompts
# Alternatives are tried in the order listed; the only capture group (group 1)
# is the number in a ":<number>)" token. Its character class "[.\d]+" also
# accepts strings such as "1.2.3" or "." that float() cannot convert.
re_attention = re.compile(
    r"""
    \\\(| # Escaped left parenthesis \(
    \\\)| # Escaped right parenthesis \)
    \\\[| # Escaped left bracket \[
    \\]| # Escaped right bracket \]
    \\\\| # Escaped backslash \\
    \\| # Single backslash (escape character)
    \(| # Left parenthesis - starts an enhanced attention area
    \[| # Left bracket - starts a reduced attention area
    :\s*([+-]?[.\d]+)\s*\)| # Colon followed by number and right parenthesis - custom weight value
    \)| # Right parenthesis - ends enhanced attention area
    ]| # Right bracket - ends reduced attention area
    [^\\()\[\]:]+| # Regular text (any text not containing special characters)
    : # Single colon
    """,
    re.VERBOSE,  # Enables verbose mode, allowing comments and whitespace in regex
)
31
+
32
+
33
+ def apply_multiplier_to_range(
34
+ tokens: list[list[str | float]],
35
+ start_position: int,
36
+ multiplier: float,
37
+ ) -> None:
38
+ """
39
+ Applies a weight multiplier to a range of tokens starting from a specified position.
40
+
41
+ This function is used to process weight adjustments for text within brackets,
42
+ such as weight changes in (text) or [text].
43
+
44
+ Args:
45
+ tokens: List of [text, weight] pairs to modify
46
+ start_position: Position to start applying the multiplier
47
+ multiplier: Weight multiplier to apply
48
+ """
49
+ for p in range(start_position, len(tokens)):
50
+ tokens[p][1] *= multiplier
51
+
52
+
53
def process_text_token(text: str) -> list[list[str | float]]:
    """
    Turn a plain text fragment into [text, weight] pairs, splitting on BREAK.

    The fragment is split wherever the BREAK keyword occurs. Every piece
    becomes a pair with weight 1.0, and a ["BREAK", -1] marker pair is placed
    between consecutive pieces.

    Args:
        text: Text to process

    Returns:
        List of [text, weight] pairs
    """
    pieces = re_break.split(text)

    # Splitting always yields at least one piece; it goes in without a marker
    pairs = [[pieces[0], 1.0]]
    for piece in pieces[1:]:
        # Each further piece was preceded by a BREAK keyword
        pairs.append(["BREAK", -1])
        pairs.append([piece, 1.0])
    return pairs
76
+
77
+
78
+ def merge_identical_weights(tokens: list[list[str | float]]) -> list[list[str | float]]:
79
+ """
80
+ Merges consecutive tokens with identical weights.
81
+
82
+ When multiple consecutive text fragments have the same weight, this function
83
+ combines them into one to simplify output and improve efficiency.
84
+
85
+ Args:
86
+ tokens: List of [text, weight] pairs
87
+
88
+ Returns:
89
+ List of merged tokens
90
+ """
91
+ if not tokens:
92
+ return [["", 1.0]] # Return a default value if list is empty
93
+
94
+ i = 0
95
+ while i + 1 < len(tokens):
96
+ if tokens[i][1] == tokens[i + 1][1]:
97
+ # When two consecutive tokens have the same weight, merge their text
98
+ tokens[i][0] += tokens[i + 1][0]
99
+ tokens.pop(i + 1) # Remove the merged token
100
+ else:
101
+ i += 1
102
+
103
+ return tokens
104
+
105
+
106
def parse_prompt_attention(text: str) -> list[list[str | float]]:
    r"""
    Parses a string with attention markers and returns a list of text and associated weight pairs.

    This function is the core of prompt parsing, handling various attention control symbols
    like parentheses and brackets used to adjust focus on different parts of the prompt during generation.

    Supported markers:
        (abc) - increases attention to abc by a multiplier of 1.1
        (abc:3.12) - increases attention to abc by a multiplier of 3.12
        [abc] - decreases attention to abc, dividing its weight by 1.1
        \( - literal character '('
        \[ - literal character '['
        \) - literal character ')'
        \] - literal character ']'
        \\ - literal character '\'
        anything else - just text

    Brackets nest: multipliers of enclosing regions are multiplied together.
    Brackets that are never closed are treated as closed at the end of the
    text; closing brackets without an opener are kept as plain text. A weight
    that is not a valid number (e.g. "(abc:1.2.3)") is kept as literal text
    and the region gets the default 1.1 multiplier instead of raising.

    Args:
        text: Prompt text to parse

    Returns:
        List of [text, weight] pairs representing the parsed prompt parts and their weights
    """
    res: list[list[str | float]] = []  # Result list storing [text, weight] pairs
    round_brackets: list[int] = []  # Stack of result positions where "(" was seen
    square_brackets: list[int] = []  # Stack of result positions where "[" was seen

    # Define weight multiplier constants
    round_bracket_multiplier = 1.1  # Default enhancement factor for parentheses
    square_bracket_multiplier = 1 / 1.1  # Default reduction factor for brackets (reciprocal)

    # Parse each token in the text using regex
    for m in re_attention.finditer(text):
        token_text = m.group(0)  # Current matched text
        weight = m.group(1)  # Possible weight value (if any)

        if token_text.startswith("\\"):
            # Handle escape characters - remove backslash, preserve original character
            res.append([token_text[1:], 1.0])
        elif token_text == "(":
            # Left parenthesis - push current position to stack, mark start of enhancement area
            round_brackets.append(len(res))
        elif token_text == "[":
            # Left bracket - push current position to stack, mark start of reduction area
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            # Right parenthesis with custom weight - adjust area with specified weight
            try:
                multiplier = float(weight)
            except ValueError:
                # The pattern accepts digit/dot runs that are not numbers.
                # Keep the ":..." part (without the closing parenthesis) as
                # text and close the region with the default multiplier.
                res.extend(process_text_token(token_text[:-1]))
                multiplier = round_bracket_multiplier
            apply_multiplier_to_range(res, round_brackets.pop(), multiplier)
        elif token_text == ")" and round_brackets:
            # Regular right parenthesis - enhance area with default multiplier
            apply_multiplier_to_range(res, round_brackets.pop(), round_bracket_multiplier)
        elif token_text == "]" and square_brackets:
            # Right bracket - reduce area with default multiplier
            apply_multiplier_to_range(res, square_brackets.pop(), square_bracket_multiplier)
        else:
            # Process regular text or unmatched brackets
            res.extend(process_text_token(token_text))

    # Handle unclosed brackets (treat them as closed at the end of the text)
    for pos in round_brackets:
        # Apply default enhancement for unclosed parentheses
        apply_multiplier_to_range(res, pos, round_bracket_multiplier)

    for pos in square_brackets:
        # Apply default reduction for unclosed brackets
        apply_multiplier_to_range(res, pos, square_bracket_multiplier)

    # Merge consecutive tokens with identical weights
    res = merge_identical_weights(res)

    # Ensure all elements in the returned list have the correct types
    return [[str(fragment), float(fragment_weight)] for fragment, fragment_weight in res]
File without changes
@@ -0,0 +1,152 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffusion-prompt-embedder
3
+ Version: 0.1.0
4
+ Summary: A Python library for parsing and processing prompts with support for embedding and tokenization
5
+ Project-URL: Homepage, https://github.com/jannchie/diffusion-prompt-embedder
6
+ Project-URL: Bug Tracker, https://github.com/jannchie/diffusion-prompt-embedder/issues
7
+ Project-URL: Documentation, https://github.com/jannchie/diffusion-prompt-embedder#readme
8
+ Author-email: Jianqi Pan <jannchie@gmail.com>
9
+ License: MIT
10
+ Keywords: ai,embedding,nlp,prompt,tokenization
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Requires-Python: >=3.10
20
+ Provides-Extra: all
21
+ Requires-Dist: torch>=2.0.0; extra == 'all'
22
+ Requires-Dist: transformers>=4.51.3; extra == 'all'
23
+ Provides-Extra: dev
24
+ Requires-Dist: pytest-cov>=6.1.1; extra == 'dev'
25
+ Requires-Dist: pytest>=8.3.5; extra == 'dev'
26
+ Requires-Dist: torch>=2.0.0; extra == 'dev'
27
+ Requires-Dist: transformers>=4.51.3; extra == 'dev'
28
+ Provides-Extra: torch
29
+ Requires-Dist: torch>=2.0.0; extra == 'torch'
30
+ Provides-Extra: transformers
31
+ Requires-Dist: transformers>=4.51.3; extra == 'transformers'
32
+ Description-Content-Type: text/markdown
33
+
34
+ # Diffusion Prompt Embedder
35
+
36
+ [![PyPI version](https://img.shields.io/pypi/v/diffusion-prompt-embedder.svg)](https://pypi.org/project/diffusion-prompt-embedder/)
37
+ [![Python Version](https://img.shields.io/pypi/pyversions/diffusion-prompt-embedder.svg)](https://pypi.org/project/diffusion-prompt-embedder/)
38
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
39
+ [![Code Coverage](https://img.shields.io/badge/coverage-100%25-brightgreen.svg)](https://github.com/jannchie/diffusion-prompt-embedder)
40
+
41
+ A Python library specialized for parsing and processing weighted prompt text, supporting embedding generation and tokenization to enhance text processing for AI models like Stable Diffusion. It's compatible with SD Web UI's weighted prompts but doesn't include scheduling.
42
+
43
+ ## Features
44
+
45
+ - 💬 **Prompt Parsing**: Parse text prompts with weight markers (e.g., `a (cat:1.5) in the garden`)
46
+ - 🔢 **Weight Management**: Support for increased-weight `(text)` and reduced-weight `[text]` syntax
47
+ - 📚 **CLIP Integration**: Seamless integration with CLIP text models for embedding generation
48
+ - 🔄 **Batch Processing**: Efficiently process batches of multiple prompts
49
+ - 🪄 **Long Text Support**: Handle prompts that exceed standard CLIP context length
50
+
51
+ ## Installation
52
+
53
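+ Install the base library using pip (torch and transformers are optional extras; install `diffusion-prompt-embedder[all]` to pull in both):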
54
+
55
+ ```bash
56
+ pip install diffusion-prompt-embedder
57
+ ```
58
+
59
+ ## Usage Examples
60
+
61
+ ### Parse Weighted Prompts
62
+
63
+ ```python
64
+ from diffusion_prompt_embedder import parse_prompt_attention
65
+
66
+ # Basic parsing
67
+ result = parse_prompt_attention("a (cat:1.5) in the garden")
68
+ print(result) # [['a ', 1.0], ['cat', 1.5], [' in the garden', 1.0]]
69
+
70
+ # Using brackets to lower weight
71
+ result = parse_prompt_attention("a [cat] in the garden")
72
+ print(result) # [['a ', 1.0], ['cat', 0.9090909090909091], [' in the garden', 1.0]]
73
+
74
+ # Complex prompt example
75
+ result = parse_prompt_attention("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).")
76
+ print(result)
77
+ ```
78
+
79
+ ### Generate CLIP Embeddings
80
+
81
+ ```python
82
+ import torch
83
+ from transformers import CLIPTokenizer, CLIPTextModel
84
+ from diffusion_prompt_embedder import get_embeddings_sd15
85
+
86
+ # Initialize CLIP model
87
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
88
+ text_encoder = CLIPTextModel.from_pretrained(
89
+ "openai/clip-vit-large-patch14",
90
+ torch_dtype=torch.float16
91
+ ).to("cuda")
92
+
93
+ # Generate embeddings
94
+ prompt_embeds, neg_prompt_embeds = get_embeddings_sd15(
95
+ tokenizer=tokenizer,
96
+ text_encoder=text_encoder,
97
+ prompt="a (white:1.2) cat",
98
+ neg_prompt="blur, bad quality",
99
+ clip_skip=1 # Optional: skip layers in CLIP model
100
+ )
101
+
102
+ # Batch processing multiple prompts
103
+ from diffusion_prompt_embedder import get_embeddings_sd_15_batch
104
+
105
+ batch_embeds = get_embeddings_sd_15_batch(
106
+ tokenizer=tokenizer,
107
+ text_encoder=text_encoder,
108
+ prompts=["a (white:1.2) cat", "a (blue:1.4) dog", "a red bird"]
109
+ )
110
+ ```
111
+
112
+ ## Prompt Syntax
113
+
114
+ ### Basic Weight Syntax
115
+
116
+ - `(text)` - Increases the prompt weight by 1.1x
117
+ - `(text:1.5)` - Sets the prompt weight to 1.5
118
+ - `[text]` - Decreases the prompt weight to 1/1.1 of original
119
+ - `\( \[ \) \]` - Use backslash to escape bracket characters
120
+
121
+ ### BREAK Syntax
122
+
123
+ Use the `BREAK` keyword to create breakpoints in prompts:
124
+
125
+ ```python
126
+ result = parse_prompt_attention("text1 BREAK text2")
127
+ # Result: [["text1", 1.0], ["BREAK", -1.0], ["text2", 1.0]]
128
+ ```
129
+
130
+ ## Development
131
+
132
+ Clone the repository and install development dependencies:
133
+
134
+ ```bash
135
+ git clone https://github.com/jannchie/diffusion-prompt-embedder.git
135
+ cd diffusion-prompt-embedder
137
+ pip install -e ".[dev]"
138
+ ```
139
+
140
+ Run tests:
141
+
142
+ ```bash
143
+ pytest
144
+ ```
145
+
146
+ ## License
147
+
148
+ [MIT](https://opensource.org/licenses/MIT)
149
+
150
+ ## Author
151
+
152
+ - Jianqi Pan ([@jannchie](https://github.com/jannchie))
@@ -0,0 +1,10 @@
1
+ diffusion_prompt_embedder/__init__.py,sha256=60W03g8iWpUgaMmZTJSkMRsjWwZ1mAGpNQbD_O0QZ70,583
2
+ diffusion_prompt_embedder/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ diffusion_prompt_embedder/clip/__init__.py,sha256=yikfkQ1fqg70OWMzO3i4G3yEmaXudV0XJHB-NMStXcA,286
4
+ diffusion_prompt_embedder/clip/tokenization.py,sha256=63wd3-Gib7ZvUwmvVNfi74n0ntF79zjHc3KHvqwzAGs,4658
5
+ diffusion_prompt_embedder/core/__init__.py,sha256=KUtvybDsDid_NCBlo9AcpLu5bbFNy24NDn8mqPQPWpc,543
6
+ diffusion_prompt_embedder/core/embedding.py,sha256=KfgncrrYshYevD_RfVahymw7qoNrAB2S3hbRC2CZEG0,11278
7
+ diffusion_prompt_embedder/core/parser.py,sha256=kp4Xr5XNl0JPzfouHMSZoyv_y6a7P1NEzlqhg669ubo,7383
8
+ diffusion_prompt_embedder-0.1.0.dist-info/METADATA,sha256=dj8TG5hqOtIsHtNLfSNpA2qbUujHxv8gbH4riTytbQY,5204
9
+ diffusion_prompt_embedder-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ diffusion_prompt_embedder-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any