wherobots-python-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wherobots/client.py ADDED
@@ -0,0 +1,640 @@
1
+ """Wherobots Jobs API Client.
2
+
3
+ ``WherobotsJob`` is the primary user-facing class. It orchestrates
4
+ job submission, monitoring, log streaming, and cancellation by
5
+ delegating HTTP work to the ``RunsAPI`` layer.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import time
11
+ from collections.abc import Callable, Iterator
12
+ from typing import Any
13
+
14
+ from wherobots.api.runs import RunsAPI
15
+ from wherobots.config import WherobotsConfig
16
+ from wherobots.enums import DependencyType, JobStatus, Region, Runtime, is_terminal_status
17
+ from wherobots.exceptions import (
18
+ WherobotsAPIError,
19
+ WherobotsJobError,
20
+ WherobotsTimeoutError,
21
+ WherobotsValidationError,
22
+ )
23
+ from wherobots.models import (
24
+ CreateRunPayload,
25
+ LogsResponse,
26
+ RunEnvironment,
27
+ RunJarPayload,
28
+ RunListPage,
29
+ RunMetricsResponse,
30
+ RunPythonPayload,
31
+ RunView,
32
+ )
33
+ from wherobots.utils.logger import get_logger
34
+ from wherobots.utils.validation import validate_name
35
+
36
# Module-level logger used by every function and method in this file; it is
# obtained from the SDK's ``get_logger`` helper under the name "client".
logger = get_logger("client")
37
+
38
+
39
+ def _status_str(status: JobStatus | str | None) -> str:
40
+ """Human-readable status string for logging, handling the
41
+ forward-compat case where *status* is a raw string the enum
42
+ doesn't know about."""
43
+ if status is None:
44
+ return "UNKNOWN"
45
+ if isinstance(status, JobStatus):
46
+ return status.value
47
+ return status
48
+
49
+
50
class WherobotsJob:
    """Wherobots Job Run Manager.

    Manages the lifecycle of Wherobots job runs including submission,
    monitoring, log streaming, and cancellation.

    Each instance owns a ``RunsAPI`` client built in ``__init__``; call
    :meth:`close` (or use the instance as a context manager) to release it.
    """

    def __init__(
        self,
        script: str,
        name: str,
        runtime: str | Runtime = "tiny",
        region: str | Region | None = None,
        api_key: str | None = None,
        version: str | None = None,
        timeout_seconds: int = 3600,
        args: list[str] | None = None,
        spark_configs: dict[str, str] | None = None,
        dependencies: list[dict[str, Any]] | None = None,
        spark_driver_disk_gb: int | None = None,
        spark_executor_disk_gb: int | None = None,
        s3_bucket: str | None = None,
        s3_prefix: str | None = None,
        jar_main_class: str | None = None,
        auto_upload: bool = True,
        base_url: str | None = None,
        request_timeout_seconds: int | None = None,
        config: WherobotsConfig | None = None,
    ):
        """Initialize a Wherobots Job.

        Args:
            script: Path to Python script (.py) or JAR file (.jar).
                Can be local path or S3 URI (s3://bucket/key)
            name: Job name (8-255 chars, alphanumeric, _, -, .)
            runtime: Compute runtime size (default: "tiny")
            region: AWS region (default: from config or env)
            api_key: Wherobots API key (or set WHEROBOTS_API_KEY env var)
            version: Wherobots version ("latest" or "preview")
            timeout_seconds: Job timeout in seconds (default: 3600)
            args: Optional list of arguments to pass to script
            spark_configs: Optional Spark configuration dict
            dependencies: Optional list of PyPI or file dependencies
            spark_driver_disk_gb: Optional driver disk size
            spark_executor_disk_gb: Optional executor disk size
            s3_bucket: Deprecated. Ignored; presigned uploads are used
                instead. Will be removed in a future release.
            s3_prefix: Deprecated. Ignored; presigned uploads are used
                instead. Will be removed in a future release.
            jar_main_class: Main class for JAR files (required for JARs)
            auto_upload: Automatically upload local scripts via presigned
                URL (default: True)
            base_url: Override the API base URL
            request_timeout_seconds: HTTP request timeout in seconds
            config: Optional WherobotsConfig to override defaults. When
                given, ``api_key``/``base_url``/``request_timeout_seconds``
                are not used to build a config.

        Raises:
            WherobotsValidationError: If script is None/empty, disk sizes
                are negative, or jar_main_class is missing for JARs.
        """
        if not script:
            raise WherobotsValidationError("script must not be None or empty")

        if spark_driver_disk_gb is not None and spark_driver_disk_gb < 0:
            raise WherobotsValidationError(
                f"spark_driver_disk_gb must be non-negative, got {spark_driver_disk_gb}"
            )
        if spark_executor_disk_gb is not None and spark_executor_disk_gb < 0:
            raise WherobotsValidationError(
                f"spark_executor_disk_gb must be non-negative, got {spark_executor_disk_gb}"
            )

        self.script = script
        self.name = validate_name(name)

        # Validate the JAR / main-class pairing before building the config so
        # a cheap argument error surfaces without loading env/config first.
        self.is_jar = script.lower().endswith(".jar")
        if self.is_jar and not jar_main_class:
            raise WherobotsValidationError("jar_main_class is required for JAR files")

        self.runtime = runtime.value if isinstance(runtime, Runtime) else runtime
        region_value = region.value if isinstance(region, Region) else region

        # Deprecation warnings for s3_bucket / s3_prefix are emitted by
        # ``WherobotsConfig.from_env`` so they fire exactly once per
        # construction regardless of whether the user went through
        # ``WherobotsJob`` or ``WherobotsConfig`` directly.
        self._config = config or WherobotsConfig.from_env(
            api_key=api_key,
            region=region_value,
            s3_bucket=s3_bucket,
            s3_prefix=s3_prefix,
            base_url=base_url,
            version=version,
            request_timeout_seconds=request_timeout_seconds,
        )

        # Explicit argument wins, then config, then the hard-coded fallback.
        self.region = region_value or self._config.region or "aws-us-west-2"
        self.version = version or self._config.version
        self.timeout_seconds = timeout_seconds
        self.args = args or []
        self.spark_configs = spark_configs or {}
        self.dependencies = dependencies or []
        self.spark_driver_disk_gb = spark_driver_disk_gb
        self.spark_executor_disk_gb = spark_executor_disk_gb
        self.s3_bucket = s3_bucket or self._config.s3_bucket
        self.s3_prefix = s3_prefix or self._config.s3_prefix
        self.jar_main_class = jar_main_class
        self.auto_upload = auto_upload

        self.run_id: str | None = None
        # Status may be a string when the server returns a value the
        # SDK's JobStatus enum doesn't recognize yet (forward-compat).
        self.status: JobStatus | str | None = None
        # Position of the next unread log entry; advanced by poll_for_logs.
        self._last_log_cursor: int | str = 0
        # Cached result of _prepare_script_uri (set after first resolution).
        self._script_uri: str | None = None

        # Build the API layer
        self._api = RunsAPI.from_config(self._config)

    # ------------------------------------------------------------------ #
    # Upload helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_s3_uri(path: str) -> bool:
        """Return True when *path* uses the ``s3://`` scheme."""
        return path.startswith("s3://")

    def _prepare_script_uri(self) -> str:
        """Resolve the script to an S3 URI, uploading if necessary.

        Decision tree:

        1. Already resolved → return cached URI.
        2. Script is already an ``s3://`` URI → use directly.
        3. Local file + ``auto_upload=True`` → upload via presigned URL
           (``FilesAPI``). Only an API key is needed.
        4. ``auto_upload=False`` → error.

        Raises:
            WherobotsValidationError: If the script is local and
                ``auto_upload`` is False.
        """
        if self._script_uri:
            return self._script_uri

        if self._is_s3_uri(self.script):
            self._script_uri = self.script
        elif self.auto_upload:
            self._script_uri = self._upload_via_presigned(self.script)
        else:
            raise WherobotsValidationError(
                f"Script path '{self.script}' is not an S3 URI. "
                "Set auto_upload=True or provide an S3 URI."
            )

        assert self._script_uri is not None  # all branches above set or raise
        return self._script_uri

    def _upload_via_presigned(self, local_path: str) -> str:
        """Upload a local file via presigned URL (FilesAPI).

        Args:
            local_path: Path to the local file to upload.

        Returns:
            The ``s3://`` URI of the uploaded file.

        Raises:
            WherobotsAPIError: If the presigned upload fails.
        """
        # Imported lazily: only needed when a local script must be uploaded.
        from wherobots.api.files import FilesAPI

        files_api = FilesAPI.from_config(self._config)
        try:
            s3_uri = files_api.upload_file(local_path)
            logger.info("Presigned upload succeeded: %s", s3_uri)
            return s3_uri
        finally:
            files_api.close()

    # ------------------------------------------------------------------ #
    # Payload construction
    # ------------------------------------------------------------------ #

    def _build_payload(self) -> CreateRunPayload:
        """Assemble the create-run request body, resolving the script URI first."""
        script_uri = self._prepare_script_uri()

        run_python: RunPythonPayload | None = None
        run_jar: RunJarPayload | None = None

        if self.is_jar:
            run_jar = RunJarPayload(
                uri=script_uri,
                main_class=self.jar_main_class or "",
                args=self.args,
            )
        else:
            run_python = RunPythonPayload(uri=script_uri, args=self.args)

        # Only send an environment block when at least one field is set.
        environment: RunEnvironment | None = None
        if (
            self.spark_configs
            or self.dependencies
            or self.spark_driver_disk_gb is not None
            or self.spark_executor_disk_gb is not None
        ):
            environment = RunEnvironment(
                spark_configs=self.spark_configs or None,
                dependencies=self.dependencies or None,
                spark_driver_disk_gb=self.spark_driver_disk_gb,
                spark_executor_disk_gb=self.spark_executor_disk_gb,
            )

        return CreateRunPayload(
            runtime=self.runtime,
            name=self.name,
            version=self.version or "latest",
            timeout_seconds=self.timeout_seconds,
            run_python=run_python,
            run_jar=run_jar,
            environment=environment,
        )

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def close(self) -> None:
        """Close the underlying API session.

        Should be called when the job manager is no longer needed,
        or use as a context manager instead.
        """
        self._api.close()

    def __enter__(self) -> WherobotsJob:
        return self

    def __exit__(self, *exc: Any) -> None:
        self.close()

    # ------------------------------------------------------------------ #
    # Internal polling helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _sleep_bounded(interval: float, deadline: float | None) -> None:
        """Sleep for *interval* seconds, but never past *deadline*.

        Args:
            interval: Desired sleep duration in seconds.
            deadline: A ``time.monotonic()`` value to stop at, or ``None``
                for no bound. An already-passed deadline sleeps zero seconds.
        """
        if deadline is not None:
            interval = min(interval, max(0.0, deadline - time.monotonic()))
        time.sleep(interval)

    def _drain_remaining_logs(
        self,
        handler: Callable[[dict[str, Any]], None],
        page_size: int = 1000,
        max_pages: int = 1000,
    ) -> None:
        """Dispatch every remaining log entry to *handler*, page by page.

        Advances ``_last_log_cursor`` after each page so a later call does
        not re-emit entries already handled. Stops when the server reports
        no further page, returns the same cursor, or returns an empty page;
        *max_pages* guards against a server that never terminates.
        """
        for _ in range(max_pages):
            page = self.get_logs(cursor=self._last_log_cursor, size=page_size)
            for item in page.items:
                handler(item.to_dict())
            next_page = page.next_page
            if next_page is None or next_page == self._last_log_cursor:
                return
            self._last_log_cursor = next_page
            if not page.items:
                return
        logger.warning(
            "Stopped draining logs for run %s after %d pages", self.run_id, max_pages
        )

    def _poll_logs_once(
        self,
        handler: Callable[[dict[str, Any]], None],
        follow: bool,
    ) -> bool:
        """Fetch one page of logs (and, when following, the job status).

        Returns:
            True when polling should stop: either *follow* is False, or the
            job reached a terminal status and remaining logs were drained.
        """
        logs = self.get_logs(cursor=self._last_log_cursor)
        for item in logs.items:
            handler(item.to_dict())
        if logs.next_page is not None:
            self._last_log_cursor = logs.next_page

        if not follow:
            return True

        run_view = self.get_status()
        if is_terminal_status(run_view.status):
            # The job is done; flush whatever backlog is still unread.
            self._drain_remaining_logs(handler)
            return True
        return False

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def submit(self) -> str:
        """Submit the job run.

        Calling this again after a successful submit logs a warning and
        returns the existing run ID without resubmitting.

        Returns:
            Run ID
        """
        if self.run_id:
            logger.warning("Job already submitted with run_id: %s", self.run_id)
            return self.run_id

        payload = self._build_payload()
        logger.info("Submitting job '%s' to %s", self.name, self.region)

        run_view = self._api.create(payload, region=self.region)
        self.run_id = run_view.id
        self.status = run_view.status

        logger.info("Job submitted successfully. Run ID: %s", self.run_id)
        if self.status:
            logger.info("Status: %s", _status_str(self.status))

        return self.run_id

    def get_status(self) -> RunView:
        """Get current job status and details, updating ``self.status``.

        Returns:
            RunView with full job details

        Raises:
            WherobotsJobError: If the job has not been submitted.
        """
        if not self.run_id:
            raise WherobotsJobError("Job not submitted. Call submit() first.")

        run_view = self._api.get(self.run_id)
        self.status = run_view.status
        return run_view

    def get_logs(
        self,
        cursor: int | str = 0,
        size: int = 100,
    ) -> LogsResponse:
        """Get job logs.

        Args:
            cursor: Position to start reading logs. Accepts integer byte
                offsets (legacy) or string cursors returned in
                ``LogsResponse.next_page``. Defaults to ``0``.
            size: Maximum number of log entries to retrieve (default: 100)

        Returns:
            LogsResponse with items and pagination info

        Raises:
            WherobotsJobError: If the job has not been submitted.
        """
        if not self.run_id:
            raise WherobotsJobError("Job not submitted. Call submit() first.")

        return self._api.get_logs(self.run_id, cursor=cursor, size=size)

    def get_metrics(self) -> RunMetricsResponse:
        """Get job metrics for the run.

        Returns:
            RunMetricsResponse with series and instant metrics

        Raises:
            WherobotsJobError: If the job has not been submitted.
        """
        if not self.run_id:
            raise WherobotsJobError("Job not submitted. Call submit() first.")

        return self._api.get_metrics(self.run_id)

    def iter_logs(
        self,
        cursor: int | str = 0,
        size: int = 100,
        max_pages: int = 10_000,
    ) -> Iterator[dict[str, Any]]:
        """Iterate over log entries until pagination ends.

        Unlike :meth:`poll_for_logs`, this does not touch the instance's
        stored log cursor.

        Args:
            cursor: Initial pagination cursor.
            size: Page size passed to :meth:`get_logs`.
            max_pages: Hard ceiling on pages fetched, to guard against a
                buggy server that always returns a distinct
                ``next_page`` cursor. Defaults to ``10_000``.

        Raises:
            WherobotsJobError: If the page ceiling is hit.
        """
        next_cursor: int | str = cursor
        for _ in range(max_pages):
            logs = self.get_logs(cursor=next_cursor, size=size)
            for item in logs.items:
                yield item.to_dict()
            if logs.next_page is None or logs.next_page == next_cursor:
                return
            next_cursor = logs.next_page
        raise WherobotsJobError(
            f"iter_logs exceeded max_pages={max_pages} without reaching end of log stream"
        )

    def poll_for_logs(
        self,
        follow: bool = True,
        interval: float = 2.0,
        log_handler: Callable[[dict[str, Any]], None] | None = None,
        max_errors: int = 10,
        _deadline: float | None = None,
    ) -> None:
        """Poll and print job logs.

        Each iteration fetches one page of logs and, when following, the job
        status. Once the job reaches a terminal status, all remaining logs
        are drained page by page before returning.

        Args:
            follow: If True, continue polling until job completes.
            interval: Polling interval in seconds (default: 2.0).
            log_handler: Optional handler for each log item dict.
            max_errors: Max consecutive transient errors before giving up
                (default: 10). An iteration counts as successful only when
                both the log fetch and the status check succeed.
                Non-transient errors (4xx except 429) are raised immediately.
            _deadline: Internal. Monotonic clock deadline for timeout
                enforcement. Callers should use
                ``wait_for_completion(max_wait_seconds=...)`` instead.

        Raises:
            WherobotsJobError: If the job has not been submitted.
            WherobotsTimeoutError: If *_deadline* is set and exceeded.
            WherobotsAPIError: On non-transient HTTP errors.
        """
        if not self.run_id:
            raise WherobotsJobError("Job not submitted. Call submit() first.")

        def default_handler(item: dict[str, Any]) -> None:
            print(item.get("raw", ""), flush=True)

        handler = log_handler or default_handler
        consecutive_errors = 0

        while True:
            # Enforce timeout when called from wait_for_completion
            if _deadline is not None and time.monotonic() >= _deadline:
                raise WherobotsTimeoutError(
                    f"Timed out waiting for job {self.run_id} "
                    f"(deadline exceeded during log streaming)"
                )

            try:
                finished = self._poll_logs_once(handler, follow)
            except KeyboardInterrupt:
                logger.info("Log streaming interrupted by user")
                break
            except WherobotsAPIError as exc:
                # Non-transient HTTP errors should not be retried
                if exc.status_code and 400 <= exc.status_code < 500 and exc.status_code != 429:
                    raise
                consecutive_errors += 1
                logger.error("Error polling logs (%d/%d): %s", consecutive_errors, max_errors, exc)
                if consecutive_errors >= max_errors or not follow:
                    raise
            except Exception as exc:
                consecutive_errors += 1
                logger.error("Error polling logs (%d/%d): %s", consecutive_errors, max_errors, exc)
                if consecutive_errors >= max_errors or not follow:
                    raise
            else:
                # Reset only after the whole iteration (logs AND status)
                # succeeded; resetting earlier would let a persistently
                # failing status check retry forever.
                consecutive_errors = 0
                if finished:
                    break

            # Sleep outside the except clauses so Ctrl-C during any wait is
            # handled uniformly, and never sleep past the deadline.
            try:
                self._sleep_bounded(interval, _deadline)
            except KeyboardInterrupt:
                logger.info("Log streaming interrupted by user")
                break

    def cancel(self) -> bool:
        """Cancel the job run.

        Returns:
            True if cancellation request was successful

        Raises:
            WherobotsJobError: If the job has not been submitted.
        """
        if not self.run_id:
            raise WherobotsJobError("Job not submitted. Call submit() first.")

        self._api.cancel(self.run_id)

        logger.info("Cancellation requested for run %s", self.run_id)
        logger.info("It may take a minute for status to change to CANCELLED")

        return True

    def wait_for_completion(
        self,
        poll_interval: float = 5.0,
        stream_logs: bool = True,
        log_interval: float = 2.0,
        max_wait_seconds: float | None = None,
    ) -> JobStatus | str:
        """Wait for job to complete, optionally streaming logs.

        Args:
            poll_interval: Status check interval in seconds (default: 5.0)
            stream_logs: Stream logs while waiting (default: True)
            log_interval: Log polling interval in seconds (default: 2.0)
            max_wait_seconds: Maximum time to wait before raising
                ``WherobotsTimeoutError``. ``None`` means wait
                indefinitely (default: None).

        Returns:
            Final job status

        Raises:
            WherobotsJobError: If the job has not been submitted or the
                final status is missing.
            WherobotsTimeoutError: If max_wait_seconds is exceeded.
        """
        if not self.run_id:
            raise WherobotsJobError("Job not submitted. Call submit() first.")

        logger.info("Waiting for job %s to complete...", self.run_id)
        start = time.monotonic()
        deadline = start + max_wait_seconds if max_wait_seconds is not None else None

        if stream_logs:
            self.poll_for_logs(follow=True, interval=log_interval, _deadline=deadline)
        else:
            while True:
                if deadline is not None:
                    now = time.monotonic()
                    if now >= deadline:
                        elapsed = now - start
                        raise WherobotsTimeoutError(
                            f"Timed out waiting for job {self.run_id} "
                            f"after {elapsed:.0f}s (limit: {max_wait_seconds}s)"
                        )
                run_view = self.get_status()
                if is_terminal_status(run_view.status):
                    break
                if run_view.status:
                    logger.info("Status: %s", _status_str(run_view.status))
                # Bounded so the timeout fires near max_wait_seconds rather
                # than up to a full poll_interval later.
                self._sleep_bounded(poll_interval, deadline)

        final_view = self.get_status()
        if not final_view.status:
            raise WherobotsJobError(f"Job {self.run_id} returned an unrecognized status")

        self.status = final_view.status
        logger.info("Job %s finished with status: %s", self.run_id, _status_str(self.status))

        if self.status == JobStatus.FAILED:
            logger.error("Job failed. Check logs for details.")
        elif self.status == JobStatus.CANCELLED:
            logger.warning("Job was cancelled.")

        return self.status

    # ------------------------------------------------------------------ #
    # Class-level helpers (no instance required)
    # ------------------------------------------------------------------ #

    @classmethod
    def list_runs(
        cls,
        api_key: str | None = None,
        region: str | None = None,
        name_pattern: str | None = None,
        created_after: str | None = None,
        status: list[str | JobStatus] | None = None,
        cursor: str | None = None,
        size: int = 50,
        base_url: str | None = None,
        request_timeout_seconds: int | None = None,
        config: WherobotsConfig | None = None,
    ) -> RunListPage:
        """List job runs in the organization.

        Args:
            api_key: Wherobots API key (or use WHEROBOTS_API_KEY env var)
            region: Filter by region
            name_pattern: Filter by name pattern (supports * wildcard)
            created_after: Filter runs created after this ISO timestamp
            status: Filter by status (list of statuses)
            cursor: Pagination cursor
            size: Number of results per page (default: 50, max: 100)
            base_url: Override API base URL
            request_timeout_seconds: HTTP request timeout in seconds
            config: Optional WherobotsConfig to override defaults

        Returns:
            RunListPage with items, total, and pagination cursor
        """
        cfg = config or WherobotsConfig.from_env(
            api_key=api_key,
            region=region,
            base_url=base_url,
            request_timeout_seconds=request_timeout_seconds,
        )

        # The API session is opened and closed within this call.
        with RunsAPI.from_config(cfg) as api:
            return api.list(
                region=region or cfg.region,
                name_pattern=name_pattern,
                created_after=created_after,
                status=status,
                cursor=cursor,
                size=size,
            )

    @classmethod
    def add_pypi_dependency(cls, library_name: str, library_version: str) -> dict[str, str]:
        """Create a PyPI dependency object.

        Args:
            library_name: PyPI package name
            library_version: Package version

        Returns:
            Dependency dictionary
        """
        return {
            "sourceType": DependencyType.PYPI.value,
            "libraryName": library_name,
            "libraryVersion": library_version,
        }

    @classmethod
    def add_file_dependency(cls, file_path: str) -> dict[str, str]:
        """Create a file dependency object.

        Args:
            file_path: Path to dependency file (.jar, .whl, .zip, .json)

        Returns:
            Dependency dictionary

        Raises:
            WherobotsValidationError: If the extension is not supported
                (the check is case-sensitive).
        """
        valid_extensions = [".jar", ".whl", ".zip", ".json"]
        if not file_path.endswith(tuple(valid_extensions)):
            raise WherobotsValidationError(
                f"File must have one of these extensions: {valid_extensions}"
            )

        return {
            "sourceType": DependencyType.FILE.value,
            "filePath": file_path,
        }