sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/test/test_programs.py CHANGED
@@ -1,6 +1,4 @@
-"""
-This file contains the SGL programs used for unit testing.
-"""
+"""This file contains the SGL programs used for unit testing."""
 
 import json
 import re
@@ -226,7 +224,7 @@ Action 3: Finish [United States].\n
 
 def test_parallel_decoding():
     max_tokens = 64
-    number = 5
+    fork_size = 5
 
     @sgl.function
     def parallel_decoding(s, topic):
@@ -234,17 +232,17 @@ def test_parallel_decoding():
         s += "USER: Give some tips for " + topic + ".\n"
         s += (
             "ASSISTANT: Okay. Here are "
-            + str(number)
+            + str(fork_size)
             + " concise tips, each under 8 words:\n"
         )
 
         # Generate skeleton
-        for i in range(1, 1 + number):
+        for i in range(1, 1 + fork_size):
             s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"
 
         # Generate detailed tips
-        forks = s.fork(number)
-        for i in range(number):
+        forks = s.fork(fork_size)
+        for i in range(fork_size):
             forks[
                 i
             ] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
@@ -253,7 +251,7 @@ def test_parallel_decoding():
 
         # Concatenate tips and summarize
         s += "Here are these tips with detailed explanation:\n"
-        for i in range(number):
+        for i in range(fork_size):
             s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"
 
         s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
@@ -296,7 +294,7 @@ def test_parallel_encoding(check_answer=True):
 def test_image_qa():
     @sgl.function
     def image_qa(s, question):
-        s += sgl.user(sgl.image("test_image.png") + question)
+        s += sgl.user(sgl.image("example_image.png") + question)
         s += sgl.assistant(sgl.gen("answer"))
 
     state = image_qa.run(
@@ -304,6 +302,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
@@ -313,6 +312,7 @@ def test_image_qa():
 def test_stream():
     @sgl.function
     def qa(s, question):
+        s += sgl.system("You are a helpful assistant.")
         s += sgl.user(question)
         s += sgl.assistant(sgl.gen("answer"))
 
@@ -348,3 +348,66 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(num_api_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert (
+        usage_with_spec < usage_with_no_spec
+    ), f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(num_api_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant(
+            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        )
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant(
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+            + "\nJob:"
+            + sgl.gen("job", stop="\n")
+        )
+
+    gen_character_spec().sync()
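The new speculative tests above compare prompt-token usage with and without API speculative execution, so they need a default backend that tracks token usage. A minimal driver sketch, not part of the diff; the backend and model name are illustrative assumptions:

# Hypothetical driver for the new speculative tests (illustrative only).
import sglang as sgl
from sglang.test.test_programs import (
    test_chat_completion_speculative,
    test_completion_speculative,
)

if __name__ == "__main__":
    # Assumes an OpenAI-style backend; the model name here is just an example.
    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
    test_completion_speculative()
    test_chat_completion_speculative()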
sglang/test/test_utils.py CHANGED
@@ -1,13 +1,20 @@
 """Common utilities for testing and benchmarking"""
 
+import asyncio
+from functools import partial
+
 import numpy as np
 import requests
+
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
+from sglang.utils import get_exception_traceback
 
 
-def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
+def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
+
     data = {
         "inputs": prompt,
         "parameters": {
@@ -22,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
     return pred
 
 
-def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
+def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
+    assert url is not None
+
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -40,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
 
 
 def call_generate_outlines(
-    prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
+    prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
 ):
+    assert url is not None
+
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -59,7 +70,9 @@ def call_generate_outlines(
     return pred
 
 
-def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
+def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
+
     data = {
         "text": prompt,
         "sampling_params": {
@@ -75,7 +88,98 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
     return pred
 
 
-def call_select_lightllm(context, choices, url):
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+    import grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
+
+    sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+    sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+    if stop is None:
+        stop_strings = None
+    else:
+        stop_strings = [stop]
+
+    sample_request = sampler_pb2.SampleTextRequest(
+        prompt=prompt,
+        settings=sampler_pb2.SampleSettings(
+            max_len=max_tokens,
+            rng_seed=0,
+            temperature=max(temperature, 1e-7),
+            nucleus_p=1,
+            stop_strings=stop_strings,
+        ),
+    )
+    stream = sampler.SampleText(sample_request)
+    response = "".join([x.text for x in stream])
+    return response
+
+
+def call_generate_guidance(
+    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
+):
+    assert model is not None
+    from guidance import gen
+
+    rets = []
+    for _ in range(n):
+        out = (
+            model
+            + prompt
+            + gen(
+                name="answer",
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
+                regex=regex,
+            )
+        )
+        rets.append(out["answer"])
+    return rets if n > 1 else rets[0]
+
+
+async def call_generate_lmql(
+    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
+):
+    assert model is not None
+    import lmql
+
+    if stop != None:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens, stop):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
+            return ANSWER
+            '''
+
+    else:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
+            return ANSWER
+            '''
+
+    tasks = [
+        program(
+            question=prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stop=stop,
+            max_len=max_len,
+            **kwargs,
+        )
+        for _ in range(n)
+    ]
+    rets = await asyncio.gather(*tasks)
+    return rets if n > 1 else rets[0]
+
+
+def call_select_lightllm(context, choices, url=None):
+    assert url is not None
+
     scores = []
     for i in range(len(choices)):
         data = {
@@ -90,7 +194,9 @@ def call_select_lightllm(context, choices, url):
     return np.argmax(scores)
 
 
-def call_select_vllm(context, choices, url):
+def call_select_vllm(context, choices, url=None):
+    assert url is not None
+
     scores = []
     for i in range(len(choices)):
         data = {
@@ -112,6 +218,31 @@ def call_select_vllm(context, choices, url):
     """
 
 
+def call_select_guidance(context, choices, model=None):
+    assert model is not None
+    from guidance import select
+
+    out = model + context + select(choices, name="answer")
+    return choices.index(out["answer"])
+
+
+async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
+    assert model is not None
+    import lmql
+
+    @lmql.query(model=model)
+    async def program(ctx, choices):
+        '''lmql
+        """{ctx}[ANSWER]""" where ANSWER in set(choices)
+        return ANSWER
+        '''
+
+    answer = await program(
+        ctx=context, choices=choices, temperature=temperature, max_len=max_len
+    )
+    return choices.index(answer)
+
+
 def add_common_other_args_and_parse(parser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
@@ -120,8 +251,18 @@ def add_common_other_args_and_parse(parser):
         "--backend",
         type=str,
         required=True,
-        choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"],
+        choices=[
+            "vllm",
+            "outlines",
+            "lightllm",
+            "ginfer",
+            "guidance",
+            "lmql",
+            "srt-raw",
+            "llama.cpp",
+        ],
     )
+    parser.add_argument("--n-ctx", type=int, default=4096)
     parser.add_argument(
         "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
     )
@@ -131,9 +272,11 @@ def add_common_other_args_and_parse(parser):
     if args.port is None:
         default_port = {
             "vllm": 21000,
+            "outlines": 21000,
             "lightllm": 22000,
             "lmql": 23000,
             "srt-raw": 30000,
+            "ginfer": 9988,
         }
         args.port = default_port.get(args.backend, None)
     return args
@@ -160,3 +303,79 @@ def select_sglang_backend(args):
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
     return backend
+
+
+def _get_call_generate(args):
+    if args.backend == "lightllm":
+        return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "srt-raw":
+        return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
+    elif args.backend == "outlines":
+        return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_generate = partial(call_generate_guidance, model=model)
+        call_generate("Hello,", 1.0, 8, ".")
+        return call_generate
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_generate_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+
+def _get_call_select(args):
+    if args.backend == "lightllm":
+        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_select = partial(call_select_guidance, model=model)
+
+        call_select("Hello,", ["world", "earth"])
+        return call_select
+
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_select_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+
+def get_call_generate(args):
+    call_generate = _get_call_generate(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_generate(*args, **kwargs)
+        except Exception:
+            print("Exception in call_generate:\n" + get_exception_traceback())
+            raise
+
+    return func
+
+
+def get_call_select(args):
+    call_select = _get_call_select(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_select(*args, **kwargs)
+        except Exception:
+            print("Exception in call_select:\n" + get_exception_traceback())
+            raise
+
+    return func
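The new get_call_generate / get_call_select dispatchers wrap the per-backend callers with traceback logging before re-raising. A rough sketch of how a benchmark script might use them, not taken from the diff; the prompt and flag values are illustrative:

# Illustrative benchmark wiring (flag values such as --backend are passed on the CLI).
import argparse

from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate

parser = argparse.ArgumentParser()
args = add_common_other_args_and_parse(parser)  # e.g. --backend vllm --port 21000

call_generate = get_call_generate(args)
answer = call_generate("Q: What is 2 + 2?\nA:", temperature=0, max_tokens=8, stop="\n")
print(answer)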
sglang/utils.py CHANGED
@@ -2,40 +2,26 @@
 
 import base64
 import json
+import logging
+import signal
+import sys
 import threading
+import traceback
 import urllib.request
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
 
+import numpy as np
 import requests
 
+logger = logging.getLogger(__name__)
 
-def get_available_gpu_memory(gpu_id, distributed=True):
-    """
-    Get available memory for cuda:gpu_id device.
-    When distributed is True, the available memory is the minimum available memory of all GPUs.
-    """
-    import torch
 
-    num_gpus = torch.cuda.device_count()
-    assert gpu_id < num_gpus
-
-    if torch.cuda.current_device() != gpu_id:
-        print(
-            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-            "which may cause useless memory allocation for torch CUDA context.",
-        )
-
-    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
-
-    if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device("cuda", gpu_id)
-        )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
-        free_gpu_memory = tensor.item()
-
-    return free_gpu_memory / (1 << 30)
+def get_exception_traceback():
+    etype, value, tb = sys.exc_info()
+    err_str = "".join(traceback.format_exception(etype, value, tb))
+    return err_str
 
 
 def is_same_type(values):
@@ -110,8 +96,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")
-    resp = urllib.request.urlopen(req, data=data, cafile=verify)
-    return HttpResponse(resp)
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)
 
 
 def encode_image_base64(image_path):
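With this change, an HTTP error comes back as an HttpResponse wrapping the HTTPError instead of propagating as an exception, so callers can branch on the status. A sketch under assumptions: the endpoint URL is illustrative, and it assumes the HttpResponse wrapper exposes a status_code attribute as elsewhere in the codebase:

# Illustrative only: inspect the status instead of catching HTTPError.
from sglang.utils import http_request

res = http_request("http://127.0.0.1:30000/get_model_info")  # hypothetical local server
if res.status_code == 200:
    print(res.json())
else:
    print("Request failed with status", res.status_code)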
@@ -130,6 +120,75 @@ def encode_image_base64(image_path):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
+def encode_frame(frame):
+    import cv2  # pip install opencv-python-headless
+    from PIL import Image
+
+    # Convert the frame to RGB (OpenCV uses BGR by default)
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    # Convert the frame to PIL Image to easily convert to bytes
+    im_pil = Image.fromarray(frame)
+
+    # Convert to bytes
+    buffered = BytesIO()
+
+    # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    im_pil.save(buffered, format="PNG")
+
+    frame_bytes = buffered.getvalue()
+
+    # Return the bytes of the frame
+    return frame_bytes
+
+
+def encode_video_base64(video_path, num_frames=16):
+    import cv2  # pip install opencv-python-headless
+
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise IOError(f"Could not open video file:{video_path}")
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    print(f"target_frames: {num_frames}")
+
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+    frames = []
+    for i in range(total_frames):
+        ret, frame = cap.read()
+        if ret:
+            frames.append(frame)
+        else:
+            # Handle the case where the frame could not be read
+            # print(f"Warning: Could not read frame at index {i}.")
+            pass
+
+    cap.release()
+
+    # Safely select frames based on frame_indices, avoiding IndexError
+    frames = [frames[i] for i in frame_indices if i < len(frames)]
+
+    # If there are not enough frames, duplicate the last frame until we reach the target
+    while len(frames) < num_frames:
+        frames.append(frames[-1])
+
+    # Use ThreadPoolExecutor to process and encode frames in parallel
+    with ThreadPoolExecutor() as executor:
+        encoded_frames = list(executor.map(encode_frame, frames))
+
+    # encoded_frames = list(map(encode_frame, frames))
+
+    # Concatenate all frames bytes
+    video_bytes = b"".join(encoded_frames)
+
+    # Encode the concatenated bytes to base64
+    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+
+    return video_base64
+
+
 def _is_chinese_char(cp):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:
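The new video helpers sample a fixed number of frames, PNG-encode them in parallel, and pack them into a single base64 payload, presumably for the new llavavid (video LLaVA) path. A minimal usage sketch; the clip path is illustrative:

# Illustrative only: encode a short clip for a video-QA request.
from sglang.utils import encode_video_base64

video_b64 = encode_video_base64("example_clip.mp4", num_frames=16)
print(video_b64[:32], f"... ({len(video_b64)} characters)")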
@@ -191,3 +250,14 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()
 
     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(
+            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+        )
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
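graceful_registry installs a SIGTERM handler that logs which sub-module is shutting down, presumably for the long-running manager subprocesses. A minimal sketch of how a worker entry point might register it; the sub-module name is illustrative:

# Illustrative only: register the SIGTERM logger in a worker process.
import logging

from sglang.utils import graceful_registry

logging.basicConfig(level=logging.INFO)


def main():
    graceful_registry("detokenizer_manager")  # hypothetical sub-module name
    ...  # run the worker loop here


if __name__ == "__main__":
    main()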