openrunner-sdk 2.3.0__tar.gz → 2.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/PKG-INFO +3 -1
  2. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/__init__.py +1 -0
  3. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/api_client.py +80 -0
  4. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/run.py +31 -10
  5. openrunner_sdk-2.4.1/openrunner/wer.py +232 -0
  6. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/pyproject.toml +2 -1
  7. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_alert.py +1 -1
  8. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_sklearn.py +7 -7
  9. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_media.py +12 -4
  10. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_plot.py +7 -0
  11. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_system_metrics.py +2 -2
  12. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/.gitignore +0 -0
  13. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/=6.0 +0 -0
  14. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/=8.1 +0 -0
  15. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/README.md +0 -0
  16. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/artifact.py +0 -0
  17. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/buffer.py +0 -0
  18. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/cache.py +0 -0
  19. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/cli.py +0 -0
  20. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/config.py +0 -0
  21. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/cost.py +0 -0
  22. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/dataset.py +0 -0
  23. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/environment.py +0 -0
  24. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/evaluation.py +0 -0
  25. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/feedback.py +0 -0
  26. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/git_info.py +0 -0
  27. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/guardrails.py +0 -0
  28. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/__init__.py +0 -0
  29. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/accelerate.py +0 -0
  30. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/anthropic_tracer.py +0 -0
  31. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/catboost.py +0 -0
  32. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/diffusers.py +0 -0
  33. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/fastai.py +0 -0
  34. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/forced_alignment.py +0 -0
  35. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/gladia.py +0 -0
  36. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/gymnasium.py +0 -0
  37. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/huggingface.py +0 -0
  38. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/hydra.py +0 -0
  39. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/ignite.py +0 -0
  40. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/jax.py +0 -0
  41. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/keras.py +0 -0
  42. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/langchain.py +0 -0
  43. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/lightgbm.py +0 -0
  44. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/lightning.py +0 -0
  45. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/llamaindex.py +0 -0
  46. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/openai_finetune.py +0 -0
  47. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/openai_tracer.py +0 -0
  48. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/optuna.py +0 -0
  49. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/pytorch.py +0 -0
  50. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/sb3.py +0 -0
  51. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/sklearn.py +0 -0
  52. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/tensorflow.py +0 -0
  53. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/trl.py +0 -0
  54. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/tts.py +0 -0
  55. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/ultralytics.py +0 -0
  56. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/voice_agent.py +0 -0
  57. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/whisper.py +0 -0
  58. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/integration/xgboost.py +0 -0
  59. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/launch.py +0 -0
  60. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/media.py +0 -0
  61. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/migrate.py +0 -0
  62. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/model.py +0 -0
  63. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/offline.py +0 -0
  64. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/pii.py +0 -0
  65. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/plot.py +0 -0
  66. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/prompt.py +0 -0
  67. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/query_api.py +0 -0
  68. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/scorers.py +0 -0
  69. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/sender.py +0 -0
  70. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/settings.py +0 -0
  71. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/summary.py +0 -0
  72. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/sweep.py +0 -0
  73. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/system_metrics.py +0 -0
  74. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/tensorboard.py +0 -0
  75. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/trace.py +0 -0
  76. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/transcript_formatter.py +0 -0
  77. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/wal.py +0 -0
  78. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/wandb_compat/__init__.py +0 -0
  79. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/openrunner/wandb_compat/_shim.py +0 -0
  80. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/__init__.py +0 -0
  81. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/conftest.py +0 -0
  82. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_aliases.py +0 -0
  83. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_api_client.py +0 -0
  84. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_artifact.py +0 -0
  85. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_buffer.py +0 -0
  86. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_cache.py +0 -0
  87. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_class_scorers.py +0 -0
  88. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_cli.py +0 -0
  89. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_config.py +0 -0
  90. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_evaluation.py +0 -0
  91. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_finish.py +0 -0
  92. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_git_info.py +0 -0
  93. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_init.py +0 -0
  94. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_fastai.py +0 -0
  95. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_huggingface.py +0 -0
  96. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_keras.py +0 -0
  97. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_langchain.py +0 -0
  98. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_lightning.py +0 -0
  99. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_pytorch.py +0 -0
  100. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_integration_xgboost.py +0 -0
  101. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_launch.py +0 -0
  102. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_log.py +0 -0
  103. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_log_code.py +0 -0
  104. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_migrate.py +0 -0
  105. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_offline.py +0 -0
  106. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_offline_sync.py +0 -0
  107. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_pii.py +0 -0
  108. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_query_api.py +0 -0
  109. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_resume.py +0 -0
  110. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_sdk_features.py +0 -0
  111. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_sender.py +0 -0
  112. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_summary.py +0 -0
  113. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_sweep.py +0 -0
  114. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_trace.py +0 -0
  115. {openrunner_sdk-2.3.0 → openrunner_sdk-2.4.1}/tests/test_wandb_compat.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openrunner-sdk
