fabricatio-0.2.3.dev3-cp312-cp312-win_amd64.whl → fabricatio-0.2.4.dev1-cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/models/usage.py CHANGED
@@ -1,19 +1,18 @@
  """This module contains classes that manage the usage of language models and tools in tasks."""
 
  from asyncio import gather
- from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Union, Unpack, overload
+ from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
 
  import asyncstdlib
  import litellm
- import orjson
  from fabricatio._rust_instances import template_manager
  from fabricatio.config import configs
  from fabricatio.journal import logger
- from fabricatio.models.generic import Base, WithBriefing
+ from fabricatio.models.generic import ScopedConfig, WithBriefing
  from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
  from fabricatio.models.task import Task
  from fabricatio.models.tool import Tool, ToolBox
- from fabricatio.models.utils import Messages, MilvusData
+ from fabricatio.models.utils import Messages
  from fabricatio.parser import JsonCapture
  from litellm import stream_chunk_builder
  from litellm.types.utils import (
@@ -23,135 +22,16 @@ from litellm.types.utils import (
  StreamingChoices,
  )
  from litellm.utils import CustomStreamWrapper
- from pydantic import Field, HttpUrl, NonNegativeFloat, NonNegativeInt, PositiveInt, SecretStr
+ from more_itertools import duplicates_everseen
+ from pydantic import Field, NonNegativeInt, PositiveInt
 
 
- class LLMUsage(Base):
+ class LLMUsage(ScopedConfig):
      """Class that manages LLM (Large Language Model) usage parameters and methods."""
 
-     llm_api_endpoint: Optional[HttpUrl] = None
-     """The OpenAI API endpoint."""
-
-     llm_api_key: Optional[SecretStr] = None
-     """The OpenAI API key."""
-
-     llm_timeout: Optional[PositiveInt] = None
-     """The timeout of the LLM model."""
-
-     llm_max_retries: Optional[PositiveInt] = None
-     """The maximum number of retries."""
-
-     llm_model: Optional[str] = None
-     """The LLM model name."""
-
-     llm_temperature: Optional[NonNegativeFloat] = None
-     """The temperature of the LLM model."""
-
-     llm_stop_sign: Optional[str | List[str]] = None
-     """The stop sign of the LLM model."""
-
-     llm_top_p: Optional[NonNegativeFloat] = None
-     """The top p of the LLM model."""
-
-     llm_generation_count: Optional[PositiveInt] = None
-     """The number of generations to generate."""
-
-     llm_stream: Optional[bool] = None
-     """Whether to stream the LLM model's response."""
-
-     llm_max_tokens: Optional[PositiveInt] = None
-     """The maximum number of tokens to generate."""
-
-     async def aembedding(
-         self,
-         input_text: List[str],
-         model: Optional[str] = None,
-         dimensions: Optional[int] = None,
-         timeout: Optional[PositiveInt] = None,
-         caching: Optional[bool] = False,
-     ) -> EmbeddingResponse:
-         """Asynchronously generates embeddings for the given input text.
-
-         Args:
-             input_text (List[str]): A list of strings to generate embeddings for.
-             model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
-             dimensions (Optional[int]): The dimensions of the embedding. Defaults to None.
-             timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
-             caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
-
-
-         Returns:
-             EmbeddingResponse: The response containing the embeddings.
-         """
-         return await litellm.aembedding(
-             input=input_text,
-             caching=caching,
-             dimensions=dimensions,
-             model=model or self.llm_model or configs.llm.model,
-             timeout=timeout or self.llm_timeout or configs.llm.timeout,
-             api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
-             api_base=self.llm_api_endpoint.unicode_string().rstrip(
-                 "/"
-             )  # seems embedding function takes no base_url end with a slash
-             if self.llm_api_endpoint
-             else configs.llm.api_endpoint.unicode_string().rstrip("/"),
-         )
-
-     @overload
-     async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
-     @overload
-     async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
-
-     async def vectorize(
-         self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
-     ) -> List[List[float]] | List[float]:
-         """Asynchronously generates vector embeddings for the given input text.
-
-         Args:
-             input_text (List[str] | str): A string or list of strings to generate embeddings for.
-             **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
-
-         Returns:
-             List[List[float]] | List[float]: The generated embeddings.
-         """
-         if isinstance(input_text, str):
-             return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
-
-         return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
-
-     @overload
-     async def pack(
-         self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
-     ) -> List[MilvusData]: ...
-     @overload
-     async def pack(
-         self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
-     ) -> MilvusData: ...
-
-     async def pack(
-         self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
-     ) -> List[MilvusData] | MilvusData:
-         """Asynchronously generates MilvusData objects for the given input text.
-
-         Args:
-             input_text (List[str] | str): A string or list of strings to generate embeddings for.
-             subject (Optional[str]): The subject of the input text. Defaults to None.
-             **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
-
-         Returns:
-             List[MilvusData] | MilvusData: The generated MilvusData objects.
-         """
-         if isinstance(input_text, str):
-             return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
-         vecs = await self.vectorize(input_text, **kwargs)
-         return [
-             MilvusData(
-                 vector=vec,
-                 text=text,
-                 subject=subject,
-             )
-             for text, vec in zip(input_text, vecs, strict=True)
-         ]
+     @classmethod
+     def _scoped_model(cls) -> Type["LLMUsage"]:
+         return LLMUsage
 
      async def aquery(
          self,
@@ -181,10 +61,8 @@ class LLMUsage(Base):
              stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
              timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
              max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
-             api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
-             base_url=self.llm_api_endpoint.unicode_string()
-             if self.llm_api_endpoint
-             else configs.llm.api_endpoint.unicode_string(),
+             api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
+             base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
          )
 
      async def ainvoke(
@@ -213,13 +91,13 @@ class LLMUsage(Base):
          if isinstance(resp, ModelResponse):
              return resp.choices
          if isinstance(resp, CustomStreamWrapper):
-             if configs.debug.streaming_visible:
-                 chunks = []
-                 async for chunk in resp:
-                     chunks.append(chunk)
-                     print(chunk.choices[0].delta.content or "", end="")  # noqa: T201
-                 return stream_chunk_builder(chunks).choices
-             return stream_chunk_builder(await asyncstdlib.list(resp)).choices
+             if not configs.debug.streaming_visible:
+                 return stream_chunk_builder(await asyncstdlib.list(resp)).choices
+             chunks = []
+             async for chunk in resp:
+                 chunks.append(chunk)
+                 print(chunk.choices[0].delta.content or "", end="")  # noqa: T201
+             return stream_chunk_builder(chunks).choices
          logger.critical(err := f"Unexpected response type: {type(resp)}")
          raise ValueError(err)
 
@@ -334,11 +212,10 @@ class LLMUsage(Base):
                      **kwargs,
                  )
              ) and (validated := validator(response)):
-                 logger.debug(f"Successfully validated the response at {i}th attempt. response: \n{response}")
+                 logger.debug(f"Successfully validated the response at {i}th attempt.")
                  return validated
-             logger.debug(f"Failed to validate the response at {i}th attempt. response: \n{response}")
-         logger.error(f"Failed to validate the response after {max_validations} attempts.")
-         raise ValueError("Failed to validate the response.")
+         logger.error(err := f"Failed to validate the response after {max_validations} attempts.")
+         raise ValueError(err)
 
      async def aask_validate_batch[T](
          self,
@@ -361,6 +238,26 @@ class LLMUsage(Base):
          """
          return await gather(*[self.aask_validate(question, validator, **kwargs) for question in questions])
 
+     async def aliststr(self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[GenerateKwargs]) -> List[str]:
+         """Asynchronously generates a list of strings based on a given requirement.
+
+         Args:
+             requirement (str): The requirement for the list of strings.
+             k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
+             **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+
+         Returns:
+             List[str]: The validated response as a list of strings.
+         """
+         return await self.aask_validate(
+             template_manager.render_template(
+                 configs.templates.liststr_template,
+                 {"requirement": requirement, "k": k},
+             ),
+             lambda resp: JsonCapture.validate_with(resp, target_type=list, elements_type=str, length=k),
+             **kwargs,
+         )
+
      async def achoose[T: WithBriefing](
          self,
          instruction: str,
@@ -384,28 +281,28 @@ class LLMUsage(Base):
          - Ensures response compliance through JSON parsing and format validation.
          - Relies on `aask_validate` to implement retry mechanisms with validation.
          """
+         if dup := duplicates_everseen(choices, key=lambda x: x.name):
+             logger.error(err := f"Redundant choices: {dup}")
+             raise ValueError(err)
          prompt = template_manager.render_template(
              configs.templates.make_choice_template,
              {
                  "instruction": instruction,
-                 "options": [{"name": m.name, "briefing": m.briefing} for m in choices],
+                 "options": [m.model_dump(include={"name", "briefing"}) for m in choices],
                  "k": k,
              },
          )
          names = {c.name for c in choices}
+
          logger.debug(f"Start choosing between {names} with prompt: \n{prompt}")
 
          def _validate(response: str) -> List[T] | None:
-             ret = JsonCapture.convert_with(response, orjson.loads)
-
-             if not isinstance(ret, List) or (0 < k != len(ret)):
-                 logger.error(f"Incorrect Type or length of response: \n{ret}")
+             ret = JsonCapture.validate_with(response, target_type=List, elements_type=str, length=k)
+             if ret is None or set(ret) - names:
                  return None
-             if any(n not in names for n in ret):
-                 logger.error(f"Invalid choice in response: \n{ret}")
-                 return None
-
-             return [next(toolbox for toolbox in choices if toolbox.name == toolbox_str) for toolbox_str in ret]
+             return [
+                 next(candidate for candidate in choices if candidate.name == candidate_name) for candidate_name in ret
+             ]
 
          return await self.aask_validate(
              question=prompt,
@@ -459,55 +356,91 @@ class LLMUsage(Base):
          Returns:
              bool: The judgment result (True or False) based on the AI's response.
          """
-
-         def _validate(response: str) -> bool | None:
-             ret = JsonCapture.convert_with(response, orjson.loads)
-             if not isinstance(ret, bool):
-                 return None
-             return ret
-
          return await self.aask_validate(
              question=template_manager.render_template(
                  configs.templates.make_judgment_template,
                  {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
              ),
-             validator=_validate,
+             validator=lambda resp: JsonCapture.validate_with(resp, target_type=bool),
              **kwargs,
          )
 
-     def fallback_to(self, other: "LLMUsage") -> Self:
-         """Fallback to another instance's attribute values if the current instance's attributes are None.
+
+ class EmbeddingUsage(LLMUsage):
+     """A class representing the embedding model."""
+
+     async def aembedding(
+         self,
+         input_text: List[str],
+         model: Optional[str] = None,
+         dimensions: Optional[int] = None,
+         timeout: Optional[PositiveInt] = None,
+         caching: Optional[bool] = False,
+     ) -> EmbeddingResponse:
+         """Asynchronously generates embeddings for the given input text.
 
          Args:
-             other (LLMUsage): Another instance from which to copy attribute values.
+             input_text (List[str]): A list of strings to generate embeddings for.
+             model (Optional[str]): The model to use for embedding. Defaults to the instance's `llm_model` or the global configuration.
+             dimensions (Optional[int]): The dimensions of the embedding output should have, which is used to validate the result. Defaults to None.
+             timeout (Optional[PositiveInt]): The timeout for the embedding request. Defaults to the instance's `llm_timeout` or the global configuration.
+             caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.
+
 
          Returns:
-             Self: The current instance, allowing for method chaining.
+             EmbeddingResponse: The response containing the embeddings.
          """
-         # Iterate over the attribute names and copy values from 'other' to 'self' where applicable
-         # noinspection PydanticTypeChecker,PyTypeChecker
-         for attr_name in LLMUsage.model_fields:
-             # Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
-             if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
-                 setattr(self, attr_name, attr)
-
-         # Return the current instance to allow for method chaining
-         return self
+         # check seq length
+         max_len = self.embedding_max_sequence_length or configs.embedding.max_sequence_length
+         if any(len(t) > max_len for t in input_text):
+             logger.error(err := f"Input text exceeds maximum sequence length {max_len}.")
+             raise ValueError(err)
+
+         return await litellm.aembedding(
+             input=input_text,
+             caching=caching or self.embedding_caching or configs.embedding.caching,
+             dimensions=dimensions or self.embedding_dimensions or configs.embedding.dimensions,
+             model=model or self.embedding_model or configs.embedding.model or self.llm_model or configs.llm.model,
+             timeout=timeout
+             or self.embedding_timeout
+             or configs.embedding.timeout
+             or self.llm_timeout
+             or configs.llm.timeout,
+             api_key=(
+                 self.embedding_api_key or configs.embedding.api_key or self.llm_api_key or configs.llm.api_key
+             ).get_secret_value(),
+             api_base=(
+                 self.embedding_api_endpoint
+                 or configs.embedding.api_endpoint
+                 or self.llm_api_endpoint
+                 or configs.llm.api_endpoint
+             )
+             .unicode_string()
+             .rstrip("/"),
+             # seems embedding function takes no base_url end with a slash
+         )
+
+     @overload
+     async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
+     @overload
+     async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...
 
-     def hold_to(self, others: Union["LLMUsage", Iterable["LLMUsage"]]) -> Self:
-         """Hold to another instance's attribute values if the current instance's attributes are None.
+     async def vectorize(
+         self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
+     ) -> List[List[float]] | List[float]:
+         """Asynchronously generates vector embeddings for the given input text.
 
          Args:
-             others (LLMUsage | Iterable[LLMUsage]): Another instance or iterable of instances from which to copy attribute values.
+             input_text (List[str] | str): A string or list of strings to generate embeddings for.
+             **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
 
          Returns:
-             Self: The current instance, allowing for method chaining.
+             List[List[float]] | List[float]: The generated embeddings.
          """
-         for other in others:
-             # noinspection PyTypeChecker,PydanticTypeChecker
-             for attr_name in LLMUsage.model_fields:
-                 if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
-                     setattr(other, attr_name, attr)
+         if isinstance(input_text, str):
+             return (await self.aembedding([input_text], **kwargs)).data[0].get("embedding")
+
+         return [o.get("embedding") for o in (await self.aembedding(input_text, **kwargs)).data]
 
 
  class ToolBoxUsage(LLMUsage):
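
Taken together, this file's changes (module path inferred from the imports as fabricatio/models/usage.py) move all embedding logic out of `LLMUsage` into the new `EmbeddingUsage` subclass, drop `pack`/`MilvusData` here, and add `aliststr` for validated list generation. A minimal usage sketch, assuming an `EmbeddingUsage` instance with valid credentials configured; the `demo` function and `agent` name are illustrative, not part of the diff:

    from fabricatio.models.usage import EmbeddingUsage

    async def demo(agent: EmbeddingUsage) -> None:
        # aliststr renders the liststr template, then validates the reply as a JSON list of k strings.
        titles = await agent.aliststr("Five catchy titles for a post about async Python.", k=5)
        # vectorize is overloaded: a single string yields List[float], a list yields List[List[float]].
        vec = await agent.vectorize(titles[0])
        vecs = await agent.vectorize(titles)
        print(len(vec), len(vecs))
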
fabricatio/models/utils.py CHANGED
@@ -1,5 +1,6 @@
  """A module containing utility classes for the models."""
 
+ from enum import Enum
  from typing import Any, Dict, List, Literal, Optional, Self
 
  from pydantic import BaseModel, ConfigDict, Field
@@ -125,3 +126,21 @@ class MilvusData(BaseModel):
          """
          self.id = new_id
          return self
+
+
+ class TaskStatus(Enum):
+     """An enumeration representing the status of a task.
+
+     Attributes:
+         Pending: The task is pending.
+         Running: The task is currently running.
+         Finished: The task has been successfully completed.
+         Failed: The task has failed.
+         Cancelled: The task has been cancelled.
+     """
+
+     Pending = "pending"
+     Running = "running"
+     Finished = "finished"
+     Failed = "failed"
+     Cancelled = "cancelled"
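
The new `TaskStatus` enum gives task states stable string values. A small sketch of how it might be used (the `status` variable is illustrative, not from the diff):

    from fabricatio.models.utils import TaskStatus

    status = TaskStatus.Pending
    assert status.value == "pending"
    if status is not TaskStatus.Finished:
        print(f"task is still {status.value}")
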
fabricatio/parser.py CHANGED
@@ -1,9 +1,10 @@
  """A module to parse text using regular expressions."""
 
- from typing import Any, Callable, Self, Tuple
+ from typing import Any, Callable, Optional, Self, Tuple, Type
 
+ import orjson
  import regex
- from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
+ from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
  from regex import Pattern, compile
 
  from fabricatio.journal import logger
@@ -27,11 +28,7 @@ class Capture(BaseModel):
      _compiled: Pattern = PrivateAttr()
 
      def model_post_init(self, __context: Any) -> None:
-         """Initialize the compiled regular expression pattern after the model is initialized.
-
-         Args:
-             __context (Any): The context in which the model is initialized.
-         """
+         """Initialize the compiled pattern."""
          self._compiled = compile(self.pattern, self.flags)
 
      def capture(self, text: str) -> Tuple[str, ...] | str | None:
@@ -70,10 +67,40 @@ class Capture(BaseModel):
              return None
          try:
              return convertor(cap)
-         except (ValueError, SyntaxError) as e:
+         except (ValueError, SyntaxError, ValidationError) as e:
              logger.error(f"Failed to convert text using {convertor.__name__} to convert.\nerror: {e}\n {cap}")
              return None
 
+     def validate_with[K, T, E](
+         self,
+         text: str,
+         target_type: Type[T],
+         elements_type: Optional[Type[E]] = None,
+         length: Optional[int] = None,
+         deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = orjson.loads,
+     ) -> T | None:
+         """Validate the given text using the pattern.
+
+         Args:
+             text (str): The text to search the pattern in.
+             target_type (Type[T]): The expected type of the output, dict or list.
+             elements_type (Optional[Type[E]]): The expected type of the elements in the output dict keys or list elements.
+             length (Optional[int]): The expected length of the output, bool(length)==False means no length validation.
+             deserializer (Callable[[Tuple[str, ...]], K] | Callable[[str], K]): The function to deserialize the captured text.
+
+         Returns:
+             T | None: The validated text if the pattern is found and the output is of the expected type, otherwise None.
+         """
+         judges = [lambda output_obj: isinstance(output_obj, target_type)]
+         if elements_type:
+             judges.append(lambda output_obj: all(isinstance(e, elements_type) for e in output_obj))
+         if length:
+             judges.append(lambda output_obj: len(output_obj) == length)
+
+         if (out := self.convert_with(text, deserializer)) and all(j(out) for j in judges):
+             return out
+         return None
+
      @classmethod
      def capture_code_block(cls, language: str) -> Self:
          """Capture the first occurrence of a code block in the given text.