sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -1,3 +1,5 @@
+"""The arguments of the server."""
+
 import argparse
 import dataclasses
 from typing import List, Optional, Union
@@ -5,34 +7,47 @@ from typing import List, Optional, Union
 
 @dataclasses.dataclass
 class ServerArgs:
+    # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    host: str = "127.0.0.1"
-    port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
+    context_length: Optional[int] = None
+
+    # Port
+    host: str = "127.0.0.1"
+    port: int = 30000
+    additional_ports: Optional[Union[List[int], int]] = None
+
+    # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
-    context_length: Optional[int] = None
-    tp_size: int = 1
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
-    attention_reduce_in_fp32: bool = False
-    random_seed: int = 42
+
+    # Other runtime options
+    tp_size: int = 1
     stream_interval: int = 8
+    random_seed: int = 42
+
+    # Logging
+    log_level: str = "info"
+    log_requests: bool = False
     disable_log_stats: bool = False
     log_stats_interval: int = 10
-    log_level: str = "info"
+    show_time_cost: bool = False
 
-    # optional modes
-    disable_radix_cache: bool = False
+    # Other
+    api_key: str = ""
+
+    # Optimization/debug options
     enable_flashinfer: bool = False
+    attention_reduce_in_fp32: bool = False
+    disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    api_key: str = ""
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -65,15 +80,18 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host)
-        parser.add_argument("--port", type=int, default=ServerArgs.port)
-        # we want to be able to pass a list of ports
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
         parser.add_argument(
             "--additional-ports",
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for launching server.",
+            help="Additional ports specified for the server.",
         )
         parser.add_argument(
             "--load-format",
@@ -111,6 +129,12 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -123,23 +147,12 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
-        )
-        parser.add_argument(
-            "--tp-size",
-            type=int,
-            default=ServerArgs.tp_size,
-            help="Tensor parallelism degree.",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
             default=ServerArgs.schedule_heuristic,
-            help="Schudule mode: [lpm, weight, random, fcfs]",
+            choices=["lpm", "random", "fcfs", "dfs-weight"],
+            help="Scheduling Heuristic.",
         )
         parser.add_argument(
             "--schedule-conservativeness",
@@ -148,15 +161,10 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
-            "--random-seed",
+            "--tp-size",
             type=int,
-            default=ServerArgs.random_seed,
-            help="Random seed.",
-        )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            default=ServerArgs.tp_size,
+            help="Tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -164,11 +172,22 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--random-seed",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="Random seed.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="Log level",
+            help="Logging level",
+        )
+        parser.add_argument(
+            "--log-requests",
+            action="store_true",
+            help="Log all requests",
         )
         parser.add_argument(
             "--disable-log-stats",
@@ -181,17 +200,34 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
-        # optional modes
         parser.add_argument(
-            "--disable-radix-cache",
+            "--show-time-cost",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Show time cost of custom marks",
+        )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API key of the server",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
             action="store_true",
             help="Enable flashinfer inference kernels",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -202,12 +238,6 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
-        parser.add_argument(
-            "--api-key",
-            type=str,
-            default=ServerArgs.api_key,
-            help="Set API Key",
-        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -217,13 +247,13 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def get_optional_modes_logging(self):
+    def print_mode_args(self):
         return (
-            f"disable_radix_cache={self.disable_radix_cache}, "
             f"enable_flashinfer={self.enable_flashinfer}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
+            f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
         )
sglang/srt/utils.py CHANGED
@@ -1,58 +1,74 @@
+"""Common utilities."""
+
 import base64
 import os
 import random
 import socket
-import sys
 import time
-import traceback
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import pydantic
 import requests
 import torch
-import torch.distributed as dist
+from fastapi.responses import JSONResponse
+from packaging import version as pkg_version
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware
 
-is_show_cost_time = False
+from sglang.utils import get_exception_traceback
 
+show_time_cost = False
+time_infos = {}
 
-def mark_cost_time(func_name):
-    def inner_func(func):
-        def time_func(*args, **kwargs):
-            if dist.get_rank() in [0, 1] and is_show_cost_time:
-                torch.cuda.synchronize()
-                start_time = time.time()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                print(func_name, "cost time:", (time.time() - start_time) * 1000)
-                return ans
-            else:
-                torch.cuda.synchronize()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                return ans
 
-        return time_func
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True
 
-    return inner_func
 
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
+
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False
 
-time_mark = {}
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
 
 
-def mark_start(key):
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    time_mark[key] = time.time()
-    return
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()
 
 
-def mark_end(key, print_min_cost=0.0):
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    cost_time = (time.time() - time_mark[key]) * 1000
-    if cost_time > print_min_cost:
-        print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()
 
 
 def calculate_time(show=False, min_cost_ms=0.0):
@@ -74,6 +90,32 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
+def get_available_gpu_memory(gpu_id, distributed=True):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+    """
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus
+
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
+
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+
+    if distributed:
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+            torch.device("cuda", gpu_id)
+        )
+        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        free_gpu_memory = tensor.item()
+
+    return free_gpu_memory / (1 << 30)
+
+
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
 
@@ -89,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
             continue
 
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             try:
                 s.bind(("", port))
+                s.listen(1)  # Attempt to listen on the port
                 port_list.append(port)
             except socket.error:
-                pass
+                pass  # If any error occurs, this port is not usable
 
         if len(port_list) == num:
             return port_list
@@ -110,7 +154,7 @@ def check_port(port):
         return False
 
 
-def handle_port_init(
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
@@ -142,15 +186,7 @@
     return port, additional_ports
 
 
-def get_exception_traceback():
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
-
-
 def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
-
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -231,20 +267,102 @@ def wrap_kernel_launcher(kernel):
 
 
 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
 
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path
-    raise Exception("unrecognized type")
+        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1
+
+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found
 
 
 def load_image(image_file):
     from PIL import Image
 
-    image = None
+    image = image_size = None
 
     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
@@ -255,7 +373,54 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image, image_size
+
+
+def assert_pkg_version(pkg: str, min_version: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version} which "
+                f"is less than the minimum required version {min_version}"
+            )
+    except PackageNotFoundError:
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed"
+        )
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+# FIXME: Remove this once we drop support for pydantic 1.x
+IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+
+
+def jsonify_pydantic_model(obj: BaseModel):
+    if IS_PYDANTIC_1:
+        return obj.json(ensure_ascii=False)
+    return obj.model_dump_json()
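
Taken together, the removed mark_cost_time/time_mark machinery is replaced by opt-in, accumulating marks gated on enable_show_time_cost (wired to the new --show-time-cost server flag). A minimal sketch of the intended usage; the loop body is an illustrative placeholder, and a CUDA device is required because both marks call torch.cuda.synchronize():

# Sketch: accumulate the time spent in a labeled region across iterations.
# The workload below is a placeholder, not taken from this diff.
import torch

from sglang.srt.utils import enable_show_time_cost, mark_end, mark_start

enable_show_time_cost()  # same effect as launching with --show-time-cost

x = torch.randn(1024, 1024, device="cuda")
for _ in range(100):
    mark_start("matmul", interval=0.1, color=32, indent=1)  # opens/resumes the timer
    x = (x @ x.T).clamp(-1, 1)
    mark_end("matmul")  # closes it; prints whenever another 0.1s has accumulated

The design trades the old per-call millisecond printout for an accumulator: mark_start subtracts time.time() from acc_time and mark_end adds it back, so only complete start/end pairs contribute, while TimeInfo.check throttles output to once per interval seconds of accumulated time. Separately, note that clients of a server launched with --api-key must now send the key in the X-API-Key header (API_KEY_HEADER_NAME); any other request is rejected with a 403 by APIKeyValidatorMiddleware.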