3
- Version: 2.3.0
3
+ Version: 2.4.1
4
4
  Summary: OpenRunner SDK - W&B-compatible ML experiment tracking client
5
5
  Project-URL: Homepage, https://github.com/jqueguiner/openrunner
6
6
  Project-URL: Repository, https://github.com/jqueguiner/openrunner
@@ -82,6 +82,8 @@ Requires-Dist: numpy>=1.24; extra == 'tts'
82
82
  Provides-Extra: ultralytics
83
83
  Requires-Dist: ultralytics>=8.0; extra == 'ultralytics'
84
84
  Provides-Extra: voice-agent
85
+ Provides-Extra: wer
86
+ Requires-Dist: num2words2>=0.1; extra == 'wer'
85
87
  Provides-Extra: whisper
86
88
  Requires-Dist: openai-whisper>=20231117; extra == 'whisper'
87
89
  Provides-Extra: xgboost
@@ -97,6 +97,7 @@ from openrunner.settings import SDKSettings
97
97
  from openrunner.summary import Summary
98
98
  from openrunner.sweep import agent, sweep
99
99
  from openrunner.evaluation import EvaluationLogger, Scorer, evaluate, scorer
100
+ from openrunner.wer import WERScorer, compute_wer, compute_wer_batch
100
101
  from openrunner.guardrails import (
101
102
  GuardrailCheckResult,
102
103
  GuardrailResult,
@@ -490,6 +490,86 @@ class APIClient:
490
490
  logger.warning("use_artifact failed: %s", e)
491
491
  return None
492
492
 
493
def download_artifact(
    self,
    run_id: str,
    artifact_name: str,
    dest_dir: str = ".",
    version: int | None = None,
    alias: str | None = None,
) -> str | None:
    """Download all files from an artifact version to a local directory.

    Files are written to ``<dest_dir>/<artifact name>/v<version>/``. Each file
    entry is fetched from its ``presigned_url``, ``url`` or ``download_url``:

    - relative URLs (starting with ``/``) go through the authenticated API
      client;
    - absolute URLs use ``download_file_from_presigned_url``. If that fails
      for a ``localhost`` URL, the download is retried via the
      ``/storage/download`` proxy using the entry's ``storage_key``/``key``.

    File names that would resolve outside the output directory are skipped.

    Args:
        run_id: Run ID that used/created the artifact
        artifact_name: Artifact name (or "name:alias")
        dest_dir: Local directory to save files (created if needed)
        version: Specific version number (optional)
        alias: Alias name like "latest", "best" (optional)

    Returns:
        Path to the download directory, or None when the artifact cannot be
        resolved, lists no files, or none of its files could be downloaded.

    Example:
        path = client.download_artifact(run_id, "model-checkpoint", "./models")
        # Files saved to ./models/model-checkpoint/v3/...
    """
    from pathlib import Path
    from urllib.parse import quote

    def _authenticated_get(path: str) -> bytes | None:
        # Best effort: one failed file must not abort the whole download.
        try:
            resp = self._request("GET", path)
        except Exception as exc:
            logger.debug("download_artifact: GET %s failed: %s", path, exc)
            return None
        return resp.content if resp.status_code == 200 else None

    info = self.use_artifact(run_id, artifact_name, version=version, alias=alias)
    if not info:
        logger.warning("download_artifact: use_artifact returned no data")
        return None

    # A present-but-null "version" would otherwise yield a "vNone" directory.
    ver = info.get("version")
    if ver is None:
        ver = version or 1
    files = info.get("files") or []
    if not files:
        logger.warning("download_artifact: no files in artifact")
        return None

    # Flatten the artifact name into a single directory component.
    name_clean = artifact_name.split(":")[0].replace("/", "_").replace("\\", "_")
    if name_clean in ("", ".", ".."):
        name_clean = "artifact"
    out_dir = Path(dest_dir) / name_clean / f"v{ver}"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_root = out_dir.resolve()

    downloaded = 0
    for index, f in enumerate(files):
        url = f.get("presigned_url") or f.get("url") or f.get("download_url")
        fname = f.get("name") or f.get("path") or f"file_{index}"
        if not url:
            logger.warning("download_artifact: no download URL for %s", fname)
            continue

        # File names come from the server: reject absolute paths and ".."
        # segments that would place the file outside out_dir.
        file_path = out_dir / fname
        if out_root not in file_path.resolve().parents:
            logger.warning("download_artifact: skipping unsafe file name %r", fname)
            continue

        if url.startswith("/"):
            # Relative proxy URL: must be requested with the client's auth.
            data = _authenticated_get(url)
        else:
            data = self.download_file_from_presigned_url(url)
            if data is None and "localhost" in url:
                # Fall back to the authenticated storage proxy.
                key = f.get("storage_key") or f.get("key")
                if key:
                    data = _authenticated_get(
                        f"/storage/download?key={quote(key, safe='/')}"
                    )

        if data:
            file_path.parent.mkdir(parents=True, exist_ok=True)
            file_path.write_bytes(data)
            downloaded += 1
            logger.info("downloaded: %s (%d bytes)", fname, len(data))
        else:
            logger.warning("failed to download: %s", fname)

    if downloaded == 0:
        logger.warning("download_artifact: no files downloaded")
        return None

    logger.info("artifact downloaded: %d files → %s", downloaded, out_dir)
    return str(out_dir)
572
+
493
573
  def set_alias(
494
574
  self,
495
575
  artifact_id: str,
@@ -940,10 +940,11 @@ class Run:
940
940
  artifact_dir.mkdir(parents=True, exist_ok=True)
941
941
 
942
942
  for file_info in result.get("files", []):
943
- content_hash = file_info.get("content_hash", "")
944
- cached = self._artifact_cache.get(content_hash)
943
+ content_hash = file_info.get("content_hash") or ""
944
+ file_path = file_info.get("name") or file_info.get("path") or "file"
945
+ cached = self._artifact_cache.get(content_hash) if content_hash else None
945
946
 
946
- dest = artifact_dir / file_info["path"]
947
+ dest = artifact_dir / file_path
947
948
  dest.parent.mkdir(parents=True, exist_ok=True)
948
949
 
949
950
  if cached:
@@ -954,11 +955,21 @@ class Run:
954
955
  shutil.copy2(str(cached), str(dest))
955
956
  else:
956
957
  # Download and cache
957
- data = self._client.download_file_from_presigned_url(
958
- file_info["presigned_url"]
959
- )
958
+ url = file_info.get("download_url") or file_info.get("presigned_url", "")
959
+ data = None
960
+ if url.startswith("/"):
961
+ # Relative proxy URL — use authenticated client
962
+ try:
963
+ resp = self._client._request("GET", url)
964
+ if resp.status_code == 200:
965
+ data = resp.content
966
+ except Exception:
967
+ pass
968
+ else:
969
+ data = self._client.download_file_from_presigned_url(url)
960
970
  if data:
961
- self._artifact_cache.put(content_hash, data)
971
+ if content_hash:
972
+ self._artifact_cache.put(content_hash, data)
962
973
  dest.write_bytes(data)
963
974
 
964
975
  return artifact_dir
@@ -1154,18 +1165,28 @@ class Run:
1154
1165
  name: str,
1155
1166
  version: int | None = None,
1156
1167
  alias: str | None = None,
1157
- ):
1158
- """Convenience: download a model artifact.
1168
+ dest_dir: str = "./artifacts",
1169
+ ) -> str | None:
1170
+ """Download a model artifact to disk.
1159
1171
 
1160
1172
  Args:
1161
1173
  name: Model artifact name (supports "name:alias" syntax).
1162
1174
  version: Specific version number, or None for latest.
1163
1175
  alias: Alias name to resolve.
1176
+ dest_dir: Local directory for downloaded files.
1164
1177
 
1165
1178
  Returns:
1166
1179
  Path to the local artifact directory, or None on failure.
1167
1180
  """
1168
- return self.use_artifact(name, version=version, alias=alias)
1181
+ if not self._client:
1182
+ return None
1183
+ return self._client.download_artifact(
1184
+ run_id=self._run_id,
1185
+ artifact_name=name,
1186
+ dest_dir=dest_dir,
1187
+ version=version,
1188
+ alias=alias,
1189
+ )
1169
1190
 
1170
1191
  def link_model(
1171
1192
  self,
@@ -0,0 +1,232 @@
1
+ """Word Error Rate (WER) computation with num2words2 normalization.
2
+
3
+ Normalizes numbers to words before comparison so "50" vs "fifty" aren't
4
+ counted as substitution errors. Uses num2words2 (modern fork optimized
5
+ for LLM/AI/speech applications).
6
+
7
+ Install: pip install num2words2
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import re
13
+ from typing import Any
14
+
15
+
16
+ def _normalize_text(text: str, language: str = "en") -> str:
17
+ """Normalize text for WER: lowercase, expand numbers, strip punctuation."""
18
+ text = text.lower().strip()
19
+
20
+ # Expand numbers to words using num2words2
21
+ try:
22
+ from num2words2 import num2words
23
+
24
+ def _expand_number(match: re.Match) -> str:
25
+ num_str = match.group(0)
26
+ try:
27
+ # Handle decimals
28
+ if "." in num_str:
29
+ return num2words(float(num_str), lang=language)
30
+ # Handle integers
31
+ return num2words(int(num_str), lang=language)
32
+ except (ValueError, OverflowError):
33
+ return num_str
34
+
35
+ # Match numbers (integers, decimals, negatives)
36
+ text = re.sub(r"-?\d+\.?\d*", _expand_number, text)
37
+
38
+ except ImportError:
39
+ pass # num2words2 not installed — skip normalization
40
+
41
+ # Expand common currency symbols
42
+ text = re.sub(r"\$\s*", "dollars ", text)
43
+ text = re.sub(r"€\s*", "euros ", text)
44
+ text = re.sub(r"£\s*", "pounds ", text)
45
+ text = re.sub(r"%", " percent", text)
46
+
47
+ # Strip punctuation (keep hyphens inside words for compound words)
48
+ text = re.sub(r"[^\w\s-]", " ", text)
49
+ # Collapse whitespace
50
+ text = re.sub(r"\s+", " ", text).strip()
51
+
52
+ return text
53
+
54
+
55
def compute_wer(
    reference: str,
    hypothesis: str,
    normalize: bool = True,
    language: str = "en",
) -> dict[str, Any]:
    """Return the Word Error Rate of ``hypothesis`` against ``reference``.

    WER = (substitutions + deletions + insertions) / reference word count,
    from a word-level Levenshtein alignment. An empty reference is treated
    as having one word, so the result is the number of inserted words rather
    than a division by zero.

    Args:
        reference: Ground-truth transcription.
        hypothesis: Model output to evaluate.
        normalize: When True, run both strings through ``_normalize_text``
            (number expansion, punctuation stripping); when False, only
            lowercase and trim them.
        language: Language code forwarded to number expansion.

    Returns:
        Dict with ``wer`` (rounded to 4 decimals), ``substitutions``,
        ``insertions``, ``deletions``, ``errors``, ``ref_words`` and
        ``hyp_words``.
    """
    if normalize:
        ref_text = _normalize_text(reference, language)
        hyp_text = _normalize_text(hypothesis, language)
    else:
        ref_text = reference.lower().strip()
        hyp_text = hypothesis.lower().strip()

    ref_tokens = ref_text.split()
    hyp_tokens = hyp_text.split()
    rows, cols = len(ref_tokens), len(hyp_tokens)

    # cost[r][c]: fewest edits turning the first r reference words into the
    # first c hypothesis words.
    cost = [list(range(cols + 1))]
    for r in range(1, rows + 1):
        above = cost[r - 1]
        row = [r] + [0] * cols
        ref_word = ref_tokens[r - 1]
        for c in range(1, cols + 1):
            if ref_word == hyp_tokens[c - 1]:
                row[c] = above[c - 1]
            else:
                row[c] = 1 + min(above[c - 1], above[c], row[c - 1])
        cost.append(row)

    # Walk back from the bottom-right cell, preferring a match, then a
    # substitution, then a deletion, and finally an insertion.
    substitutions = deletions = insertions = 0
    r, c = rows, cols
    while r > 0 or c > 0:
        diagonal_ok = r > 0 and c > 0
        if diagonal_ok and ref_tokens[r - 1] == hyp_tokens[c - 1]:
            r, c = r - 1, c - 1
        elif diagonal_ok and cost[r][c] == cost[r - 1][c - 1] + 1:
            substitutions += 1
            r, c = r - 1, c - 1
        elif r > 0 and cost[r][c] == cost[r - 1][c] + 1:
            deletions += 1
            r -= 1
        else:
            insertions += 1
            c -= 1

    errors = substitutions + deletions + insertions
    return {
        "wer": round(errors / max(rows, 1), 4),
        "substitutions": substitutions,
        "insertions": insertions,
        "deletions": deletions,
        "errors": errors,
        "ref_words": rows,
        "hyp_words": cols,
    }
130
+
131
+
132
def compute_wer_batch(
    references: list[str],
    hypotheses: list[str],
    normalize: bool = True,
    language: str = "en",
) -> dict[str, Any]:
    """Compute WER across a batch of reference/hypothesis pairs.

    The aggregate ``wer`` is corpus-level: total errors divided by total
    reference words (not the mean of the per-example WERs).

    Args:
        references: Ground-truth transcriptions.
        hypotheses: Model outputs, aligned one-to-one with ``references``.
        normalize: Forwarded to :func:`compute_wer`.
        language: Forwarded to :func:`compute_wer`.

    Returns:
        Dict with aggregate ``wer``, ``substitutions``, ``insertions``,
        ``deletions``, ``errors``, ``ref_words``, ``hyp_words``,
        ``n_examples``, and ``examples`` (the per-pair ``compute_wer`` results).

    Raises:
        ValueError: If ``references`` and ``hypotheses`` differ in length
            (previously the extra items were silently dropped).
    """
    # Materialize so that iterables (not just lists) keep working with len().
    references = list(references)
    hypotheses = list(hypotheses)
    if len(references) != len(hypotheses):
        raise ValueError(
            "references and hypotheses must have the same length "
            f"(got {len(references)} and {len(hypotheses)})"
        )

    examples = [
        compute_wer(ref, hyp, normalize=normalize, language=language)
        for ref, hyp in zip(references, hypotheses)
    ]

    def _total(key: str) -> int:
        return sum(result[key] for result in examples)

    total_errors = _total("errors")
    total_ref_words = _total("ref_words")
    wer_score = total_errors / max(total_ref_words, 1)

    return {
        "wer": round(wer_score, 4),
        "substitutions": _total("substitutions"),
        "insertions": _total("insertions"),
        "deletions": _total("deletions"),
        "errors": total_errors,
        "ref_words": total_ref_words,
        "hyp_words": _total("hyp_words"),
        "n_examples": len(examples),
        "examples": examples,
    }
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # Scorer integration for openrunner.evaluate()
174
+ # ---------------------------------------------------------------------------
175
+
176
+ from openrunner.evaluation import Scorer
177
+
178
+
179
class WERScorer(Scorer):
    """Word Error Rate scorer for the evaluation framework.

    Each example is scored with :func:`compute_wer`. ``summarize`` reports
    the corpus-level WER (total errors / total reference words) alongside
    the unweighted mean of the per-example WERs.

    Args:
        normalize: Expand numbers to words and strip punctuation before
            comparing (default True).
        language: Language code for number expansion (default "en").

    Example:
        results = openrunner.evaluate(
            model_fn=my_asr_model,
            dataset=[{"input": audio, "expected": "fifty dollars"}],
            scorers=[WERScorer(language="en")],
        )
    """

    def __init__(
        self,
        normalize: bool = True,
        language: str = "en",
    ):
        # NOTE(review): Scorer.__init__ is not called (same as before);
        # confirm the base class needs no initialization.
        self.normalize = normalize
        self.language = language

    def score(self, output: Any, expected: Any, **kwargs) -> dict:
        """Score one example; falsy ``output``/``expected`` count as empty text.

        Returns:
            Dict with ``wer``, ``substitutions``, ``insertions``,
            ``deletions`` and ``ref_words``. ``ref_words`` is the reference
            length, which lets ``summarize`` compute an exact corpus WER.
        """
        ref = str(expected) if expected else ""
        hyp = str(output) if output else ""
        result = compute_wer(ref, hyp, normalize=self.normalize, language=self.language)
        return {
            "wer": result["wer"],
            "substitutions": result["substitutions"],
            "insertions": result["insertions"],
            "deletions": result["deletions"],
            "ref_words": result["ref_words"],
        }

    def summarize(self, scores: list[dict]) -> dict:
        """Aggregate WER across all examples.

        ``wer`` is total errors over total reference words (corpus level).
        ``mean_wer`` is the unweighted mean of the per-example WERs. Score
        dicts that lack ``ref_words`` (e.g. produced by an older ``score``)
        fall back to estimating the reference length from the rounded
        per-example WER, or 10 words when that WER is 0.
        """
        total_errors = 0
        total_ref = 0
        for s in scores:
            errors = s["substitutions"] + s["insertions"] + s["deletions"]
            total_errors += errors
            if "ref_words" in s:
                total_ref += s["ref_words"]
            elif s["wer"] > 0:
                # Legacy approximation: back out the length from errors / WER.
                total_ref += round(errors / s["wer"])
            else:
                total_ref += 10  # legacy placeholder length
        return {
            "wer": round(total_errors / max(total_ref, 1), 4),
            "mean_wer": round(sum(s["wer"] for s in scores) / max(len(scores), 1), 4),
            "total_substitutions": sum(s["substitutions"] for s in scores),
            "total_insertions": sum(s["insertions"] for s in scores),
            "total_deletions": sum(s["deletions"] for s in scores),
        }
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "openrunner-sdk"
3
- version = "2.3.0"
3
+ version = "2.4.1"
4
4
  description = "OpenRunner SDK - W&B-compatible ML experiment tracking client"
5
5
  readme = "README.md"
6
6
  license = {text = "MIT"}
@@ -30,6 +30,7 @@ Issues = "https://github.com/jqueguiner/openrunner/issues"
30
30
  openrunner = "openrunner.cli:main"
31
31
 
32
32
  [project.optional-dependencies]
33
+ wer = ["num2words2>=0.1"]
33
34
  gpu = ["nvidia-ml-py>=12.0"]
34
35
  pytorch = ["torch>=2.0"]
35
36
  huggingface = ["transformers>=4.30"]
@@ -192,7 +192,7 @@ class TestModuleLevelAlert:
192
192
  result = openrunner.alert("Title", text="Body", level="WARN")
193
193
  assert result == {"id": "x"}
194
194
  mock_run.alert.assert_called_once_with(
195
- title="Title", text="Body", level="WARN"
195
+ title="Title", text="Body", level="WARN", wait_duration=None
196
196
  )
197
197
  finally:
198
198
  openrunner._active_run = original
@@ -70,15 +70,15 @@ class TestLogModel:
70
70
  from openrunner.integration.sklearn import log_model
71
71
  import openrunner
72
72
 
73
- mock_log = MagicMock()
74
- monkeypatch.setattr("openrunner.log", mock_log)
73
+ mock_config = MagicMock()
74
+ monkeypatch.setattr("openrunner.config", mock_config)
75
75
  monkeypatch.setattr("openrunner._active_run", MagicMock())
76
76
 
77
77
  model = _make_model(params={"n_estimators": 100, "max_depth": 5, "random_state": 42})
78
78
  log_model(model)
79
79
 
80
- mock_log.assert_called_once()
81
- logged = mock_log.call_args[0][0]
80
+ mock_config.update.assert_called_once()
81
+ logged = mock_config.update.call_args[0][0]
82
82
  assert logged["model/n_estimators"] == 100
83
83
  assert logged["model/max_depth"] == 5
84
84
  assert logged["model/random_state"] == 42
@@ -88,12 +88,12 @@ class TestLogModel:
88
88
  from openrunner.integration.sklearn import log_model
89
89
  import openrunner
90
90
 
91
- mock_log = MagicMock()
92
- monkeypatch.setattr("openrunner.log", mock_log)
91
+ mock_config = MagicMock()
92
+ monkeypatch.setattr("openrunner.config", mock_config)
93
93
  monkeypatch.setattr("openrunner._active_run", None)
94
94
 
95
95
  log_model(_make_model())
96
- mock_log.assert_not_called()
96
+ mock_config.update.assert_not_called()
97
97
 
98
98
 
99
99
  class TestLogClassificationReport:
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import pytest
5
6
  import io
6
7
  import json
7
8
  import struct
@@ -586,16 +587,18 @@ class TestHtml:
586
587
  """Tests for the Html class."""
587
588
 
588
589
  def test_html_basic(self) -> None:
589
- """Html with raw string serializes correctly."""
590
+ """Html with raw string serializes correctly (inject=True prepends style)."""
590
591
  html = Html("<div>Hello World</div>")
591
592
  result = html._serialize()
592
- assert result == {"html": "<div>Hello World</div>"}
593
+ assert result["html"].endswith("<div>Hello World</div>")
594
+ assert "<style>" in result["html"]
593
595
 
594
596
  def test_html_with_caption(self) -> None:
595
597
  """Html with caption includes it in serialized output."""
596
598
  html = Html("<p>Report</p>", caption="Training Report")
597
599
  result = html._serialize()
598
- assert result == {"html": "<p>Report</p>", "caption": "Training Report"}
600
+ assert result["html"].endswith("<p>Report</p>")
601
+ assert result["caption"] == "Training Report"
599
602
  assert html.caption == "Training Report"
600
603
 
601
604
  def test_html_no_caption(self) -> None:
@@ -616,7 +619,7 @@ class TestHtml:
616
619
  )
