nous_genai-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. nous/__init__.py +3 -0
  2. nous/genai/__init__.py +56 -0
  3. nous/genai/__main__.py +3 -0
  4. nous/genai/_internal/__init__.py +1 -0
  5. nous/genai/_internal/capability_rules.py +476 -0
  6. nous/genai/_internal/config.py +102 -0
  7. nous/genai/_internal/errors.py +63 -0
  8. nous/genai/_internal/http.py +951 -0
  9. nous/genai/_internal/json_schema.py +54 -0
  10. nous/genai/cli.py +1316 -0
  11. nous/genai/client.py +719 -0
  12. nous/genai/mcp_cli.py +275 -0
  13. nous/genai/mcp_server.py +1080 -0
  14. nous/genai/providers/__init__.py +15 -0
  15. nous/genai/providers/aliyun.py +535 -0
  16. nous/genai/providers/anthropic.py +483 -0
  17. nous/genai/providers/gemini.py +1606 -0
  18. nous/genai/providers/openai.py +1909 -0
  19. nous/genai/providers/tuzi.py +1158 -0
  20. nous/genai/providers/volcengine.py +273 -0
  21. nous/genai/reference/__init__.py +17 -0
  22. nous/genai/reference/catalog.py +206 -0
  23. nous/genai/reference/mappings.py +467 -0
  24. nous/genai/reference/mode_overrides.py +26 -0
  25. nous/genai/reference/model_catalog.py +82 -0
  26. nous/genai/reference/model_catalog_data/__init__.py +1 -0
  27. nous/genai/reference/model_catalog_data/aliyun.py +98 -0
  28. nous/genai/reference/model_catalog_data/anthropic.py +10 -0
  29. nous/genai/reference/model_catalog_data/google.py +45 -0
  30. nous/genai/reference/model_catalog_data/openai.py +44 -0
  31. nous/genai/reference/model_catalog_data/tuzi_anthropic.py +21 -0
  32. nous/genai/reference/model_catalog_data/tuzi_google.py +19 -0
  33. nous/genai/reference/model_catalog_data/tuzi_openai.py +75 -0
  34. nous/genai/reference/model_catalog_data/tuzi_web.py +136 -0
  35. nous/genai/reference/model_catalog_data/volcengine.py +107 -0
  36. nous/genai/tools/__init__.py +13 -0
  37. nous/genai/tools/output_parser.py +119 -0
  38. nous/genai/types.py +416 -0
  39. nous/py.typed +1 -0
  40. nous_genai-0.1.0.dist-info/METADATA +200 -0
  41. nous_genai-0.1.0.dist-info/RECORD +45 -0
  42. nous_genai-0.1.0.dist-info/WHEEL +5 -0
  43. nous_genai-0.1.0.dist-info/entry_points.txt +4 -0
  44. nous_genai-0.1.0.dist-info/licenses/LICENSE +190 -0
  45. nous_genai-0.1.0.dist-info/top_level.txt +1 -0
nous/genai/_internal/http.py
@@ -0,0 +1,951 @@
+from __future__ import annotations
+
+import errno
+import http.client
+import ipaddress
+import json
+import os
+import socket
+import ssl
+import tempfile
+import urllib.parse
+from base64 import b64encode
+from dataclasses import dataclass
+from typing import Any, Iterable, Iterator
+from uuid import uuid4
+
+from .config import get_default_timeout_ms, get_prefixed_env
+from .errors import (
+    auth_error,
+    invalid_request_error,
+    not_supported_error,
+    provider_error,
+    rate_limit_error,
+    timeout_error,
+)
+
+
+def _timeout_seconds(timeout_ms: int | None) -> float:
+    if timeout_ms is None:
+        timeout_ms = get_default_timeout_ms()
+    return max(0.001, timeout_ms / 1000.0)
+
+
+def _env_truthy(name: str) -> bool:
+    return os.environ.get(name) in {"1", "true", "TRUE", "yes", "YES"}
+
+
+def _default_url_download_max_bytes() -> int:
+    raw = get_prefixed_env("URL_DOWNLOAD_MAX_BYTES")
+    if raw is None:
+        return 128 * 1024 * 1024
+    try:
+        value = int(raw)
+    except ValueError:
+        return 128 * 1024 * 1024
+    return max(1, value)
+
+
+def _is_private_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
+    return bool(
+        ip.is_private
+        or ip.is_loopback
+        or ip.is_link_local
+        or ip.is_multicast
+        or ip.is_reserved
+        or ip.is_unspecified
+    )
+
+
+def _resolve_host_ips(host: str) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
+    out: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
+    try:
+        infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
+    except OSError:
+        return out
+    seen: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
+    for family, _, _, _, sockaddr in infos:
+        if family == socket.AF_INET:
+            addr = sockaddr[0]
+        elif family == socket.AF_INET6:
+            addr = sockaddr[0]
+        else:
+            continue
+        try:
+            ip = ipaddress.ip_address(addr)
+        except ValueError:
+            continue
+        if not isinstance(ip, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+            continue
+        if ip in seen:
+            continue
+        seen.add(ip)
+        out.append(ip)
+    return out
+
+
+def _is_private_host(host: str) -> bool:
+    h = host.strip().lower().rstrip(".")
+    if h in {"localhost"} or h.endswith(".localhost"):
+        return True
+    try:
+        ip = ipaddress.ip_address(h)
+    except ValueError:
+        ip = None
+    if isinstance(ip, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+        return _is_private_ip(ip)
+    for resolved in _resolve_host_ips(h):
+        if _is_private_ip(resolved):
+            return True
+    return False
+
+
+def _resolve_url_host_ips(
+    host: str,
+) -> tuple[list[ipaddress.IPv4Address | ipaddress.IPv6Address], bool]:
+    """
+    Resolve a URL host once and classify it as private/loopback.
+
+    Returns: (resolved_ips, is_private)
+    """
+    h = host.strip().lower().rstrip(".")
+    if h in {"localhost"} or h.endswith(".localhost"):
+        return [], True
+    try:
+        ip = ipaddress.ip_address(h)
+    except ValueError:
+        ip = None
+    if isinstance(ip, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+        return [ip], _is_private_ip(ip)
+    resolved = _resolve_host_ips(h)
+    return resolved, any(_is_private_ip(x) for x in resolved)
+
+
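The helpers above are the SSRF guard: a host counts as private when it is a literal private, loopback, link-local, multicast, reserved, or unspecified address, when it is a localhost name, or when it resolves to any such address. A quick illustration of the expected classification (illustrative only, not part of the package; literal-IP cases are deterministic, hostname cases go through DNS):

    _is_private_host("localhost")   # True  (name-based rule, no DNS lookup)
    _is_private_host("127.0.0.1")   # True  (loopback)
    _is_private_host("10.1.2.3")    # True  (RFC 1918 private range)
    _is_private_host("8.8.8.8")     # False (public address)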
+def download_to_file(
+    *,
+    url: str,
+    output_path: str,
+    timeout_ms: int | None = None,
+    max_bytes: int | None = None,
+    headers: dict[str, str] | None = None,
+    proxy_url: str | None = None,
+) -> None:
+    """
+    Download URL to a local file with timeout/proxy support and a hard size limit.
+
+    Security: by default rejects obvious private/loopback IP hosts unless `NOUS_GENAI_ALLOW_PRIVATE_URLS=1`.
+    """
+    effective_max = (
+        _default_url_download_max_bytes() if max_bytes is None else max_bytes
+    )
+    if effective_max <= 0:
+        raise invalid_request_error("max_bytes must be positive")
+
+    cur = url
+    initial_host: str | None = None
+    for _ in range(5):
+        parsed = urllib.parse.urlparse(cur)
+        if parsed.scheme.lower() not in {"http", "https"}:
+            raise invalid_request_error(f"unsupported url scheme: {parsed.scheme}")
+        if not parsed.hostname:
+            raise invalid_request_error(f"invalid url: {cur}")
+        if initial_host is None:
+            initial_host = parsed.hostname
+        resolved, is_private = _resolve_url_host_ips(parsed.hostname)
+        if is_private and not _env_truthy("NOUS_GENAI_ALLOW_PRIVATE_URLS"):
+            raise invalid_request_error(
+                "url host is private/loopback; set NOUS_GENAI_ALLOW_PRIVATE_URLS=1 to allow"
+            )
+        if not resolved:
+            raise provider_error(
+                f"dns resolution failed: {parsed.hostname}", retryable=True
+            )
+        connect_ip = str(resolved[0])
+
+        path = _path_with_query(parsed)
+        timeout_s = _timeout_seconds(timeout_ms)
+        conn = _make_connection(
+            parsed,
+            timeout_s,
+            proxy_url=proxy_url,
+            connect_host=connect_ip,
+            tls_server_hostname=parsed.hostname,
+        )
+        try:
+            req_headers: dict[str, str] = {"Accept": "*/*"}
+            if (
+                headers
+                and initial_host
+                and parsed.hostname.lower() == initial_host.lower()
+            ):
+                req_headers.update(headers)
+            if proxy_url:
+                target_port = parsed.port or (
+                    443 if parsed.scheme.lower() == "https" else 80
+                )
+                default_port = 443 if parsed.scheme.lower() == "https" else 80
+                req_headers["Host"] = (
+                    parsed.hostname
+                    if target_port == default_port
+                    else f"{parsed.hostname}:{target_port}"
+                )
+            conn.request("GET", path, headers=req_headers)
+            resp = conn.getresponse()
+            if resp.status in {301, 302, 303, 307, 308}:
+                loc = resp.getheader("Location")
+                if not loc:
+                    raise provider_error("redirect response missing Location header")
+                cur = urllib.parse.urljoin(cur, loc)
+                continue
+            if resp.status < 200 or resp.status >= 300:
+                raw = resp.read(64 * 1024 + 1)
+                _raise_for_status(resp.status, raw[: 64 * 1024])
+
+            raw_len = resp.getheader("Content-Length")
+            if raw_len:
+                try:
+                    n = int(raw_len)
+                except ValueError:
+                    n = -1
+                if n > effective_max:
+                    raise not_supported_error(
+                        f"url download too large ({n} > {effective_max}); set NOUS_GENAI_URL_DOWNLOAD_MAX_BYTES or use path/ref"
+                    )
+
+            out_dir = os.path.dirname(os.path.abspath(output_path)) or "."
+            with tempfile.NamedTemporaryFile(
+                prefix="genaisdk-dl-", dir=out_dir, delete=False
+            ) as tmp:
+                tmp_path = tmp.name
+            total = 0
+            try:
+                with open(tmp_path, "wb") as f:
+                    while True:
+                        chunk = resp.read(64 * 1024)
+                        if not chunk:
+                            break
+                        total += len(chunk)
+                        if total > effective_max:
+                            raise not_supported_error(
+                                f"url download exceeded limit ({total} > {effective_max}); set NOUS_GENAI_URL_DOWNLOAD_MAX_BYTES or use path/ref"
+                            )
+                        f.write(chunk)
+            except Exception:
+                try:
+                    os.unlink(tmp_path)
+                except OSError:
+                    pass
+                raise
+            try:
+                os.replace(tmp_path, output_path)
+            except Exception:
+                try:
+                    os.unlink(tmp_path)
+                except OSError:
+                    pass
+                raise
+            return
+        except (socket.timeout, TimeoutError):
+            raise timeout_error("request timeout")
+        except (ssl.SSLError, http.client.HTTPException, OSError) as e:
+            raise provider_error(f"network error: {type(e).__name__}", retryable=True)
+        finally:
+            conn.close()
+
+    raise provider_error("too many redirects", retryable=False)
+
+
+def download_to_tempfile(
+    *,
+    url: str,
+    timeout_ms: int | None = None,
+    max_bytes: int | None = None,
+    headers: dict[str, str] | None = None,
+    suffix: str = "",
+    proxy_url: str | None = None,
+) -> str:
+    with tempfile.NamedTemporaryFile(
+        prefix="genaisdk-", suffix=suffix, delete=False
+    ) as f:
+        tmp_path = f.name
+    try:
+        download_to_file(
+            url=url,
+            output_path=tmp_path,
+            timeout_ms=timeout_ms,
+            max_bytes=max_bytes,
+            headers=headers,
+            proxy_url=proxy_url,
+        )
+    except Exception:
+        try:
+            os.unlink(tmp_path)
+        except OSError:
+            pass
+        raise
+    return tmp_path
+
+
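A minimal usage sketch for the two download helpers (the URL, limit, and timeout below are invented for illustration):

    # Streams at most 5 MiB to a sibling temp file, then atomically renames it
    # into place via os.replace. Follows redirects (at most five fetch attempts)
    # and forwards the caller's headers only while the host still matches the
    # original request host.
    download_to_file(
        url="https://example.com/data.bin",
        output_path="/tmp/data.bin",
        timeout_ms=30_000,
        max_bytes=5 * 1024 * 1024,
    )

    # Same behavior, but the caller owns the returned temp file path and must
    # eventually unlink it; the helper deletes it only on failure.
    tmp = download_to_tempfile(url="https://example.com/data.bin", suffix=".bin")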
+def _proxy_tunnel_headers(proxy: urllib.parse.ParseResult) -> dict[str, str] | None:
+    user = proxy.username
+    pw = proxy.password
+    if user is None and pw is None:
+        return None
+    user = "" if user is None else user
+    pw = "" if pw is None else pw
+    token = b64encode(f"{user}:{pw}".encode("utf-8")).decode("ascii")
+    return {"Proxy-Authorization": f"Basic {token}"}
+
+
+class _PinnedHTTPConnection(http.client.HTTPConnection):
+    def __init__(
+        self,
+        host: str,
+        port: int,
+        *,
+        connect_host: str,
+        timeout: float,
+    ) -> None:
+        super().__init__(host, port, timeout=timeout)
+        self._connect_host = connect_host
+
+    def connect(self) -> None:
+        self.sock = socket.create_connection(
+            (self._connect_host, self.port), timeout=self.timeout
+        )
+        try:
+            self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        except OSError as e:
+            if e.errno != errno.ENOPROTOOPT:
+                raise
+        if getattr(self, "_tunnel_host", None):
+            self._tunnel()  # type: ignore[attr-defined]
+
+
+class _PinnedHTTPSConnection(http.client.HTTPSConnection):
+    def __init__(
+        self,
+        host: str,
+        port: int,
+        *,
+        connect_host: str,
+        tls_server_hostname: str,
+        timeout: float,
+        context: ssl.SSLContext,
+    ) -> None:
+        super().__init__(host, port, timeout=timeout, context=context)
+        self._connect_host = connect_host
+        self._tls_server_hostname = tls_server_hostname
+        self._ssl_context = context
+
+    def connect(self) -> None:
+        self.sock = socket.create_connection(
+            (self._connect_host, self.port), timeout=self.timeout
+        )
+        try:
+            self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        except OSError as e:
+            if e.errno != errno.ENOPROTOOPT:
+                raise
+        if getattr(self, "_tunnel_host", None):
+            self._tunnel()  # type: ignore[attr-defined]
+        assert self.sock is not None
+        self.sock = self._ssl_context.wrap_socket(
+            self.sock,
+            server_hostname=self._tls_server_hostname,
+        )
+
+
+def _make_connection(
+    parsed: urllib.parse.ParseResult,
+    timeout_s: float,
+    *,
+    proxy_url: str | None,
+    connect_host: str | None = None,
+    tls_server_hostname: str | None = None,
+) -> http.client.HTTPConnection:
+    scheme = parsed.scheme.lower()
+    target_host = parsed.hostname
+    if not target_host:
+        raise invalid_request_error(f"invalid url: {parsed.geturl()}")
+
+    target_port = parsed.port
+    is_https = scheme == "https"
+    if not is_https and scheme != "http":
+        raise invalid_request_error(f"unsupported url scheme: {scheme}")
+
+    target_connect_host = target_host if connect_host is None else connect_host
+    tls_hostname = target_host if tls_server_hostname is None else tls_server_hostname
+
+    if proxy_url:
+        p = urllib.parse.urlparse(proxy_url)
+        if not p.hostname:
+            raise invalid_request_error(f"invalid proxy url: {proxy_url}")
+        if p.scheme.lower() not in {"http", "https"}:
+            raise invalid_request_error(f"unsupported proxy url scheme: {p.scheme}")
+        proxy_port = p.port or (443 if p.scheme == "https" else 80)
+        effective_target_port = target_port or (443 if is_https else 80)
+        if is_https:
+            ctx = ssl.create_default_context()
+            conn: http.client.HTTPConnection = _PinnedHTTPSConnection(
+                p.hostname,
+                proxy_port,
+                connect_host=p.hostname,
+                tls_server_hostname=tls_hostname,
+                timeout=timeout_s,
+                context=ctx,
+            )
+        else:
+            conn = http.client.HTTPConnection(p.hostname, proxy_port, timeout=timeout_s)
+        conn.set_tunnel(
+            target_connect_host, effective_target_port, headers=_proxy_tunnel_headers(p)
+        )
+        return conn
+
+    if is_https:
+        ctx = ssl.create_default_context()
+        effective_port = target_port or 443
+        if target_connect_host != target_host or tls_hostname != target_host:
+            return _PinnedHTTPSConnection(
+                target_host,
+                effective_port,
+                connect_host=target_connect_host,
+                tls_server_hostname=tls_hostname,
+                timeout=timeout_s,
+                context=ctx,
+            )
+        return http.client.HTTPSConnection(
+            target_host, effective_port, timeout=timeout_s, context=ctx
+        )
+    effective_port = target_port or 80
+    if target_connect_host != target_host:
+        return _PinnedHTTPConnection(
+            target_host,
+            effective_port,
+            connect_host=target_connect_host,
+            timeout=timeout_s,
+        )
+    return http.client.HTTPConnection(target_host, effective_port, timeout=timeout_s)
+
+
+def _path_with_query(parsed: urllib.parse.ParseResult) -> str:
+    path = parsed.path or "/"
+    if parsed.query:
+        return f"{path}?{parsed.query}"
+    return path
+
+
+def _extract_error_message(body: bytes) -> tuple[str, str | None]:
+    if not body:
+        return "empty error body", None
+    try:
+        obj = json.loads(body)
+    except Exception:
+        text = body.decode("utf-8", errors="replace")
+        return text[:2_000], None
+
+    if isinstance(obj, dict):
+        if isinstance(obj.get("error"), dict):
+            err = obj["error"]
+            msg = err.get("message") or err.get("detail") or str(err)
+            code = err.get("code") or err.get("status") or None
+            return str(msg)[:2_000], str(code) if code is not None else None
+        msg = obj.get("message") or str(obj)
+        code = obj.get("code") or obj.get("status") or None
+        return str(msg)[:2_000], str(code) if code is not None else None
+
+    return str(obj)[:2_000], None
+
+
+def _raise_for_status(status: int, body: bytes) -> None:
+    message, provider_code = _extract_error_message(body)
+    if status in (401, 403):
+        raise auth_error(message, provider_code=provider_code)
+    if status == 429:
+        raise rate_limit_error(message, provider_code=provider_code)
+    if status in (400, 404, 409, 415, 422):
+        raise invalid_request_error(message, provider_code=provider_code)
+    if status in (408, 504):
+        raise timeout_error(message)
+    if 500 <= status <= 599:
+        raise provider_error(message, provider_code=provider_code, retryable=True)
+    raise provider_error(message, provider_code=provider_code, retryable=False)
+
+
+@dataclass(frozen=True, slots=True)
+class StreamingBody:
+    content_type: str
+    content_length: int
+    chunks: Iterable[bytes]
+
+
+def multipart_form_data_fields(*, fields: dict[str, str]) -> StreamingBody:
+    boundary = uuid4().hex
+    boundary_bytes = boundary.encode("ascii")
+    crlf = b"\r\n"
+    tail = b"--" + boundary_bytes + b"--" + crlf
+
+    parts = []
+    total = len(tail)
+    for k, v in fields.items():
+        part = (
+            b"--"
+            + boundary_bytes
+            + crlf
+            + f'Content-Disposition: form-data; name="{k}"'.encode("utf-8")
+            + crlf
+            + crlf
+            + v.encode("utf-8")
+            + crlf
+        )
+        parts.append(part)
+        total += len(part)
+
+    def _chunks() -> Iterator[bytes]:
+        yield from parts
+        yield tail
+
+    return StreamingBody(
+        content_type=f"multipart/form-data; boundary={boundary}",
+        content_length=total,
+        chunks=_chunks(),
+    )
+
+
+def multipart_form_data(
+    *,
+    fields: dict[str, str],
+    file_field: str,
+    file_path: str,
+    filename: str,
+    file_mime_type: str,
+    chunk_size: int = 64 * 1024,
+) -> StreamingBody:
+    boundary = uuid4().hex
+    boundary_bytes = boundary.encode("ascii")
+    crlf = b"\r\n"
+
+    def _field_part(name: str, value: str) -> bytes:
+        header = (
+            b"--"
+            + boundary_bytes
+            + crlf
+            + f'Content-Disposition: form-data; name="{name}"'.encode("utf-8")
+            + crlf
+            + crlf
+        )
+        return header + value.encode("utf-8") + crlf
+
+    def _file_preamble() -> bytes:
+        return (
+            b"--"
+            + boundary_bytes
+            + crlf
+            + (
+                f'Content-Disposition: form-data; name="{file_field}"; filename="{filename}"'
+            ).encode("utf-8")
+            + crlf
+            + f"Content-Type: {file_mime_type}".encode("utf-8")
+            + crlf
+            + crlf
+        )
+
+    file_size = os.stat(file_path).st_size
+    tail = b"--" + boundary_bytes + b"--" + crlf
+    preamble = _file_preamble()
+
+    total = len(tail) + len(preamble) + file_size + len(crlf)
+    field_parts = []
+    for k, v in fields.items():
+        part = _field_part(k, v)
+        field_parts.append(part)
+        total += len(part)
+
+    def _chunks() -> Iterator[bytes]:
+        for part in field_parts:
+            yield part
+        yield preamble
+        with open(file_path, "rb") as f:
+            while True:
+                chunk = f.read(chunk_size)
+                if not chunk:
+                    break
+                yield chunk
+        yield crlf
+        yield tail
+
+    return StreamingBody(
+        content_type=f"multipart/form-data; boundary={boundary}",
+        content_length=total,
+        chunks=_chunks(),
+    )
+
+
+def multipart_form_data_json_and_file(
+    *,
+    metadata_field: str,
+    metadata: dict[str, Any],
+    file_field: str,
+    file_path: str,
+    filename: str,
+    file_mime_type: str,
+    chunk_size: int = 64 * 1024,
+) -> StreamingBody:
+    boundary = uuid4().hex
+    boundary_bytes = boundary.encode("ascii")
+    crlf = b"\r\n"
+
+    meta_bytes = json.dumps(metadata, separators=(",", ":")).encode("utf-8")
+    meta_part = (
+        b"--"
+        + boundary_bytes
+        + crlf
+        + f'Content-Disposition: form-data; name="{metadata_field}"'.encode("utf-8")
+        + crlf
+        + b"Content-Type: application/json; charset=utf-8"
+        + crlf
+        + crlf
+        + meta_bytes
+        + crlf
+    )
+    file_preamble = (
+        b"--"
+        + boundary_bytes
+        + crlf
+        + (
+            f'Content-Disposition: form-data; name="{file_field}"; filename="{filename}"'
+        ).encode("utf-8")
+        + crlf
+        + f"Content-Type: {file_mime_type}".encode("utf-8")
+        + crlf
+        + crlf
+    )
+    tail = b"--" + boundary_bytes + b"--" + crlf
+
+    file_size = os.stat(file_path).st_size
+    total = len(meta_part) + len(file_preamble) + file_size + len(crlf) + len(tail)
+
+    def _chunks() -> Iterator[bytes]:
+        yield meta_part
+        yield file_preamble
+        with open(file_path, "rb") as f:
+            while True:
+                chunk = f.read(chunk_size)
+                if not chunk:
+                    break
+                yield chunk
+        yield crlf
+        yield tail
+
+    return StreamingBody(
+        content_type=f"multipart/form-data; boundary={boundary}",
+        content_length=total,
+        chunks=_chunks(),
+    )
+
+
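A sketch of how a caller might combine the multipart builders with the request helpers further below (field names and paths are invented for illustration):

    body = multipart_form_data(
        fields={"purpose": "upload"},
        file_field="file",
        file_path="/tmp/audio.wav",
        filename="audio.wav",
        file_mime_type="audio/wav",
    )
    # content_length is computed exactly up front (os.stat for the file part),
    # so request_streaming_body_json can send a fixed Content-Length header
    # while body.chunks streams the file in 64 KiB pieces instead of loading
    # it into memory.
    assert body.content_type.startswith("multipart/form-data; boundary=")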
+def request_json(
+    *,
+    method: str,
+    url: str,
+    headers: dict[str, str] | None = None,
+    json_body: Any | None = None,
+    timeout_ms: int | None = None,
+    proxy_url: str | None = None,
+) -> dict[str, Any]:
+    body = (
+        None
+        if json_body is None
+        else json.dumps(json_body, separators=(",", ":")).encode("utf-8")
+    )
+    req_headers = {"Accept": "application/json"}
+    if body is not None:
+        req_headers["Content-Type"] = "application/json"
+        req_headers["Content-Length"] = str(len(body))
+    if headers:
+        req_headers.update(headers)
+
+    parsed = urllib.parse.urlparse(url)
+    path = _path_with_query(parsed)
+    timeout_s = _timeout_seconds(timeout_ms)
+    conn = _make_connection(parsed, timeout_s, proxy_url=proxy_url)
+    try:
+        conn.request(method.upper(), path, body=body, headers=req_headers)
+        resp = conn.getresponse()
+        raw = resp.read()
+        if resp.status < 200 or resp.status >= 300:
+            _raise_for_status(resp.status, raw)
+        if not raw:
+            return {}
+        try:
+            obj = json.loads(raw)
+        except Exception:
+            raise provider_error("invalid json response", retryable=True)
+        if not isinstance(obj, dict):
+            raise provider_error("invalid json response", retryable=True)
+        return obj
+    except (socket.timeout, TimeoutError):
+        raise timeout_error("request timeout")
+    except (ssl.SSLError, http.client.HTTPException, OSError) as e:
+        raise provider_error(f"network error: {type(e).__name__}", retryable=True)
+    finally:
+        conn.close()
+
+
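A usage sketch for request_json (the endpoint, token, and payload are placeholders, not real API details):

    resp = request_json(
        method="POST",
        url="https://api.example.com/v1/generate",
        headers={"Authorization": "Bearer <token>"},
        json_body={"prompt": "hi"},
        timeout_ms=60_000,
    )
    # Non-2xx statuses never return: _raise_for_status maps 401/403 to
    # auth_error, 429 to rate_limit_error, 408/504 to timeout_error, and
    # 5xx to a retryable provider_error. An empty 2xx body comes back as {}.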
+def request_bytes(
+    *,
+    method: str,
+    url: str,
+    headers: dict[str, str] | None = None,
+    body: bytes | None = None,
+    timeout_ms: int | None = None,
+    proxy_url: str | None = None,
+) -> bytes:
+    req_headers: dict[str, str] = {}
+    if body is not None:
+        req_headers["Content-Length"] = str(len(body))
+    if headers:
+        req_headers.update(headers)
+
+    parsed = urllib.parse.urlparse(url)
+    path = _path_with_query(parsed)
+    timeout_s = _timeout_seconds(timeout_ms)
+    conn = _make_connection(parsed, timeout_s, proxy_url=proxy_url)
+    try:
+        conn.request(method.upper(), path, body=body, headers=req_headers)
+        resp = conn.getresponse()
+        raw = resp.read()
+        if resp.status < 200 or resp.status >= 300:
+            _raise_for_status(resp.status, raw)
+        return raw
+    except (socket.timeout, TimeoutError):
+        raise timeout_error("request timeout")
+    except (ssl.SSLError, http.client.HTTPException, OSError) as e:
+        raise provider_error(f"network error: {type(e).__name__}", retryable=True)
+    finally:
+        conn.close()
+
+
+def request_stream_json_sse(
+    *,
+    method: str,
+    url: str,
+    headers: dict[str, str] | None = None,
+    json_body: Any | None = None,
+    timeout_ms: int | None = None,
+    proxy_url: str | None = None,
+) -> Iterator[dict[str, Any]]:
+    body = (
+        None
+        if json_body is None
+        else json.dumps(json_body, separators=(",", ":")).encode("utf-8")
+    )
+    req_headers = {"Accept": "text/event-stream"}
+    if body is not None:
+        req_headers["Content-Type"] = "application/json"
+        req_headers["Content-Length"] = str(len(body))
+    if headers:
+        req_headers.update(headers)
+
+    parsed = urllib.parse.urlparse(url)
+    path = _path_with_query(parsed)
+    timeout_s = _timeout_seconds(timeout_ms)
+    conn = _make_connection(parsed, timeout_s, proxy_url=proxy_url)
+
+    try:
+        conn.request(method.upper(), path, body=body, headers=req_headers)
+        resp = conn.getresponse()
+        if resp.status < 200 or resp.status >= 300:
+            raw = resp.read()
+            _raise_for_status(resp.status, raw)
+
+        def _iter() -> Iterator[dict[str, Any]]:
+            try:
+                for ev in _iter_sse_events(resp):
+                    if not ev.data:
+                        continue
+                    if ev.data == "[DONE]":
+                        return
+                    try:
+                        obj = json.loads(ev.data)
+                    except Exception:
+                        raise provider_error(f"invalid sse json: {ev.data[:200]}")
+                    if isinstance(obj, dict):
+                        yield obj
+            finally:
+                conn.close()
+
+        return _iter()
+    except (socket.timeout, TimeoutError):
+        conn.close()
+        raise timeout_error("request timeout")
+    except (ssl.SSLError, http.client.HTTPException, OSError) as e:
+        conn.close()
+        raise provider_error(f"network error: {type(e).__name__}", retryable=True)
+    except Exception:
+        conn.close()
+        raise
+
+
+def request_streaming_body_json(
+    *,
+    method: str,
+    url: str,
+    headers: dict[str, str] | None,
+    body: StreamingBody,
+    timeout_ms: int | None = None,
+    proxy_url: str | None = None,
+) -> dict[str, Any]:
+    req_headers = {"Accept": "application/json"}
+    req_headers["Content-Type"] = body.content_type
+    req_headers["Content-Length"] = str(body.content_length)
+    if headers:
+        req_headers.update(headers)
+
+    parsed = urllib.parse.urlparse(url)
+    path = _path_with_query(parsed)
+    timeout_s = _timeout_seconds(timeout_ms)
+    conn = _make_connection(parsed, timeout_s, proxy_url=proxy_url)
+
+    try:
+        conn.putrequest(method.upper(), path)
+        for k, v in req_headers.items():
+            conn.putheader(k, v)
+        conn.endheaders()
+        for chunk in body.chunks:
+            if chunk:
+                conn.send(chunk)
+        resp = conn.getresponse()
+        raw = resp.read()
+        if resp.status < 200 or resp.status >= 300:
+            _raise_for_status(resp.status, raw)
+        if not raw:
+            return {}
+        try:
+            obj = json.loads(raw)
+        except Exception:
+            raise provider_error("invalid json response", retryable=True)
+        if not isinstance(obj, dict):
+            raise provider_error("invalid json response", retryable=True)
+        return obj
+    except (socket.timeout, TimeoutError):
+        raise timeout_error("request timeout")
+    except (ssl.SSLError, http.client.HTTPException, OSError) as e:
+        raise provider_error(f"network error: {type(e).__name__}", retryable=True)
+    finally:
+        conn.close()
+
+
+@dataclass(frozen=True, slots=True)
+class SSEEvent:
+    data: str
+    event: str | None = None
+    id: str | None = None
+    retry: int | None = None
+
+
+def request_stream_sse(
+    *,
+    method: str,
+    url: str,
+    headers: dict[str, str] | None = None,
+    body: bytes | None = None,
+    timeout_ms: int | None = None,
+    proxy_url: str | None = None,
+) -> Iterator[SSEEvent]:
+    req_headers = {"Accept": "text/event-stream"}
+    if body is not None:
+        req_headers["Content-Length"] = str(len(body))
+    if headers:
+        req_headers.update(headers)
+
+    parsed = urllib.parse.urlparse(url)
+    path = _path_with_query(parsed)
+    timeout_s = _timeout_seconds(timeout_ms)
+    conn = _make_connection(parsed, timeout_s, proxy_url=proxy_url)
+
+    try:
+        conn.request(method.upper(), path, body=body, headers=req_headers)
+        resp = conn.getresponse()
+        if resp.status < 200 or resp.status >= 300:
+            raw = resp.read()
+            _raise_for_status(resp.status, raw)
+
+        def _iter() -> Iterator[SSEEvent]:
+            try:
+                yield from _iter_sse_events(resp)
+            finally:
+                conn.close()
+
+        return _iter()
+    except (socket.timeout, TimeoutError):
+        conn.close()
+        raise timeout_error("request timeout")
+    except (ssl.SSLError, http.client.HTTPException, OSError) as e:
+        conn.close()
+        raise provider_error(f"network error: {type(e).__name__}", retryable=True)
+    except Exception:
+        conn.close()
+        raise
+
+
+def _iter_sse_events(resp: http.client.HTTPResponse) -> Iterator[SSEEvent]:
+    buffer: list[str] = []
+    event_type: str | None = None
+    event_id: str | None = None
+    retry_ms: int | None = None
+    while True:
+        line = resp.readline()
+        if not line:
+            break
+        text = line.decode("utf-8", errors="replace")
+        text = text.rstrip("\n")
+        if text.endswith("\r"):
+            text = text[:-1]
+        if not text:
+            if (
+                buffer
+                or event_type is not None
+                or event_id is not None
+                or retry_ms is not None
+            ):
+                yield SSEEvent(
+                    data="\n".join(buffer),
+                    event=event_type,
+                    id=event_id,
+                    retry=retry_ms,
+                )
+            buffer.clear()
+            event_type = None
+            event_id = None
+            retry_ms = None
+            continue
+        if text.startswith(":"):
+            continue
+        if ":" in text:
+            field, value = text.split(":", 1)
+            value = value[1:] if value.startswith(" ") else value
+        else:
+            field, value = text, ""
+        if field == "data":
+            buffer.append(value)
+        elif field == "event":
+            event_type = value or None
+        elif field == "id":
+            if "\x00" not in value:
+                event_id = value
+        elif field == "retry":
+            try:
+                retry_ms = int(value)
+            except ValueError:
+                continue
+    if buffer or event_type is not None or event_id is not None or retry_ms is not None:
+        yield SSEEvent(
+            data="\n".join(buffer), event=event_type, id=event_id, retry=retry_ms
+        )
+
+
+def _iter_sse_data(resp: http.client.HTTPResponse) -> Iterator[str]:
+    for ev in _iter_sse_events(resp):
+        if ev.data:
+            yield ev.data
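For reference, _iter_sse_events follows standard server-sent-events framing: blank lines delimit events, multiple data: lines within one event are joined with newlines, and lines starting with ":" are comments. A worked example of what the parser yields for an assumed response body (illustrative only):

    data: {"delta":"he"}
    data: {"delta":"llo"}

    event: done
    data: [DONE]

    # -> SSEEvent(data='{"delta":"he"}\n{"delta":"llo"}', event=None, ...)
    # -> SSEEvent(data="[DONE]", event="done", ...)
    # request_stream_json_sse stops iterating at the "[DONE]" sentinel.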