speedy-utils 1.0.5__py3-none-any.whl → 1.0.11__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
@@ -18,7 +18,9 @@ from typing import (
 )
 
 from httpx import URL
+from huggingface_hub import repo_info
 from loguru import logger
+from numpy import isin
 from openai import OpenAI, AuthenticationError, RateLimitError
 from openai.pagination import SyncPage
 from openai.types.chat import (
@@ -42,6 +44,29 @@ LegacyMsgs = List[Dict[str, str]] # old “…role/content…” dicts
 RawMsgs = Union[Messages, LegacyMsgs]  # what __call__ accepts
 
 
+# --------------------------------------------------------------------------- #
+# color formatting helpers
+# --------------------------------------------------------------------------- #
+def _red(text: str) -> str:
+    """Format text with red color."""
+    return f"\x1b[31m{text}\x1b[0m"
+
+
+def _green(text: str) -> str:
+    """Format text with green color."""
+    return f"\x1b[32m{text}\x1b[0m"
+
+
+def _blue(text: str) -> str:
+    """Format text with blue color."""
+    return f"\x1b[34m{text}\x1b[0m"
+
+
+def _yellow(text: str) -> str:
+    """Format text with yellow color."""
+    return f"\x1b[33m{text}\x1b[0m"
+
+
 class LM:
     """
     Unified language-model wrapper.
@@ -60,7 +85,7 @@ class LM:
         temperature: float = 0.0,
         max_tokens: int = 2_000,
         host: str = "localhost",
-        port: Optional[int] = None,
+        port: Optional[int | str] = None,
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
         cache: bool = True,
@@ -90,6 +115,7 @@ class LM:
         prompt: str | None = ...,
         messages: RawMsgs | None = ...,
         response_format: type[str] = str,
+        return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> str: ...
 
@@ -100,6 +126,7 @@ class LM:
         prompt: str | None = ...,
         messages: RawMsgs | None = ...,
         response_format: Type[TModel],
+        return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> TModel: ...
 
@@ -111,6 +138,7 @@ class LM:
         response_format: Union[type[str], Type[BaseModel]] = str,
         cache: Optional[bool] = None,
         max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
         **kwargs: Any,
     ):
         # argument validation ------------------------------------------------
@@ -132,17 +160,117 @@ class LM:
             self.openai_kwargs,
             temperature=self.temperature,
             max_tokens=max_tokens or self.max_tokens,
-            **kwargs,
         )
+        kw.update(kwargs)
         use_cache = self.do_cache if cache is None else cache
 
-        raw = self._call_raw(
+        raw_response = self._call_raw(
             openai_msgs,
             response_format=response_format,
             use_cache=use_cache,
             **kw,
         )
-        return self._parse_output(raw, response_format)
+
+        if return_openai_response:
+            response = raw_response
+        else:
+            response = self._parse_output(raw_response, response_format)
+
+        self.last_log = [prompt, messages, raw_response]
+        return response
+
+    def inspect_history(self) -> None:
+        if not hasattr(self, "last_log"):
+            raise ValueError("No history available. Please call the model first.")
+
+        prompt, messages, response = self.last_log
+        # Ensure response is a dictionary
+        if hasattr(response, "model_dump"):
+            response = response.model_dump()
+
+        if not messages:
+            messages = [{"role": "user", "content": prompt}]
+
+        print("\n\n")
+        print(_blue("[Conversation History]") + "\n")
+
+        # Print all messages in the conversation
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            print(_red(f"{role.capitalize()}:"))
+
+            if isinstance(content, str):
+                print(content.strip())
+            elif isinstance(content, list):
+                # Handle multimodal content
+                for item in content:
+                    if item.get("type") == "text":
+                        print(item["text"].strip())
+                    elif item.get("type") == "image_url":
+                        image_url = item["image_url"]["url"]
+                        if "base64" in image_url:
+                            len_base64 = len(image_url.split("base64,")[1])
+                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
+                        else:
+                            print(_blue(f"<image_url: {image_url}>"))
+            print("\n")
+
+        # Print the response - now always an OpenAI completion
+        print(_red("Response:"))
+
+        # Handle OpenAI response object
+        if isinstance(response, dict) and 'choices' in response and response['choices']:
+            message = response['choices'][0].get('message', {})
+
+            # Check for reasoning content (if available)
+            reasoning = message.get('reasoning_content')
+
+            # Check for parsed content (structured mode)
+            parsed = message.get('parsed')
+
+            # Get regular content
+            content = message.get('content')
+
+            # Display reasoning if available
+            if reasoning:
+                print(_yellow('<think>'))
+                print(reasoning.strip())
+                print(_yellow('</think>'))
+                print()
+
+            # Display parsed content for structured responses
+            if parsed:
+                # print(_green('<Parsed Structure>'))
+                if hasattr(parsed, 'model_dump'):
+                    print(json.dumps(parsed.model_dump(), indent=2))
+                else:
+                    print(json.dumps(parsed, indent=2))
+                # print(_green('</Parsed Structure>'))
+                print()
+
+            else:
+                if content:
+                    # print(_green("<Content>"))
+                    print(content.strip())
+                    # print(_green("</Content>"))
+                else:
+                    print(_green("[No content]"))
+
+            # Show if there were multiple completions
+            if len(response['choices']) > 1:
+                print(_blue(f"\n(Plus {len(response['choices']) - 1} other completions)"))
+        else:
+            # Fallback for non-standard response objects or cached responses
+            print(_yellow("Warning: Not a standard OpenAI response object"))
+            if isinstance(response, str):
+                print(_green(response.strip()))
+            elif isinstance(response, dict):
+                print(_green(json.dumps(response, indent=2)))
+            else:
+                print(_green(str(response)))
+
+        # print("\n\n")
 
     # --------------------------------------------------------------------- #
     # low-level OpenAI call
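
Taken together, this hunk wires the new return_openai_response flag into __call__ and adds the inspect_history() helper that replays last_log. A minimal usage sketch follows; the import path and constructor arguments are assumptions (they do not appear in this diff), and a running OpenAI-compatible endpoint is required:

    # Hypothetical usage sketch -- import path and constructor arguments are assumed;
    # only __call__(return_openai_response=...) and inspect_history() come from this diff.
    from speedy_utils import LM  # adjust to the package's actual module path

    lm = LM(model="gpt-4o-mini", port=8000, cache=False)  # illustrative settings

    # Default behaviour: the parsed string content is returned.
    text = lm(prompt="Summarize the plot of Hamlet in one sentence.")

    # New in 1.0.11: return the full ChatCompletion object instead of the parsed text.
    full = lm(prompt="Summarize the plot of Hamlet in one sentence.",
              return_openai_response=True)

    # Pretty-print the last prompt/messages and raw response with ANSI colors.
    lm.inspect_history()
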
@@ -156,8 +284,11 @@ class LM:
     ):
         assert self.model is not None, "Model must be set before making a call."
         model: str = self.model
+
         cache_key = (
-            self._cache_key(messages, kw, response_format) if use_cache else None
+            self._cache_key(messages, kw, response_format)
+            if use_cache
+            else None
         )
         if cache_key and (hit := self._load_cache(cache_key)) is not None:
             return hit
@@ -165,31 +296,28 @@ class LM:
         try:
             # structured mode
             if response_format is not str and issubclass(response_format, BaseModel):
-                rsp: ParsedChatCompletion[BaseModel] = (
-                    self.client.beta.chat.completions.parse(
-                        model=model,
-                        messages=list(messages),
-                        response_format=response_format,  # type: ignore[arg-type]
-                        **kw,
-                    )
+                openai_response = self.client.beta.chat.completions.parse(
+                    model=model,
+                    messages=list(messages),
+                    response_format=response_format,  # type: ignore[arg-type]
+                    **kw,
                 )
-                result: Any = rsp.choices[0].message.parsed  # already a model
             # plain-text mode
             else:
-                rsp = self.client.chat.completions.create(
+                openai_response = self.client.chat.completions.create(
                     model=model,
                     messages=list(messages),
                     **kw,
                 )
-                result = rsp.choices[0].message.content  # str
+
         except (AuthenticationError, RateLimitError) as exc:  # pragma: no cover
             logger.error(exc)
             raise
 
         if cache_key:
-            self._dump_cache(cache_key, result)
+            self._dump_cache(cache_key, openai_response)
 
-        return result
+        return openai_response
 
     # --------------------------------------------------------------------- #
     # legacy → typed messages
@@ -232,36 +360,68 @@ class LM:
     # --------------------------------------------------------------------- #
     @staticmethod
     def _parse_output(
-        raw: Any,
+        raw_response: Any,
         response_format: Union[type[str], Type[BaseModel]],
     ) -> str | BaseModel:
+        # Convert any object to dict if needed
+        if hasattr(raw_response, 'model_dump'):
+            raw_response = raw_response.model_dump()
+
         if response_format is str:
-            return cast(str, raw)
-
+            # Extract the content from OpenAI response dict
+            if isinstance(raw_response, dict) and 'choices' in raw_response:
+                message = raw_response['choices'][0]['message']
+                return message.get('content', '') or ''
+            return cast(str, raw_response)
+
         # For the type-checker: we *know* it's a BaseModel subclass here.
         model_cls = cast(Type[BaseModel], response_format)
 
-        if isinstance(raw, model_cls):
-            return raw
-        if isinstance(raw, dict):
-            return model_cls.model_validate(raw)
+        # Handle structured response
+        if isinstance(raw_response, dict) and 'choices' in raw_response:
+            message = raw_response['choices'][0]['message']
+
+            # Check if already parsed by OpenAI client
+            if 'parsed' in message:
+                return model_cls.model_validate(message['parsed'])
+
+            # Need to parse the content
+            content = message.get('content')
+            if content is None:
+                raise ValueError("Model returned empty content")
+
+            try:
+                data = json.loads(content)
+                return model_cls.model_validate(data)
+            except Exception as exc:
+                raise ValueError(f"Failed to parse model output as JSON:\n{content}") from exc
+
+        # Handle cached response or other formats
+        if isinstance(raw_response, model_cls):
+            return raw_response
+        if isinstance(raw_response, dict):
+            return model_cls.model_validate(raw_response)
+
+        # Try parsing as JSON string
         try:
-            data = json.loads(raw)
-        except Exception as exc:  # noqa: BLE001
-            raise ValueError(f"Model did not return JSON:\n---\n{raw}") from exc
-        return model_cls.model_validate(data)
+            data = json.loads(raw_response)
+            return model_cls.model_validate(data)
+        except Exception as exc:
+            raise ValueError(f"Model did not return valid JSON:\n---\n{raw_response}") from exc
 
     # --------------------------------------------------------------------- #
     # tiny disk cache
     # --------------------------------------------------------------------- #
     @staticmethod
     def _cache_key(
-        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
+        messages: Any,
+        kw: Any,
+        response_format: Union[type[str], Type[BaseModel]],
     ) -> str:
         tag = response_format.__name__ if response_format is not str else "text"
         blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
         return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
-
+
     @staticmethod
     def _cache_path(key: str) -> str:
         return os.path.expanduser(f"~/.cache/lm/{key}.json")
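
The reworked _parse_output now accepts the full ChatCompletion (or its cached dict form) and either reuses the client-parsed object or validates the JSON content against the Pydantic model. A hedged sketch of the structured path, under the same import-path assumption as above; the CityFact schema is invented for illustration:

    # Illustrative only: CityFact is a made-up schema; the call follows the
    # response_format overloads shown earlier in this diff.
    from pydantic import BaseModel

    from speedy_utils import LM  # assumed import path, as above


    class CityFact(BaseModel):
        city: str
        country: str


    lm = LM(model="gpt-4o-mini", port=8000)  # illustrative settings

    # With a BaseModel subclass, _call_raw goes through client.beta.chat.completions.parse()
    # and _parse_output validates message["parsed"] or the JSON content into CityFact.
    fact = lm(prompt="Name one capital city and its country as JSON.",
              response_format=CityFact)
    print(fact.city, fact.country)
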
@@ -289,12 +449,12 @@ class LM:
         return None
 
     @staticmethod
-    def list_models(port=None) -> List[str]:
+    def list_models(port=None, host="localhost") -> List[str]:
         """
         List available models.
         """
         try:
-            client: OpenAI = LM(port=port).client
+            client: OpenAI = LM(port=port, host=host).client
             base_url: URL = client.base_url
             logger.debug(f"Base URL: {base_url}")
             models: SyncPage[Model] = client.models.list()
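
Finally, list_models accepts a host alongside port (and the constructor's port hint is widened to int | str), so models on a remote OpenAI-compatible server can be listed. A short hedged sketch under the same import-path assumption:

    # Hypothetical: import path assumed; the (port, host) signature comes from this diff.
    from speedy_utils import LM

    # Port may be an int or a string; host defaults to "localhost".
    models = LM.list_models(port="8000", host="192.168.1.20")
    print(models)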