nexaai-1.0.19rc16-cp310-cp310-macosx_13_0_x86_64.whl → nexaai-1.0.19rc18-cp310-cp310-macosx_13_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of nexaai might be problematic.
Binary file
nexaai/_version.py CHANGED
@@ -1,4 +1,4 @@
 # This file is generated by CMake from _version.py.in
 # Do not modify this file manually - it will be overwritten

-__version__ = "1.0.19-rc16"
+__version__ = "1.0.19-rc18"
Binary file
@@ -43,17 +43,156 @@ def _ensure_list(x: Union[str, List[str], None]) -> Optional[List[str]]:
     return x if isinstance(x, list) else [x]


+def get_model_configs(model_name: str):
+    """Get model configurations based on model name"""
+
+    # 4B model configs (default)
+    if model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking"]:
+        vision_config = VisionConfig(
+            hidden_size=1024,
+            intermediate_size=4096,
+            num_heads=16,
+            num_hidden_layers=24,
+            patch_size=16,
+            temporal_patch_size=2,
+            in_channels=3,
+            hidden_act="gelu",
+            spatial_merge_size=2,
+            out_hidden_size=2560,
+            num_position_embeddings=2304,
+            deepstack_visual_indexes=[5, 11, 17],
+        )
+
+        text_config = TextConfig(
+            model_type="qwen3vl",
+            hidden_size=2560,
+            num_hidden_layers=36,
+            intermediate_size=9728,
+            num_attention_heads=32,
+            num_key_value_heads=8,
+            rms_norm_eps=1e-6,
+            vocab_size=151936,
+            max_position_embeddings=32768,
+            rope_theta=5000000.0,
+            head_dim=128,
+            tie_word_embeddings=True,
+            attention_bias=False,
+            attention_dropout=0.0,
+            rope_scaling={"mrope_section": [24, 20, 20],
+                          "rope_type": "default", "type": "default"},
+        )
+
+    # 8B model configs
+    elif model_name in ["qwen3vl-8b", "qwen3vl-8b-thinking"]:
+        vision_config = VisionConfig(
+            hidden_size=1152,
+            intermediate_size=4304,
+            num_heads=16,
+            num_hidden_layers=27,
+            patch_size=16,
+            temporal_patch_size=2,
+            in_channels=3,
+            hidden_act="gelu",
+            spatial_merge_size=2,
+            out_hidden_size=4096,
+            num_position_embeddings=2304,
+            deepstack_visual_indexes=[8, 16, 24],
+        )
+
+        text_config = TextConfig(
+            model_type="qwen3vl",
+            hidden_size=4096,
+            num_hidden_layers=36,
+            intermediate_size=12288,
+            num_attention_heads=32,
+            num_key_value_heads=8,
+            rms_norm_eps=1e-6,
+            vocab_size=151936,
+            max_position_embeddings=262144,
+            rope_theta=5000000,
+            head_dim=128,
+            tie_word_embeddings=False,
+            attention_bias=False,
+            attention_dropout=0.0,
+            rope_scaling={"mrope_section": [24, 20, 20], "rope_type": "default", "mrope_interleaved": True},
+        )
+    else:
+        # Fallback to 4B config
+        return get_model_configs("qwen3vl-4b")
+
+    return vision_config, text_config
+
+def get_weight_filenames(model_name: str, model_path: Path):
+    """Get appropriate weight filenames based on model name and available files"""
+
+    # Determine model size and type based on the actual file structure
+    if "4b" in model_name:
+        size_prefix = "4b"
+    elif "8b" in model_name:
+        size_prefix = "8b"
+    else:
+        size_prefix = "4b"
+
+    # Determine model type
+    if "thinking" in model_name:
+        model_type = f"{size_prefix}_thinking"
+    else:
+        model_type = f"{size_prefix}_instruct"
+
+    # Try different weight file patterns matching the actual file structure
+    llm_patterns = [
+        # New naming convention matching actual files
+        f"qwen3vl-llm-{model_type}-q4_0.safetensors",
+        f"qwen3vl-llm-{model_type}-q8_0.safetensors",
+        f"qwen3vl-llm-{model_type}-f16.safetensors",
+        # Legacy naming convention
+        f"qwen3vl-llm-{size_prefix.upper()}-q4_0.safetensors",
+        f"qwen3vl-llm-{size_prefix.upper()}-q8_0.safetensors",
+        f"qwen3vl-llm-{size_prefix.upper()}-f16.safetensors",
+        f"qwen3vl-llm-{size_prefix.upper()}-f32.safetensors",
+    ]
+
+    vision_patterns = [
+        f"qwen3vl-vision-{model_type}-f16.safetensors",
+        f"qwen3vl-vision-{size_prefix.upper()}-f16.safetensors",
+    ]
+
+    # Find LLM weights
+    llm_weights_path = None
+    quantization_bits = None
+
+    for pattern in llm_patterns:
+        candidate_path = model_path / pattern
+        if candidate_path.exists():
+            llm_weights_path = candidate_path
+            if "q4_0" in pattern:
+                quantization_bits = 4
+            elif "q8_0" in pattern:
+                quantization_bits = 8
+            else:
+                quantization_bits = 16
+            break
+
+    # Find vision weights
+    vision_weights_path = None
+    for pattern in vision_patterns:
+        candidate_path = model_path / pattern
+        if candidate_path.exists():
+            vision_weights_path = candidate_path
+            break
+
+    return llm_weights_path, vision_weights_path, quantization_bits
+
+# Update the load_qwen3_vl function signature and implementation:
 def load_qwen3_vl(
     path_or_repo: str,
     adapter_path: Optional[str] = None,
     lazy: bool = False,
     revision: Optional[str] = None,
+    model_name: Optional[str] = None,
     **kwargs,
 ) -> Tuple[Qwen3VLBundledModel, Qwen3VLProcessor]:
-    """Load Qwen3-VL quantized models and processor.
-
-    Parameters are aligned with .generate.load for compatibility.
-    """
+    """Load Qwen3-VL quantized models and processor with support for different model sizes."""

     model_path = Path(path_or_repo)
     if not model_path.exists():
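For orientation, here is a minimal sketch of how the helpers introduced in this hunk behave (illustrative only; it assumes get_model_configs and get_weight_filenames are imported from the same, unnamed loader module shown above):

    # Illustrative usage of the new helpers; not part of the changed code.
    from pathlib import Path

    # "qwen3vl-8b-thinking" selects the 8B hyperparameters
    # (hidden_size=4096, 27 vision layers, deepstack indexes [8, 16, 24]).
    vision_cfg, text_cfg = get_model_configs("qwen3vl-8b-thinking")

    # Any unrecognized name falls back to the 4B configuration.
    vision_cfg, text_cfg = get_model_configs("some-unrecognized-name")

    # Weight resolution tries q4_0, then q8_0, then f16, then the legacy
    # upper-case filenames; entries that cannot be found come back as None.
    llm_path, vision_path, bits = get_weight_filenames(
        "qwen3vl-8b-thinking", Path("/path/to/modelfiles")
    )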
@@ -67,70 +206,22 @@ def load_qwen3_vl(
         if not model_path.exists():
             model_path = curr_dir / "modelfiles"

-    # Model configs (kept identical to main)
-    vision_config = VisionConfig(
-        hidden_size=1024,
-        intermediate_size=4096,
-        num_heads=16,
-        num_hidden_layers=24,
-        patch_size=16,
-        temporal_patch_size=2,
-        in_channels=3,
-        hidden_act="gelu",
-        spatial_merge_size=2,
-        out_hidden_size=2560,
-        num_position_embeddings=2304,
-        deepstack_visual_indexes=[5, 11, 17],
-    )
-
-    text_config = TextConfig(
-        model_type="qwen3vl",
-        hidden_size=2560,
-        num_hidden_layers=36,
-        intermediate_size=9728,
-        num_attention_heads=32,
-        num_key_value_heads=8,
-        rms_norm_eps=1e-6,
-        vocab_size=151936,
-        max_position_embeddings=32768,
-        rope_theta=5000000.0,
-        head_dim=128,
-        tie_word_embeddings=True,
-        attention_bias=False,
-        attention_dropout=0.0,
-        rope_scaling={"mrope_section": [24, 20, 20],
-                      "rope_type": "default", "type": "default"},
-    )
+    # Get model configurations based on model name
+    if model_name:
+        vision_config, text_config = get_model_configs(model_name)
+    else:
+        # Default to 4B config
+        vision_config, text_config = get_model_configs("qwen3vl-4b")

     vision_model = VEGModel(vision_config)
     llm_model = LLMModel(text_config)

-    # Try to load LLM model from available files in order of preference
-    preferred_order = [
-        ("qwen3vl-llm-4B-q4_0.safetensors", 4),
-        ("qwen3vl-llm-4B-q8_0.safetensors", 8),
-        ("qwen3vl-llm-4B-f32.safetensors", 32)
-    ]
-
-    llm_weights_path = None
-    quantization_bits = None
-
-    # Try loading in order of preference
-    for filename, bits in preferred_order:
-        candidate_path = model_path / filename
-        if candidate_path.exists():
-            llm_weights_path = candidate_path
-            quantization_bits = bits
-            break
-
-    if llm_weights_path is None:
-        # Fallback to original hardcoded path for backward compatibility
-        llm_weights_path = model_path / "qwen3vl-llm-4B-q4_0.safetensors"
-        quantization_bits = 4
-
-    vision_weights_path = model_path / "qwen3vl-vision-4B-f16.safetensors"
+    # Get appropriate weight filenames
+    llm_weights_path, vision_weights_path, quantization_bits = get_weight_filenames(
+        model_name or "qwen3vl-4b", model_path
+    )

-    if not vision_weights_path.exists() or not llm_weights_path.exists():
+    if not vision_weights_path or not llm_weights_path:
         raise FileNotFoundError(
             f"Missing safetensors. Vision: {vision_weights_path}, LLM: {llm_weights_path}"
         )
@@ -146,8 +237,14 @@ def load_qwen3_vl(

     llm_model.load_weights(str(llm_weights_path), strict=True)

-    # Tokenizer and processor
-    tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(str(model_path))
+    except Exception:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
+        except Exception:
+            raise Exception("Failed to load tokenizer from the same path where model weights are loaded and original path_or_repo.")
+
     processor = Qwen3VLProcessor(tokenizer=tokenizer)

     return Qwen3VLBundledModel(vision_model=vision_model, llm_model=llm_model), processor
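Taken together, these loader hunks let callers select the model variant at load time instead of always getting the hard-coded 4B paths. A minimal sketch, assuming load_qwen3_vl is imported from the same loader module (its file path is not shown in this diff):

    # Illustrative call; the import location of load_qwen3_vl is an assumption.
    model, processor = load_qwen3_vl(
        "/path/to/modelfiles",          # directory containing the *.safetensors shards
        model_name="qwen3vl-8b",        # selects the 8B configs and weight filenames
    )
    # Omitting model_name keeps the old behaviour: the 4B configuration is used.
    # The tokenizer is now loaded from the resolved model_path first, falling back
    # to the original path_or_repo if that fails.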
@@ -81,12 +81,13 @@ class VLM(ProfilingMixin):

         if model_name == "qwen3vl-moe":
             load_impl = load_qwen3_vl_moe
-        elif model_name == "qwen3vl":
+        elif model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking"]:
             load_impl = load_qwen3_vl
         else:
             load_impl = load

-        self.model, self.processor = load_impl(str(model_path))
+        # Pass model_name to the loader for proper configuration
+        self.model, self.processor = load_impl(str(model_path), model_name=model_name)

         # Init deafutl sampler config with defualt.
         self.sampler_config = SamplerConfig()
@@ -94,6 +95,19 @@ class VLM(ProfilingMixin):
         # Track global character position for incremental processing
         self.global_n_past_chars = 0

+        # Add conversation state tracking to VLM class
+        if model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking"]:
+            # Import here to avoid circular imports
+            from .modeling.models.qwen3_vl.llm_common.cache import make_prompt_cache
+            import mlx.core as mx
+
+            # Initialize conversation state
+            self.rope_deltas_total = mx.zeros((1, 1), dtype=mx.int32)
+            self.prompt_cache = make_prompt_cache(self.model.llm_model, max_kv_size=4096)
+        else:
+            self.rope_deltas_total = None
+            self.prompt_cache = None
+
     def destroy(self) -> None:
         """Destroy the model and free resources."""
         self.model = None
@@ -103,6 +117,14 @@ class VLM(ProfilingMixin):
         """Reset the model state."""
         self._reset_cache()
         self.global_n_past_chars = 0
+
+        # Reset conversation state for qwen3vl models
+        if self.model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking"]:
+            import mlx.core as mx
+            from .modeling.models.qwen3_vl.llm_common.cache import make_prompt_cache
+
+            self.rope_deltas_total = mx.zeros((1, 1), dtype=mx.int32)
+            self.prompt_cache = make_prompt_cache(self.model.llm_model, max_kv_size=4096)

     def _reset_cache(self) -> None:
         """Reset the KV cache."""
@@ -280,7 +302,7 @@ class VLM(ProfilingMixin):

         # Apply incremental processing only for non-qwen3vl models
         # qwen3vl requires complete JSON conversation structure
-        if self.model_name != "qwen3vl":
+        if self.model_name not in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking", "qwen3vl-moe"]:
             if self.global_n_past_chars < full_prompt_len:
                 incremental_prompt = prompt[self.global_n_past_chars:]
             else:
@@ -297,7 +319,7 @@ class VLM(ProfilingMixin):

         if self.model_name == "qwen3vl-moe":
             stream_generate_impl = stream_generate_qwen3_vl_moe
-        elif self.model_name == "qwen3vl":
+        elif self.model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking"]:
             stream_generate_impl = stream_generate_qwen3_vl
         else:
             stream_generate_impl = stream_generate
@@ -305,28 +327,59 @@ class VLM(ProfilingMixin):
         try:
             token_count = 0

-            for result in stream_generate_impl(
-                self.model,
-                self.processor,
-                incremental_prompt,  # Use incremental prompt instead of full prompt
-                image=image_list,
-                audio=audio_list,
-                **gen_kwargs,
-            ):
-                token_count += 1
-
-                # Record TTFT on first token
-                if first_token:
-                    self._record_ttft()
-                    first_token = False
-
-                # Call the token callback if provided
-                if on_token is not None:
-                    if not on_token(result.text):
-                        self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
-                        break
-                text += result.text
-                last_result = result
+            # Pass conversation state for qwen3vl models
+            if self.model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking"]:
+                for result in stream_generate_impl(
+                    self.model,
+                    self.processor,
+                    incremental_prompt,
+                    image=image_list,
+                    audio=audio_list,
+                    rope_deltas_total=self.rope_deltas_total,  # Pass conversation state
+                    prompt_cache=self.prompt_cache,  # Pass KV cache
+                    **gen_kwargs,
+                ):
+                    token_count += 1
+
+                    # Record TTFT on first token
+                    if first_token:
+                        self._record_ttft()
+                        first_token = False
+
+                    # Call the token callback if provided
+                    if on_token is not None:
+                        if not on_token(result.text):
+                            self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
+                            break
+                    text += result.text
+                    last_result = result
+
+                # Update conversation state after each token
+                # Note: rope_deltas_total is updated inside stream_generate_qwen3_vl
+
+            else:
+                for result in stream_generate_impl(
+                    self.model,
+                    self.processor,
+                    incremental_prompt,
+                    image=image_list,
+                    audio=audio_list,
+                    **gen_kwargs,
+                ):
+                    token_count += 1
+
+                    # Record TTFT on first token
+                    if first_token:
+                        self._record_ttft()
+                        first_token = False
+
+                    # Call the token callback if provided
+                    if on_token is not None:
+                        if not on_token(result.text):
+                            self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
+                            break
+                    text += result.text
+                    last_result = result


             # Set stop reason if not user stop
@@ -339,7 +392,7 @@ class VLM(ProfilingMixin):
            self._update_generated_tokens(last_result.generation_tokens)

         # Update global character position (not needed for qwen3vl JSON processing)
-        if self.model_name != "qwen3vl":
+        if self.model_name not in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking", "qwen3vl-moe"]:
             old_pos = self.global_n_past_chars
             self.global_n_past_chars = full_prompt_len + len(text)

@@ -444,11 +497,10 @@ class VLM(ProfilingMixin):

     def apply_chat_template_with_media(self, messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = True) -> str:
         """Apply chat template to messages with proper image/audio token insertion and optional tools support."""
-        if self.model_name == "qwen3vl":
+        if self.model_name in ["qwen3vl", "qwen3vl-4b", "qwen3vl-4b-thinking", "qwen3vl-8b", "qwen3vl-8b-thinking"]:
             return apply_chat_template_qwen3_vl(messages, num_images=num_images, num_audios=num_audios, tools=tools, enable_thinking=enable_thinking)
         if self.model_name == "qwen3vl-moe":
             return apply_chat_template_qwen3_vl_moe(messages, num_images=num_images, num_audios=num_audios, tools=tools, enable_thinking=enable_thinking)
-        # Convert ChatMessage objects to dictionaries for the processor
         messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]

         parsed_tools = None
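Across these hunks, the VLM class now recognizes the size- and thinking-specific qwen3vl names everywhere it previously matched only "qwen3vl", and keeps per-conversation state (prompt_cache and rope_deltas_total) that it threads into stream generation. A rough usage sketch; the constructor arguments mirror the CLI call in the changes below, and nothing beyond them should be read as the actual API:

    # Illustrative sketch only; argument names come from the CLI changes below.
    vlm = VLM(
        model_name="qwen3vl-4b-thinking",
        model_path="NexaAI/qwen3vl-4B-thinking-4bit-mlx",
        mmproj_path=None,
        context_length=4096,
        device=None,
    )
    # For qwen3vl variants, prompt_cache and rope_deltas_total are created at
    # construction and reused across turns; the reset path shown above rebuilds
    # both, which is how a caller starts a fresh conversation.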
@@ -40,6 +40,57 @@ def parse_media_from_input(user_input):

     return prompt, image_paths if image_paths else None, audio_paths if audio_paths else None

+def detect_model_name_and_repo(model_path):
+    """Detect model name and corresponding HuggingFace repo based on model path or name"""
+    model_path_lower = model_path.lower()
+
+    # Handle HuggingFace repo format
+    if "/" in model_path:
+        repo_name = model_path.split("/")[-1] if model_path.endswith("/") else model_path.split("/")[-1]
+        repo_name_lower = repo_name.lower()
+    else:
+        repo_name_lower = model_path_lower
+
+    # Model name mapping based on the provided examples
+    model_mappings = {
+        # 4B models
+        "qwen3vl-4b-4bit-mlx": ("qwen3vl-4b", "NexaAI/qwen3vl-4B-4bit-mlx"),
+        "qwen3vl-4b-fp16-mlx": ("qwen3vl-4b", "NexaAI/qwen3vl-4B-fp16-mlx"),
+        "qwen3vl-4b-thinking-4bit-mlx": ("qwen3vl-4b-thinking", "NexaAI/qwen3vl-4B-thinking-4bit-mlx"),
+        "qwen3vl-4b-thinking-fp16-mlx": ("qwen3vl-4b-thinking", "NexaAI/qwen3vl-4B-thinking-fp16-mlx"),
+
+        # 8B models
+        "qwen3vl-8b-4bit-mlx": ("qwen3vl-8b", "NexaAI/qwen3vl-8B-4bit-mlx"),
+        "qwen3vl-8b-fp16-mlx": ("qwen3vl-8b", "NexaAI/qwen3vl-8B-fp16-mlx"),
+        "qwen3vl-8b-thinking-4bit-mlx": ("qwen3vl-8b-thinking", "NexaAI/qwen3vl-8B-thinking-4bit-mlx"),
+        "qwen3vl-8b-thinking-fp16-mlx": ("qwen3vl-8b-thinking", "NexaAI/qwen3vl-8B-thinking-fp16-mlx"),
+    }
+
+    # Check exact matches first
+    for key, (model_name, repo) in model_mappings.items():
+        if key in repo_name_lower:
+            return model_name, repo if "/" not in model_path else model_path
+
+    # Fallback detection based on patterns
+    if "qwen3vl" in repo_name_lower:
+        if "8b" in repo_name_lower:
+            if "thinking" in repo_name_lower:
+                return "qwen3vl-8b-thinking", model_path
+            else:
+                return "qwen3vl-8b", model_path
+        elif "4b" in repo_name_lower:
+            if "thinking" in repo_name_lower:
+                return "qwen3vl-4b-thinking", model_path
+            else:
+                return "qwen3vl-4b", model_path
+        else:
+            # Default to 4B if size not specified
+            return "qwen3vl-4b", model_path
+    elif "gemma" in repo_name_lower:
+        return "gemma3", model_path
+
+    return "", model_path
+
 def parse_arguments():
     """Parse command line arguments for the VLM main function."""
     parser = argparse.ArgumentParser(
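For reference, a short worked trace of detect_model_name_and_repo as defined above (the return values follow directly from the mapping and fallback rules; the paths are placeholders):

    # Worked examples for detect_model_name_and_repo; paths are placeholders.
    detect_model_name_and_repo("NexaAI/qwen3vl-8B-thinking-4bit-mlx")
    # -> ("qwen3vl-8b-thinking", "NexaAI/qwen3vl-8B-thinking-4bit-mlx")  exact-match entry

    detect_model_name_and_repo("/models/qwen3vl-4B-fp16-mlx")
    # -> ("qwen3vl-4b", "/models/qwen3vl-4B-fp16-mlx")  matched via the lower-cased basename

    detect_model_name_and_repo("mlx-community/gemma-3-4b-it-8bit")
    # -> ("gemma3", "mlx-community/gemma-3-4b-it-8bit")  gemma fallback

    detect_model_name_and_repo("some-other-model")
    # -> ("", "some-other-model")  no match: empty model_name, path unchanged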
@@ -48,14 +99,14 @@ def parse_arguments():
     parser.add_argument(
         "--model_path",
         type=str,
-        default="mlx-community/gemma-3-4b-it-8bit",
+        default="NexaAI/qwen3vl-4B-4bit-mlx",
         help="The path to the local model directory or Hugging Face repo."
     )
     parser.add_argument(
         "--model_name",
         type=str,
         default="",
-        help="Specific model name/type (e.g., 'qwen3vl', 'qwen3vl-moe', 'gemma3'). If empty, auto-detect from model_path."
+        help="Specific model name/type (e.g., 'qwen3vl-4b', 'qwen3vl-4b-thinking', 'qwen3vl-8b', 'qwen3vl-8b-thinking'). If empty, auto-detect from model_path."
     )
     parser.add_argument(
         "--context_length",
@@ -89,22 +140,16 @@ def main():

     # Auto-detect model name if not provided
     model_name = args.model_name
-
-    # TODO: avoid such hardcoded model name detection
+    model_path = args.model_path
+
     if not model_name:
-        if "qwen3vl-30B" in args.model_path.lower():
-            model_name = "qwen3vl-moe"
-        elif "qwen3" in args.model_path.lower():
-            model_name = "qwen3vl"
-        elif "gemma" in args.model_path.lower():
-            model_name = "gemma3"
-        else:
-            model_name = ""
+        model_name, model_path = detect_model_name_and_repo(args.model_path)
+        print(f"Auto-detected model: {model_name} from path: {model_path}")

     # Load the VLM instance
     vlm = VLM(
         model_name=model_name,
-        model_path=args.model_path,
+        model_path=model_path,
         mmproj_path=None,  # Not needed for this model
         context_length=args.context_length,
         device=None