lm-deluge 0.0.3 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,111 @@
+ import io
+ import json
+ from ..prompt import Conversation
+ import asyncio
+ from ..client import LLMClient
+ from typing import Optional, Any
+ from ..util.json import load_json
+
+ try:
+     from PIL import Image as PILImage
+ except ImportError:
+     PILImage = None
+
+
+ async def extract_async(
+     inputs: list[str | Any],
+     schema: Any,
+     client: LLMClient,
+     document_name: Optional[str] = None,
+     object_name: Optional[str] = None,
+     show_progress: bool = True,
+     return_prompts: bool = False,
+ ):
+     if hasattr(schema, "model_json_schema"):
+         schema_dict = schema.model_json_schema()
+     elif isinstance(schema, dict):
+         schema_dict = schema
+     else:
+         raise ValueError("schema must be a pydantic model or a dict.")
+
+     # warn if json_mode is not True
+     for sp in client.sampling_params:
+         if sp.json_mode is False:
+             print(
+                 "Warning: json_mode is False for one or more sampling params. You may get invalid output."
+             )
+             break
+     # check_schema(schema_dict) -- figure out later
+     if document_name is None:
+         document_name = "text"
+     if object_name is None:
+         object_name = ""
+     else:
+         object_name += " "  # trailing space so the prompt reads naturally
+
+     text_only_prompt = (
+         f"Given the following {document_name}, extract the {object_name}information "
+         + "from it according to the following JSON schema:\n\n```json\n"
+         + json.dumps(schema_dict, indent=2)
+         + f"\n```\n\nHere is the {document_name}:\n\n```\n" + "{<<__REPLACE_WITH_TEXT__>>}\n```\n\n"
+         + "Return the extracted information as JSON, no explanation required. "
+         + f"If the {document_name} seems to be totally irrelevant to the schema (not just incomplete), you may return a JSON object "
+         + 'like `{"error": "The document is not relevant to the schema."}`.'
+     )
+
+     image_only_prompt = (
+         f"Given the attached {document_name} image, extract the {object_name}information "
+         + "from it according to the following JSON schema:\n\n```json\n"
+         + json.dumps(schema_dict, indent=2) + "\n```\n\n"
+         + "Return the extracted information as JSON, no explanation required. "
+         + f"If the {document_name} seems to be totally irrelevant to the schema (not just incomplete), you may return a JSON object "
+         + 'like `{"error": "The document is not relevant to the schema."}`.'
+     )
+
+     prompts = []
+     for inp in inputs:  # each input is either raw text or a PIL image
+         if isinstance(inp, str):
+             prompts.append(
+                 text_only_prompt.replace("{<<__REPLACE_WITH_TEXT__>>}", inp)
+             )
+         elif PILImage is not None and isinstance(inp, PILImage.Image):
+             buffer = io.BytesIO()
+             inp.save(buffer, format="PNG")
+             prompts.append(
+                 Conversation.user(text=image_only_prompt, image=buffer.getvalue())
+             )
+         else:
+             raise ValueError("inputs must be a list of strings or PIL images.")
+
+     if return_prompts:
+         return prompts
+
+     resps = await client.process_prompts_async(prompts, show_progress=show_progress)
+     completions = [
+         load_json(resp.completion) if (resp and resp.completion) else None
+         for resp in resps
+     ]
+
+     return completions
+
+
+ def extract(
+     inputs: list[str | Any],
+     schema: Any,
+     client: LLMClient,
+     document_name: Optional[str] = None,
+     object_name: Optional[str] = None,
+     show_progress: bool = True,
+     return_prompts: bool = False,
+ ):
+     return asyncio.run(
+         extract_async(
+             inputs,
+             schema,
+             client,
+             document_name,
+             object_name,
+             show_progress,
+             return_prompts,
+         )
+     )
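
The first file above defines the structured-extraction helpers. For orientation, a minimal usage sketch follows; the import paths, the LLMClient constructor arguments, and the model name are assumptions for illustration, not taken from this diff.

    # Hypothetical usage sketch -- module path and client setup are assumed, not shown in this diff.
    from pydantic import BaseModel
    from lm_deluge import LLMClient          # assumed export location
    from lm_deluge.extract import extract    # assumed module path

    class Invoice(BaseModel):
        vendor: str
        total: float

    client = LLMClient("gpt-4o-mini")         # assumed constructor; sampling params should have json_mode=True
    results = extract(
        ["Invoice #42 from Acme Corp. Total due: $1,200.00"],
        schema=Invoice,
        client=client,
        document_name="invoice",
    )
    # results is a list of parsed dicts (None where the completion was empty or failed)
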
@@ -0,0 +1,71 @@
+ from ..client import LLMClient, APIResponse
+ from ..util.logprobs import extract_prob
+
+ # def extract_prob_yes(logprobs: list[dict]):
+ #     """
+ #     Extract the log probability of the token "yes" from the logprobs object.
+ #     Since we can't rely on "yes" and "no" both being in the top_logprobs,
+ #     we do the following:
+ #     - if token is "yes", return p(yes)
+ #     - if token is "no", return 1 - p(no)
+ #     - otherwise, return 0.5
+ #     """
+ #     # use regexp to keep only alpha characters
+ #     top_token = logprobs[0]["token"].lower()
+ #     top_token = re.sub(r"[^a-z]", "", top_token)
+ #     if top_token == "yes":
+ #         return np.exp(logprobs[0]["logprob"])
+ #     elif top_token == "no":
+ #         return 1 - np.exp(logprobs[0]["logprob"])
+ #     else:
+ #         return 0.5
+
+
+ def score_llm(
+     scoring_prompt_template: str,
+     inputs: list[tuple | list | dict],  # used to format the template
+     scoring_model: LLMClient,
+     return_probabilities: bool,
+     yes_token: str = "yes",
+ ) -> list[bool | None] | list[float | None]:
+     if return_probabilities:
+         if not hasattr(scoring_model, "logprobs") or not scoring_model.logprobs:
+             raise ValueError(
+                 "return_probabilities=True requires scoring_model to have logprobs=True. "
+                 "You may need to upgrade lm_deluge to have access to this option."
+             )
+
+     if scoring_prompt_template is None:
+         raise ValueError("scoring_prompt must be provided.")
+
+     scoring_prompts = []  # one formatted prompt per input
+     for inp in inputs:
+         if isinstance(inp, dict):
+             scoring_prompt = scoring_prompt_template.format(**inp)
+         elif isinstance(inp, (tuple, list)):
+             scoring_prompt = scoring_prompt_template.format(*inp)
+         else:
+             raise ValueError("inputs must be a list of tuples, lists, or dicts.")
+         scoring_prompts.append(scoring_prompt)
+
+     resps: list[APIResponse] = scoring_model.process_prompts_sync(  # pyright: ignore
+         prompts=scoring_prompts,
+         show_progress=False,
+     )
+
+     if return_probabilities:
+         logprobs_list = [resp.logprobs for resp in resps]
+         scores = [
+             extract_prob(yes_token, logprobs, use_complement=True)
+             if logprobs is not None
+             else None
+             for logprobs in logprobs_list
+         ]
+     else:
+         completions = [resp.completion for resp in resps]
+         scores = [
+             yes_token.lower().strip() in c.lower() if c is not None else None
+             for c in completions
+         ]
+
+     return scores  # p(yes) or bool yes/no
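
The second file is a yes/no LLM scorer: score_llm formats one prompt per input, runs them synchronously through the scoring client, and returns either booleans (whether yes_token appears in the completion) or probabilities derived from token logprobs. A rough usage sketch, with the client construction assumed rather than taken from this diff:

    # Hypothetical usage sketch -- the scoring client setup is assumed.
    template = "Question: {q}\nAnswer: {a}\n\nIs the answer correct? Reply yes or no."
    inputs = [
        {"q": "What is 2 + 2?", "a": "4"},
        {"q": "What is the capital of France?", "a": "Berlin"},
    ]
    labels = score_llm(template, inputs, scoring_model=client, return_probabilities=False)
    # -> e.g. [True, False]; with return_probabilities=True (and logprobs enabled on the
    #    client) each entry is instead p(yes) extracted from the top logprobs.
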
@@ -0,0 +1,44 @@
+ import asyncio
+ from ..client import LLMClient
+
+ translation_prompt = (
+     "Translate the following text (enclosed in ```) into English. "
+     "Reply with only the translation. Text:\n\n```\n{}\n```\n\nYour translation:"
+ )
+
+
+ def is_english(text: str, low_memory: bool = True):
+     try:
+         from ftlangdetect import detect  # pyright: ignore
+
+         lang = detect(text.replace("\n", " "), low_memory=low_memory)["lang"]
+         return lang == "en"
+     except ImportError:
+         print(
+             "Warning: fasttext-langdetect is not installed; the translate tool will assume all inputs are English."
+         )
+         return True
+
+
+ async def translate_async(texts: list[str], client: LLMClient, low_memory: bool = True):
+     to_translate_idxs = [  # indexes of texts detected as non-English
+         i for i, text in enumerate(texts) if not is_english(text, low_memory=low_memory)
+     ]
+     if len(to_translate_idxs) == 0:
+         return texts
+
+     prompts = [translation_prompt.format(texts[i]) for i in to_translate_idxs]
+     resps = await client.process_prompts_async(prompts)
+     translations = [
+         resp.completion.strip() if (resp and resp.completion is not None) else None
+         for resp in resps
+     ]
+     for i, translation in zip(to_translate_idxs, translations):  # write back in place
+         if translation:
+             texts[i] = translation
+
+     return texts
+
+
+ def translate(texts: list[str], client: LLMClient, low_memory: bool = True):
+     return asyncio.run(translate_async(texts, client, low_memory))
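
The third file handles translation: is_english gates which entries are sent to the model, translate_async translates only those and writes the results back into the original list, and translate is the synchronous wrapper. A small usage sketch (client setup assumed, not shown in this diff):

    # Hypothetical usage sketch -- client setup assumed.
    texts = ["hello world", "bonjour le monde", "hola mundo"]
    texts = translate(texts, client)
    # English entries pass through untouched; non-English entries are replaced in
    # place with the model's translation (originals are kept on empty responses).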