lm-deluge 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of lm-deluge might be problematic.

@@ -0,0 +1,446 @@
+ import re
+ import numpy as np
+ from typing import TypedDict, Optional, Callable
+
+
+ class TopLogprob(TypedDict):
+     token: str
+     logprob: float
+     bytes: list[int]
+
+
+ class LogprobEntry(TypedDict):
+     token: str
+     logprob: float
+     bytes: list[int]
+     top_logprobs: list[TopLogprob]
+
+
+ Logprobs = list[LogprobEntry]
+
+ ## In our implementation of APIResponse, the 'logprobs' field contains
+ ## just the 'content' field from the response.choices[0].logprobs object.
+ # {
+ #     "id": "chatcmpl-A6izyp6wnlEv6SLAb0ehIwBqCDLyR",
+ #     "object": "chat.completion",
+ #     "created": 1726166306,
+ #     "model": "gpt-4o-mini-2024-07-18",
+ #     "choices": [
+ #         {
+ #             "index": 0,
+ #             "message": {
+ #                 "role": "assistant",
+ #                 "content": "A loop within loops,",
+ #                 "refusal": null
+ #             },
+ #             "logprobs": {
+ #                 "content": [
+ #                     {
+ #                         "token": "A",
+ #                         "logprob": -1.0330456,
+ #                         "bytes": [65],
+ #                         "top_logprobs": [
+ #                             {"token": "A", "logprob": -1.0330456, "bytes": [65]},
+ #                             {"token": "In", "logprob": -2.0330458, "bytes": [73, 110]},
+ #                             {"token": "Nested", "logprob": -2.0330458, "bytes": [78, 101, 115, 116, 101, 100]},
+ #                             {"token": "Function", "logprob": -2.7830458, "bytes": [70, 117, 110, 99, 116, 105, 111, 110]},
+ #                             {"token": "Layers", "logprob": -3.1580458, "bytes": [76, 97, 121, 101, 114, 115]}
+ #                         ]
+ #                     },
+ #                     {
+ #                         "token": " loop",
+ #                         "logprob": -2.909274,
+ #                         "bytes": [32, 108, 111, 111, 112],
+ #                         "top_logprobs": [
+ #                             {"token": " function", "logprob": -0.9092741, "bytes": [32, 102, 117, 110, 99, 116, 105, 111, 110]},
+ #                             {"token": " call", "logprob": -1.0342741, "bytes": [32, 99, 97, 108, 108]},
+ #                             {"token": " task", "logprob": -2.409274, "bytes": [32, 116, 97, 115, 107]},
+ #                             {"token": " loop", "logprob": -2.909274, "bytes": [32, 108, 111, 111, 112]},
+ #                             {"token": " problem", "logprob": -4.034274, "bytes": [32, 112, 114, 111, 98, 108, 101, 109]}
+ #                         ]
+ #                     },
+ #                     {
+ #                         "token": " within",
+ #                         "logprob": -0.09628018,
+ #                         "bytes": [32, 119, 105, 116, 104, 105, 110],
+ #                         "top_logprobs": [
+ #                             {"token": " within", "logprob": -0.09628018, "bytes": [32, 119, 105, 116, 104, 105, 110]},
+ #                             {"token": " in", "logprob": -2.72128, "bytes": [32, 105, 110]},
+ #                             {"token": " of", "logprob": -4.47128, "bytes": [32, 111, 102]},
+ #                             {"token": " that", "logprob": -5.34628, "bytes": [32, 116, 104, 97, 116]},
+ #                             {"token": " inside", "logprob": -5.59628, "bytes": [32, 105, 110, 115, 105, 100, 101]}
+ #                         ]
+ #                     },
+ #                     {
+ #                         "token": " loops",
+ #                         "logprob": -0.12761699,
+ #                         "bytes": [32, 108, 111, 111, 112, 115],
+ #                         "top_logprobs": [
+ #                             {"token": " loops", "logprob": -0.12761699, "bytes": [32, 108, 111, 111, 112, 115]},
+ #                             {"token": " self", "logprob": -3.127617, "bytes": [32, 115, 101, 108, 102]},
+ #                             {"token": " loop", "logprob": -3.627617, "bytes": [32, 108, 111, 111, 112]},
+ #                             {"token": " calls", "logprob": -4.377617, "bytes": [32, 99, 97, 108, 108, 115]},
+ #                             {"token": " itself", "logprob": -4.877617, "bytes": [32, 105, 116, 115, 101, 108, 102]}
+ #                         ]
+ #                     },
+ #                     {
+ #                         "token": ",",
+ #                         "logprob": -1.7432603e-6,
+ #                         "bytes": [44],
+ #                         "top_logprobs": [
+ #                             {"token": ",", "logprob": -1.7432603e-6, "bytes": [44]},
+ #                             {"token": "  \n", "logprob": -13.875002, "bytes": [32, 32, 10]},
+ #                             {"token": "—", "logprob": -14.750002, "bytes": [226, 128, 148]},
+ #                             {"token": ",\n", "logprob": -15.000002, "bytes": [44, 10]},
+ #                             {"token": ";", "logprob": -17.375002, "bytes": [59]}
+ #                         ]
+ #                     }
+ #                 ],
+ #                 "refusal": null
+ #             },
+ #             "finish_reason": "length"
+ #         }
+ #     ],
+ #     "usage": {
+ #         "prompt_tokens": 28,
+ #         "completion_tokens": 5,
+ #         "total_tokens": 33
+ #     },
+ #     "system_fingerprint": "fp_483d39d857"
+ # }
+
+
+ def normalize_token(token: str):
+     return re.sub(r"[^a-z]", "", token.lower())
+
+
+ def is_match(token1: str, token2: str):
+     token1 = normalize_token(token1)
+     token2 = normalize_token(token2)
+     if token1 == token2:
+         return True
+     elif token1.startswith(token2):
+         return True
+     elif token2.startswith(token1):
+         return True
+     else:
+         return False
+
+
+ def extract_prob(
+     token: str,
+     logprobs: Logprobs,
+     use_top_logprobs: bool = False,
+     normalize_top_logprobs: bool = True,  # if using top_logprobs, normalize by all the present tokens so they add up to 1
+     use_complement: bool = False,  # if True, assume there are 2 choices, and return 1 - p if the top token doesn't match
+     token_index: int = 0,  # get from the first token of the completion by default
+     token_match_fn: Optional[Callable[[str, str], bool]] = is_match,
+ ):
+     """
+     Extract the probability of the token from the logprobs object of a single
+     completion.
+     """
+     # ensure the token_index is valid
+     if token_index >= len(logprobs):
+         raise ValueError("token_index must be less than the length of logprobs.")
+     entry: LogprobEntry = logprobs[token_index]
+     match_fn = token_match_fn or is_match
+     # if using top_logprobs, ensure that at least one top_logprob is present
+     if use_top_logprobs:
+         if entry.get("top_logprobs", None) is None or len(entry["top_logprobs"]) == 0:
+             raise ValueError(
+                 "top_logprobs must be present in logprobs to use top_logprobs=True."
+             )
+         top_tokens = [t["token"] for t in entry["top_logprobs"]]
+         top_probs = [np.exp(t["logprob"]) for t in entry["top_logprobs"]]
+         combined_prob = sum(
+             p for t, p in zip(top_tokens, top_probs) if match_fn(t, token)
+         )
+
+         if normalize_top_logprobs:
+             # no point in using the complement if normalizing; the prob is always 0 if not present
+             return combined_prob / sum(top_probs)
+         elif combined_prob > 0:
+             return combined_prob
+         elif use_complement:
+             return 1 - combined_prob
+         else:
+             return 0.0
+
+     else:
+         top_token = entry["token"]
+         top_prob = np.exp(entry["logprob"])
+         if match_fn(top_token, token):
+             return top_prob
+         elif use_complement:
+             return 1 - top_prob
+         else:
+             return 0.0
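For orientation, here is a minimal sketch of how `extract_prob` might be called on a yes/no-style completion. The `logprobs` value is hand-written in the shape documented above, not real API output, and the numbers are made up:

```python
# Hypothetical logprobs for a completion whose sampled first token is "Yes".
logprobs: Logprobs = [
    {
        "token": "Yes",
        "logprob": -0.105,  # exp(-0.105) ~= 0.90
        "bytes": [89, 101, 115],
        "top_logprobs": [
            {"token": "Yes", "logprob": -0.105, "bytes": [89, 101, 115]},
            {"token": "No", "logprob": -2.30, "bytes": [78, 111]},
        ],
    }
]

# Probability read straight off the sampled token; the default is_match
# lowercases and strips non-letters, so "yes" matches "Yes".
p_yes = extract_prob("yes", logprobs)  # ~0.90

# Same question, renormalized over the tokens present in top_logprobs.
p_yes = extract_prob("yes", logprobs, use_top_logprobs=True)  # ~0.90 / (0.90 + 0.10)
```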
lm_deluge/util/pdf.py ADDED
@@ -0,0 +1,45 @@
+ import io
+
+
+ def text_from_pdf(pdf: str | bytes | io.BytesIO):
+     """
+     Extract text from a PDF. Does NOT use OCR; extracts the literal text.
+     The source can be:
+     - A file path (str)
+     - Bytes of a PDF file
+     - A BytesIO object containing a PDF file
+     """
+     try:
+         import pymupdf  # pyright: ignore
+     except ImportError:
+         raise ImportError(
+             "pymupdf is required to extract text from PDFs. Install lm_deluge[pdf] or lm_deluge[full]."
+         )
+     if isinstance(pdf, str):
+         # It's a file path
+         doc = pymupdf.open(pdf)
+     elif isinstance(pdf, (bytes, io.BytesIO)):
+         # It's bytes or a BytesIO object
+         if isinstance(pdf, bytes):
+             pdf = io.BytesIO(pdf)
+         doc = pymupdf.open(stream=pdf, filetype="pdf")
+     else:
+         raise ValueError("Unsupported pdf type. Must be str, bytes, or BytesIO.")
+
+     text_content = []
+     for page in doc:
+         blocks = page.get_text("blocks", sort=True)
+         for block in blocks:
+             # block[4] contains the text content
+             text_content.append(block[4].strip())
+             text_content.append("\n")  # add an extra newline between blocks
+
+     # Join all text content with newlines
+     full_text = "\n".join(text_content).strip()
+     # Collapse runs of spaces within each line (splitting on all whitespace
+     # would also destroy the newlines we just added)
+     full_text = "\n".join(
+         " ".join(x for x in line.split(" ") if x) for line in full_text.split("\n")
+     )
+     # Drop the empty lines left behind by the block separators
+     full_text = "\n".join(x for x in full_text.split("\n") if x)
+
+     return full_text
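A brief usage sketch; `paper.pdf` is a stand-in for any local PDF:

```python
# From a path on disk:
text = text_from_pdf("paper.pdf")

# Or from bytes already in memory, e.g. a downloaded file:
with open("paper.pdf", "rb") as f:
    text = text_from_pdf(f.read())
```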
@@ -0,0 +1,46 @@
+ from pydantic import BaseModel, ValidationError
+ from .json import load_json
+ from .xml import get_tag, xml_to_object
+
+
+ def get_model_from_json(
+     json_string: str,
+     model_class: type[BaseModel],
+ ) -> BaseModel:
+     try:
+         model_dict = load_json(json_string)
+         return model_class(**model_dict)  # pyright: ignore
+     except ValidationError as ve:
+         # re-raise so callers can handle validation failures themselves
+         raise ve
+
+
+ def get_model_from_xml(xml_string: str, model_class: type[BaseModel], shallow: bool = True):
+     """
+     Convert an XML string to a Pydantic model.
+     If shallow is True, we don't try to parse the whole XML tree
+     into a Python object; we just try to extract each key's tag
+     with regex and fill the model's fields that way.
+     """
+     if shallow:
+         # iterate over the fields of the model
+         model_dict = {}
+         for field_name in model_class.__fields__:
+             val = get_tag(xml_string, field_name)
+             if val is not None:
+                 # no nested models for 'shallow' mode
+                 model_dict[field_name] = val
+
+         try:
+             return model_class(**model_dict)  # pyright: ignore
+         except ValidationError as ve:
+             # re-raise so callers can handle validation failures themselves
+             raise ve
+     else:
+         # use the helper to parse the whole tree
+         model_dict = xml_to_object(xml_string)
+         try:
+             return model_class(**model_dict)  # pyright: ignore
+         except ValidationError as ve:
+             # re-raise so callers can handle validation failures themselves
+             raise ve
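For illustration, a sketch of shallow mode, assuming `get_tag` pulls the text between `<field>...</field>` as the docstring implies; the model and XML here are invented:

```python
from pydantic import BaseModel


class Verdict(BaseModel):
    answer: str
    explanation: str


# Shallow mode regexes each field's tag out of the raw string, so
# surrounding prose from the LLM is tolerated.
raw = (
    "Sure, here is my verdict:\n"
    "<answer>yes</answer>\n"
    "<explanation>The loop terminates because i strictly increases.</explanation>"
)
verdict = get_model_from_xml(raw, Verdict)  # Verdict(answer='yes', explanation='...')
```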