617
620
  html = Html(content, caption="complex")
618
621
  result = html._serialize()
619
- assert result["html"] == content
622
+ assert result["html"].endswith(content)
620
623
  assert result["caption"] == "complex"
621
624
 
622
625
 
@@ -627,6 +630,11 @@ class TestHtml:
627
630
  class TestMatplotlibFigure:
628
631
  """Tests for the MatplotlibFigure class."""
629
632
 
633
+ pytestmark = pytest.mark.skipif(
634
+ not __import__("importlib").util.find_spec("matplotlib"),
635
+ reason="matplotlib not installed",
636
+ )
637
+
630
638
  def test_matplotlib_explicit_figure(self) -> None:
631
639
  """MatplotlibFigure from explicit figure serializes to PNG bytes."""
632
640
  import matplotlib
@@ -208,6 +208,11 @@ class TestLineSeries:
208
208
  # pr_curve()
209
209
  # ---------------------------------------------------------------------------
210
210
 
211
+ import importlib.util
212
+ _has_sklearn = importlib.util.find_spec("sklearn") is not None
213
+
214
+
215
+ @pytest.mark.skipif(not _has_sklearn, reason="scikit-learn not installed")
211
216
  class TestPRCurve:
212
217
  """Tests for plot.pr_curve()."""
