clarifai 11.7.5__py3-none-any.whl → 11.8.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -193,6 +193,7 @@ class Pipeline(Lister, BaseClient):
 """
 start_time = time.time()
 seen_logs = set()
+ current_page = 1 # Track current page for log pagination.

 while time.time() - start_time < timeout:
 # Get run status
@@ -217,8 +218,8 @@ class Pipeline(Lister, BaseClient):
 pipeline_run, preserving_proto_field_name=True
 )

- # Display new log entries
- self._display_new_logs(run_id, seen_logs)
+ # Display new log entries and update current page
+ current_page = self._display_new_logs(run_id, seen_logs, current_page)

 elapsed_time = time.time() - start_time
 logger.info(f"Pipeline run monitoring... (elapsed {elapsed_time:.1f}s)")
@@ -276,12 +277,16 @@ class Pipeline(Lister, BaseClient):
 logger.error(f"Pipeline run timed out after {timeout} seconds")
 return {"status": "timeout"}

- def _display_new_logs(self, run_id: str, seen_logs: set):
+ def _display_new_logs(self, run_id: str, seen_logs: set, current_page: int = 1) -> int:
 """Display new log entries for a pipeline version run.

 Args:
 run_id (str): The pipeline version run ID.
 seen_logs (set): Set of already seen log entry IDs.
+ current_page (int): The current page to fetch logs from.
+
+ Returns:
+ int: The next page number to fetch from in subsequent calls.
 """
 try:
 logs_request = service_pb2.ListLogEntriesRequest()
@@ -290,7 +295,7 @@ class Pipeline(Lister, BaseClient):
 logs_request.pipeline_version_id = self.pipeline_version_id or ""
 logs_request.pipeline_version_run_id = run_id
 logs_request.log_type = "pipeline.version.run" # Set required log type
- logs_request.page = 1
+ logs_request.page = current_page
 logs_request.per_page = 50

 logs_response = self.STUB.ListLogEntries(
@@ -298,7 +303,9 @@ class Pipeline(Lister, BaseClient):
 )

 if logs_response.status.code == status_code_pb2.StatusCode.SUCCESS:
+ entries_count = 0
 for log_entry in logs_response.log_entries:
+ entries_count += 1
 # Use log entry URL or timestamp as unique identifier
 log_id = log_entry.url or f"{log_entry.created_at.seconds}_{log_entry.message}"
 if log_id not in seen_logs:
@@ -312,5 +319,14 @@ class Pipeline(Lister, BaseClient):
 else:
 logger.info(log_message)

+ # If we got a full page (50 entries), there might be more logs on the next page
+ # If we got fewer than 50 entries, we've reached the end and should stay on current page
+ if entries_count == 50:
+ return current_page + 1
+ else:
+ return current_page
+
 except Exception as e:
 logger.debug(f"Error fetching logs: {e}")
+ # Return current page on error to retry the same page next fetch
+ return current_page
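Taken out of the diff, the pagination pattern these hunks introduce is: the monitor loop owns a `current_page` counter, `_display_new_logs` returns the page to request next, and the page only advances when a full page (here, 50 entries) comes back, because a short page means the tail of the log has been reached. A standalone sketch of that pattern, with a hypothetical `fetch_page` callable standing in for the ListLogEntries call:

```python
from typing import Callable, List, Set, Tuple

PAGE_SIZE = 50  # mirrors the per_page=50 used in the diff


def poll_new_entries(
    fetch_page: Callable[[int], List[str]],  # hypothetical: returns log lines for a page
    seen: Set[str],
    current_page: int,
) -> Tuple[List[str], int]:
    """Fetch one page of logs, drop already-seen entries, and decide the next page.

    Advance only when the page was full; otherwise re-poll the same page so
    entries that arrive later on the partial page are not skipped.
    """
    entries = fetch_page(current_page)
    new_entries = [e for e in entries if e not in seen]
    seen.update(new_entries)
    next_page = current_page + 1 if len(entries) == PAGE_SIZE else current_page
    return new_entries, next_page


if __name__ == "__main__":
    # Toy backend: 120 fixed log lines served in pages of 50.
    log = [f"line-{i}" for i in range(120)]
    fetch = lambda page: log[(page - 1) * PAGE_SIZE : page * PAGE_SIZE]

    seen: Set[str] = set()
    page = 1
    for _ in range(4):
        new, page = poll_new_entries(fetch, seen, page)
        print(f"got {len(new)} new entries, next page {page}")
```

Because a partial page is re-fetched on every poll until it fills up, the seen-set is what prevents duplicate output, exactly the role `seen_logs` plays above.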
@@ -10,7 +10,7 @@ import tarfile
 import time
 import webbrowser
 from string import Template
- from typing import Literal
+ from typing import Any, Dict, Literal, Optional
 from unittest.mock import MagicMock

 import yaml
@@ -1107,7 +1107,7 @@ class ModelBuilder:
 logger.error(f"Failed to download checkpoints for model {repo_id}")
 sys.exit(1)
 else:
- logger.info(f"Downloaded checkpoints for model {repo_id}")
+ logger.info(f"Downloaded checkpoints for model {repo_id} successfully to {path}")
 return path

 def _concepts_protos_from_concepts(self, concepts):
@@ -1140,7 +1140,109 @@ class ModelBuilder:
 concepts = config.get('concepts')
 logger.info(f"Updated config.yaml with {len(concepts)} concepts.")

- def get_model_version_proto(self):
+ def _get_git_info(self) -> Optional[Dict[str, Any]]:
+ """
+ Get git repository information for the model path.
+
+ Returns:
+ Dict with git info (url, commit, branch) or None if not a git repository
+ """
+ try:
+ # Check if the folder is within a git repository
+ result = subprocess.run(
+ ['git', 'rev-parse', '--git-dir'],
+ cwd=self.folder,
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+
+ # Get git remote URL
+ remote_result = subprocess.run(
+ ['git', 'config', '--get', 'remote.origin.url'],
+ cwd=self.folder,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ # Get current commit hash
+ commit_result = subprocess.run(
+ ['git', 'rev-parse', 'HEAD'],
+ cwd=self.folder,
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+
+ # Get current branch
+ branch_result = subprocess.run(
+ ['git', 'branch', '--show-current'],
+ cwd=self.folder,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ git_info = {
+ 'commit': commit_result.stdout.strip(),
+ 'branch': branch_result.stdout.strip()
+ if branch_result.returncode == 0
+ else 'HEAD',
+ }
+
+ if remote_result.returncode == 0:
+ git_info['url'] = remote_result.stdout.strip()
+
+ return git_info
+
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # Not a git repository or git not available
+ return None
+
+ def _check_git_status_and_prompt(self) -> bool:
+ """
+ Check for uncommitted changes in git repository within the model path and prompt user.
+
+ Returns:
+ True if should continue with upload, False if should abort
+ """
+ try:
+ # Check for uncommitted changes within the model path only
+ status_result = subprocess.run(
+ ['git', 'status', '--porcelain', '.'],
+ cwd=self.folder,
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+
+ if status_result.stdout.strip():
+ logger.warning("Uncommitted changes detected in model path:")
+ logger.warning(status_result.stdout)
+
+ response = input(
+ "\nDo you want to continue upload with uncommitted changes? (y/N): "
+ )
+ return response.lower() in ['y', 'yes']
+ else:
+ logger.info("Model path has no uncommitted changes.")
+ return True
+
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # Error checking git status, but we already know it's a git repo
+ logger.warning("Could not check git status, continuing with upload.")
+ return True
+
+ def get_model_version_proto(self, git_info: Optional[Dict[str, Any]] = None):
+ """
+ Create a ModelVersion protobuf message for the model.
+ Args:
+ git_info (Optional[Dict[str, Any]]): Git repository information to include in metadata.
+ Returns:
+ resources_pb2.ModelVersion: The ModelVersion protobuf message.
+ """
+
 signatures = self.get_method_signatures()
 model_version_proto = resources_pb2.ModelVersion(
 pretrained_model_config=resources_pb2.PretrainedModelConfig(),
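Stripped of the builder context, `_get_git_info` is a handful of `git` subprocess calls with graceful degradation when git is absent or the folder is not a repository. A compact standalone sketch of the same idea (the `git_info_for` helper and its return shape are illustrative, not part of the clarifai API):

```python
import subprocess
from typing import Any, Dict, Optional


def git_info_for(folder: str) -> Optional[Dict[str, Any]]:
    """Return {'commit', 'branch', maybe 'url'} for folder, or None if it is not a git repo."""

    def run(*args: str, required: bool = True) -> Optional[str]:
        # Run a git command inside `folder`; return stripped stdout, or None on a tolerated failure.
        proc = subprocess.run(
            ["git", *args], cwd=folder, capture_output=True, text=True, check=False
        )
        if proc.returncode != 0:
            if required:
                raise subprocess.CalledProcessError(proc.returncode, args)
            return None
        return proc.stdout.strip()

    try:
        run("rev-parse", "--git-dir")  # raises if `folder` is not inside a repository
        info: Dict[str, Any] = {
            "commit": run("rev-parse", "HEAD"),
            # Detached HEAD prints nothing for --show-current, so fall back to 'HEAD'.
            "branch": run("branch", "--show-current", required=False) or "HEAD",
        }
        url = run("config", "--get", "remote.origin.url", required=False)
        if url:
            info["url"] = url
        return info
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not a repository, or the git binary is not installed.
        return None
```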
@@ -1148,6 +1250,14 @@ class ModelBuilder:
 method_signatures=signatures,
 )

+ # Add git information to metadata if available
+ if git_info:
+ from google.protobuf.struct_pb2 import Struct
+
+ metadata_struct = Struct()
+ metadata_struct.update({'git_registry': git_info})
+ model_version_proto.metadata.CopyFrom(metadata_struct)
+
 model_type_id = self.config.get('model').get('model_type_id')
 if model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE:
 if 'concepts' in self.config:
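This hunk relies on `google.protobuf.struct_pb2.Struct`, whose `update()` accepts a plain Python dict and which can then be copied into any Struct-typed proto field with `CopyFrom()`. A small round-trip illustration, independent of the clarifai protos:

```python
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

# Example payload shaped like the git_info dict built above.
git_info = {"commit": "abc123", "branch": "main", "url": "git@github.com:org/repo.git"}

# Struct.update() converts nested dicts/lists/scalars into protobuf Values.
metadata = Struct()
metadata.update({"git_registry": git_info})

# In the diff this struct is copied into the message with
# model_version_proto.metadata.CopyFrom(metadata_struct); here we just read it back.
as_dict = json_format.MessageToDict(metadata)
print(as_dict["git_registry"]["commit"])  # -> "abc123"
```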
@@ -1175,7 +1285,7 @@ class ModelBuilder:
 )
 return model_version_proto

- def upload_model_version(self):
+ def upload_model_version(self, git_info=None):
 file_path = f"{self.folder}.tar.gz"
 logger.debug(f"Will tar it into file: {file_path}")

@@ -1208,7 +1318,7 @@ class ModelBuilder:
 )
 return

- model_version_proto = self.get_model_version_proto()
+ model_version_proto = self.get_model_version_proto(git_info)

 def filter_func(tarinfo):
 name = tarinfo.name
@@ -1228,7 +1338,7 @@ class ModelBuilder:
 if when != "upload" and self.config.get("checkpoints"):
 # Get the checkpoint size to add to the storage request.
 # First check for the env variable, then try querying huggingface. If all else fails, use the default.
- checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
+ checkpoint_size = int(os.environ.get('CHECKPOINT_SIZE_BYTES', 0))
 if not checkpoint_size:
 _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
 checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
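The `int()` cast matters because `os.environ.get` always returns a string when the variable is set: the string `"0"` is truthy, so the fallback size lookup would be skipped, and later arithmetic on the size would fail. A quick illustration:

```python
import os

os.environ["CHECKPOINT_SIZE_BYTES"] = "0"

raw = os.environ.get("CHECKPOINT_SIZE_BYTES", 0)
print(bool(raw))    # True  -- the string "0" is truthy, so the fallback lookup is skipped
print(type(raw))    # <class 'str'>

fixed = int(os.environ.get("CHECKPOINT_SIZE_BYTES", 0))
print(bool(fixed))  # False -- 0 now correctly triggers the HuggingFace size lookup
print(fixed + 1024) # arithmetic works; raw + 1024 would raise TypeError
```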
@@ -1332,13 +1442,13 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 )
 return result

- def get_model_build_logs(self):
+ def get_model_build_logs(self, current_page=1):
 logs_request = service_pb2.ListLogEntriesRequest(
 log_type="builder",
 user_app_id=self.client.user_app_id,
 model_id=self.model_proto.id,
 model_version_id=self.model_version_id,
- page=1,
+ page=current_page,
 per_page=50,
 )
 response = self.client.STUB.ListLogEntries(logs_request)
@@ -1347,6 +1457,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 def monitor_model_build(self):
 st = time.time()
 seen_logs = set() # To avoid duplicate log messages
+ current_page = 1 # Track current page for log pagination
 while True:
 resp = self.client.STUB.GetModelVersion(
 service_pb2.GetModelVersionRequest(
@@ -1357,8 +1468,10 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 )

 status_code = resp.model_version.status.code
- logs = self.get_model_build_logs()
+ logs = self.get_model_build_logs(current_page)
+ entries_count = 0
 for log_entry in logs.log_entries:
+ entries_count += 1
 if log_entry.url not in seen_logs:
 seen_logs.add(log_entry.url)
 log_entry_msg = re.sub(
@@ -1367,6 +1480,12 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 log_entry.message.strip(),
 )
 logger.info(log_entry_msg)
+
+ # If we got a full page (50 entries), there might be more logs on the next page
+ # If we got fewer than 50 entries, we've reached the end and should stay on current page
+ if entries_count == 50:
+ current_page += 1
+ # else: stay on current_page
 if status_code == status_code_pb2.MODEL_BUILDING:
 print(
 f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True
@@ -1414,8 +1533,20 @@ def upload_model(folder, stage, skip_dockerfile, pat=None, base_url=None):
 f"New model will be created at {builder.model_ui_url} with it's first version."
 )

+ # Check for git repository information
+ git_info = builder._get_git_info()
+ if git_info:
+ logger.info(f"Detected git repository: {git_info.get('url', 'local repository')}")
+ logger.info(f"Current commit: {git_info['commit']}")
+ logger.info(f"Current branch: {git_info['branch']}")
+
+ # Check for uncommitted changes and prompt user
+ if not builder._check_git_status_and_prompt():
+ logger.info("Upload cancelled by user due to uncommitted changes.")
+ return
 input("Press Enter to continue...")
- model_version = builder.upload_model_version()
+
+ model_version = builder.upload_model_version(git_info)

 # Ask user if they want to deploy the model
 if model_version is not None: # if it comes back None then it failed.
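Read without the builder internals, the new upload gate boils down to: collect git info if available, warn and ask before uploading a dirty tree, and pass whatever was collected on to `upload_model_version`. A stripped-down sketch of that control flow, with the callables as stand-ins rather than clarifai APIs:

```python
def confirm(prompt: str) -> bool:
    """Return True only if the user explicitly answers yes (default is No)."""
    answer = input(f"{prompt} (y/N): ").strip().lower()
    return answer in ("y", "yes")


def gated_upload(get_git_info, has_uncommitted_changes, do_upload):
    """Sketch of the upload gate introduced in this hunk, with builder calls stubbed out."""
    git_info = get_git_info()  # None when the folder is not inside a git repo
    if git_info and has_uncommitted_changes():
        if not confirm("Do you want to continue upload with uncommitted changes?"):
            print("Upload cancelled by user due to uncommitted changes.")
            return None
    return do_upload(git_info)


if __name__ == "__main__":
    gated_upload(
        get_git_info=lambda: {"commit": "abc123", "branch": "main"},
        has_uncommitted_changes=lambda: False,  # clean tree: no prompt, upload proceeds
        do_upload=lambda git_info: print(f"uploading with metadata {git_info}"),
    )
```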
@@ -9,7 +9,7 @@ from clarifai_protocol.utils.health import HealthProbeRequestHandler
 from clarifai.client.auth.helper import ClarifaiAuthHelper
 from clarifai.utils.constants import STATUS_FAIL, STATUS_MIXED, STATUS_OK, STATUS_UNKNOWN
 from clarifai.utils.logging import get_req_id_from_context, logger
- from clarifai.utils.secrets import inject_secrets
+ from clarifai.utils.secrets import inject_secrets, req_secrets_context

 from ..utils.url_fetcher import ensure_urls_downloaded
 from .model_class import ModelClass
@@ -126,7 +126,10 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
 if method_name == '_GET_SIGNATURES':
 logging = False

- resp = self.model.predict_wrapper(request)
+ # Use req_secrets_context to temporarily set request-type secrets as environment variables
+ with req_secrets_context(request):
+ resp = self.model.predict_wrapper(request)
+
 # if we have any non-successful code already it's an error we can return.
 if (
 resp.status.code != status_code_pb2.SUCCESS
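The wrapper `req_secrets_context(request)` is used here so that request-scoped secrets are visible as environment variables only while the prediction runs. Its actual implementation lives in `clarifai.utils.secrets`; as a generic illustration of the pattern, a context manager with that shape might look like the following (the `secrets` dict is a stand-in for whatever gets extracted from the request):

```python
import os
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def temporary_env(secrets: Dict[str, str]) -> Iterator[None]:
    """Expose the given values as environment variables, restoring the old state on exit."""
    previous = {key: os.environ.get(key) for key in secrets}
    os.environ.update(secrets)
    try:
        yield
    finally:
        for key, old_value in previous.items():
            if old_value is None:
                os.environ.pop(key, None)  # the variable did not exist before
            else:
                os.environ[key] = old_value


if __name__ == "__main__":
    with temporary_env({"MY_API_KEY": "request-scoped-value"}):
        print(os.environ["MY_API_KEY"])  # visible only inside the block
    print("MY_API_KEY" in os.environ)    # False again afterwards
```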
@@ -185,45 +188,49 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
 status_str = STATUS_UNKNOWN
 endpoint = "POST /v2/.../outputs/generate"

- for resp in self.model.generate_wrapper(request):
- # if we have any non-successful code already it's an error we can return.
- if (
- resp.status.code != status_code_pb2.SUCCESS
- and resp.status.code != status_code_pb2.ZERO
- ):
- status_str = f"{resp.status.code} ERROR"
- duration_ms = (time.time() - start_time) * 1000
- logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
- yield service_pb2.RunnerItemOutput(multi_output_response=resp)
- continue
- successes = []
- for output in resp.outputs:
- if not output.HasField('status') or not output.status.code:
- raise Exception(
- "Output must have a status code, please check the model implementation."
+ # Use req_secrets_context to temporarily set request-type secrets as environment variables
+ with req_secrets_context(request):
+ for resp in self.model.generate_wrapper(request):
+ # if we have any non-successful code already it's an error we can return.
+ if (
+ resp.status.code != status_code_pb2.SUCCESS
+ and resp.status.code != status_code_pb2.ZERO
+ ):
+ status_str = f"{resp.status.code} ERROR"
+ duration_ms = (time.time() - start_time) * 1000
+ logger.info(
+ f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}"
 )
- successes.append(output.status.code == status_code_pb2.SUCCESS)
- if all(successes):
- status = status_pb2.Status(
- code=status_code_pb2.SUCCESS,
- description="Success",
- )
- status_str = STATUS_OK
- elif any(successes):
- status = status_pb2.Status(
- code=status_code_pb2.MIXED_STATUS,
- description="Mixed Status",
- )
- status_str = STATUS_MIXED
- else:
- status = status_pb2.Status(
- code=status_code_pb2.FAILURE,
- description="Failed",
- )
- status_str = STATUS_FAIL
- resp.status.CopyFrom(status)
+ yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+ continue
+ successes = []
+ for output in resp.outputs:
+ if not output.HasField('status') or not output.status.code:
+ raise Exception(
+ "Output must have a status code, please check the model implementation."
+ )
+ successes.append(output.status.code == status_code_pb2.SUCCESS)
+ if all(successes):
+ status = status_pb2.Status(
+ code=status_code_pb2.SUCCESS,
+ description="Success",
+ )
+ status_str = STATUS_OK
+ elif any(successes):
+ status = status_pb2.Status(
+ code=status_code_pb2.MIXED_STATUS,
+ description="Mixed Status",
+ )
+ status_str = STATUS_MIXED
+ else:
+ status = status_pb2.Status(
+ code=status_code_pb2.FAILURE,
+ description="Failed",
+ )
+ status_str = STATUS_FAIL
+ resp.status.CopyFrom(status)

- yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+ yield service_pb2.RunnerItemOutput(multi_output_response=resp)

 duration_ms = (time.time() - start_time) * 1000
 logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
@@ -237,45 +244,55 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
 status_str = STATUS_UNKNOWN
 endpoint = "POST /v2/.../outputs/stream "

- for resp in self.model.stream_wrapper(pmo_iterator(runner_item_iterator)):
- # if we have any non-successful code already it's an error we can return.
- if (
- resp.status.code != status_code_pb2.SUCCESS
- and resp.status.code != status_code_pb2.ZERO
- ):
- status_str = f"{resp.status.code} ERROR"
- duration_ms = (time.time() - start_time) * 1000
- logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
- yield service_pb2.RunnerItemOutput(multi_output_response=resp)
- continue
- successes = []
- for output in resp.outputs:
- if not output.HasField('status') or not output.status.code:
- raise Exception(
- "Output must have a status code, please check the model implementation."
+ # Get the first request to establish secrets context
+ first_request = None
+ runner_items = list(runner_item_iterator) # Convert to list to avoid consuming iterator
+ if runner_items:
+ first_request = runner_items[0].post_model_outputs_request
+
+ # Use req_secrets_context based on the first request (secrets should be consistent across stream)
+ with req_secrets_context(first_request):
+ for resp in self.model.stream_wrapper(pmo_iterator(iter(runner_items))):
+ # if we have any non-successful code already it's an error we can return.
+ if (
+ resp.status.code != status_code_pb2.SUCCESS
+ and resp.status.code != status_code_pb2.ZERO
+ ):
+ status_str = f"{resp.status.code} ERROR"
+ duration_ms = (time.time() - start_time) * 1000
+ logger.info(
+ f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}"
 )
- successes.append(output.status.code == status_code_pb2.SUCCESS)
- if all(successes):
- status = status_pb2.Status(
- code=status_code_pb2.SUCCESS,
- description="Success",
- )
- status_str = STATUS_OK
- elif any(successes):
- status = status_pb2.Status(
- code=status_code_pb2.MIXED_STATUS,
- description="Mixed Status",
- )
- status_str = STATUS_MIXED
- else:
- status = status_pb2.Status(
- code=status_code_pb2.FAILURE,
- description="Failed",
- )
- status_str = STATUS_FAIL
- resp.status.CopyFrom(status)
+ yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+ continue
+ successes = []
+ for output in resp.outputs:
+ if not output.HasField('status') or not output.status.code:
+ raise Exception(
+ "Output must have a status code, please check the model implementation."
+ )
+ successes.append(output.status.code == status_code_pb2.SUCCESS)
+ if all(successes):
+ status = status_pb2.Status(
+ code=status_code_pb2.SUCCESS,
+ description="Success",
+ )
+ status_str = STATUS_OK
+ elif any(successes):
+ status = status_pb2.Status(
+ code=status_code_pb2.MIXED_STATUS,
+ description="Mixed Status",
+ )
+ status_str = STATUS_MIXED
+ else:
+ status = status_pb2.Status(
+ code=status_code_pb2.FAILURE,
+ description="Failed",
+ )
+ status_str = STATUS_FAIL
+ resp.status.CopyFrom(status)

- yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+ yield service_pb2.RunnerItemOutput(multi_output_response=resp)

 duration_ms = (time.time() - start_time) * 1000
 logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
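For the streaming endpoint the change materializes the whole `runner_item_iterator` into a list so the first request can be inspected before opening the secrets context. That is simple but buffers every incoming item in memory; a common alternative is to peek at just the first element and re-chain it, as in this generic sketch (not what the package does):

```python
import itertools
from typing import Iterator, Optional, Tuple, TypeVar

T = TypeVar("T")


def peek_first(items: Iterator[T]) -> Tuple[Optional[T], Iterator[T]]:
    """Return (first_item_or_None, iterator_yielding_all_items) without buffering the rest."""
    try:
        first = next(items)
    except StopIteration:
        return None, iter(())
    # Re-attach the consumed first item so the downstream consumer still sees everything.
    return first, itertools.chain([first], items)


if __name__ == "__main__":
    first, stream = peek_first(iter(range(3)))
    print(first)         # 0 -- available up front, e.g. to open a secrets context
    print(list(stream))  # [0, 1, 2] -- the downstream consumer still sees every item
```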
@@ -284,6 +301,27 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
 """Set the model for this runner."""
 self.model = model

+ def handle_liveness_probe(self):
+ # if the model has a handle_liveness_probe method, call it to determine liveness
+ # otherwise rely on HealthProbeRequestHandler.is_alive from the protocol
+ if hasattr(self.model, 'handle_liveness_probe'):
+ HealthProbeRequestHandler.is_alive = self.model.handle_liveness_probe()
+ return super().handle_liveness_probe()
+
+ def handle_readiness_probe(self):
+ # if the model has a handle_readiness_probe method, call it to determine readiness
+ # otherwise rely on HealthProbeRequestHandler.is_ready from the protocol
+ if hasattr(self.model, 'handle_readiness_probe'):
+ HealthProbeRequestHandler.is_ready = self.model.handle_readiness_probe()
+ return super().handle_readiness_probe()
+
+ def handle_startup_probe(self):
+ # if the model has a handle_startup_probe method, call it to determine startup
+ # otherwise rely on HealthProbeRequestHandler.is_startup from the protocol
+ if hasattr(self.model, 'handle_startup_probe'):
+ HealthProbeRequestHandler.is_startup = self.model.handle_startup_probe()
+ return super().handle_startup_probe()
+

 def pmo_iterator(runner_item_iterator, auth_helper=None):
 for runner_item in runner_item_iterator:
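A model that wants to drive these probes only has to expose methods with the matching names; the runner's `hasattr` checks pick them up and mirror the result onto `HealthProbeRequestHandler`. A minimal, hypothetical example of opting in (the `MyModel` class and its internals are invented for illustration, and the import path assumes the usual `clarifai.runners.models.model_class` location):

```python
from clarifai.runners.models.model_class import ModelClass


class MyModel(ModelClass):
    """Hypothetical model that reports readiness based on whether its weights are loaded."""

    def load_model(self):
        # Pretend this is an expensive weights load; flip the flag once it finishes.
        self._weights_loaded = True

    def handle_readiness_probe(self) -> bool:
        # Picked up by ModelRunner.handle_readiness_probe via hasattr().
        return getattr(self, "_weights_loaded", False)

    def handle_liveness_probe(self) -> bool:
        # Keep liveness cheap: the process is alive as long as this method can run.
        return True
```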
@@ -69,6 +69,24 @@ class OpenAIModelClass(ModelClass):

 return completion_args

+ def handle_liveness_probe(self) -> bool:
+ """Handle liveness probe by checking if the client can list models."""
+ try:
+ _ = self.client.models.list()
+ return True
+ except Exception as e:
+ logger.error(f"Liveness probe failed: {e}", exc_info=True)
+ return False
+
+ def handle_readiness_probe(self) -> bool:
+ """Handle readiness probe by checking if the client can list models."""
+ try:
+ _ = self.client.models.list()
+ return True
+ except Exception as e:
+ logger.error(f"Readiness probe failed: {e}", exc_info=True)
+ return False
+
 def _set_usage(self, resp):
 if resp.usage and resp.usage.prompt_tokens and resp.usage.completion_tokens:
 self.set_output_context(
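These hooks treat a successful `models.list()` call against the wrapped OpenAI-compatible client as evidence that the backing server is reachable. The same check works as a standalone helper against any OpenAI-compatible endpoint; in this sketch the base URL and API key are placeholders for, say, a locally hosted server:

```python
import logging

from openai import OpenAI

logger = logging.getLogger(__name__)


def server_is_reachable(client: OpenAI) -> bool:
    """Return True if the OpenAI-compatible server answers a cheap models.list() call."""
    try:
        client.models.list()
        return True
    except Exception as exc:  # any failure means "not ready"
        logger.error("Probe failed: %s", exc, exc_info=True)
        return False


if __name__ == "__main__":
    # Placeholder endpoint/key for an OpenAI-compatible server (e.g. a local vLLM instance).
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
    print(server_is_reachable(client))
```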
@@ -408,6 +408,7 @@ COPY --link=true requirements.txt config.yaml /home/nonroot/main/
 """
 max_checks = timeout_sec // interval_sec
 seen_logs = set() # To avoid duplicate log messages
+ current_page = 1 # Track current page for log pagination
 st = time.time()

 for _ in range(max_checks):
@@ -434,14 +435,16 @@ COPY --link=true requirements.txt config.yaml /home/nonroot/main/
 user_app_id=self.client.user_app_id,
 pipeline_step_id=self.pipeline_step_id,
 pipeline_step_version_id=self.pipeline_step_version_id,
- page=1,
+ page=current_page,
 per_page=50,
 )
 logs = self.client.STUB.ListLogEntries(
 logs_request, metadata=self.client.auth_helper.metadata
 )

+ entries_count = 0
 for log_entry in logs.log_entries:
+ entries_count += 1
 if log_entry.url not in seen_logs:
 seen_logs.add(log_entry.url)
 log_entry_msg = re.sub(
@@ -451,6 +454,12 @@ COPY --link=true requirements.txt config.yaml /home/nonroot/main/
 )
 logger.info(log_entry_msg)

+ # If we got a full page (50 entries), there might be more logs on the next page
+ # If we got fewer than 50 entries, we've reached the end and should stay on current page
+ if entries_count == 50:
+ current_page += 1
+ # else: stay on current_page
+
 status = response.pipeline_step_version.status.code
 if status in {
 status_code_pb2.StatusCode.PIPELINE_STEP_READY,
@@ -301,7 +301,7 @@ class ModelServer:
 f"> Playground: To chat with your model, visit: {context.ui}/playground?model={context.model_id}__{context.model_version_id}&user_id={context.user_id}&app_id={context.app_id}\n"
 )
 logger.info(
- f"> API URL: To call your model via the API, use this model URL: {context.ui}/users/{context.user_id}/apps/{context.app_id}/models/{context.model_id}\n"
+ f"> API URL: To call your model via the API, use this model URL: {context.ui}/{context.user_id}/{context.app_id}/models/{context.model_id}\n"
 )
 logger.info("Press CTRL+C to stop the runner.\n")
 self._runner.start() # start the runner to fetch work from the API.