kreuzberg 3.2.0__py3-none-any.whl → 3.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. kreuzberg/__init__.py +3 -0
  2. kreuzberg/__main__.py +8 -0
  3. kreuzberg/_api/__init__.py +0 -0
  4. kreuzberg/_api/main.py +87 -0
  5. kreuzberg/_cli_config.py +175 -0
  6. kreuzberg/_extractors/_image.py +39 -4
  7. kreuzberg/_extractors/_pandoc.py +158 -18
  8. kreuzberg/_extractors/_pdf.py +199 -19
  9. kreuzberg/_extractors/_presentation.py +1 -1
  10. kreuzberg/_extractors/_spread_sheet.py +65 -7
  11. kreuzberg/_gmft.py +222 -16
  12. kreuzberg/_mime_types.py +62 -16
  13. kreuzberg/_multiprocessing/__init__.py +6 -0
  14. kreuzberg/_multiprocessing/gmft_isolated.py +332 -0
  15. kreuzberg/_multiprocessing/process_manager.py +188 -0
  16. kreuzberg/_multiprocessing/sync_tesseract.py +261 -0
  17. kreuzberg/_multiprocessing/tesseract_pool.py +359 -0
  18. kreuzberg/_ocr/_easyocr.py +6 -12
  19. kreuzberg/_ocr/_paddleocr.py +15 -13
  20. kreuzberg/_ocr/_tesseract.py +136 -46
  21. kreuzberg/_playa.py +43 -0
  22. kreuzberg/_types.py +4 -0
  23. kreuzberg/_utils/_cache.py +372 -0
  24. kreuzberg/_utils/_device.py +10 -27
  25. kreuzberg/_utils/_document_cache.py +220 -0
  26. kreuzberg/_utils/_errors.py +232 -0
  27. kreuzberg/_utils/_pdf_lock.py +72 -0
  28. kreuzberg/_utils/_process_pool.py +100 -0
  29. kreuzberg/_utils/_serialization.py +82 -0
  30. kreuzberg/_utils/_string.py +1 -1
  31. kreuzberg/_utils/_sync.py +21 -0
  32. kreuzberg/cli.py +338 -0
  33. kreuzberg/extraction.py +247 -36
  34. kreuzberg-3.4.0.dist-info/METADATA +290 -0
  35. kreuzberg-3.4.0.dist-info/RECORD +50 -0
  36. {kreuzberg-3.2.0.dist-info → kreuzberg-3.4.0.dist-info}/WHEEL +1 -2
  37. kreuzberg-3.4.0.dist-info/entry_points.txt +2 -0
  38. kreuzberg-3.2.0.dist-info/METADATA +0 -166
  39. kreuzberg-3.2.0.dist-info/RECORD +0 -34
  40. kreuzberg-3.2.0.dist-info/top_level.txt +0 -1
  41. {kreuzberg-3.2.0.dist-info → kreuzberg-3.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,372 @@
1
+ """General-purpose file-based caching layer for Kreuzberg."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import os
7
+ import threading
8
+ import time
9
+ from contextlib import suppress
10
+ from pathlib import Path
11
+ from typing import Any, Generic, TypeVar
12
+
13
+ from anyio import Path as AsyncPath
14
+
15
+ from kreuzberg._types import ExtractionResult
16
+ from kreuzberg._utils._serialization import deserialize, serialize
17
+ from kreuzberg._utils._sync import run_sync
18
+
19
T = TypeVar("T")


class KreuzbergCache(Generic[T]):
    """File-based cache for Kreuzberg operations.

    Provides both sync and async interfaces for caching extraction results,
    OCR results, table data, and other expensive operations to disk.
    Entries are stored as msgpack files named after a SHA-256 digest of the
    caller-supplied keyword arguments.
    """

    def __init__(
        self,
        cache_type: str,
        cache_dir: Path | str | None = None,
        max_cache_size_mb: float = 500.0,
        max_age_days: int = 30,
    ) -> None:
        """Initialize cache.

        Args:
            cache_type: Type of cache (e.g., 'ocr', 'tables', 'documents', 'mime')
            cache_dir: Cache directory (defaults to .kreuzberg/{cache_type} in cwd)
            max_cache_size_mb: Maximum cache size in MB (default: 500MB)
            max_age_days: Maximum age of cached results in days (default: 30 days)
        """
        if cache_dir is None:
            cache_dir = Path.cwd() / ".kreuzberg" / cache_type

        self.cache_dir = Path(cache_dir)
        self.cache_type = cache_type
        self.max_cache_size_mb = max_cache_size_mb
        self.max_age_days = max_age_days

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # In-memory tracking of processing state (session-scoped) # ~keep
        self._processing: dict[str, threading.Event] = {}
        self._lock = threading.Lock()

    def _get_cache_key(self, **kwargs: Any) -> str:
        """Generate cache key from kwargs.

        Args:
            **kwargs: Key-value pairs to generate cache key from

        Returns:
            Unique cache key string (16 hex characters).
        """
        # Sort for consistent hashing regardless of kwarg order # ~keep
        cache_str = str(sorted(kwargs.items()))
        return hashlib.sha256(cache_str.encode()).hexdigest()[:16]

    def _get_cache_path(self, cache_key: str) -> Path:
        """Get cache file path for key."""
        return self.cache_dir / f"{cache_key}.msgpack"

    def _is_cache_valid(self, cache_path: Path) -> bool:
        """Check if a cached result exists and is younger than max_age_days."""
        try:
            if not cache_path.exists():
                return False

            mtime = cache_path.stat().st_mtime
            age_days = (time.time() - mtime) / (24 * 3600)

            return age_days <= self.max_age_days
        except OSError:
            return False

    def _serialize_result(self, result: T) -> dict[str, Any]:
        """Wrap result with its type name and a timestamp for caching."""
        return {"type": type(result).__name__, "data": result, "cached_at": time.time()}

    def _deserialize_result(self, cached_data: dict[str, Any]) -> T:
        """Unwrap a cached payload, rehydrating ExtractionResult instances."""
        data = cached_data["data"]

        if cached_data.get("type") == "ExtractionResult" and isinstance(data, dict):
            # Local import avoids depending on module import order at call time.
            from kreuzberg._types import ExtractionResult

            return ExtractionResult(**data)  # type: ignore[return-value]

        return data  # type: ignore[no-any-return]

    def _should_cleanup(self, cache_key: str) -> bool:
        """Decide whether this write should trigger opportunistic cleanup.

        Samples ~1% of writes using the hex cache key itself so the decision
        is deterministic across processes. The previous ``hash(cache_key)``
        check was randomized per interpreter run (PYTHONHASHSEED), making the
        cleanup trigger non-reproducible.
        """
        return int(cache_key, 16) % 100 == 0

    def _cleanup_cache(self) -> None:
        """Clean up old and oversized cache entries (best-effort)."""
        try:
            cache_files = list(self.cache_dir.glob("*.msgpack"))

            # First pass: drop entries older than max_age_days.
            cutoff_time = time.time() - (self.max_age_days * 24 * 3600)
            for cache_file in cache_files[:]:
                try:
                    if cache_file.stat().st_mtime < cutoff_time:
                        cache_file.unlink(missing_ok=True)
                        cache_files.remove(cache_file)
                except OSError:  # noqa: PERF203
                    continue

            total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists()) / (
                1024 * 1024
            )

            # Second pass: evict oldest entries until under 80% of the size cap.
            if total_size > self.max_cache_size_mb:
                cache_files.sort(key=lambda f: f.stat().st_mtime if f.exists() else 0)

                for cache_file in cache_files:
                    try:
                        size_mb = cache_file.stat().st_size / (1024 * 1024)
                        cache_file.unlink(missing_ok=True)
                        total_size -= size_mb

                        if total_size <= self.max_cache_size_mb * 0.8:
                            break
                    except OSError:
                        continue
        except (OSError, ValueError, TypeError):
            # Cache maintenance must never break the primary operation.
            pass

    def get(self, **kwargs: Any) -> T | None:
        """Get cached result (sync).

        Args:
            **kwargs: Key-value pairs to generate cache key from

        Returns:
            Cached result if available, None otherwise
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = self._get_cache_path(cache_key)

        if not self._is_cache_valid(cache_path):
            return None

        try:
            content = cache_path.read_bytes()
            cached_data = deserialize(content, dict)
            return self._deserialize_result(cached_data)
        except (OSError, ValueError, KeyError):
            # Corrupt/unreadable entry: drop it and treat as a miss.
            with suppress(OSError):
                cache_path.unlink(missing_ok=True)
            return None

    def set(self, result: T, **kwargs: Any) -> None:
        """Cache result (sync).

        Args:
            result: Result to cache
            **kwargs: Key-value pairs to generate cache key from
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = self._get_cache_path(cache_key)

        try:
            serialized = self._serialize_result(result)
            content = serialize(serialized)
            cache_path.write_bytes(content)

            # Opportunistically clean up on ~1% of writes (deterministic sampling).
            if self._should_cleanup(cache_key):
                self._cleanup_cache()
        except (OSError, TypeError, ValueError):
            pass

    async def aget(self, **kwargs: Any) -> T | None:
        """Get cached result (async).

        Args:
            **kwargs: Key-value pairs to generate cache key from

        Returns:
            Cached result if available, None otherwise
        """
        cache_key = self._get_cache_key(**kwargs)
        sync_path = self._get_cache_path(cache_key)

        # Validity check does blocking stat() calls, so run it off the event loop.
        if not await run_sync(self._is_cache_valid, sync_path):
            return None

        cache_path = AsyncPath(sync_path)
        try:
            content = await cache_path.read_bytes()
            cached_data = deserialize(content, dict)
            return self._deserialize_result(cached_data)
        except (OSError, ValueError, KeyError):
            # Corrupt/unreadable entry: drop it and treat as a miss
            # (narrowed from suppress(Exception) to match the sync path).
            with suppress(OSError):
                await cache_path.unlink(missing_ok=True)
            return None

    async def aset(self, result: T, **kwargs: Any) -> None:
        """Cache result (async).

        Args:
            result: Result to cache
            **kwargs: Key-value pairs to generate cache key from
        """
        cache_key = self._get_cache_key(**kwargs)
        cache_path = AsyncPath(self._get_cache_path(cache_key))

        try:
            serialized = self._serialize_result(result)
            content = serialize(serialized)
            await cache_path.write_bytes(content)

            # Opportunistically clean up on ~1% of writes (deterministic sampling).
            if self._should_cleanup(cache_key):
                await run_sync(self._cleanup_cache)
        except (OSError, TypeError, ValueError):
            pass

    def is_processing(self, **kwargs: Any) -> bool:
        """Check if operation is currently being processed."""
        cache_key = self._get_cache_key(**kwargs)
        with self._lock:
            return cache_key in self._processing

    def mark_processing(self, **kwargs: Any) -> threading.Event:
        """Mark operation as being processed and return event to wait on."""
        cache_key = self._get_cache_key(**kwargs)

        with self._lock:
            if cache_key not in self._processing:
                self._processing[cache_key] = threading.Event()
            return self._processing[cache_key]

    def mark_complete(self, **kwargs: Any) -> None:
        """Mark operation processing as complete."""
        cache_key = self._get_cache_key(**kwargs)

        with self._lock:
            if cache_key in self._processing:
                event = self._processing.pop(cache_key)
                event.set()

    def clear(self) -> None:
        """Clear all cached results and reset in-memory processing state."""
        try:
            for cache_file in self.cache_dir.glob("*.msgpack"):
                cache_file.unlink(missing_ok=True)
        except OSError:
            pass

        # Release any threads waiting on in-flight markers and drop the
        # session-scoped state. The previous implementation acquired the
        # lock but left self._processing untouched, leaking events and
        # potentially hanging waiters forever.
        with self._lock:
            for event in self._processing.values():
                event.set()
            self._processing.clear()

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        try:
            cache_files = list(self.cache_dir.glob("*.msgpack"))
            total_size = sum(cache_file.stat().st_size for cache_file in cache_files if cache_file.exists())

            return {
                "cache_type": self.cache_type,
                "cached_results": len(cache_files),
                "processing_results": len(self._processing),
                "total_cache_size_mb": total_size / 1024 / 1024,
                "avg_result_size_kb": (total_size / len(cache_files) / 1024) if cache_files else 0,
                "cache_dir": str(self.cache_dir),
                "max_cache_size_mb": self.max_cache_size_mb,
                "max_age_days": self.max_age_days,
            }
        except OSError:
            # Filesystem unavailable: report empty-but-configured stats.
            return {
                "cache_type": self.cache_type,
                "cached_results": 0,
                "processing_results": len(self._processing),
                "total_cache_size_mb": 0.0,
                "avg_result_size_kb": 0.0,
                "cache_dir": str(self.cache_dir),
                "max_cache_size_mb": self.max_cache_size_mb,
                "max_age_days": self.max_age_days,
            }
287
+
288
+
289
# Lazily-initialized module-level singletons; each is created on first use
# by the corresponding get_*_cache() factory function below.
_ocr_cache: KreuzbergCache[ExtractionResult] | None = None
_document_cache: KreuzbergCache[ExtractionResult] | None = None
_table_cache: KreuzbergCache[Any] | None = None
_mime_cache: KreuzbergCache[str] | None = None
293
+
294
+
295
def get_ocr_cache() -> KreuzbergCache[ExtractionResult]:
    """Get the global OCR cache instance.

    Lazily builds the singleton, honoring KREUZBERG_CACHE_DIR for the
    location and the KREUZBERG_OCR_CACHE_* variables for limits.
    """
    global _ocr_cache
    if _ocr_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        # Empty/unset env var falls back to the KreuzbergCache default dir.
        target_dir = Path(base_dir) / "ocr" if base_dir else None
        _ocr_cache = KreuzbergCache[ExtractionResult](
            cache_type="ocr",
            cache_dir=target_dir,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_OCR_CACHE_SIZE_MB", "500")),
            max_age_days=int(os.environ.get("KREUZBERG_OCR_CACHE_AGE_DAYS", "30")),
        )
    return _ocr_cache
311
+
312
+
313
def get_document_cache() -> KreuzbergCache[ExtractionResult]:
    """Get the global document cache instance.

    Lazily builds the singleton, honoring KREUZBERG_CACHE_DIR for the
    location and the KREUZBERG_DOCUMENT_CACHE_* variables for limits.
    """
    global _document_cache
    if _document_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        # Empty/unset env var falls back to the KreuzbergCache default dir.
        target_dir = Path(base_dir) / "documents" if base_dir else None
        _document_cache = KreuzbergCache[ExtractionResult](
            cache_type="documents",
            cache_dir=target_dir,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_DOCUMENT_CACHE_SIZE_MB", "1000")),
            max_age_days=int(os.environ.get("KREUZBERG_DOCUMENT_CACHE_AGE_DAYS", "7")),
        )
    return _document_cache
329
+
330
+
331
def get_table_cache() -> KreuzbergCache[Any]:
    """Get the global table cache instance.

    Lazily builds the singleton, honoring KREUZBERG_CACHE_DIR for the
    location and the KREUZBERG_TABLE_CACHE_* variables for limits.
    """
    global _table_cache
    if _table_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        # Empty/unset env var falls back to the KreuzbergCache default dir.
        target_dir = Path(base_dir) / "tables" if base_dir else None
        _table_cache = KreuzbergCache[Any](
            cache_type="tables",
            cache_dir=target_dir,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_TABLE_CACHE_SIZE_MB", "200")),
            max_age_days=int(os.environ.get("KREUZBERG_TABLE_CACHE_AGE_DAYS", "30")),
        )
    return _table_cache
347
+
348
+
349
def get_mime_cache() -> KreuzbergCache[str]:
    """Get the global MIME type cache instance.

    Lazily builds the singleton, honoring KREUZBERG_CACHE_DIR for the
    location and the KREUZBERG_MIME_CACHE_* variables for limits.
    """
    global _mime_cache
    if _mime_cache is None:
        base_dir = os.environ.get("KREUZBERG_CACHE_DIR")
        # Empty/unset env var falls back to the KreuzbergCache default dir.
        target_dir = Path(base_dir) / "mime" if base_dir else None
        _mime_cache = KreuzbergCache[str](
            cache_type="mime",
            cache_dir=target_dir,
            max_cache_size_mb=float(os.environ.get("KREUZBERG_MIME_CACHE_SIZE_MB", "50")),
            max_age_days=int(os.environ.get("KREUZBERG_MIME_CACHE_AGE_DAYS", "60")),
        )
    return _mime_cache
365
+
366
+
367
def clear_all_caches() -> None:
    """Clear all caches (OCR, document, table, and MIME)."""
    # Each factory lazily creates its singleton if needed, then we wipe it.
    for factory in (get_ocr_cache, get_document_cache, get_table_cache, get_mime_cache):
        factory().clear()
@@ -1,5 +1,5 @@
1
1
  """Device detection and management utilities for GPU acceleration."""
2
- # ruff: noqa: BLE001
2
+ # ruff: noqa: BLE001 # ~keep
3
3
 
4
4
  from __future__ import annotations
5
5
 
@@ -36,7 +36,6 @@ def detect_available_devices() -> list[DeviceInfo]:
36
36
  """
37
37
  devices: list[DeviceInfo] = []
38
38
 
39
- # Always include CPU as fallback
40
39
  devices.append(
41
40
  DeviceInfo(
42
41
  device_type="cpu",
@@ -44,18 +43,15 @@ def detect_available_devices() -> list[DeviceInfo]:
44
43
  )
45
44
  )
46
45
 
47
- # Check for CUDA (NVIDIA GPUs)
48
46
  if _is_cuda_available():
49
47
  cuda_devices = _get_cuda_devices()
50
48
  devices.extend(cuda_devices)
51
49
 
52
- # Check for MPS (Apple Silicon)
53
50
  if _is_mps_available():
54
51
  mps_device = _get_mps_device()
55
52
  if mps_device:
56
53
  devices.append(mps_device)
57
54
 
58
- # Reorder to put GPU devices first
59
55
  gpu_devices = [d for d in devices if d.device_type != "cpu"]
60
56
  cpu_devices = [d for d in devices if d.device_type == "cpu"]
61
57
 
@@ -95,14 +91,12 @@ def validate_device_request(
95
91
  """
96
92
  available_devices = detect_available_devices()
97
93
 
98
- # Handle auto device selection
99
94
  if requested == "auto":
100
95
  device = get_optimal_device()
101
96
  if memory_limit is not None:
102
97
  _validate_memory_limit(device, memory_limit)
103
98
  return device
104
99
 
105
- # Find requested device
106
100
  matching_devices = [d for d in available_devices if d.device_type == requested]
107
101
 
108
102
  if not matching_devices:
@@ -125,10 +119,8 @@ def validate_device_request(
125
119
  },
126
120
  )
127
121
 
128
- # Use the first matching device (typically the best one)
129
122
  device = matching_devices[0]
130
123
 
131
- # Validate memory limit if specified
132
124
  if memory_limit is not None:
133
125
  _validate_memory_limit(device, memory_limit)
134
126
 
@@ -159,7 +151,7 @@ def get_device_memory_info(device: DeviceInfo) -> tuple[float | None, float | No
159
151
  def _is_cuda_available() -> bool:
160
152
  """Check if CUDA is available."""
161
153
  try:
162
- import torch
154
+ import torch # type: ignore[import-not-found,unused-ignore]
163
155
 
164
156
  return torch.cuda.is_available()
165
157
  except ImportError:
@@ -169,7 +161,7 @@ def _is_cuda_available() -> bool:
169
161
  def _is_mps_available() -> bool:
170
162
  """Check if MPS (Apple Silicon) is available."""
171
163
  try:
172
- import torch
164
+ import torch # type: ignore[import-not-found,unused-ignore]
173
165
 
174
166
  return torch.backends.mps.is_available()
175
167
  except ImportError:
@@ -188,17 +180,14 @@ def _get_cuda_devices() -> list[DeviceInfo]:
188
180
 
189
181
  for i in range(torch.cuda.device_count()):
190
182
  props = torch.cuda.get_device_properties(i)
191
- total_memory = props.total_memory / (1024**3) # Convert to GB
183
+ total_memory = props.total_memory / (1024**3)
192
184
 
193
- # Get available memory
194
185
  torch.cuda.set_device(i)
195
186
  available_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)
196
187
  try:
197
- # Try to get current memory usage
198
188
  allocated = torch.cuda.memory_allocated(i) / (1024**3)
199
189
  available_memory = total_memory - allocated
200
190
  except Exception:
201
- # Fallback to total memory if we can't get allocation info
202
191
  available_memory = total_memory
203
192
 
204
193
  devices.append(
@@ -225,7 +214,6 @@ def _get_mps_device() -> DeviceInfo | None:
225
214
  if not torch.backends.mps.is_available():
226
215
  return None
227
216
 
228
- # MPS doesn't provide detailed memory info
229
217
  return DeviceInfo(
230
218
  device_type="mps",
231
219
  name="Apple Silicon GPU (MPS)",
@@ -260,8 +248,6 @@ def _get_cuda_memory_info(device_id: int) -> tuple[float | None, float | None]:
260
248
 
261
249
  def _get_mps_memory_info() -> tuple[float | None, float | None]:
262
250
  """Get MPS memory information."""
263
- # MPS doesn't provide detailed memory info through PyTorch
264
- # We could potentially use system calls but that's platform-specific
265
251
  return None, None
266
252
 
267
253
 
@@ -276,7 +262,7 @@ def _validate_memory_limit(device: DeviceInfo, memory_limit: float) -> None:
276
262
  ValidationError: If the device doesn't have enough memory.
277
263
  """
278
264
  if device.device_type == "cpu":
279
- # CPU memory validation is complex and OS-dependent, skip for now
265
+ # CPU memory validation is complex and OS-dependent, skip for now # ~keep
280
266
  return
281
267
 
282
268
  total_memory, available_memory = get_device_memory_info(device)
@@ -311,7 +297,7 @@ def is_backend_gpu_compatible(backend: str) -> bool:
311
297
  Returns:
312
298
  True if the backend supports GPU acceleration.
313
299
  """
314
- # EasyOCR and PaddleOCR support GPU, Tesseract does not
300
+ # EasyOCR and PaddleOCR support GPU, Tesseract does not # ~keep
315
301
  return backend.lower() in ("easyocr", "paddleocr")
316
302
 
317
303
 
@@ -326,25 +312,22 @@ def get_recommended_batch_size(device: DeviceInfo, input_size_mb: float = 10.0)
326
312
  Recommended batch size.
327
313
  """
328
314
  if device.device_type == "cpu":
329
- # Conservative batch size for CPU
315
+ # Conservative batch size for CPU # ~keep
330
316
  return 1
331
317
 
332
- # For GPU devices, estimate based on available memory
333
318
  _, available_memory = get_device_memory_info(device)
334
319
 
335
320
  if available_memory is None:
336
- # Conservative default for unknown memory
337
321
  return 4
338
322
 
339
- # Reserve some memory for model and intermediate calculations
340
- # Use approximately 50% of available memory for batching
323
+ # Use approximately 50% of available memory for batching # ~keep
341
324
  usable_memory_gb = available_memory * 0.5
342
325
  usable_memory_mb = usable_memory_gb * 1024
343
326
 
344
- # Estimate batch size (conservative)
327
+ # Estimate batch size (conservative) # ~keep
345
328
  estimated_batch_size = max(1, int(usable_memory_mb / (input_size_mb * 4)))
346
329
 
347
- # Cap at reasonable limits
330
+ # Cap at reasonable limits # ~keep
348
331
  return min(estimated_batch_size, 32)
349
332
 
350
333