213
218
 
@@ -272,6 +277,7 @@ class TestPRCurve:
272
277
  # roc_curve()
273
278
  # ---------------------------------------------------------------------------
274
279
 
280
+ @pytest.mark.skipif(not _has_sklearn, reason="scikit-learn not installed")
275
281
  class TestROCCurve:
276
282
  """Tests for plot.roc_curve()."""
277
283
 
@@ -336,6 +342,7 @@ class TestROCCurve:
336
342
  # confusion_matrix()
337
343
  # ---------------------------------------------------------------------------
338
344
 
345
+ @pytest.mark.skipif(not _has_sklearn, reason="scikit-learn not installed")
339
346
  class TestConfusionMatrix:
340
347
  """Tests for plot.confusion_matrix()."""
341
348
 
@@ -82,8 +82,8 @@ class TestGPUMetrics:
82
82
  assert result["system.gpu.0.gpu"] == 75.0
83
83
  assert "system.gpu.0.memory" in result
84
84
  assert result["system.gpu.0.memory"] == 50.0
85
- assert "system.gpu.0.memoryAllocatedBytes" in result
86
- assert result["system.gpu.0.memoryAllocatedBytes"] == 4_000_000_000.0
85
+ assert "system.gpu.0.memoryAllocatedMB" in result
86
+ assert result["system.gpu.0.memoryAllocatedMB"] == 4_000_000_000.0 / (1024 * 1024)
87
87
 
88
88
 
89
89
  # ---------------------------------------------------------------------------
File without changes
File without changes
File without changes