kaggle 1.7.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. kaggle/LICENSE +201 -0
  2. kaggle/__init__.py +6 -0
  3. kaggle/api/__init__.py +0 -0
  4. kaggle/api/kaggle_api.py +614 -0
  5. kaggle/api/kaggle_api_extended.py +4657 -0
  6. kaggle/cli.py +1606 -0
  7. kaggle/configuration.py +206 -0
  8. kaggle/models/__init__.py +0 -0
  9. kaggle/models/api_blob_type.py +4 -0
  10. kaggle/models/dataset_column.py +228 -0
  11. kaggle/models/dataset_new_request.py +385 -0
  12. kaggle/models/dataset_new_version_request.py +287 -0
  13. kaggle/models/dataset_update_settings_request.py +310 -0
  14. kaggle/models/kaggle_models_extended.py +276 -0
  15. kaggle/models/kernel_push_request.py +556 -0
  16. kaggle/models/model_instance_new_version_request.py +145 -0
  17. kaggle/models/model_instance_update_request.py +351 -0
  18. kaggle/models/model_new_instance_request.py +417 -0
  19. kaggle/models/model_new_request.py +314 -0
  20. kaggle/models/model_update_request.py +282 -0
  21. kaggle/models/start_blob_upload_request.py +232 -0
  22. kaggle/models/start_blob_upload_response.py +137 -0
  23. kaggle/models/upload_file.py +169 -0
  24. kaggle/test/__init__.py +0 -0
  25. kaggle/test/test_authenticate.py +43 -0
  26. kaggle-1.7.3b1.dist-info/METADATA +348 -0
  27. kaggle-1.7.3b1.dist-info/RECORD +89 -0
  28. kaggle-1.7.3b1.dist-info/WHEEL +4 -0
  29. kaggle-1.7.3b1.dist-info/entry_points.txt +2 -0
  30. kaggle-1.7.3b1.dist-info/licenses/LICENSE.txt +201 -0
  31. kagglesdk/LICENSE +201 -0
  32. kagglesdk/__init__.py +2 -0
  33. kagglesdk/admin/__init__.py +0 -0
  34. kagglesdk/admin/services/__init__.py +0 -0
  35. kagglesdk/admin/services/inbox_file_service.py +22 -0
  36. kagglesdk/admin/types/__init__.py +0 -0
  37. kagglesdk/admin/types/inbox_file_service.py +74 -0
  38. kagglesdk/blobs/__init__.py +0 -0
  39. kagglesdk/blobs/services/__init__.py +0 -0
  40. kagglesdk/blobs/services/blob_api_service.py +25 -0
  41. kagglesdk/blobs/types/__init__.py +0 -0
  42. kagglesdk/blobs/types/blob_api_service.py +177 -0
  43. kagglesdk/common/__init__.py +0 -0
  44. kagglesdk/common/types/__init__.py +0 -0
  45. kagglesdk/common/types/file_download.py +102 -0
  46. kagglesdk/common/types/http_redirect.py +105 -0
  47. kagglesdk/competitions/__init__.py +0 -0
  48. kagglesdk/competitions/services/__init__.py +0 -0
  49. kagglesdk/competitions/services/competition_api_service.py +129 -0
  50. kagglesdk/competitions/types/__init__.py +0 -0
  51. kagglesdk/competitions/types/competition_api_service.py +1874 -0
  52. kagglesdk/competitions/types/competition_enums.py +53 -0
  53. kagglesdk/competitions/types/submission_status.py +9 -0
  54. kagglesdk/datasets/__init__.py +0 -0
  55. kagglesdk/datasets/services/__init__.py +0 -0
  56. kagglesdk/datasets/services/dataset_api_service.py +170 -0
  57. kagglesdk/datasets/types/__init__.py +0 -0
  58. kagglesdk/datasets/types/dataset_api_service.py +2777 -0
  59. kagglesdk/datasets/types/dataset_enums.py +82 -0
  60. kagglesdk/datasets/types/dataset_types.py +646 -0
  61. kagglesdk/education/__init__.py +0 -0
  62. kagglesdk/education/services/__init__.py +0 -0
  63. kagglesdk/education/services/education_api_service.py +19 -0
  64. kagglesdk/education/types/__init__.py +0 -0
  65. kagglesdk/education/types/education_api_service.py +248 -0
  66. kagglesdk/education/types/education_service.py +139 -0
  67. kagglesdk/kaggle_client.py +66 -0
  68. kagglesdk/kaggle_env.py +42 -0
  69. kagglesdk/kaggle_http_client.py +316 -0
  70. kagglesdk/kaggle_object.py +293 -0
  71. kagglesdk/kernels/__init__.py +0 -0
  72. kagglesdk/kernels/services/__init__.py +0 -0
  73. kagglesdk/kernels/services/kernels_api_service.py +109 -0
  74. kagglesdk/kernels/types/__init__.py +0 -0
  75. kagglesdk/kernels/types/kernels_api_service.py +1951 -0
  76. kagglesdk/kernels/types/kernels_enums.py +33 -0
  77. kagglesdk/models/__init__.py +0 -0
  78. kagglesdk/models/services/__init__.py +0 -0
  79. kagglesdk/models/services/model_api_service.py +255 -0
  80. kagglesdk/models/services/model_service.py +19 -0
  81. kagglesdk/models/types/__init__.py +0 -0
  82. kagglesdk/models/types/model_api_service.py +3719 -0
  83. kagglesdk/models/types/model_enums.py +60 -0
  84. kagglesdk/models/types/model_service.py +275 -0
  85. kagglesdk/models/types/model_types.py +286 -0
  86. kagglesdk/test/test_client.py +45 -0
  87. kagglesdk/users/__init__.py +0 -0
  88. kagglesdk/users/types/__init__.py +0 -0
  89. kagglesdk/users/types/users_enums.py +22 -0
kaggle/api/kaggle_api_extended.py
@@ -0,0 +1,4657 @@
+ #!/usr/bin/python
+ #
+ # Copyright 2024 Kaggle Inc
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # coding=utf-8
+ from __future__ import print_function
+
+ import csv
+ import io
+ import json  # used by the config and resumable-upload code below
+ import os
+ import re  # used by camel_to_snake below
+ import shutil
+ import sys
+ import tarfile
+ import tempfile
+ import time
+ import zipfile
+ from os.path import expanduser
+ from random import random
+
+ import bleach
+ import requests
+ import urllib3.exceptions as urllib3_exceptions
+ from requests import RequestException
+
+ from kaggle.models.kaggle_models_extended import ResumableUploadResult, File
+
+ from requests.adapters import HTTPAdapter
+ from slugify import slugify
+ from tqdm import tqdm
+ from urllib3.util.retry import Retry
+ from google.protobuf import field_mask_pb2
+
+ from kaggle.configuration import Configuration
+ from kagglesdk import KaggleClient, KaggleEnv
+ from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest
+ from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, \
+     ApiStartBlobUploadResponse, ApiBlobType
+ from kagglesdk.competitions.types.competition_api_service import *
+ from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, \
+     ApiListDatasetFilesRequest, \
+     ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, \
+     ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \
+     ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, \
+     ApiDatasetNewFile, ApiUpdateDatasetMetadataRequest, \
+     ApiGetDatasetMetadataRequest, ApiListDatasetFilesResponse, ApiDatasetFile
+ from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, \
+     DatasetSortBy, DatasetFileTypeGroup, DatasetLicenseGroup
+ from kagglesdk.datasets.types.dataset_types import DatasetSettings, \
+     SettingsLicense, DatasetCollaborator
+ from kagglesdk.kernels.types.kernels_api_service import ApiListKernelsRequest, \
+     ApiListKernelFilesRequest, ApiSaveKernelRequest, ApiGetKernelRequest, \
+     ApiListKernelSessionOutputRequest, ApiGetKernelSessionStatusRequest
+ from kagglesdk.kernels.types.kernels_enums import KernelsListSortType, \
+     KernelsListViewType
+ from kagglesdk.models.types.model_api_service import ApiListModelsRequest, \
+     ApiCreateModelRequest, ApiGetModelRequest, ApiDeleteModelRequest, \
+     ApiUpdateModelRequest, ApiGetModelInstanceRequest, \
+     ApiCreateModelInstanceRequest, ApiCreateModelInstanceRequestBody, \
+     ApiListModelInstanceVersionFilesRequest, ApiUpdateModelInstanceRequest, \
+     ApiDeleteModelInstanceRequest, ApiCreateModelInstanceVersionRequest, \
+     ApiCreateModelInstanceVersionRequestBody, \
+     ApiDownloadModelInstanceVersionRequest, ApiDeleteModelInstanceVersionRequest
+ from kagglesdk.models.types.model_enums import ListModelsOrderBy, \
+     ModelInstanceType, ModelFramework
+ from ..models.dataset_column import DatasetColumn
+ from ..models.upload_file import UploadFile
+
+
+
+ class DirectoryArchive(object):
+
+   def __init__(self, fullpath, format):
+     self._fullpath = fullpath
+     self._format = format
+     self.name = None
+     self.path = None
+
+   def __enter__(self):
+     self._temp_dir = tempfile.mkdtemp()
+     _, dir_name = os.path.split(self._fullpath)
+     self.path = shutil.make_archive(
+         os.path.join(self._temp_dir, dir_name), self._format, self._fullpath)
+     _, self.name = os.path.split(self.path)
+     return self
+
+   def __exit__(self, *args):
+     shutil.rmtree(self._temp_dir)
+
+
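DirectoryArchive packs a directory into a temporary archive for upload and deletes it again on exit. A minimal usage sketch (the path is a placeholder):

    with DirectoryArchive('/path/to/my-dataset', 'zip') as archive:
      print(archive.name)  # e.g. 'my-dataset.zip'
      print(archive.path)  # absolute path to the archive inside a temp dir
    # __exit__ removes the temp dir, and the archive with it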
+ class ResumableUploadContext(object):
+
+   def __init__(self, no_resume=False):
+     self.no_resume = no_resume
+     self._temp_dir = os.path.join(tempfile.gettempdir(), '.kaggle/uploads')
+     self._file_uploads = []
+
+   def __enter__(self):
+     if self.no_resume:
+       return
+     self._create_temp_dir()
+     return self
+
+   def __exit__(self, exc_type, exc_value, exc_traceback):
+     if self.no_resume:
+       return
+     if exc_type is not None:
+       # Don't delete the upload file info when there is an error
+       # to give it a chance to retry/resume on the next invocation.
+       return
+     for file_upload in self._file_uploads:
+       file_upload.cleanup()
+
+   def get_upload_info_file_path(self, path):
+     return os.path.join(
+         self._temp_dir,
+         '%s.json' % path.replace(os.path.sep, '_').replace(':', '_'))
+
+   def new_resumable_file_upload(self, path, start_blob_upload_request):
+     file_upload = ResumableFileUpload(path, start_blob_upload_request, self)
+     self._file_uploads.append(file_upload)
+     file_upload.load()
+     return file_upload
+
+   def _create_temp_dir(self):
+     try:
+       os.makedirs(self._temp_dir)
+     except FileExistsError:
+       pass
+
+
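The context keeps one JSON record per upload under the system temp directory, flattening path separators and colons into underscores. A quick sketch of the mapping (the temp prefix varies by platform):

    ctx = ResumableUploadContext()
    print(ctx.get_upload_info_file_path('/data/train.csv'))
    # e.g. /tmp/.kaggle/uploads/_data_train.csv.json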
+ class ResumableFileUpload(object):
+   # Reference: https://cloud.google.com/storage/docs/resumable-uploads
+   # A resumable upload must be completed within a week of being initiated
+   RESUMABLE_UPLOAD_EXPIRY_SECONDS = 6 * 24 * 3600
+
+   def __init__(self, path, start_blob_upload_request, context):
+     self.path = path
+     self.start_blob_upload_request = start_blob_upload_request
+     self.context = context
+     self.timestamp = int(time.time())
+     self.start_blob_upload_response = None
+     self.can_resume = False
+     self.upload_complete = False
+     if self.context.no_resume:
+       return
+     self._upload_info_file_path = self.context.get_upload_info_file_path(path)
+
+   def get_token(self):
+     if self.upload_complete:
+       return self.start_blob_upload_response.token
+     return None
+
+   def load(self):
+     if self.context.no_resume:
+       return
+     self._load_previous_if_any()
+
+   def _load_previous_if_any(self):
+     if not os.path.exists(self._upload_info_file_path):
+       return False
+
+     try:
+       with io.open(self._upload_info_file_path, 'r') as f:
+         previous = ResumableFileUpload.from_dict(json.load(f), self.context)
+       if self._is_previous_valid(previous):
+         self.start_blob_upload_response = previous.start_blob_upload_response
+         self.timestamp = previous.timestamp
+         self.can_resume = True
+     except Exception as e:
+       print('Error while trying to load upload info:', e)
+
+   def _is_previous_valid(self, previous):
+     return previous.path == self.path and \
+         previous.start_blob_upload_request == self.start_blob_upload_request and \
+         previous.timestamp > time.time() - ResumableFileUpload.RESUMABLE_UPLOAD_EXPIRY_SECONDS
+
+   def upload_initiated(self, start_blob_upload_response):
+     if self.context.no_resume:
+       return
+
+     self.start_blob_upload_response = start_blob_upload_response
+     with io.open(self._upload_info_file_path, 'w') as f:
+       json.dump(self.to_dict(), f, indent=True)
+
+   def upload_completed(self):
+     if self.context.no_resume:
+       return
+
+     self.upload_complete = True
+     self._save()
+
+   def _save(self):
+     with io.open(self._upload_info_file_path, 'w') as f:
+       json.dump(self.to_dict(), f, indent=True)
+
+   def cleanup(self):
+     if self.context.no_resume:
+       return
+
+     try:
+       os.remove(self._upload_info_file_path)
+     except OSError:
+       pass
+
+   def to_dict(self):
+     return {
+         'path': self.path,
+         'start_blob_upload_request': self.start_blob_upload_request.to_dict(),
+         'timestamp': self.timestamp,
+         'start_blob_upload_response': self.start_blob_upload_response.to_dict()
+             if self.start_blob_upload_response is not None else None,
+         'upload_complete': self.upload_complete,
+     }
+
+   @staticmethod
+   def from_dict(other, context):
+     # Rebuild the request with from_dict so nested fields are deserialized.
+     req = ApiStartBlobUploadRequest()
+     req.from_dict(other['start_blob_upload_request'])
+     new = ResumableFileUpload(other['path'], req, context)
+     new.timestamp = other.get('timestamp')
+     start_blob_upload_response = other.get('start_blob_upload_response')
+     if start_blob_upload_response is not None:
+       new.start_blob_upload_response = ApiStartBlobUploadResponse(
+           **start_blob_upload_response)
+     new.upload_complete = other.get('upload_complete') or False
+     return new
+
+   def to_str(self):
+     return str(self.to_dict())
+
+   def __repr__(self):
+     return self.to_str()
+
+
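For reference, the record that to_dict() persists (and from_dict() restores on a later invocation) has this shape; the values are illustrative and the nested request/response dicts are elided:

    {
      "path": "/data/train.csv",
      "start_blob_upload_request": {...},
      "timestamp": 1700000000,
      "start_blob_upload_response": {...},
      "upload_complete": false
    }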
+ class KaggleApi:
+   __version__ = '1.7.3b1'
+
+   CONFIG_NAME_PROXY = 'proxy'
+   CONFIG_NAME_COMPETITION = 'competition'
+   CONFIG_NAME_PATH = 'path'
+   CONFIG_NAME_USER = 'username'
+   CONFIG_NAME_KEY = 'key'
+   CONFIG_NAME_SSL_CA_CERT = 'ssl_ca_cert'
+
+   HEADER_API_VERSION = 'X-Kaggle-ApiVersion'
+   DATASET_METADATA_FILE = 'dataset-metadata.json'
+   OLD_DATASET_METADATA_FILE = 'datapackage.json'
+   KERNEL_METADATA_FILE = 'kernel-metadata.json'
+   MODEL_METADATA_FILE = 'model-metadata.json'
+   MODEL_INSTANCE_METADATA_FILE = 'model-instance-metadata.json'
+   MAX_NUM_INBOX_FILES_TO_UPLOAD = 1000
+   MAX_UPLOAD_RESUME_ATTEMPTS = 10
+
+   config_dir = os.environ.get('KAGGLE_CONFIG_DIR')
+
+   if not config_dir:
+     config_dir = os.path.join(expanduser('~'), '.kaggle')
+     # Use ~/.kaggle if it already exists for backwards compatibility,
+     # otherwise follow XDG base directory specification
+     if sys.platform.startswith('linux') and not os.path.exists(config_dir):
+       config_dir = os.path.join((os.environ.get('XDG_CONFIG_HOME') or
+                                  os.path.join(expanduser('~'), '.config')),
+                                 'kaggle')
+
+   if not os.path.exists(config_dir):
+     os.makedirs(config_dir)
+
+   config_file = 'kaggle.json'
+   config = os.path.join(config_dir, config_file)
+   config_values = {}
+   already_printed_version_warning = False
+
+   args = {}  # DEBUG Add --local to use localhost
+   if os.environ.get('KAGGLE_API_ENVIRONMENT') == 'LOCALHOST':
+     args = {'--local'}
+
+   # Kernels valid types
+   valid_push_kernel_types = ['script', 'notebook']
+   valid_push_language_types = ['python', 'r', 'rmarkdown']
+   valid_push_pinning_types = ['original', 'latest']
+   valid_list_languages = ['all', 'python', 'r', 'sqlite', 'julia']
+   valid_list_kernel_types = ['all', 'script', 'notebook']
+   valid_list_output_types = ['all', 'visualization', 'data']
+   valid_list_sort_by = [
+       'hotness', 'commentCount', 'dateCreated', 'dateRun', 'relevance',
+       'scoreAscending', 'scoreDescending', 'viewCount', 'voteCount'
+   ]
+
+   # Competitions valid types
+   valid_competition_groups = [
+       'general', 'entered', 'community', 'hosted', 'unlaunched',
+       'unlaunched_community'
+   ]
+   valid_competition_categories = [
+       'all', 'featured', 'research', 'recruitment', 'gettingStarted', 'masters',
+       'playground'
+   ]
+   valid_competition_sort_by = [
+       'grouped', 'best', 'prize', 'earliestDeadline', 'latestDeadline',
+       'numberOfTeams', 'relevance', 'recentlyCreated'
+   ]
+
+   # Datasets valid types
+   valid_dataset_file_types = ['all', 'csv', 'sqlite', 'json', 'bigQuery']
+   valid_dataset_license_names = ['all', 'cc', 'gpl', 'odb', 'other']
+   valid_dataset_sort_bys = [
+       'hottest', 'votes', 'updated', 'active', 'published'
+   ]
+
+   # Models valid types
+   valid_model_sort_bys = [
+       'hotness', 'downloadCount', 'voteCount', 'notebookCount', 'createTime'
+   ]
+
+   # Command prefixes that are valid without authentication.
+   command_prefixes_allowing_anonymous_access = ('datasets download',
+                                                 'datasets files')
+
+   # Attributes
+   competition_fields = [
+       'ref', 'deadline', 'category', 'reward', 'teamCount', 'userHasEntered'
+   ]
+   submission_fields = [
+       'fileName', 'date', 'description', 'status', 'publicScore', 'privateScore'
+   ]
+   competition_file_fields = ['name', 'totalBytes', 'creationDate']
+   competition_file_labels = ['name', 'size', 'creationDate']
+   competition_leaderboard_fields = [
+       'teamId', 'teamName', 'submissionDate', 'score'
+   ]
+   dataset_fields = [
+       'ref', 'title', 'totalBytes', 'lastUpdated', 'downloadCount', 'voteCount',
+       'usabilityRating'
+   ]
+   dataset_labels = [
+       'ref', 'title', 'size', 'lastUpdated', 'downloadCount', 'voteCount',
+       'usabilityRating'
+   ]
+   dataset_file_fields = ['name', 'total_bytes', 'creationDate']
+   model_fields = ['id', 'ref', 'title', 'subtitle', 'author']
+   model_all_fields = [
+       'id', 'ref', 'author', 'slug', 'title', 'subtitle', 'isPrivate',
+       'description', 'publishTime'
+   ]
+   model_file_fields = ['name', 'size', 'creationDate']
+
+   def _is_retriable(self, e):
+     return issubclass(type(e), ConnectionError) or \
+         issubclass(type(e), urllib3_exceptions.ConnectionError) or \
+         issubclass(type(e), urllib3_exceptions.ConnectTimeoutError) or \
+         issubclass(type(e), urllib3_exceptions.ProtocolError) or \
+         issubclass(type(e), requests.exceptions.ConnectionError) or \
+         issubclass(type(e), requests.exceptions.ConnectTimeout)
+
+   def _calculate_backoff_delay(self, attempt, initial_delay_millis,
+                                retry_multiplier, randomness_factor):
+     delay_ms = initial_delay_millis * (retry_multiplier**attempt)
+     # Jitter uniformly within +/- randomness_factor of the base delay.
+     random_wait_ms = (random() - 0.5) * 2 * delay_ms * randomness_factor
+     total_delay = (delay_ms + random_wait_ms) / 1000.0
+     return total_delay
+
+   def with_retry(self,
+                  func,
+                  max_retries=10,
+                  initial_delay_millis=500,
+                  retry_multiplier=1.7,
+                  randomness_factor=0.5):
+
+     def retriable_func(*args):
+       for i in range(1, max_retries + 1):
+         try:
+           return func(*args)
+         except Exception as e:
+           if self._is_retriable(e) and i < max_retries:
+             total_delay = self._calculate_backoff_delay(i, initial_delay_millis,
+                                                         retry_multiplier,
+                                                         randomness_factor)
+             print('Request failed: %s. Will retry in %2.1f seconds' %
+                   (e, total_delay))
+             time.sleep(total_delay)
+             continue
+           raise
+
+     return retriable_func
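With the defaults (initial_delay_millis=500, retry_multiplier=1.7, randomness_factor=0.5) the base delay grows geometrically: ignoring jitter, the first retries wait about 0.85 s, 1.45 s, and 2.46 s. A minimal usage sketch of the wrapper:

    api = KaggleApi()
    api.authenticate()
    list_with_retry = api.with_retry(api.competitions_list)
    competitions = list_with_retry()  # retried only for the connection errors above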
+
+   ## Authentication
+
+   def authenticate(self):
+     """authenticate the user with the Kaggle API. This method will generate
+        a configuration, first checking the environment for credential
+        variables, and falling back to looking for the .kaggle/kaggle.json
+        configuration file.
+     """
+
+     config_data = {}
+     # Ex: 'datasets list', 'competitions files', 'models instances get', etc.
+     api_command = ' '.join(sys.argv[1:])
+
+     # Step 1: try getting username/password from environment
+     config_data = self.read_config_environment(config_data)
+
+     # Step 2: if credentials were not in env read in configuration file
+     if self.CONFIG_NAME_USER not in config_data \
+         or self.CONFIG_NAME_KEY not in config_data:
+       if os.path.exists(self.config):
+         config_data = self.read_config_file(config_data)
+       elif self._is_help_or_version_command(api_command) or (len(
+           sys.argv) > 2 and api_command.startswith(
+               self.command_prefixes_allowing_anonymous_access)):
+         # Some API commands should be allowed without authentication.
+         return
+       else:
+         raise IOError('Could not find {}. Make sure it\'s located in'
+                       ' {}. Or use the environment method. See setup'
+                       ' instructions at'
+                       ' https://github.com/Kaggle/kaggle-api/'.format(
+                           self.config_file, self.config_dir))
+
+     # Step 3: load into configuration!
+     self._load_config(config_data)
+
+   def _is_help_or_version_command(self, api_command):
+     """determines if the string command passed in is for a help or version
+        command.
+
+     Parameters
+     ==========
+     api_command: a string, 'datasets list', 'competitions files',
+                  'models instances get', etc.
+     """
+     return api_command.endswith(('-h', '--help', '-v', '--version'))
+
+   def read_config_environment(self, config_data=None, quiet=False):
+     """read_config_environment is the first effort to get a username
+        and key to authenticate to the Kaggle API. The environment keys
+        are equivalent to the kaggle.json file, but with a "KAGGLE_" prefix
+        to define a unique namespace.
+
+     Parameters
+     ==========
+     config_data: a partially loaded configuration dictionary (optional)
+     quiet: suppress verbose print of output (default is False)
+     """
+
+     # Add all variables that start with KAGGLE_ to config data
+
+     if config_data is None:
+       config_data = {}
+     for key, val in os.environ.items():
+       if key.startswith('KAGGLE_'):
+         config_key = key.replace('KAGGLE_', '', 1).lower()
+         config_data[config_key] = val
+
+     return config_data
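Any KAGGLE_-prefixed variable is picked up, with the prefix stripped and the rest lower-cased, so the standard credential pair maps to the same keys as kaggle.json. A minimal sketch (credential values are placeholders):

    # export KAGGLE_USERNAME=example_user
    # export KAGGLE_KEY=0123456789abcdef
    api = KaggleApi()
    api.authenticate()  # reads the credentials from os.environ; no kaggle.json needed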
+
+   ## Configuration
+
+   def _load_config(self, config_data):
+     """the final step of authentication, where we load the values
+        from config_data into the Configuration object.
+
+     Parameters
+     ==========
+     config_data: a dictionary with configuration values (keys) to read
+                  into self.config_values
+     """
+     # Username and password are required.
+
+     for item in [self.CONFIG_NAME_USER, self.CONFIG_NAME_KEY]:
+       if item not in config_data:
+         raise ValueError('Error: Missing %s in configuration.' % item)
+
+     configuration = Configuration()
+
+     # Add to the final configuration (required)
+
+     configuration.username = config_data[self.CONFIG_NAME_USER]
+     configuration.password = config_data[self.CONFIG_NAME_KEY]
+
+     # Proxy
+
+     if self.CONFIG_NAME_PROXY in config_data:
+       configuration.proxy = config_data[self.CONFIG_NAME_PROXY]
+
+     # Cert File
+
+     if self.CONFIG_NAME_SSL_CA_CERT in config_data:
+       configuration.ssl_ca_cert = config_data[self.CONFIG_NAME_SSL_CA_CERT]
+
+     # Keep config values with class instance, and load api client!
+
+     self.config_values = config_data
+
+   def read_config_file(self, config_data=None, quiet=False):
+     """read_config_file is the second effort to get a username and key to
+        authenticate to the Kaggle API, used when the environment does not
+        supply them.
+
+     Parameters
+     ==========
+     config_data: a dictionary to fill with the username and key, if defined
+     quiet: suppress verbose print of output (default is False)
+     """
+     if config_data is None:
+       config_data = {}
+
+     if os.path.exists(self.config):
+
+       try:
+         if os.name != 'nt':
+           permissions = os.stat(self.config).st_mode
+           # 4 = world-readable, 32 = group-readable
+           if (permissions & 4) or (permissions & 32):
+             print('Warning: Your Kaggle API key is readable by other '
+                   'users on this system! To fix this, you can run ' +
+                   '\'chmod 600 {}\''.format(self.config))
+
+         with open(self.config) as f:
+           config_data = json.load(f)
+       except Exception:
+         pass
+
+     else:
+
+       # Warn the user that configuration will be reliant on environment
+       if not quiet:
+         print('No Kaggle API config file found, will use environment.')
+
+     return config_data
+
+   def _read_config_file(self):
+     """read in the configuration file, a json file defined at self.config"""
+
+     try:
+       with open(self.config, 'r') as f:
+         config_data = json.load(f)
+     except FileNotFoundError:
+       config_data = {}
+
+     return config_data
+
+   def _write_config_file(self, config_data, indent=2):
+     """write config data to file.
+
+     Parameters
+     ==========
+     config_data: a dictionary of configuration values to write out
+     indent: number of spaces to indent the written json
+     """
+     with open(self.config, 'w') as f:
+       json.dump(config_data, f, indent=indent)
+
+   def set_config_value(self, name, value, quiet=False):
+     """a client helper function to set a configuration value: read in the
+        configuration file (if it exists), set the new value, and write the
+        file back out
+
+     Parameters
+     ==========
+     name: the name of the value to set (key in dictionary)
+     value: the value to set at the key
+     quiet: disable verbose output if True (default is False)
+     """
+
+     config_data = self._read_config_file()
+
+     if value is not None:
+
+       # Update the config file with the value
+       config_data[name] = value
+
+       # Update the instance with the value
+       self.config_values[name] = value
+
+       # If defined by client, set and save!
+       self._write_config_file(config_data)
+
+     if not quiet:
+       self.print_config_value(name, separator=' is now set to: ')
+
+   def unset_config_value(self, name, quiet=False):
+     """unset a configuration value
+
+     Parameters
+     ==========
+     name: the name of the value to unset (remove key in dictionary)
+     quiet: disable verbose output if True (default is False)
+     """
+
+     config_data = self._read_config_file()
+
+     if name in config_data:
+
+       del config_data[name]
+
+       self._write_config_file(config_data)
+
+       if not quiet:
+         self.print_config_value(name, separator=' is now set to: ')
+
+   def get_config_value(self, name):
+     """return a config value (with key name) if it's in the config_values,
+        otherwise return None
+
+     Parameters
+     ==========
+     name: the config value key to get
+     """
+     if name in self.config_values:
+       return self.config_values[name]
+
+   def get_default_download_dir(self, *subdirs):
+     """Get the download path for a file. If a 'path' config value is not
+        defined, default to the current working directory.
+
+     Parameters
+     ==========
+     subdirs: a single (or list of) subfolders under the basepath
+     """
+     # Look up value for key "path" in the config
+     path = self.get_config_value(self.CONFIG_NAME_PATH)
+
+     # If not set in config, default to present working directory
+     if path is None:
+       return os.getcwd()
+
+     return os.path.join(path, *subdirs)
+
+   def print_config_value(self, name, prefix='- ', separator=': '):
+     """print a single configuration value, based on a prefix and separator
+
+     Parameters
+     ==========
+     name: the key of the config value in self.config_values to print
+     prefix: the prefix to print
+     separator: the separator to use (default is : )
+     """
+
+     value_out = 'None'
+     if name in self.config_values and self.config_values[name] is not None:
+       value_out = self.config_values[name]
+     print(prefix + name + separator + value_out)
+
+   def print_config_values(self, prefix='- '):
+     """a wrapper to print_config_value to print all configuration values
+
+     Parameters
+     ==========
+     prefix: the character prefix to put before the printed config value,
+             defaults to "- "
+     """
+     print('Configuration values from ' + self.config_dir)
+     self.print_config_value(self.CONFIG_NAME_USER, prefix=prefix)
+     self.print_config_value(self.CONFIG_NAME_PATH, prefix=prefix)
+     self.print_config_value(self.CONFIG_NAME_PROXY, prefix=prefix)
+     self.print_config_value(self.CONFIG_NAME_COMPETITION, prefix=prefix)
+
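The backing store for these helpers is the plain-JSON kaggle.json in config_dir; username and key are required, proxy and the other CONFIG_NAME_* keys are optional. A sketch of the read-modify-write cycle (values are placeholders):

    # ~/.kaggle/kaggle.json: {"username": "example_user", "key": "0123456789abcdef"}
    api = KaggleApi()
    api.authenticate()
    api.set_config_value('competition', 'titanic')  # persisted back to kaggle.json
    print(api.get_config_value('competition'))      # -> 'titanic'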
+   def build_kaggle_client(self):
+     env = KaggleEnv.STAGING if '--staging' in self.args \
+         else KaggleEnv.ADMIN if '--admin' in self.args \
+         else KaggleEnv.LOCAL if '--local' in self.args \
+         else KaggleEnv.PROD
+     verbose = '--verbose' in self.args or '-v' in self.args
+     # config = self.api_client.configuration
+     return KaggleClient(
+         env=env,
+         verbose=verbose,
+         username=self.config_values['username'],
+         password=self.config_values['key'])
+
+   def camel_to_snake(self, name):
+     """
+     :param name: field in camel case
+     :return: field in snake case
+     """
+     name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
+     return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
+
+   def lookup_enum(self, enum_class, item_name):
+     item = self.camel_to_snake(item_name).upper()
+     try:
+       return enum_class[item]
+     except KeyError:
+       # Fall back to the prefixed member name, e.g. DatasetSortBy['HOTTEST']
+       # -> DatasetSortBy['DATASET_SORT_BY_HOTTEST'].
+       prefix = self.camel_to_snake(enum_class.__name__).upper()
+       return enum_class[f'{prefix}_{self.camel_to_snake(item_name).upper()}']
+
+   def short_enum_name(self, value):
+     full_name = str(value)
+     names = full_name.split('.')
+     prefix_len = len(self.camel_to_snake(names[0])) + 1  # underscore
+     return names[1][prefix_len:].lower()
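These helpers translate between the CLI's camelCase vocabulary and the kagglesdk enums, whose member names repeat the enum name as a prefix. A sketch of both directions (member names follow the pattern used elsewhere in this file):

    api = KaggleApi()
    api.camel_to_snake('dateCreated')          # -> 'date_created'
    api.lookup_enum(DatasetSortBy, 'hottest')  # -> DatasetSortBy.DATASET_SORT_BY_HOTTEST
    api.short_enum_name(DatasetSortBy.DATASET_SORT_BY_HOTTEST)  # -> 'hottest'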
+
+   ## Competitions
+
+   def competitions_list(self,
+                         group=None,
+                         category=None,
+                         sort_by=None,
+                         page=1,
+                         search=None):
+     """Make a call to list competitions, format the response, and return
+        a list of ApiCompetition instances.
+
+     Parameters
+     ==========
+     group: group to filter result to
+     category: category to filter result to; use 'all' to get closed competitions
+     sort_by: how to sort the result, see valid_competition_sort_by for options
+     page: the page to return (default is 1)
+     search: a search term to use (default is empty string)
+     """
+     if group:
+       if group not in self.valid_competition_groups:
+         raise ValueError('Invalid group specified. Valid options are ' +
+                          str(self.valid_competition_groups))
+       if group == 'all':
+         group = CompetitionListTab.COMPETITION_LIST_TAB_EVERYTHING
+       else:
+         group = self.lookup_enum(CompetitionListTab, group)
+
+     if category:
+       if category not in self.valid_competition_categories:
+         raise ValueError('Invalid category specified. Valid options are ' +
+                          str(self.valid_competition_categories))
+       category = self.lookup_enum(HostSegment, category)
+
+     if sort_by:
+       if sort_by not in self.valid_competition_sort_by:
+         raise ValueError('Invalid sort_by specified. Valid options are ' +
+                          str(self.valid_competition_sort_by))
+       sort_by = self.lookup_enum(CompetitionSortBy, sort_by)
+
+     with self.build_kaggle_client() as kaggle:
+       request = ApiListCompetitionsRequest()
+       request.group = group
+       request.page = page
+       request.category = category
+       request.search = search
+       request.sort_by = sort_by
+       response = kaggle.competitions.competition_api_client.list_competitions(
+           request)
+       return response.competitions
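A usage sketch, with arguments drawn from the valid_competition_* lists above:

    api = KaggleApi()
    api.authenticate()
    for c in api.competitions_list(category='gettingStarted', sort_by='prize'):
      print(c.ref)  # attributes per competition_fields above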
+
+   def competitions_list_cli(self,
+                             group=None,
+                             category=None,
+                             sort_by=None,
+                             page=1,
+                             search=None,
+                             csv_display=False):
+     """A wrapper for competitions_list for the client.
+
+     Parameters
+     ==========
+     group: group to filter result to
+     category: category to filter result to
+     sort_by: how to sort the result, see valid_competition_sort_by for options
+     page: the page to return (default is 1)
+     search: a search term to use (default is empty string)
+     csv_display: if True, print comma separated values
+     """
+     competitions = self.competitions_list(
+         group=group,
+         category=category,
+         sort_by=sort_by,
+         page=page,
+         search=search)
+     if competitions:
+       if csv_display:
+         self.print_csv(competitions, self.competition_fields)
+       else:
+         self.print_table(competitions, self.competition_fields)
+     else:
+       print('No competitions found')
+
+   def competition_submit(self, file_name, message, competition, quiet=False):
+     """Submit to a competition.
+
+     Parameters
+     ==========
+     file_name: the file to submit
+     message: the submission description
+     competition: the competition name; if not given use the 'competition' config value
+     quiet: suppress verbose output (default is False)
+     """
+     if competition is None:
+       competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+       if competition is not None and not quiet:
+         print('Using competition: ' + competition)
+
+     if competition is None:
+       raise ValueError('No competition specified')
+     else:
+       with self.build_kaggle_client() as kaggle:
+         request = ApiStartSubmissionUploadRequest()
+         request.competition_name = competition
+         request.file_name = os.path.basename(file_name)
+         request.content_length = os.path.getsize(file_name)
+         request.last_modified_epoch_seconds = int(os.path.getmtime(file_name))
+         response = kaggle.competitions.competition_api_client.start_submission_upload(
+             request)
+         upload_status = self.upload_complete(file_name, response.create_url,
+                                              quiet)
+         if upload_status != ResumableUploadResult.COMPLETE:
+           # Actual error is printed during upload_complete. Not
+           # ideal but changing would not be backwards compatible
+           return "Could not submit to competition"
+
+         submit_request = ApiCreateSubmissionRequest()
+         submit_request.competition_name = competition
+         submit_request.blob_file_tokens = response.token
+         submit_request.submission_description = message
+         submit_response = kaggle.competitions.competition_api_client.create_submission(
+             submit_request)
+         return submit_response
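A submission sketch (file and competition names are placeholders); the CLI wrapper below funnels into this same call:

    api = KaggleApi()
    api.authenticate()
    api.competition_submit('submission.csv', 'first attempt', 'titanic')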
+
+   def competition_submit_cli(self,
+                              file_name,
+                              message,
+                              competition,
+                              competition_opt=None,
+                              quiet=False):
+     """Submit to a competition using the client. Arguments are the same as
+        for competition_submit, except for the extra arguments listed here.
+
+     Parameters
+     ==========
+     file_name: the file to submit
+     message: the submission description
+     competition: the competition name; if not given use the 'competition' config value
+     quiet: suppress verbose output (default is False)
+     competition_opt: an alternative competition option provided by cli
+     """
+     competition = competition or competition_opt
+     try:
+       submit_result = self.competition_submit(file_name, message, competition,
+                                               quiet)
+     except RequestException as e:
+       if e.response and e.response.status_code == 404:
+         print('Could not find competition - please verify that you '
+               'entered the correct competition ID and that the '
+               'competition is still accepting submissions.')
+         return None
+       else:
+         raise e
+     return submit_result.message
+
+   def competition_submissions(self,
+                               competition,
+                               group=None,
+                               sort=None,
+                               page_token=0,
+                               page_size=20):
+     """Get the list of submissions for a particular competition.
+
+     Parameters
+     ==========
+     competition: the name of the competition
+     group: the submission group
+     sort: the sort-by option
+     page_token: token for pagination
+     page_size: the number of items per page
+     """
+     with self.build_kaggle_client() as kaggle:
+       request = ApiListSubmissionsRequest()
+       request.competition_name = competition
+       request.page = page_token
+       request.group = group
+       request.sort_by = sort
+       response = kaggle.competitions.competition_api_client.list_submissions(
+           request)
+       return response.submissions
+
+   def competition_submissions_cli(self,
+                                   competition=None,
+                                   competition_opt=None,
+                                   csv_display=False,
+                                   page_token=None,
+                                   page_size=20,
+                                   quiet=False):
+     """A wrapper to competition_submissions that prints the result as a
+        table or comma separated values. Additional parameters are listed
+        below; see competition_submissions for the rest.
+
+     Parameters
+     ==========
+     competition: the name of the competition. If None, look to config
+     competition_opt: an alternative competition option provided by cli
+     csv_display: if True, print comma separated values
+     page_token: token for pagination
+     page_size: the number of items per page
+     quiet: suppress verbose output (default is False)
+     """
+     competition = competition or competition_opt
+     if competition is None:
+       competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+       if competition is not None and not quiet:
+         print('Using competition: ' + competition)
+
+     if competition is None:
+       raise ValueError('No competition specified')
+     else:
+       submissions = self.competition_submissions(
+           competition, page_token=page_token, page_size=page_size)
+       if submissions:
+         if csv_display:
+           self.print_csv(submissions, self.submission_fields)
+         else:
+           self.print_table(submissions, self.submission_fields)
+       else:
+         print('No submissions found')
+
+   def competition_list_files(self, competition, page_token=None, page_size=20):
+     """List files for a competition.
+
+     Parameters
+     ==========
+     competition: the name of the competition
+     page_token: the page token for pagination
+     page_size: the number of items per page
+     """
+     with self.build_kaggle_client() as kaggle:
+       request = ApiListDataFilesRequest()
+       request.competition_name = competition
+       request.page_token = page_token
+       request.page_size = page_size
+       response = kaggle.competitions.competition_api_client.list_data_files(
+           request)
+       return response
+
+   def competition_list_files_cli(self,
+                                  competition,
+                                  competition_opt=None,
+                                  csv_display=False,
+                                  page_token=None,
+                                  page_size=20,
+                                  quiet=False):
+     """List files for a competition, if it exists.
+
+     Parameters
+     ==========
+     competition: the name of the competition. If None, look to config
+     competition_opt: an alternative competition option provided by cli
+     csv_display: if True, print comma separated values
+     page_token: the page token for pagination
+     page_size: the number of items per page
+     quiet: suppress verbose output (default is False)
+     """
+     competition = competition or competition_opt
+     if competition is None:
+       competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+       if competition is not None and not quiet:
+         print('Using competition: ' + competition)
+
+     if competition is None:
+       raise ValueError('No competition specified')
+     else:
+       result = self.competition_list_files(competition, page_token, page_size)
+       next_page_token = result.next_page_token
+       if next_page_token:
+         print('Next Page Token = {}'.format(next_page_token))
+       if result:
+         if csv_display:
+           self.print_csv(result.files, self.competition_file_fields,
+                          self.competition_file_labels)
+         else:
+           self.print_table(result.files, self.competition_file_fields,
+                            self.competition_file_labels)
+       else:
+         print('No files found')
+
+   def competition_download_file(self,
+                                 competition,
+                                 file_name,
+                                 path=None,
+                                 force=False,
+                                 quiet=False):
+     """Download a competition file to a designated location, or use
+        a default location.
+
+     Parameters
+     =========
+     competition: the name of the competition
+     file_name: the name of the competition file to download
+     path: a path to download the file to
+     force: force the download if the file already exists (default False)
+     quiet: suppress verbose output (default is False)
+     """
+     if path is None:
+       effective_path = self.get_default_download_dir('competitions',
+                                                      competition)
+     else:
+       effective_path = path
+
+     with self.build_kaggle_client() as kaggle:
+       request = ApiDownloadDataFileRequest()
+       request.competition_name = competition
+       request.file_name = file_name
+       response = kaggle.competitions.competition_api_client.download_data_file(
+           request)
+       url = response.history[0].url
+       outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1])
+
+       if force or self.download_needed(response, outfile, quiet):
+         self.download_file(response, outfile, kaggle.http_client(), quiet,
+                            not force)
+
+   def competition_download_files(self,
+                                  competition,
+                                  path=None,
+                                  force=False,
+                                  quiet=True):
+     """Download all competition files.
+
+     Parameters
+     =========
+     competition: the name of the competition
+     path: a path to download the files to
+     force: force the download if the file already exists (default False)
+     quiet: suppress verbose output (default is True)
+     """
+     if path is None:
+       effective_path = self.get_default_download_dir('competitions',
+                                                      competition)
+     else:
+       effective_path = path
+
+     with self.build_kaggle_client() as kaggle:
+       request = ApiDownloadDataFilesRequest()
+       request.competition_name = competition
+       response = kaggle.competitions.competition_api_client.download_data_files(
+           request)
+       url = response.url.split('?')[0]
+       outfile = os.path.join(effective_path,
+                              competition + '.' + url.split('.')[-1])
+
+       if force or self.download_needed(response, outfile, quiet):
+         self.download_file(response, outfile, quiet, not force)
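A download sketch; with path=None both methods fall back to get_default_download_dir('competitions', <name>), and the bundle is saved as <competition>.<ext> (typically .zip):

    api = KaggleApi()
    api.authenticate()
    api.competition_download_file('titanic', 'train.csv', path='/tmp/titanic')
    api.competition_download_files('titanic')  # e.g. ./titanic.zip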
+
+   def competition_download_cli(self,
+                                competition,
+                                competition_opt=None,
+                                file_name=None,
+                                path=None,
+                                force=False,
+                                quiet=False):
+     """A wrapper to competition_download_files, but first will parse input
+        from API client. Additional parameters are listed here, see
+        competition_download_files for remaining.
+
+     Parameters
+     =========
+     competition: the name of the competition
+     competition_opt: an alternative competition option provided by cli
+     file_name: the name of a specific file to download; if None, all files are downloaded
+     path: a path to download the file to
+     force: force the download if the file already exists (default False)
+     quiet: suppress verbose output (default is False)
+     """
+     competition = competition or competition_opt
+     if competition is None:
+       competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+       if competition is not None and not quiet:
+         print('Using competition: ' + competition)
+
+     if competition is None:
+       raise ValueError('No competition specified')
+     else:
+       if file_name is None:
+         self.competition_download_files(competition, path, force, quiet)
+       else:
+         self.competition_download_file(competition, file_name, path, force,
+                                        quiet)
+
+   def competition_leaderboard_download(self, competition, path, quiet=True):
+     """Download a competition leaderboard.
+
+     Parameters
+     =========
+     competition: the name of the competition
+     path: a path to download the file to
+     quiet: suppress verbose output (default is True)
+     """
+     with self.build_kaggle_client() as kaggle:
+       request = ApiDownloadLeaderboardRequest()
+       request.competition_name = competition
+       response = kaggle.competitions.competition_api_client.download_leaderboard(
+           request)
+       if path is None:
+         effective_path = self.get_default_download_dir('competitions',
+                                                        competition)
+       else:
+         effective_path = path
+
+       file_name = competition + '.zip'
+       outfile = os.path.join(effective_path, file_name)
+       self.download_file(response, outfile, quiet)
+
+   def competition_leaderboard_view(self, competition):
+     """View a leaderboard based on a competition name.
+
+     Parameters
+     ==========
+     competition: the competition name to view leaderboard for
+     """
+     with self.build_kaggle_client() as kaggle:
+       request = ApiGetLeaderboardRequest()
+       request.competition_name = competition
+       response = kaggle.competitions.competition_api_client.get_leaderboard(
+           request)
+       return response.submissions
+
+   def competition_leaderboard_cli(self,
+                                   competition,
+                                   competition_opt=None,
+                                   path=None,
+                                   view=False,
+                                   download=False,
+                                   csv_display=False,
+                                   quiet=False):
+     """A wrapper for competition_leaderboard_view that will print the
+        results as a table or comma separated values.
+
+     Parameters
+     ==========
+     competition: the competition name to view leaderboard for
+     competition_opt: an alternative competition option provided by cli
+     path: a path to download to, if download is True
+     view: if True, show the results in the terminal as csv or table
+     download: if True, download the entire leaderboard
+     csv_display: if True, print comma separated values instead of table
+     quiet: suppress verbose output (default is False)
+     """
+     competition = competition or competition_opt
+     if not view and not download:
+       raise ValueError('Either --show or --download must be specified')
+
+     if competition is None:
+       competition = self.get_config_value(self.CONFIG_NAME_COMPETITION)
+       if competition is not None and not quiet:
+         print('Using competition: ' + competition)
+
+     if competition is None:
+       raise ValueError('No competition specified')
+
+     if download:
+       self.competition_leaderboard_download(competition, path, quiet)
+
+     if view:
+       results = self.competition_leaderboard_view(competition)
+       if results:
+         if csv_display:
+           self.print_csv(results, self.competition_leaderboard_fields)
+         else:
+           self.print_table(results, self.competition_leaderboard_fields)
+       else:
+         print('No results found')
+
+   def dataset_list(self,
+                    sort_by=None,
+                    size=None,
+                    file_type=None,
+                    license_name=None,
+                    tag_ids=None,
+                    search=None,
+                    user=None,
+                    mine=False,
+                    page=1,
+                    max_size=None,
+                    min_size=None):
+     """Return a list of datasets.
+
+     Parameters
+     ==========
+     sort_by: how to sort the result, see valid_dataset_sort_bys for options
+     size: deprecated; use max_size and min_size instead
+     file_type: the format, see valid_dataset_file_types for string options
+     license_name: string descriptor for license, see valid_dataset_license_names
+     tag_ids: tag identifiers to filter the search
+     search: a search term to use (default is empty string)
+     user: username to filter the search to
+     mine: boolean if True, group is changed to "my" to return personal
+     page: the page to return (default is 1)
+     max_size: the maximum size of the dataset to return (bytes)
+     min_size: the minimum size of the dataset to return (bytes)
+     """
+     if sort_by:
+       if sort_by not in self.valid_dataset_sort_bys:
+         raise ValueError('Invalid sort by specified. Valid options are ' +
+                          str(self.valid_dataset_sort_bys))
+       else:
+         sort_by = self.lookup_enum(DatasetSortBy, sort_by)
+
+     if size:
+       raise ValueError(
+           'The --size parameter has been deprecated. ' +
+           'Please use --max-size and --min-size to filter dataset sizes.')
+
+     if file_type:
+       if file_type not in self.valid_dataset_file_types:
+         raise ValueError('Invalid file type specified. Valid options are ' +
+                          str(self.valid_dataset_file_types))
+       else:
+         file_type = self.lookup_enum(DatasetFileTypeGroup, file_type)
+
+     if license_name:
+       if license_name not in self.valid_dataset_license_names:
+         raise ValueError('Invalid license specified. Valid options are ' +
+                          str(self.valid_dataset_license_names))
+       else:
+         license_name = self.lookup_enum(DatasetLicenseGroup, license_name)
+
+     if int(page) <= 0:
+       raise ValueError('Page number must be >= 1')
+
+     if max_size and min_size:
+       if int(max_size) < int(min_size):
+         raise ValueError('Max Size must be max_size >= min_size')
+     if max_size and int(max_size) <= 0:
+       raise ValueError('Max Size must be > 0')
+     elif min_size and int(min_size) < 0:
+       raise ValueError('Min Size must be >= 0')
+
+     group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_PUBLIC
+     if mine:
+       group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_MY
+       if user:
+         raise ValueError('Cannot specify both mine and a user')
+     if user:
+       group = DatasetSelectionGroup.DATASET_SELECTION_GROUP_USER
+
+     with self.build_kaggle_client() as kaggle:
+       request = ApiListDatasetsRequest()
+       request.group = group
+       request.sort_by = sort_by
+       request.file_type = file_type
+       request.license = license_name
+       request.tag_ids = tag_ids
+       request.search = search
+       request.user = user
+       request.page = page
+       request.max_size = max_size
+       request.min_size = min_size
+       response = kaggle.datasets.dataset_api_client.list_datasets(request)
+       return response.datasets
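A search sketch using the vocabularies validated above (sizes are bytes):

    api = KaggleApi()
    api.authenticate()
    datasets = api.dataset_list(sort_by='votes', file_type='csv',
                                search='weather', max_size=100000000)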
+
+   def dataset_list_cli(self,
+                        sort_by=None,
+                        size=None,
+                        file_type=None,
+                        license_name=None,
+                        tag_ids=None,
+                        search=None,
+                        user=None,
+                        mine=False,
+                        page=1,
+                        csv_display=False,
+                        max_size=None,
+                        min_size=None):
+     """A wrapper to dataset_list for the client. Additional parameters
+        are described here, see dataset_list for others.
+
+     Parameters
+     ==========
+     sort_by: how to sort the result, see valid_dataset_sort_bys for options
+     size: deprecated; use max_size and min_size instead
+     file_type: the format, see valid_dataset_file_types for string options
+     license_name: string descriptor for license, see valid_dataset_license_names
+     tag_ids: tag identifiers to filter the search
+     search: a search term to use (default is empty string)
+     user: username to filter the search to
+     mine: boolean if True, group is changed to "my" to return personal
+     page: the page to return (default is 1)
+     csv_display: if True, print comma separated values instead of table
+     max_size: the maximum size of the dataset to return (bytes)
+     min_size: the minimum size of the dataset to return (bytes)
+     """
+     datasets = self.dataset_list(sort_by, size, file_type, license_name,
+                                  tag_ids, search, user, mine, page, max_size,
+                                  min_size)
+     if datasets:
+       if csv_display:
+         self.print_csv(datasets, self.dataset_fields, self.dataset_labels)
+       else:
+         self.print_table(datasets, self.dataset_fields, self.dataset_labels)
+     else:
+       print('No datasets found')
+
+   def dataset_metadata_prep(self, dataset, path):
+     if dataset is None:
+       raise ValueError('A dataset must be specified')
+     if '/' in dataset:
+       self.validate_dataset_string(dataset)
+       dataset_urls = dataset.split('/')
+       owner_slug = dataset_urls[0]
+       dataset_slug = dataset_urls[1]
+     else:
+       owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+       dataset_slug = dataset
+
+     if path is None:
+       effective_path = self.get_default_download_dir('datasets', owner_slug,
+                                                      dataset_slug)
+     else:
+       effective_path = path
+
+     return (owner_slug, dataset_slug, effective_path)
+
+   def dataset_metadata_update(self, dataset, path):
+     (owner_slug, dataset_slug,
+      effective_path) = self.dataset_metadata_prep(dataset, path)
+     meta_file = self.get_dataset_metadata_file(effective_path)
+     with open(meta_file, 'r') as f:
+       s = json.load(f)
+     metadata = json.loads(s)
+     update_settings = DatasetSettings()
+     update_settings.title = metadata.get('title') or ''
+     update_settings.subtitle = metadata.get('subtitle') or ''
+     update_settings.description = metadata.get('description') or ''
+     update_settings.is_private = metadata.get('isPrivate') or False
+     update_settings.licenses = [
+         self._new_license(l['name']) for l in metadata['licenses']
+     ] if metadata.get('licenses') else []
+     update_settings.keywords = metadata.get('keywords')
+     update_settings.collaborators = [
+         self._new_collaborator(c['username'], c['role'])
+         for c in metadata['collaborators']
+     ] if metadata.get('collaborators') else []
+     update_settings.data = metadata.get('data')
+     request = ApiUpdateDatasetMetadataRequest()
+     request.owner_slug = owner_slug
+     request.dataset_slug = dataset_slug
+     request.settings = update_settings
+     with self.build_kaggle_client() as kaggle:
+       response = kaggle.datasets.dataset_api_client.update_dataset_metadata(
+           request)
+       if len(response.errors) > 0:
+         for e in response.errors:
+           print(e['message'])
+         sys.exit(1)
+
+   @staticmethod
+   def _new_license(name):
+     l = SettingsLicense()
+     l.name = name
+     return l
+
+   @staticmethod
+   def _new_collaborator(name, role):
+     u = DatasetCollaborator()
+     u.username = name
+     u.role = role
+     return u
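dataset_metadata_update reads a dataset-metadata.json found via get_dataset_metadata_file; the keys it consumes are exactly the ones above. An illustrative file (all values are placeholders):

    {
      "title": "Example Dataset",
      "subtitle": "A short subtitle",
      "description": "A longer description.",
      "isPrivate": false,
      "licenses": [{"name": "CC0-1.0"}],
      "keywords": ["example", "weather"],
      "collaborators": [{"username": "example_user", "role": "reader"}]
    }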
+
+   def dataset_metadata(self, dataset, path):
+     (owner_slug, dataset_slug,
+      effective_path) = self.dataset_metadata_prep(dataset, path)
+
+     if not os.path.exists(effective_path):
+       os.makedirs(effective_path)
+
+     with self.build_kaggle_client() as kaggle:
+       request = ApiGetDatasetMetadataRequest()
+       request.owner_slug = owner_slug
+       request.dataset_slug = dataset_slug
+       response = kaggle.datasets.dataset_api_client.get_dataset_metadata(
+           request)
+       if response.error_message:
+         raise Exception(response.error_message)
+
+     meta_file = os.path.join(effective_path, self.DATASET_METADATA_FILE)
+     with open(meta_file, 'w') as f:
+       json.dump(
+           response.to_json(response.info),
+           f,
+           indent=2,
+           default=lambda o: o.__dict__)
+
+     return meta_file
+
+   def dataset_metadata_cli(self, dataset, path, update, dataset_opt=None):
+     dataset = dataset or dataset_opt
+     if update:
+       print('updating dataset metadata')
+       self.dataset_metadata_update(dataset, path)
+       print('successfully updated dataset metadata')
+     else:
+       meta_file = self.dataset_metadata(dataset, path)
+       print('Downloaded metadata to ' + meta_file)
+
+   def dataset_list_files(self, dataset, page_token=None, page_size=20):
+     """List files for a dataset.
+
+     Parameters
+     ==========
+     dataset: the string identifier of the dataset,
+              in the format [owner]/[dataset-name]
+     page_token: the page token for pagination
+     page_size: the number of items per page
+     """
+     if dataset is None:
+       raise ValueError('A dataset must be specified')
+     owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string(
+         dataset)
+
+     with self.build_kaggle_client() as kaggle:
+       request = ApiListDatasetFilesRequest()
+       request.owner_slug = owner_slug
+       request.dataset_slug = dataset_slug
+       request.dataset_version_number = dataset_version_number
+       request.page_token = page_token
+       request.page_size = page_size
+       response = kaggle.datasets.dataset_api_client.list_dataset_files(request)
+       return response
+
1433
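+  # Usage sketch (illustrative; the dataset ref is a placeholder): paging
+  # through a dataset's files with the token returned on each response.
+  #
+  #   result = api.dataset_list_files('owner/dataset-name', page_size=10)
+  #   for f in result.files:
+  #       print(f.name, f.total_bytes)
+  #   if result.next_page_token:
+  #       result = api.dataset_list_files('owner/dataset-name',
+  #                                       page_token=result.next_page_token)
+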
+  def dataset_list_files_cli(self,
+                             dataset,
+                             dataset_opt=None,
+                             csv_display=False,
+                             page_token=None,
+                             page_size=20):
+    """ A wrapper to dataset_list_files for the client
+        (list files for a dataset).
+        Parameters
+        ==========
+        dataset: the string identifier of the dataset
+                 should be in format [owner]/[dataset-name]
+        dataset_opt: an alternative option to providing a dataset
+        csv_display: if True, print comma separated values instead of table
+        page_token: the page token for pagination
+        page_size: the number of items per page
+    """
+    dataset = dataset or dataset_opt
+    result = self.dataset_list_files(dataset, page_token, page_size)
+
+    if result:
+      if result.error_message:
+        print(result.error_message)
+      else:
+        next_page_token = result.next_page_token
+        if next_page_token:
+          print('Next Page Token = {}'.format(next_page_token))
+        fields = ['name', 'size', 'creationDate']
+        # Alias total_bytes to the 'size' field name used for display.
+        ApiDatasetFile.size = ApiDatasetFile.total_bytes
+        if csv_display:
+          self.print_csv(result.files, fields)
+        else:
+          self.print_table(result.files, fields)
+    else:
+      print('No files found')
+
+  def dataset_status(self, dataset):
+    """ Call to get the status of a dataset from the API.
+        Parameters
+        ==========
+        dataset: the string identifier of the dataset
+                 should be in format [owner]/[dataset-name]
+    """
+    if dataset is None:
+      raise ValueError('A dataset must be specified')
+    if '/' in dataset:
+      self.validate_dataset_string(dataset)
+      dataset_urls = dataset.split('/')
+      owner_slug = dataset_urls[0]
+      dataset_slug = dataset_urls[1]
+    else:
+      owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+      dataset_slug = dataset
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiGetDatasetStatusRequest()
+      request.owner_slug = owner_slug
+      request.dataset_slug = dataset_slug
+      response = kaggle.datasets.dataset_api_client.get_dataset_status(request)
+      return response.status.name.lower()
+
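+  # Usage sketch (illustrative): the method returns the lower-cased status
+  # enum name, so a readiness check might look like:
+  #
+  #   if api.dataset_status('owner/dataset-name') == 'ready':
+  #       print('dataset is ready for use')
+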
+  def dataset_status_cli(self, dataset, dataset_opt=None):
+    """ Client wrapper for dataset_status, with an additional
+        dataset_opt to get the status of a dataset from the API.
+        Parameters
+        ==========
+        dataset_opt: an alternative to dataset
+    """
+    dataset = dataset or dataset_opt
+    return self.dataset_status(dataset)
+
+  def dataset_download_file(self,
+                            dataset,
+                            file_name,
+                            path=None,
+                            force=False,
+                            quiet=True,
+                            licenses=[]):
+    """ Download a single file for a dataset.
+
+        Parameters
+        ==========
+        dataset: the string identifier of the dataset
+                 should be in format [owner]/[dataset-name]
+        file_name: the name of the file to download
+        path: if defined, download to this location
+        force: force the download if the file already exists (default False)
+        quiet: suppress verbose output (default is True)
+        licenses: a list of license names, e.g. ['CC0-1.0']
+    """
+    if '/' in dataset:
+      self.validate_dataset_string(dataset)
+      owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string(
+          dataset)
+    else:
+      owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+      dataset_slug = dataset
+      dataset_version_number = None
+
+    if path is None:
+      effective_path = self.get_default_download_dir('datasets', owner_slug,
+                                                     dataset_slug)
+    else:
+      effective_path = path
+
+    self._print_dataset_url_and_license(owner_slug, dataset_slug,
+                                        dataset_version_number, licenses)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiDownloadDatasetRequest()
+      request.owner_slug = owner_slug
+      request.dataset_slug = dataset_slug
+      request.dataset_version_number = dataset_version_number
+      request.file_name = file_name
+      response = kaggle.datasets.dataset_api_client.download_dataset(request)
+      url = response.history[0].url
+      outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1])
+
+      if force or self.download_needed(response, outfile, quiet):
+        # Pass the SDK http client so download_file() can resume with Range
+        # requests (its third positional parameter is http_client).
+        self.download_file(response, outfile, kaggle.http_client(), quiet,
+                           not force)
+        return True
+      else:
+        return False
+
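+  # Usage sketch (illustrative; 'train.csv' is a placeholder file name):
+  # download one file into ./data, re-downloading even if it already exists.
+  #
+  #   api.dataset_download_file('owner/dataset-name', 'train.csv',
+  #                             path='data', force=True)
+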
+  def dataset_download_files(self,
+                             dataset,
+                             path=None,
+                             force=False,
+                             quiet=True,
+                             unzip=False,
+                             licenses=[]):
+    """ Download all files for a dataset.
+
+        Parameters
+        ==========
+        dataset: the string identifier of the dataset
+                 should be in format [owner]/[dataset-name]
+        path: the path to download the dataset to
+        force: force the download if the file already exists (default False)
+        quiet: suppress verbose output (default is True)
+        unzip: if True, unzip files upon download (default is False)
+        licenses: a list of license names, e.g. ['CC0-1.0']
+    """
+    if dataset is None:
+      raise ValueError('A dataset must be specified')
+    owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string(
+        dataset)
+    if path is None:
+      effective_path = self.get_default_download_dir('datasets', owner_slug,
+                                                     dataset_slug)
+    else:
+      effective_path = path
+
+    self._print_dataset_url_and_license(owner_slug, dataset_slug,
+                                        dataset_version_number, licenses)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiDownloadDatasetRequest()
+      request.owner_slug = owner_slug
+      request.dataset_slug = dataset_slug
+      request.dataset_version_number = dataset_version_number
+      response = kaggle.datasets.dataset_api_client.download_dataset(request)
+
+      outfile = os.path.join(effective_path, dataset_slug + '.zip')
+      if force or self.download_needed(response, outfile, quiet):
+        # Pass the SDK http client so download_file() can resume with Range
+        # requests (its third positional parameter is http_client).
+        self.download_file(response, outfile, kaggle.http_client(), quiet,
+                           not force)
+        downloaded = True
+      else:
+        downloaded = False
+
+    if downloaded:
+      outfile = os.path.join(effective_path, dataset_slug + '.zip')
+      if unzip:
+        try:
+          with zipfile.ZipFile(outfile) as z:
+            z.extractall(effective_path)
+        except zipfile.BadZipFile:
+          raise ValueError(
+              f"The file {outfile} is corrupted or not a valid zip file. "
+              "Please report this issue at https://www.github.com/kaggle/kaggle-api"
+          )
+        except FileNotFoundError:
+          raise FileNotFoundError(
+              f"The file {outfile} was not found. "
+              "Please report this issue at https://www.github.com/kaggle/kaggle-api"
+          )
+        except Exception as e:
+          raise RuntimeError(
+              f"An unexpected error occurred: {e}. "
+              "Please report this issue at https://www.github.com/kaggle/kaggle-api"
+          )
+
+        try:
+          os.remove(outfile)
+        except OSError as e:
+          print('Could not delete zip file, got %s' % e)
+
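+  # Usage sketch (illustrative): download the whole dataset as a zip and
+  # extract it in place; the archive itself is deleted after extraction.
+  #
+  #   api.dataset_download_files('owner/dataset-name', path='data', unzip=True)
+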
+  def _print_dataset_url_and_license(self, owner_slug, dataset_slug,
+                                     dataset_version_number, licenses):
+    if dataset_version_number is None:
+      print('Dataset URL: https://www.kaggle.com/datasets/%s/%s' %
+            (owner_slug, dataset_slug))
+    else:
+      print('Dataset URL: https://www.kaggle.com/datasets/%s/%s/versions/%s' %
+            (owner_slug, dataset_slug, dataset_version_number))
+
+    if len(licenses) > 0:
+      print('License(s): %s' % (','.join(licenses)))
+
+  def dataset_download_cli(self,
+                           dataset,
+                           dataset_opt=None,
+                           file_name=None,
+                           path=None,
+                           unzip=False,
+                           force=False,
+                           quiet=False):
+    """ Client wrapper for downloading dataset files: a specific file when
+        file_name is provided, otherwise all files for the dataset.
+
+        Parameters
+        ==========
+        dataset: the string identifier of the dataset
+                 should be in format [owner]/[dataset-name]
+        dataset_opt: an alternative option to providing a dataset
+        file_name: the name of the file to download
+        path: the path to download the dataset to
+        force: force the download if the file already exists (default False)
+        quiet: suppress verbose output (default is False)
+        unzip: if True, unzip files upon download (default is False)
+    """
+    dataset = dataset or dataset_opt
+
+    owner_slug, dataset_slug, _ = self.split_dataset_string(dataset)
+    metadata = self.process_response(
+        self.metadata_get_with_http_info(owner_slug, dataset_slug))
+
+    if 'info' in metadata and 'licenses' in metadata['info']:
+      # license_objs format is like: [{ 'name': 'CC0-1.0' }]
+      license_objs = metadata['info']['licenses']
+      licenses = [
+          license_obj['name']
+          for license_obj in license_objs
+          if 'name' in license_obj
+      ]
+    else:
+      licenses = [
+          'Error retrieving license. Please visit the Dataset URL to view license information.'
+      ]
+
+    if file_name is None:
+      self.dataset_download_files(
+          dataset,
+          path=path,
+          unzip=unzip,
+          force=force,
+          quiet=quiet,
+          licenses=licenses)
+    else:
+      self.dataset_download_file(
+          dataset,
+          file_name,
+          path=path,
+          force=force,
+          quiet=quiet,
+          licenses=licenses)
+
+  def _upload_blob(self, path, quiet, blob_type, upload_context):
+    """ Upload a file.
+
+        Parameters
+        ==========
+        path: the complete path to upload
+        quiet: suppress verbose output (default is False)
+        blob_type (ApiBlobType): to which entity the file/blob refers
+        upload_context (ResumableUploadContext): context for resumable uploads
+    """
+    file_name = os.path.basename(path)
+    content_length = os.path.getsize(path)
+    last_modified_epoch_seconds = int(os.path.getmtime(path))
+
+    start_blob_upload_request = ApiStartBlobUploadRequest()
+    start_blob_upload_request.type = blob_type
+    start_blob_upload_request.name = file_name
+    start_blob_upload_request.content_length = content_length
+    start_blob_upload_request.last_modified_epoch_seconds = last_modified_epoch_seconds
+
+    file_upload = upload_context.new_resumable_file_upload(
+        path, start_blob_upload_request)
+
+    for _ in range(0, self.MAX_UPLOAD_RESUME_ATTEMPTS):
+      if file_upload.upload_complete:
+        return file_upload
+
+      if not file_upload.can_resume:
+        # Initiate upload on Kaggle backend to get the url and token.
+        with self.build_kaggle_client() as kaggle:
+          method = kaggle.blobs.blob_api_client.start_blob_upload
+          start_blob_upload_response = self.with_retry(method)(
+              file_upload.start_blob_upload_request)
+          file_upload.upload_initiated(start_blob_upload_response)
+
+      upload_result = self.upload_complete(
+          path,
+          file_upload.start_blob_upload_response.create_url,
+          quiet,
+          resume=file_upload.can_resume)
+      if upload_result == ResumableUploadResult.INCOMPLETE:
+        continue  # Continue (i.e., retry/resume) only if the upload is incomplete.
+
+      if upload_result == ResumableUploadResult.COMPLETE:
+        file_upload.upload_completed()
+      break
+
+    return file_upload.get_token()
+
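+  # Flow note (descriptive comments only): each pass of the loop above either
+  # (a) starts a fresh upload session via start_blob_upload when there is
+  # nothing to resume, or (b) resumes the previous session from where it left
+  # off. An INCOMPLETE result retries, COMPLETE marks the upload finished, and
+  # any other result gives up; the blob token is then returned for inclusion
+  # in a dataset/model/kernel create request.
+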
+  def dataset_create_version(self,
+                             folder,
+                             version_notes,
+                             quiet=False,
+                             convert_to_csv=True,
+                             delete_old_versions=False,
+                             dir_mode='skip'):
+    """ Create a version of a dataset.
+
+        Parameters
+        ==========
+        folder: the folder with the dataset configuration / data files
+        version_notes: notes to add for the version
+        quiet: suppress verbose output (default is False)
+        convert_to_csv: on upload, if data should be converted to csv
+        delete_old_versions: if True, delete old versions of the dataset (default False)
+        dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
+    """
+    if not os.path.isdir(folder):
+      raise ValueError('Invalid folder: ' + folder)
+
+    meta_file = self.get_dataset_metadata_file(folder)
+
+    # read json
+    with open(meta_file) as f:
+      meta_data = json.load(f)
+    ref = self.get_or_default(meta_data, 'id', None)
+    id_no = self.get_or_default(meta_data, 'id_no', None)
+    if not ref and not id_no:
+      raise ValueError('ID or slug must be specified in the metadata')
+    elif ref and ref == self.config_values[
+        self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE':
+      raise ValueError(
+          'Default slug detected, please change values before uploading')
+
+    subtitle = meta_data.get('subtitle')
+    if subtitle and (len(subtitle) < 20 or len(subtitle) > 80):
+      raise ValueError('Subtitle length must be between 20 and 80 characters')
+    resources = meta_data.get('resources')
+    if resources:
+      self.validate_resources(folder, resources)
+
+    description = meta_data.get('description')
+    keywords = self.get_or_default(meta_data, 'keywords', [])
+
+    body = ApiCreateDatasetVersionRequestBody()
+    body.version_notes = version_notes
+    body.subtitle = subtitle
+    body.description = description
+    body.files = []
+    body.category_ids = keywords
+    body.delete_old_versions = delete_old_versions
+
+    with self.build_kaggle_client() as kaggle:
+      if id_no:
+        request = ApiCreateDatasetVersionByIdRequest()
+        request.id = id_no
+        message = kaggle.datasets.dataset_api_client.create_dataset_version_by_id
+      else:
+        self.validate_dataset_string(ref)
+        ref_list = ref.split('/')
+        owner_slug = ref_list[0]
+        dataset_slug = ref_list[1]
+        request = ApiCreateDatasetVersionRequest()
+        request.owner_slug = owner_slug
+        request.dataset_slug = dataset_slug
+        message = kaggle.datasets.dataset_api_client.create_dataset_version
+      request.body = body
+      with ResumableUploadContext() as upload_context:
+        self.upload_files(body, resources, folder, ApiBlobType.DATASET,
+                          upload_context, quiet, dir_mode)
+        request.body.files = [
+            self._api_dataset_new_file(file) for file in request.body.files
+        ]
+        response = self.with_retry(message)(request)
+        return response
+
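+  # Usage sketch (illustrative; folder name is a placeholder): version a
+  # folder that holds the data files plus a dataset-metadata.json whose 'id'
+  # points at an existing dataset.
+  #
+  #   api.dataset_create_version('my-dataset-folder',
+  #                              version_notes='fixed column names',
+  #                              dir_mode='zip')
+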
+  def _api_dataset_new_file(self, file):
+    # TODO Eliminate the need for this conversion
+    f = ApiDatasetNewFile()
+    f.token = file.token
+    return f
+
+  def dataset_create_version_cli(self,
+                                 folder,
+                                 version_notes,
+                                 quiet=False,
+                                 convert_to_csv=True,
+                                 delete_old_versions=False,
+                                 dir_mode='skip'):
+    """ Client wrapper for creating a version of a dataset.
+        Parameters
+        ==========
+        folder: the folder with the dataset configuration / data files
+        version_notes: notes to add for the version
+        quiet: suppress verbose output (default is False)
+        convert_to_csv: on upload, if data should be converted to csv
+        delete_old_versions: if True, delete old versions of the dataset (default False)
+        dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
+    """
+    folder = folder or os.getcwd()
+    result = self.dataset_create_version(
+        folder,
+        version_notes,
+        quiet=quiet,
+        convert_to_csv=convert_to_csv,
+        delete_old_versions=delete_old_versions,
+        dir_mode=dir_mode)
+
+    if result is None:
+      print('Dataset version creation error: See previous output')
+    elif result.invalidTags:
+      print(('The following are not valid tags and could not be added to '
+             'the dataset: ') + str(result.invalidTags))
+    elif result.status.lower() == 'ok':
+      print('Dataset version is being created. Please check progress at ' +
+            result.url)
+    else:
+      print('Dataset version creation error: ' + result.error)
+
+  def dataset_initialize(self, folder):
+    """ Initialize a folder with a dataset configuration (metadata) file.
+
+        Parameters
+        ==========
+        folder: the folder to initialize the metadata file in
+    """
+    if not os.path.isdir(folder):
+      raise ValueError('Invalid folder: ' + folder)
+
+    ref = self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE'
+    licenses = []
+    default_license = {'name': 'CC0-1.0'}
+    licenses.append(default_license)
+
+    meta_data = {'title': 'INSERT_TITLE_HERE', 'id': ref, 'licenses': licenses}
+    meta_file = os.path.join(folder, self.DATASET_METADATA_FILE)
+    with open(meta_file, 'w') as f:
+      json.dump(meta_data, f, indent=2)
+
+    print('Data package template written to: ' + meta_file)
+    return meta_file
+
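+  # The generated dataset-metadata.json template looks like this (the
+  # username comes from the local configuration):
+  #
+  #   {
+  #     "title": "INSERT_TITLE_HERE",
+  #     "id": "<username>/INSERT_SLUG_HERE",
+  #     "licenses": [{"name": "CC0-1.0"}]
+  #   }
+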
+  def dataset_initialize_cli(self, folder=None):
+    folder = folder or os.getcwd()
+    self.dataset_initialize(folder)
+
+  def dataset_create_new(self,
+                         folder,
+                         public=False,
+                         quiet=False,
+                         convert_to_csv=True,
+                         dir_mode='skip'):
+    """ Create a new dataset, meaning the same as creating a version but
+        with extra metadata like license and user/owner.
+
+        Parameters
+        ==========
+        folder: the folder to get the metadata file from
+        public: should the dataset be public?
+        quiet: suppress verbose output (default is False)
+        convert_to_csv: if True, convert data to comma separated value
+        dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
+    """
+    if not os.path.isdir(folder):
+      raise ValueError('Invalid folder: ' + folder)
+
+    meta_file = self.get_dataset_metadata_file(folder)
+
+    # read json
+    with open(meta_file) as f:
+      meta_data = json.load(f)
+    ref = self.get_or_fail(meta_data, 'id')
+    title = self.get_or_fail(meta_data, 'title')
+    licenses = self.get_or_fail(meta_data, 'licenses')
+    ref_list = ref.split('/')
+    owner_slug = ref_list[0]
+    dataset_slug = ref_list[1]
+
+    # validations
+    if ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE':
+      raise ValueError(
+          'Default slug detected, please change values before uploading')
+    if title == 'INSERT_TITLE_HERE':
+      raise ValueError(
+          'Default title detected, please change values before uploading')
+    if len(licenses) != 1:
+      raise ValueError('Please specify exactly one license')
+    if len(dataset_slug) < 6 or len(dataset_slug) > 50:
+      raise ValueError('The dataset slug must be between 6 and 50 characters')
+    if len(title) < 6 or len(title) > 50:
+      raise ValueError('The dataset title must be between 6 and 50 characters')
+    resources = meta_data.get('resources')
+    if resources:
+      self.validate_resources(folder, resources)
+
+    license_name = self.get_or_fail(licenses[0], 'name')
+    description = meta_data.get('description')
+    keywords = self.get_or_default(meta_data, 'keywords', [])
+
+    subtitle = meta_data.get('subtitle')
+    if subtitle and (len(subtitle) < 20 or len(subtitle) > 80):
+      raise ValueError('Subtitle length must be between 20 and 80 characters')
+
+    request = ApiCreateDatasetRequest()
+    request.title = title
+    request.slug = dataset_slug
+    request.owner_slug = owner_slug
+    request.license_name = license_name
+    request.subtitle = subtitle
+    request.description = description
+    request.files = []
+    request.is_private = not public
+    # request.convert_to_csv = convert_to_csv
+    request.category_ids = keywords
+
+    with ResumableUploadContext() as upload_context:
+      self.upload_files(request, resources, folder, ApiBlobType.DATASET,
+                        upload_context, quiet, dir_mode)
+
+      with self.build_kaggle_client() as kaggle:
+        retry_request = ApiCreateDatasetRequest()
+        retry_request.title = title
+        retry_request.slug = dataset_slug
+        retry_request.owner_slug = owner_slug
+        retry_request.license_name = license_name
+        retry_request.subtitle = subtitle
+        retry_request.description = description
+        retry_request.files = [
+            self._api_dataset_new_file(file) for file in request.files
+        ]
+        retry_request.is_private = not public
+        retry_request.category_ids = keywords
+        response = self.with_retry(
+            kaggle.datasets.dataset_api_client.create_dataset)(
+                retry_request)
+        return response
+
+  def dataset_create_new_cli(self,
+                             folder=None,
+                             public=False,
+                             quiet=False,
+                             convert_to_csv=True,
+                             dir_mode='skip'):
+    """ Client wrapper for creating a new dataset.
+        Parameters
+        ==========
+        folder: the folder to get the metadata file from
+        public: should the dataset be public?
+        quiet: suppress verbose output (default is False)
+        convert_to_csv: if True, convert data to comma separated value
+        dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
+    """
+    folder = folder or os.getcwd()
+    result = self.dataset_create_new(folder, public, quiet, convert_to_csv,
+                                     dir_mode)
+    if result.invalidTags:
+      print('The following are not valid tags and could not be added to '
+            'the dataset: ' + str(result.invalidTags))
+    if result.status.lower() == 'ok':
+      if public:
+        print('Your public Dataset is being created. Please check '
+              'progress at ' + result.url)
+      else:
+        print('Your private Dataset is being created. Please check '
+              'progress at ' + result.url)
+    else:
+      print('Dataset creation error: ' + result.error)
+
+  def download_file(self,
+                    response,
+                    outfile,
+                    http_client,
+                    quiet=True,
+                    resume=False,
+                    chunk_size=1048576):
+    """ Download a file to an output file based on a chunk size.
+
+        Parameters
+        ==========
+        response: the response to download
+        outfile: the output file to download to
+        http_client: the Kaggle http client to use
+        quiet: suppress verbose output (default is True)
+        resume: whether to resume an existing download
+        chunk_size: the size of the chunk to stream
+    """
+
+    outpath = os.path.dirname(outfile)
+    if not os.path.exists(outpath):
+      os.makedirs(outpath)
+    size = int(response.headers['Content-Length'])
+    size_read = 0
+    open_mode = 'wb'
+    last_modified = response.headers.get('Last-Modified')
+    if last_modified is None:
+      remote_date = datetime.now()
+    else:
+      remote_date = datetime.strptime(last_modified,
+                                      '%a, %d %b %Y %H:%M:%S %Z')
+    remote_date_timestamp = time.mktime(remote_date.timetuple())
+
+    if not quiet:
+      print('Downloading ' + os.path.basename(outfile) + ' to ' + outpath)
+
+    file_exists = os.path.isfile(outfile)
+    resumable = 'Accept-Ranges' in response.headers and response.headers[
+        'Accept-Ranges'] == 'bytes'
+
+    if resume and resumable and file_exists:
+      size_read = os.path.getsize(outfile)
+      open_mode = 'ab'
+
+      if not quiet:
+        print("... resuming from %d bytes (%d bytes left) ..." % (
+            size_read,
+            size - size_read,
+        ))
+
+      request_history = response.history[0]
+      response = http_client.call(
+          request_history.request.method,
+          request_history.headers['location'],
+          headers={'Range': 'bytes=%d-' % (size_read,)},
+          _preload_content=False)
+
+    with tqdm(
+        total=size,
+        initial=size_read,
+        unit='B',
+        unit_scale=True,
+        unit_divisor=1024,
+        disable=quiet) as pbar:
+      with open(outfile, open_mode) as out:
+        # TODO: Delete this test after all API methods are converted.
+        if type(response).__name__ == 'HTTPResponse':
+          while True:
+            data = response.read(chunk_size)
+            if not data:
+              break
+            out.write(data)
+            os.utime(
+                outfile,
+                times=(remote_date_timestamp - 1, remote_date_timestamp - 1))
+            size_read = min(size, size_read + chunk_size)
+            pbar.update(len(data))
+        else:
+          for data in response.iter_content(chunk_size):
+            if not data:
+              break
+            out.write(data)
+            os.utime(
+                outfile,
+                times=(remote_date_timestamp - 1, remote_date_timestamp - 1))
+            size_read = min(size, size_read + chunk_size)
+            pbar.update(len(data))
+      if not quiet:
+        print('\n', end='')
+
+    os.utime(outfile, times=(remote_date_timestamp, remote_date_timestamp))
+
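+  # Resume note (descriptive comments only): when a partial outfile of N
+  # bytes exists and the server advertises 'Accept-Ranges: bytes', the method
+  # above reopens the file in append mode and re-issues the redirected
+  # request with 'Range: bytes=N-', seeding tqdm with initial=N so the
+  # progress bar covers the whole file rather than just the remainder.
+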
+  def kernels_list(self,
+                   page=1,
+                   page_size=20,
+                   dataset=None,
+                   competition=None,
+                   parent_kernel=None,
+                   search=None,
+                   mine=False,
+                   user=None,
+                   language=None,
+                   kernel_type=None,
+                   output_type=None,
+                   sort_by=None):
+    """ List kernels based on a set of search criteria.
+
+        Parameters
+        ==========
+        page: the page of results to return (default is 1)
+        page_size: results per page (default is 20)
+        dataset: if defined, filter to this dataset (default None)
+        competition: if defined, filter to this competition (default None)
+        parent_kernel: if defined, filter to those with specified parent
+        search: a custom search string to pass to the list query
+        mine: if true, group is specified as "my" to return personal kernels
+        user: filter results to a specific user
+        language: the programming language of the kernel
+        kernel_type: the type of kernel, one of valid_list_kernel_types (str)
+        output_type: the output type, one of valid_list_output_types (str)
+        sort_by: if defined, sort results by this string (valid_list_sort_by)
+    """
+    if int(page) <= 0:
+      raise ValueError('Page number must be >= 1')
+
+    page_size = int(page_size)
+    if page_size <= 0:
+      raise ValueError('Page size must be >= 1')
+    if page_size > 100:
+      page_size = 100
+
+    if language and language not in self.valid_list_languages:
+      raise ValueError('Invalid language specified. Valid options are ' +
+                       str(self.valid_list_languages))
+
+    if kernel_type and kernel_type not in self.valid_list_kernel_types:
+      raise ValueError('Invalid kernel type specified. Valid options are ' +
+                       str(self.valid_list_kernel_types))
+
+    if output_type and output_type not in self.valid_list_output_types:
+      raise ValueError('Invalid output type specified. Valid options are ' +
+                       str(self.valid_list_output_types))
+
+    if sort_by:
+      if sort_by not in self.valid_list_sort_by:
+        raise ValueError('Invalid sort by type specified. Valid options are ' +
+                         str(self.valid_list_sort_by))
+      if sort_by == 'relevance' and search == '':
+        raise ValueError('Cannot sort by relevance without a search term.')
+      sort_by = self.lookup_enum(KernelsListSortType, sort_by)
+    else:
+      sort_by = KernelsListSortType.HOTNESS
+
+    self.validate_dataset_string(dataset)
+    self.validate_kernel_string(parent_kernel)
+
+    group = 'everyone'
+    if mine:
+      group = 'profile'
+    group = self.lookup_enum(KernelsListViewType, group)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListKernelsRequest()
+      request.page = page
+      request.page_size = page_size
+      request.group = group
+      request.user = user or ''
+      request.language = language or 'all'
+      request.kernel_type = kernel_type or 'all'
+      request.output_type = output_type or 'all'
+      request.sort_by = sort_by
+      request.dataset = dataset or ''
+      request.competition = competition or ''
+      request.parent_kernel = parent_kernel or ''
+      request.search = search or ''
+      return kaggle.kernels.kernels_api_client.list_kernels(request).kernels
+
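+  # Usage sketch (illustrative): list your own most recently created
+  # notebooks ('dateCreated' is one of the valid_list_sort_by options).
+  #
+  #   kernels = api.kernels_list(mine=True, kernel_type='notebook',
+  #                              sort_by='dateCreated')
+  #   for k in kernels:
+  #       print(k.ref)
+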
+  def kernels_list_cli(self,
+                       mine=False,
+                       page=1,
+                       page_size=20,
+                       search=None,
+                       csv_display=False,
+                       parent=None,
+                       competition=None,
+                       dataset=None,
+                       user=None,
+                       language=None,
+                       kernel_type=None,
+                       output_type=None,
+                       sort_by=None):
+    """ Client wrapper for kernels_list; see kernels_list for most arguments.
+        Additional arguments are described below.
+        Parameters
+        ==========
+        csv_display: if True, print comma separated values instead of table
+    """
+    kernels = self.kernels_list(
+        page=page,
+        page_size=page_size,
+        search=search,
+        mine=mine,
+        dataset=dataset,
+        competition=competition,
+        parent_kernel=parent,
+        user=user,
+        language=language,
+        kernel_type=kernel_type,
+        output_type=output_type,
+        sort_by=sort_by)
+    fields = ['ref', 'title', 'author', 'lastRunTime', 'totalVotes']
+    if kernels:
+      if csv_display:
+        self.print_csv(kernels, fields)
+      else:
+        self.print_table(kernels, fields)
+    else:
+      print('Not found')
+
+  def kernels_list_files(self, kernel, page_token=None, page_size=20):
+    """ List files for a kernel.
+        Parameters
+        ==========
+        kernel: the string identifier of the kernel
+                should be in format [owner]/[kernel-name]
+        page_token: the page token for pagination
+        page_size: the number of items per page
+    """
+    if kernel is None:
+      raise ValueError('A kernel must be specified')
+    user_name, kernel_slug, kernel_version_number = self.split_dataset_string(
+        kernel)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListKernelFilesRequest()
+      request.kernel_slug = kernel_slug
+      request.user_name = user_name
+      request.page_token = page_token
+      request.page_size = page_size
+      return kaggle.kernels.kernels_api_client.list_kernel_files(request)
+
+  def kernels_list_files_cli(self,
+                             kernel,
+                             kernel_opt=None,
+                             csv_display=False,
+                             page_token=None,
+                             page_size=20):
+    """ A wrapper to kernels_list_files for the client
+        (list files for a kernel).
+        Parameters
+        ==========
+        kernel: the string identifier of the kernel
+                should be in format [owner]/[kernel-name]
+        kernel_opt: an alternative option to providing a kernel
+        csv_display: if True, print comma separated values instead of table
+        page_token: the page token for pagination
+        page_size: the number of items per page
+    """
+    kernel = kernel or kernel_opt
+    result = self.kernels_list_files(kernel, page_token, page_size)
+
+    if result is None:
+      print('No files found')
+      return
+
+    if result.error_message:
+      print(result.error_message)
+      return
+
+    next_page_token = result.next_page_token
+    if next_page_token:
+      print('Next Page Token = {}'.format(next_page_token))
+    fields = ['name', 'size', 'creationDate']
+    if csv_display:
+      self.print_csv(result.files, fields)
+    else:
+      self.print_table(result.files, fields)
+
+  def kernels_initialize(self, folder):
+    """ Create a new kernel in a specified folder from a template, including
+        json metadata that grabs values from the configuration.
+        Parameters
+        ==========
+        folder: the path of the folder
+    """
+    if not os.path.isdir(folder):
+      raise ValueError('Invalid folder: ' + folder)
+
+    resources = []
+    resource = {'path': 'INSERT_SCRIPT_PATH_HERE'}
+    resources.append(resource)
+
+    username = self.get_config_value(self.CONFIG_NAME_USER)
+    meta_data = {
+        'id': username + '/INSERT_KERNEL_SLUG_HERE',
+        'title': 'INSERT_TITLE_HERE',
+        'code_file': 'INSERT_CODE_FILE_PATH_HERE',
+        'language': 'Pick one of: {' +
+                    ','.join(x for x in self.valid_push_language_types) + '}',
+        'kernel_type': 'Pick one of: {' +
+                       ','.join(x for x in self.valid_push_kernel_types) + '}',
+        'is_private': 'true',
+        'enable_gpu': 'false',
+        'enable_tpu': 'false',
+        'enable_internet': 'true',
+        'dataset_sources': [],
+        'competition_sources': [],
+        'kernel_sources': [],
+        'model_sources': [],
+    }
+    meta_file = os.path.join(folder, self.KERNEL_METADATA_FILE)
+    with open(meta_file, 'w') as f:
+      json.dump(meta_data, f, indent=2)
+
+    return meta_file
+
+  def kernels_initialize_cli(self, folder=None):
+    """ A client wrapper for kernels_initialize. It takes the same arguments
+        but sets the default folder to None. If None, defaults to the present
+        working directory.
+        Parameters
+        ==========
+        folder: the path of the folder (None defaults to ${PWD})
+    """
+    folder = folder or os.getcwd()
+    meta_file = self.kernels_initialize(folder)
+    print('Kernel metadata template written to: ' + meta_file)
+
+  def kernels_push(self, folder, timeout):
+    """ Read the metadata file and kernel files from a notebook, validate
+        both, and use the Kernel API to push to Kaggle if all is valid.
+        Parameters
+        ==========
+        folder: the path of the folder
+        timeout: maximum session run time in seconds, if set
+    """
+    if not os.path.isdir(folder):
+      raise ValueError('Invalid folder: ' + folder)
+
+    meta_file = os.path.join(folder, self.KERNEL_METADATA_FILE)
+    if not os.path.isfile(meta_file):
+      raise ValueError('Metadata file not found: ' + str(meta_file))
+
+    with open(meta_file) as f:
+      meta_data = json.load(f)
+
+    title = self.get_or_default(meta_data, 'title', None)
+    if title and len(title) < 5:
+      raise ValueError('Title must be at least five characters')
+
+    code_path = self.get_or_default(meta_data, 'code_file', '')
+    if not code_path:
+      raise ValueError('A source file must be specified in the metadata')
+
+    code_file = os.path.join(folder, code_path)
+    if not os.path.isfile(code_file):
+      raise ValueError('Source file not found: ' + str(code_file))
+
+    slug = meta_data.get('id')
+    id_no = meta_data.get('id_no')
+    if not slug and not id_no:
+      raise ValueError('ID or slug must be specified in the metadata')
+    if slug:
+      self.validate_kernel_string(slug)
+      if '/' in slug:
+        kernel_slug = slug.split('/')[1]
+      else:
+        kernel_slug = slug
+      if title:
+        as_slug = slugify(title)
+        if kernel_slug.lower() != as_slug:
+          print('Your kernel title does not resolve to the specified '
+                'id. This may result in surprising behavior. We '
+                'suggest making your title something that resolves to '
+                'the specified id. See %s for more information on '
+                'how slugs are determined.' %
+                'https://en.wikipedia.org/wiki/Clean_URL#Slug')
+
+    language = self.get_or_default(meta_data, 'language', '')
+    if language not in self.valid_push_language_types:
+      raise ValueError(
+          'A valid language must be specified in the metadata. Valid '
+          'options are ' + str(self.valid_push_language_types))
+
+    kernel_type = self.get_or_default(meta_data, 'kernel_type', '')
+    if kernel_type not in self.valid_push_kernel_types:
+      raise ValueError(
+          'A valid kernel type must be specified in the metadata. Valid '
+          'options are ' + str(self.valid_push_kernel_types))
+
+    if kernel_type == 'notebook' and language == 'rmarkdown':
+      language = 'r'
+
+    dataset_sources = self.get_or_default(meta_data, 'dataset_sources', [])
+    for source in dataset_sources:
+      self.validate_dataset_string(source)
+
+    kernel_sources = self.get_or_default(meta_data, 'kernel_sources', [])
+    for source in kernel_sources:
+      self.validate_kernel_string(source)
+
+    model_sources = self.get_or_default(meta_data, 'model_sources', [])
+    for source in model_sources:
+      self.validate_model_string(source)
+
+    docker_pinning_type = self.get_or_default(meta_data,
+                                              'docker_image_pinning_type', None)
+    if (docker_pinning_type is not None and
+        docker_pinning_type not in self.valid_push_pinning_types):
+      raise ValueError('If specified, the docker_image_pinning_type must be '
+                       'one of ' + str(self.valid_push_pinning_types))
+
+    with open(code_file) as f:
+      script_body = f.read()
+
+    if kernel_type == 'notebook':
+      json_body = json.loads(script_body)
+      if 'cells' in json_body:
+        for cell in json_body['cells']:
+          if 'outputs' in cell and cell['cell_type'] == 'code':
+            cell['outputs'] = []
+          # The spec allows a list of strings,
+          # but the server expects just one
+          if 'source' in cell and isinstance(cell['source'], list):
+            cell['source'] = ''.join(cell['source'])
+      script_body = json.dumps(json_body).replace("'", "\\'")
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiSaveKernelRequest()
+      request.id = id_no
+      request.slug = slug
+      request.new_title = self.get_or_default(meta_data, 'title', None)
+      request.text = script_body
+      request.language = language
+      request.kernel_type = kernel_type
+      request.is_private = self.get_bool(meta_data, 'is_private', True)
+      request.enable_gpu = self.get_bool(meta_data, 'enable_gpu', False)
+      request.enable_tpu = self.get_bool(meta_data, 'enable_tpu', False)
+      request.enable_internet = self.get_bool(meta_data, 'enable_internet',
+                                              True)
+      request.dataset_data_sources = dataset_sources
+      request.competition_data_sources = self.get_or_default(
+          meta_data, 'competition_sources', [])
+      request.kernel_data_sources = kernel_sources
+      request.model_data_sources = model_sources
+      request.category_ids = self.get_or_default(meta_data, 'keywords', [])
+      request.docker_image_pinning_type = docker_pinning_type
+      if timeout:
+        request.session_timeout_seconds = int(timeout)
+      return kaggle.kernels.kernels_api_client.save_kernel(request)
+
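+  # Usage sketch (illustrative; folder name is a placeholder): push a folder
+  # containing kernel-metadata.json and the code file it references.
+  #
+  #   result = api.kernels_push('my-notebook-folder', timeout=None)
+  #   print(result.url)
+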
+  def kernels_push_cli(self, folder, timeout):
+    """ Client wrapper for kernels_push.
+        Parameters
+        ==========
+        folder: the path of the folder
+    """
+    folder = folder or os.getcwd()
+    result = self.kernels_push(folder, timeout)
+
+    if result is None:
+      print('Kernel push error: see previous output')
+    elif not result.error:
+      if result.invalidTags:
+        print('The following are not valid tags and could not be added '
+              'to the kernel: ' + str(result.invalidTags))
+      if result.invalidDatasetSources:
+        print('The following are not valid dataset sources and could not '
+              'be added to the kernel: ' + str(result.invalidDatasetSources))
+      if result.invalidCompetitionSources:
+        print('The following are not valid competition sources and could '
+              'not be added to the kernel: ' +
+              str(result.invalidCompetitionSources))
+      if result.invalidKernelSources:
+        print('The following are not valid kernel sources and could not '
+              'be added to the kernel: ' + str(result.invalidKernelSources))
+
+      if result.versionNumber:
+        print('Kernel version %s successfully pushed. Please check '
+              'progress at %s' % (result.versionNumber, result.url))
+      else:
+        # Shouldn't happen but didn't test exhaustively
+        print('Kernel version successfully pushed. Please check '
+              'progress at %s' % result.url)
+    else:
+      print('Kernel push error: ' + result.error)
+
+  def kernels_pull(self, kernel, path, metadata=False, quiet=True):
+    """ Pull a kernel, including a metadata file (if metadata is True)
+        and associated files to a specified path.
+        Parameters
+        ==========
+        kernel: the kernel to pull
+        path: the path to pull files to on the filesystem
+        metadata: if True, also pull metadata
+        quiet: suppress verbosity (default is True)
+    """
+    existing_metadata = None
+    if kernel is None:
+      if path is None:
+        existing_metadata_path = os.path.join(os.getcwd(),
+                                              self.KERNEL_METADATA_FILE)
+      else:
+        existing_metadata_path = os.path.join(path, self.KERNEL_METADATA_FILE)
+      if os.path.exists(existing_metadata_path):
+        with open(existing_metadata_path) as f:
+          existing_metadata = json.load(f)
+          kernel = existing_metadata['id']
+          if 'INSERT_KERNEL_SLUG_HERE' in kernel:
+            raise ValueError('A kernel must be specified')
+          else:
+            print('Using kernel ' + kernel)
+
+    if '/' in kernel:
+      self.validate_kernel_string(kernel)
+      kernel_url_list = kernel.split('/')
+      owner_slug = kernel_url_list[0]
+      kernel_slug = kernel_url_list[1]
+    else:
+      owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+      kernel_slug = kernel
+
+    if path is None:
+      effective_path = self.get_default_download_dir('kernels', owner_slug,
+                                                     kernel_slug)
+    else:
+      effective_path = path
+
+    if not os.path.exists(effective_path):
+      os.makedirs(effective_path)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiGetKernelRequest()
+      request.user_name = owner_slug
+      request.kernel_slug = kernel_slug
+      response = kaggle.kernels.kernels_api_client.get_kernel(request)
+
+    blob = response.blob
+
+    if os.path.isfile(effective_path):
+      effective_dir = os.path.dirname(effective_path)
+    else:
+      effective_dir = effective_path
+    metadata_path = os.path.join(effective_dir, self.KERNEL_METADATA_FILE)
+
+    if not os.path.isfile(effective_path):
+      language = blob.language.lower()
+      kernel_type = blob.kernel_type.lower()
+
+      file_name = None
+      if existing_metadata:
+        file_name = existing_metadata['code_file']
+      elif os.path.isfile(metadata_path):
+        with open(metadata_path) as f:
+          file_name = json.load(f)['code_file']
+
+      if not file_name or file_name == "INSERT_CODE_FILE_PATH_HERE":
+        extension = None
+        if kernel_type == 'script':
+          if language == 'python':
+            extension = '.py'
+          elif language == 'r':
+            extension = '.R'
+          elif language == 'rmarkdown':
+            extension = '.Rmd'
+          elif language == 'sqlite':
+            extension = '.sql'
+          elif language == 'julia':
+            extension = '.jl'
+        elif kernel_type == 'notebook':
+          if language == 'python':
+            extension = '.ipynb'
+          elif language == 'r':
+            extension = '.irnb'
+          elif language == 'julia':
+            extension = '.ijlnb'
+        # Guard against unknown language/kernel-type combinations so the
+        # fallback below can trigger instead of raising a TypeError.
+        file_name = blob.slug + extension if extension else None
+
+      if file_name is None:
+        print('Unknown language %s + kernel type %s - please report this '
+              'on the kaggle-api github issues' % (language, kernel_type))
+        print('Saving as a python file, even though this may not be the '
+              'correct language')
+        file_name = 'script.py'
+      script_path = os.path.join(effective_path, file_name)
+    else:
+      script_path = effective_path
+      file_name = os.path.basename(effective_path)
+
+    with open(script_path, 'w', encoding="utf-8") as f:
+      f.write(blob.source)
+
+    if metadata:
+      data = {}
+      server_metadata = response.metadata
+      data['id'] = server_metadata.ref
+      data['id_no'] = server_metadata.id
+      data['title'] = server_metadata.title
+      data['code_file'] = file_name
+      data['language'] = server_metadata.language
+      data['kernel_type'] = server_metadata.kernel_type
+      data['is_private'] = server_metadata.is_private
+      data['enable_gpu'] = server_metadata.enable_gpu
+      data['enable_tpu'] = server_metadata.enable_tpu
+      data['enable_internet'] = server_metadata.enable_internet
+      data['keywords'] = server_metadata.category_ids
+      data['dataset_sources'] = server_metadata.dataset_data_sources
+      data['kernel_sources'] = server_metadata.kernel_data_sources
+      data['competition_sources'] = server_metadata.competition_data_sources
+      data['model_sources'] = server_metadata.model_data_sources
+      with open(metadata_path, 'w') as f:
+        json.dump(data, f, indent=2)
+
+      return effective_dir
+    else:
+      return script_path
+
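+  # Usage sketch (illustrative): pull a notebook together with its metadata
+  # so it can be edited locally and pushed back with kernels_push.
+  #
+  #   out_dir = api.kernels_pull('owner/kernel-slug', path='work',
+  #                              metadata=True)
+  #   print('pulled to', out_dir)
+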
+  def kernels_pull_cli(self,
+                       kernel,
+                       kernel_opt=None,
+                       path=None,
+                       metadata=False):
+    """ Client wrapper for kernels_pull.
+    """
+    kernel = kernel or kernel_opt
+    effective_path = self.kernels_pull(
+        kernel, path=path, metadata=metadata, quiet=False)
+    if metadata:
+      print('Source code and metadata downloaded to ' + effective_path)
+    else:
+      print('Source code downloaded to ' + effective_path)
+
+  def kernels_output(self, kernel, path, force=False, quiet=True):
+    """ Retrieve the output for a specified kernel.
+        Parameters
+        ==========
+        kernel: the kernel to output
+        path: the path to pull files to on the filesystem
+        force: if output already exists, force overwrite (default False)
+        quiet: suppress verbosity (default is True)
+    """
+    if kernel is None:
+      raise ValueError('A kernel must be specified')
+    if '/' in kernel:
+      self.validate_kernel_string(kernel)
+      kernel_url_list = kernel.split('/')
+      owner_slug = kernel_url_list[0]
+      kernel_slug = kernel_url_list[1]
+    else:
+      owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+      kernel_slug = kernel
+
+    if path is None:
+      target_dir = self.get_default_download_dir('kernels', owner_slug,
+                                                 kernel_slug, 'output')
+    else:
+      target_dir = path
+
+    if not os.path.exists(target_dir):
+      os.makedirs(target_dir)
+
+    if not os.path.isdir(target_dir):
+      raise ValueError('You must specify a directory for the kernels output')
+
+    token = None
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListKernelSessionOutputRequest()
+      request.user_name = owner_slug
+      request.kernel_slug = kernel_slug
+      response = kaggle.kernels.kernels_api_client.list_kernel_session_output(
+          request)
+      token = response.next_page_token
+
+    outfiles = []
+    for item in response.files:
+      outfile = os.path.join(target_dir, item['fileName'])
+      outfiles.append(outfile)
+      download_response = requests.get(item['url'], stream=True)
+      if force or self.download_needed(download_response, outfile, quiet):
+        os.makedirs(os.path.split(outfile)[0], exist_ok=True)
+        with open(outfile, 'wb') as out:
+          out.write(download_response.content)
+        if not quiet:
+          print('Output file downloaded to %s' % outfile)
+
+    log = response.log
+    if log:
+      outfile = os.path.join(target_dir, kernel_slug + '.log')
+      outfiles.append(outfile)
+      with open(outfile, 'w') as out:
+        out.write(log)
+      if not quiet:
+        print('Kernel log downloaded to %s ' % outfile)
+
+    return outfiles, token  # Breaking change, we need to get the token to the UI
+
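+  # Usage sketch (illustrative): fetch a finished session's output files and
+  # execution log into ./output; a page token is also returned when there are
+  # more files.
+  #
+  #   files, token = api.kernels_output('owner/kernel-slug', 'output')
+  #   for f in files:
+  #       print(f)
+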
+  def kernels_output_cli(self,
+                         kernel,
+                         kernel_opt=None,
+                         path=None,
+                         force=False,
+                         quiet=False):
+    """ Client wrapper for kernels_output, with the same arguments. Extra
+        arguments are described below; see kernels_output for others.
+        Parameters
+        ==========
+        kernel_opt: option from client instead of kernel, if not defined
+    """
+    kernel = kernel or kernel_opt
+    (_, token) = self.kernels_output(kernel, path, force, quiet)
+    if token:
+      print(f"Next page token: {token}")
+
+  def kernels_status(self, kernel):
+    """ Call the API to get the status of a kernel.
+        Parameters
+        ==========
+        kernel: the kernel to get the status for
+    """
+    if kernel is None:
+      raise ValueError('A kernel must be specified')
+    if '/' in kernel:
+      self.validate_kernel_string(kernel)
+      kernel_url_list = kernel.split('/')
+      owner_slug = kernel_url_list[0]
+      kernel_slug = kernel_url_list[1]
+    else:
+      owner_slug = self.get_config_value(self.CONFIG_NAME_USER)
+      kernel_slug = kernel
+    with self.build_kaggle_client() as kaggle:
+      request = ApiGetKernelSessionStatusRequest()
+      request.user_name = owner_slug
+      request.kernel_slug = kernel_slug
+      return kaggle.kernels.kernels_api_client.get_kernel_session_status(
+          request)
+
+  def kernels_status_cli(self, kernel, kernel_opt=None):
+    """ Client wrapper for kernels_status.
+        Parameters
+        ==========
+        kernel_opt: additional option from the client, if kernel not defined
+    """
+    kernel = kernel or kernel_opt
+    response = self.kernels_status(kernel)
+    status = response.status
+    message = response.failure_message
+    print('%s has status "%s"' % (kernel, status))
+    if message:
+      print('Failure message: "%s"' % message)
+
+  def model_get(self, model):
+    """ Get a model.
+        Parameters
+        ==========
+        model: the string identifier of the model
+               should be in format [owner]/[model-name]
+    """
+    owner_slug, model_slug = self.split_model_string(model)
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiGetModelRequest()
+      request.owner_slug = owner_slug
+      request.model_slug = model_slug
+      return kaggle.models.model_api_client.get_model(request)
+
+  def model_get_cli(self, model, folder=None):
+    """ Client wrapper for model_get to fetch a model from the API and
+        either print it or write its metadata to a file.
+        Parameters
+        ==========
+        model: the string identifier of the model
+               should be in format [owner]/[model-name]
+        folder: the folder to download the model metadata file to
+    """
+    model = self.model_get(model)
+    if folder is None:
+      self.print_obj(model)
+    else:
+      meta_file = os.path.join(folder, self.MODEL_METADATA_FILE)
+
+      data = {}
+      data['id'] = model.id
+      model_ref_split = model.ref.split('/')
+      data['ownerSlug'] = model_ref_split[0]
+      data['slug'] = model_ref_split[1]
+      data['title'] = model.title
+      data['subtitle'] = model.subtitle
+      data['isPrivate'] = model.isPrivate  # TODO Test to ensure True default
+      data['description'] = model.description
+      data['publishTime'] = model.publishTime
+
+      with open(meta_file, 'w') as f:
+        json.dump(data, f, indent=2)
+      print('Metadata file written to {}'.format(meta_file))
+
+  def model_list(self,
+                 sort_by=None,
+                 search=None,
+                 owner=None,
+                 page_size=20,
+                 page_token=None):
+    """ Return a list of models.
+
+        Parameters
+        ==========
+        sort_by: how to sort the result, see valid_model_sort_bys for options
+        search: a search term to use (default is empty string)
+        owner: username or organization slug to filter the search to
+        page_size: the page size to return (default is 20)
+        page_token: the page token for pagination
+    """
+    if sort_by:
+      if sort_by not in self.valid_model_sort_bys:
+        raise ValueError('Invalid sort by specified. Valid options are ' +
+                         str(self.valid_model_sort_bys))
+      sort_by = self.lookup_enum(ListModelsOrderBy, sort_by)
+
+    if int(page_size) <= 0:
+      raise ValueError('Page size must be >= 1')
+
+    with self.build_kaggle_client() as kaggle:
+      request = ApiListModelsRequest()
+      request.sort_by = sort_by or ListModelsOrderBy.LIST_MODELS_ORDER_BY_HOTNESS
+      request.search = search or ''
+      request.owner = owner or ''
+      request.page_size = page_size
+      request.page_token = page_token
+      response = kaggle.models.model_api_client.list_models(request)
+      if response.next_page_token:
+        print('Next Page Token = {}'.format(response.next_page_token))
+      return response.models
+
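+  # Usage sketch (illustrative; the search term is a placeholder): search
+  # public models and print their refs.
+  #
+  #   models = api.model_list(search='bert', page_size=5)
+  #   for m in models:
+  #       print(m.ref)
+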
+  def model_list_cli(self,
+                     sort_by=None,
+                     search=None,
+                     owner=None,
+                     page_size=20,
+                     page_token=None,
+                     csv_display=False):
+    """ Client wrapper for model_list. Additional parameters
+        are described here, see model_list for others.
+
+        Parameters
+        ==========
+        sort_by: how to sort the result, see valid_model_sort_bys for options
+        search: a search term to use (default is empty string)
+        owner: username or organization slug to filter the search to
+        page_size: the page size to return (default is 20)
+        page_token: the page token for pagination
+        csv_display: if True, print comma separated values instead of table
+    """
+    models = self.model_list(sort_by, search, owner, page_size, page_token)
+    fields = ['id', 'ref', 'title', 'subtitle', 'author']
+    if models:
+      if csv_display:
+        self.print_csv(models, fields)
+      else:
+        self.print_table(models, fields)
+    else:
+      print('No models found')
+
+  def model_initialize(self, folder):
+    """ Initialize a folder with a model configuration (metadata) file.
+        Parameters
+        ==========
+        folder: the folder to initialize the metadata file in
+    """
+    if not os.path.isdir(folder):
+      raise ValueError('Invalid folder: ' + folder)
+
+    meta_data = {
+        'ownerSlug': 'INSERT_OWNER_SLUG_HERE',
+        'title': 'INSERT_TITLE_HERE',
+        'slug': 'INSERT_SLUG_HERE',
+        'subtitle': '',
+        'isPrivate': True,
+        'description': '''# Model Summary
+
+# Model Characteristics
+
+# Data Overview
+
+# Evaluation Results
+''',
+        'publishTime': '',
+        'provenanceSources': ''
+    }
+    meta_file = os.path.join(folder, self.MODEL_METADATA_FILE)
+    with open(meta_file, 'w') as f:
+      json.dump(meta_data, f, indent=2)
+
+    print('Model template written to: ' + meta_file)
+    return meta_file
+
+  def model_initialize_cli(self, folder=None):
+    folder = folder or os.getcwd()
+    self.model_initialize(folder)
+
+ def model_create_new(self, folder):
2952
+ """ Create a new model.
2953
+ Parameters
2954
+ ==========
2955
+ folder: the folder to get the metadata file from
2956
+ """
2957
+ if not os.path.isdir(folder):
2958
+ raise ValueError('Invalid folder: ' + folder)
2959
+
2960
+ meta_file = self.get_model_metadata_file(folder)
2961
+
2962
+ # read json
2963
+ with open(meta_file) as f:
2964
+ meta_data = json.load(f)
2965
+ owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
2966
+ slug = self.get_or_fail(meta_data, 'slug')
2967
+ title = self.get_or_fail(meta_data, 'title')
2968
+ subtitle = meta_data.get('subtitle')
2969
+ is_private = self.get_or_fail(meta_data, 'isPrivate')
2970
+ description = self.sanitize_markdown(
2971
+ self.get_or_fail(meta_data, 'description'))
2972
+ publish_time = meta_data.get('publishTime')
2973
+ provenance_sources = meta_data.get('provenanceSources')
2974
+
2975
+ # validations
2976
+ if owner_slug == 'INSERT_OWNER_SLUG_HERE':
2977
+ raise ValueError(
2978
+ 'Default ownerSlug detected, please change values before uploading')
2979
+ if title == 'INSERT_TITLE_HERE':
2980
+ raise ValueError(
2981
+ 'Default title detected, please change values before uploading')
2982
+ if slug == 'INSERT_SLUG_HERE':
2983
+ raise ValueError(
2984
+ 'Default slug detected, please change values before uploading')
2985
+ if not isinstance(is_private, bool):
2986
+ raise ValueError('model.isPrivate must be a boolean')
2987
+ if publish_time:
2988
+ self.validate_date(publish_time)
2989
+ else:
2990
+ publish_time = None
2991
+
2992
+ with self.build_kaggle_client() as kaggle:
2993
+ request = ApiCreateModelRequest()
2994
+ request.owner_slug = owner_slug
2995
+ request.slug = slug
2996
+ request.title = title
2997
+ request.subtitle = subtitle
2998
+ request.is_private = is_private
2999
+ request.description = description
3000
+ request.publish_time = publish_time
3001
+ request.provenance_sources = provenance_sources
3002
+ return kaggle.models.model_api_client.create_model(request)
3003
+
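Once the placeholders are replaced, `model_create_new` reads the file back and issues the create call. A hedged sketch of metadata that passes the validations above, reusing `folder` from the previous sketch (owner and slug are hypothetical):

```python
import json
import os

meta = {
    'ownerSlug': 'alice',            # hypothetical owner
    'title': 'My Model',
    'slug': 'my-model',
    'subtitle': '',
    'isPrivate': True,               # must be a real boolean
    'description': '# Model Summary',
    'publishTime': '',               # empty string is sent as None
    'provenanceSources': ''
}
with open(os.path.join(folder, 'model-metadata.json'), 'w') as f:
    json.dump(meta, f, indent=2)

result = api.model_create_new(folder)
```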
3004
+ def model_create_new_cli(self, folder=None):
3005
+ """ Client wrapper for creating a new model.
3006
+ Parameters
3007
+ ==========
3008
+ folder: the folder to get the metadata file from
3009
+ """
3010
+ folder = folder or os.getcwd()
3011
+ result = self.model_create_new(folder)
3012
+
3013
+ if result.hasId:
3014
+ print('Your model was created. Id={}. Url={}'.format(
3015
+ result.id, result.url))
3016
+ else:
3017
+ print('Model creation error: ' + result.error)
3018
+
3019
+ def model_delete(self, model, yes):
3020
+ """ Delete a modeL.
3021
+ Parameters
3022
+ ==========
3023
+ model: the string identifier of the model
3024
+ should be in format [owner]/[model-name]
3025
+ yes: automatic confirmation
3026
+ """
3027
+ owner_slug, model_slug = self.split_model_string(model)
3028
+
3029
+ if not yes:
3030
+ if not self.confirmation():
3031
+ print('Deletion cancelled')
3032
+ exit(0)
3033
+
3034
+ with self.build_kaggle_client() as kaggle:
3035
+ request = ApiDeleteModelRequest()
3036
+ request.owner_slug = owner_slug
3037
+ request.model_slug = model_slug
3038
+ return kaggle.models.model_api_client.delete_model(request)
3039
+
3040
+ def model_delete_cli(self, model, yes):
3041
+ """ Client wrapper for deleting a model.
3042
+ Parameters
3043
+ ==========
3044
+ model: the string identifier of the model
3045
+ should be in format [owner]/[model-name]
3046
+ yes: automatic confirmation
3047
+ """
3048
+ result = self.model_delete(model, yes)
3049
+
3050
+ if result.error:
3051
+ print('Model deletion error: ' + result.error)
3052
+ else:
3053
+ print('The model was deleted.')
3054
+
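Deletion is confirmed interactively unless `yes` is set; a sketch with hypothetical identifiers:

```python
# yes=True skips the confirmation prompt (the CLI's -y flag)
api.model_delete_cli('alice/my-model', yes=True)
```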
3055
+ def model_update(self, folder):
3056
+ """ Update a model.
3057
+ Parameters
3058
+ ==========
3059
+ folder: the folder to get the metadata file from
3060
+ """
3061
+ if not os.path.isdir(folder):
3062
+ raise ValueError('Invalid folder: ' + folder)
3063
+
3064
+ meta_file = self.get_model_metadata_file(folder)
3065
+
3066
+ # read json
3067
+ with open(meta_file) as f:
3068
+ meta_data = json.load(f)
3069
+ owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
3070
+ slug = self.get_or_fail(meta_data, 'slug')
3071
+ title = self.get_or_default(meta_data, 'title', None)
3072
+ subtitle = self.get_or_default(meta_data, 'subtitle', None)
3073
+ is_private = self.get_or_default(meta_data, 'isPrivate', None)
3074
+ description = self.get_or_default(meta_data, 'description', None)
3075
+ publish_time = self.get_or_default(meta_data, 'publishTime', None)
3076
+ provenance_sources = self.get_or_default(meta_data, 'provenanceSources',
3077
+ None)
3078
+
3079
+ # validations
3080
+ if owner_slug == 'INSERT_OWNER_SLUG_HERE':
3081
+ raise ValueError(
3082
+ 'Default ownerSlug detected, please change values before uploading')
3083
+ if slug == 'INSERT_SLUG_HERE':
3084
+ raise ValueError(
3085
+ 'Default slug detected, please change values before uploading')
3086
+ if is_private is not None and not isinstance(is_private, bool):
3087
+ raise ValueError('model.isPrivate must be a boolean')
3088
+ if publish_time:
3089
+ self.validate_date(publish_time)
3090
+
3091
+ # mask
3092
+ update_mask = {'paths': []}
3093
+ if title is not None:
3094
+ update_mask['paths'].append('title')
3095
+ if subtitle is not None:
3096
+ update_mask['paths'].append('subtitle')
3097
+ if is_private is not None:
3098
+ update_mask['paths'].append('isPrivate') # is_private
3099
+ else:
3100
+ is_private = True # default value, not updated
3101
+ if description is not None:
3102
+ description = self.sanitize_markdown(description)
3103
+ update_mask['paths'].append('description')
3104
+ if publish_time is not None and len(publish_time) > 0:
3105
+ update_mask['paths'].append('publish_time')
3106
+ else:
3107
+ publish_time = None
3108
+ if provenance_sources is not None and len(provenance_sources) > 0:
3109
+ update_mask['paths'].append('provenance_sources')
3110
+ else:
3111
+ provenance_sources = None
3112
+
3113
+ with self.build_kaggle_client() as kaggle:
3114
+ fm = field_mask_pb2.FieldMask(paths=update_mask['paths'])
3115
+ fm = fm.FromJsonString(json.dumps(update_mask))
3116
+ request = ApiUpdateModelRequest()
3117
+ request.owner_slug = owner_slug
3118
+ request.model_slug = slug
3119
+ request.title = title
3120
+ request.subtitle = subtitle
3121
+ request.is_private = is_private
3122
+ request.description = description
3123
+ request.publish_time = publish_time
3124
+ request.provenance_sources = provenance_sources
3125
+ request.update_mask = fm if len(update_mask['paths']) > 0 else None
3126
+ return kaggle.models.model_api_client.update_model(request)
3127
+
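`model_update` sends a field mask so the server only applies the keys actually present in the metadata file. A minimal sketch of such a mask, assuming the standard protobuf well-known type (the diff's `field_mask_pb2` import may be a kagglesdk shim rather than `google.protobuf`):

```python
from google.protobuf import field_mask_pb2

# only 'title' and 'description' would be applied server-side;
# unmasked request fields are ignored
fm = field_mask_pb2.FieldMask(paths=['title', 'description'])
```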
3128
+ def model_update_cli(self, folder=None):
3129
+ """ Client wrapper for updating a model.
3130
+ Parameters
3131
+ ==========
3132
+ folder: the folder to get the metadata file from
3133
+ """
3134
+ folder = folder or os.getcwd()
3135
+ result = self.model_update(folder)
3136
+
3137
+ if result.hasId:
3138
+ print('Your model was updated. Id={}. Url={}'.format(
3139
+ result.id, result.url))
3140
+ else:
3141
+ print('Model update error: ' + result.error)
3142
+
3143
+ def model_instance_get(self, model_instance):
3144
+ """ Get a model instance.
3145
+ Parameters
3146
+ ==========
3147
+ model_instance: the string identifier of the model instance
3148
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3149
+ """
3150
+ if model_instance is None:
3151
+ raise ValueError('A model instance must be specified')
3152
+ owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
3153
+ model_instance)
3154
+
3155
+ with self.build_kaggle_client() as kaggle:
3156
+ request = ApiGetModelInstanceRequest()
3157
+ request.owner_slug = owner_slug
3158
+ request.model_slug = model_slug
3159
+ request.framework = self.lookup_enum(ModelFramework, framework)
3160
+ request.instance_slug = instance_slug
3161
+ return kaggle.models.model_api_client.get_model_instance(request)
3162
+
3163
+ def model_instance_get_cli(self, model_instance, folder=None):
3164
+ """ Client wrapper for model_instance_get.
3165
+ Parameters
3166
+ ==========
3167
+ model_instance: the string identifier of the model instance
3168
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3169
+ folder: the folder to download the model metadata file
3170
+ """
3171
+ mi = self.model_instance_get(model_instance)
3172
+ if folder is None:
3173
+ self.print_obj(mi)
3174
+ else:
3175
+ meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE)
3176
+
3177
+ owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
3178
+ model_instance)
3179
+
3180
+ data = {
3181
+ 'id': mi.id,
3182
+ 'ownerSlug': owner_slug,
3183
+ 'modelSlug': model_slug,
3184
+ 'instanceSlug': mi.slug,
3185
+ 'framework': self.short_enum_name(mi.framework),
3186
+ 'overview': mi.overview,
3187
+ 'usage': mi.usage,
3188
+ 'licenseName': mi.license_name,
3189
+ 'fineTunable': mi.fine_tunable,
3190
+ 'trainingData': mi.training_data,
3191
+ 'versionId': mi.version_id,
3192
+ 'versionNumber': mi.version_number,
3193
+ 'modelInstanceType': self.short_enum_name(mi.model_instance_type)
3194
+ }
3195
+ if mi.base_model_instance_information is not None:
3196
+ # TODO Test this.
3197
+ data['baseModelInstance'] = '{}/{}/{}/{}'.format(
3198
+ mi.base_model_instance_information['owner']['slug'],
3199
+ mi.base_model_instance_information['modelSlug'],
3200
+ mi.base_model_instance_information['framework'],
3201
+ mi.base_model_instance_information['instanceSlug'])
3202
+ data['externalBaseModelUrl'] = mi.external_base_model_url
3203
+
3204
+ with open(meta_file, 'w') as f:
3205
+ json.dump(data, f, indent=2)
3206
+ print('Metadata file written to {}'.format(meta_file))
3207
+
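A sketch of fetching an instance; passing a folder writes the metadata file instead of printing. All identifiers are hypothetical, and the framework segment must resolve to a ModelFramework enum value:

```python
# prints the instance as JSON
api.model_instance_get_cli('alice/my-model/pytorch/default')

# writes model-instance-metadata.json into the current directory
api.model_instance_get_cli('alice/my-model/pytorch/default', folder='.')
```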
3208
+ def model_instance_initialize(self, folder):
3209
+ """ Initialize a folder with a model instance configuration (metadata) file.
3210
+ Parameters
3211
+ ==========
3212
+ folder: the folder to initialize the metadata file in
3213
+ """
3214
+ if not os.path.isdir(folder):
3215
+ raise ValueError('Invalid folder: ' + folder)
3216
+
3217
+ meta_data = {
3218
+ 'ownerSlug':
3219
+ 'INSERT_OWNER_SLUG_HERE',
3220
+ 'modelSlug':
3221
+ 'INSERT_EXISTING_MODEL_SLUG_HERE',
3222
+ 'instanceSlug':
3223
+ 'INSERT_INSTANCE_SLUG_HERE',
3224
+ 'framework':
3225
+ 'INSERT_FRAMEWORK_HERE',
3226
+ 'overview':
3227
+ '',
3228
+ 'usage':
3229
+ '''# Model Format
3230
+
3231
+ # Training Data
3232
+
3233
+ # Model Inputs
3234
+
3235
+ # Model Outputs
3236
+
3237
+ # Model Usage
3238
+
3239
+ # Fine-tuning
3240
+
3241
+ # Changelog
3242
+ ''',
3243
+ 'licenseName':
3244
+ 'Apache 2.0',
3245
+ 'fineTunable':
3246
+ False,
3247
+ 'trainingData': [],
3248
+ 'modelInstanceType':
3249
+ 'Unspecified',
3250
+ 'baseModelInstanceId':
3251
+ 0,
3252
+ 'externalBaseModelUrl':
3253
+ ''
3254
+ }
3255
+ meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE)
3256
+ with open(meta_file, 'w') as f:
3257
+ json.dump(meta_data, f, indent=2)
3258
+
3259
+ print('Model Instance template written to: ' + meta_file)
3260
+ return meta_file
3261
+
3262
+ def model_instance_initialize_cli(self, folder):
3263
+ folder = folder or os.getcwd()
3264
+ self.model_instance_initialize(folder)
3265
+
3266
+ def model_instance_create(self, folder, quiet=False, dir_mode='skip'):
3267
+ """ Create a new model instance.
3268
+ Parameters
3269
+ ==========
3270
+ folder: the folder to get the metadata file from
3271
+ quiet: suppress verbose output (default is False)
3272
+ dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
3273
+ """
3274
+ if not os.path.isdir(folder):
3275
+ raise ValueError('Invalid folder: ' + folder)
3276
+
3277
+ meta_file = self.get_model_instance_metadata_file(folder)
3278
+
3279
+ # read json
3280
+ with open(meta_file) as f:
3281
+ meta_data = json.load(f)
3282
+ owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
3283
+ model_slug = self.get_or_fail(meta_data, 'modelSlug')
3284
+ instance_slug = self.get_or_fail(meta_data, 'instanceSlug')
3285
+ framework = self.get_or_fail(meta_data, 'framework')
3286
+ overview = self.sanitize_markdown(
3287
+ self.get_or_default(meta_data, 'overview', ''))
3288
+ usage = self.sanitize_markdown(self.get_or_default(meta_data, 'usage', ''))
3289
+ license_name = self.get_or_fail(meta_data, 'licenseName')
3290
+ fine_tunable = self.get_or_default(meta_data, 'fineTunable', False)
3291
+ training_data = self.get_or_default(meta_data, 'trainingData', [])
3292
+ model_instance_type = self.get_or_default(meta_data, 'modelInstanceType',
3293
+ 'Unspecified')
3294
+ base_model_instance = self.get_or_default(meta_data, 'baseModelInstance',
3295
+ '')
3296
+ external_base_model_url = self.get_or_default(meta_data,
3297
+ 'externalBaseModelUrl', '')
3298
+
3299
+ # validations
3300
+ if owner_slug == 'INSERT_OWNER_SLUG_HERE':
3301
+ raise ValueError(
3302
+ 'Default ownerSlug detected, please change values before uploading')
3303
+ if model_slug == 'INSERT_EXISTING_MODEL_SLUG_HERE':
3304
+ raise ValueError(
3305
+ 'Default modelSlug detected, please change values before uploading')
3306
+ if instance_slug == 'INSERT_INSTANCE_SLUG_HERE':
3307
+ raise ValueError(
3308
+ 'Default instanceSlug detected, please change values before uploading'
3309
+ )
3310
+ if framework == 'INSERT_FRAMEWORK_HERE':
3311
+ raise ValueError(
3312
+ 'Default framework detected, please change values before uploading')
3313
+ if license_name == '':
3314
+ raise ValueError('Please specify a license')
3315
+ if not isinstance(fine_tunable, bool):
3316
+ raise ValueError('modelInstance.fineTunable must be a boolean')
3317
+ if not isinstance(training_data, list):
3318
+ raise ValueError('modelInstance.trainingData must be a list')
3319
+
3320
+ body = ApiCreateModelInstanceRequestBody()
3321
+ body.framework = self.lookup_enum(ModelFramework, framework)
3322
+ body.instance_slug = instance_slug
3323
+ body.overview = overview
3324
+ body.usage = usage
3325
+ body.license_name = license_name
3326
+ body.fine_tunable = fine_tunable
3327
+ body.training_data = training_data
3328
+ body.model_instance_type = self.lookup_enum(ModelInstanceType,
3329
+ model_instance_type)
3330
+ body.base_model_instance = base_model_instance
3331
+ body.external_base_model_url = external_base_model_url
3332
+ body.files = []
3333
+
3334
+ with self.build_kaggle_client() as kaggle:
3335
+ request = ApiCreateModelInstanceRequest()
3336
+ request.owner_slug = owner_slug
3337
+ request.model_slug = model_slug
3338
+ request.body = body
3339
+ message = kaggle.models.model_api_client.create_model_instance
3340
+ with ResumableUploadContext() as upload_context:
3341
+ self.upload_files(body, None, folder, ApiBlobType.MODEL, upload_context,
3342
+ quiet, dir_mode)
3343
+ request.body.files = [
3344
+ self._api_dataset_new_file(file) for file in request.body.files
3345
+ ]
3346
+ response = self.with_retry(message)(request)
3347
+ return response
3348
+
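A sketch of creating an instance under an existing model; the folder must contain the instance metadata file plus the files to upload, and `dir_mode='zip'` archives any subdirectories instead of skipping them (the path is hypothetical):

```python
response = api.model_instance_create(
    '/path/to/instance-folder',
    quiet=False,
    dir_mode='zip')
```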
3349
+ def model_instance_create_cli(self, folder, quiet=False, dir_mode='skip'):
3350
+ """ Client wrapper for creating a new model instance.
3351
+ Parameters
3352
+ ==========
3353
+ folder: the folder to get the metadata file from
3354
+ quiet: suppress verbose output (default is False)
3355
+ dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
3356
+ """
3357
+ folder = folder or os.getcwd()
3358
+ result = self.model_instance_create(folder, quiet, dir_mode)
3359
+
3360
+ if result.hasId:
3361
+ print('Your model instance was created. Id={}. Url={}'.format(
3362
+ result.id, result.url))
3363
+ else:
3364
+ print('Model instance creation error: ' + result.error)
3365
+
3366
+ def model_instance_delete(self, model_instance, yes):
3367
+ """ Delete a model instance.
3368
+ Parameters
3369
+ ==========
3370
+ model_instance: the string identifier of the model instance
3371
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3372
+ yes: automatic confirmation
3373
+ """
3374
+ if model_instance is None:
3375
+ raise ValueError('A model instance must be specified')
3376
+ owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
3377
+ model_instance)
3378
+
3379
+ if not yes:
3380
+ if not self.confirmation():
3381
+ print('Deletion cancelled')
3382
+ exit(0)
3383
+
3384
+ with self.build_kaggle_client() as kaggle:
3385
+ request = ApiDeleteModelInstanceRequest()
3386
+ request.owner_slug = owner_slug
3387
+ request.model_slug = model_slug
3388
+ request.framework = self.lookup_enum(ModelFramework, framework)
3389
+ request.instance_slug = instance_slug
3390
+ return kaggle.models.model_api_client.delete_model_instance(request)
3392
+
3393
+ def model_instance_delete_cli(self, model_instance, yes):
3394
+ """ Client wrapper for model_instance_delete.
3395
+ Parameters
3396
+ ==========
3397
+ model_instance: the string identifier of the model instance
3398
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3399
+ yes: automatic confirmation
3400
+ """
3401
+ result = self.model_instance_delete(model_instance, yes)
3402
+
3403
+ if len(result.error) > 0:
3404
+ print('Model instance deletion error: ' + result.error)
3405
+ else:
3406
+ print('The model instance was deleted.')
3407
+
3408
+ def model_instance_files(self,
3409
+ model_instance,
3410
+ page_token=None,
3411
+ page_size=20,
3412
+ csv_display=False):
3413
+ """ List files for the current version of a model instance.
3414
+
3415
+ Parameters
3416
+ ==========
3417
+ model_instance: the string identifier of the model instance
3418
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3419
+ page_token: token for pagination
3420
+ page_size: the number of items per page
3421
+ csv_display: if True, print comma separated values instead of table
3422
+ """
3423
+ if model_instance is None:
3424
+ raise ValueError('A model_instance must be specified')
3425
+
3426
+ self.validate_model_instance_string(model_instance)
3427
+ urls = model_instance.split('/')
3428
+ [owner_slug, model_slug, framework, instance_slug] = urls
3429
+
3430
+ with self.build_kaggle_client() as kaggle:
3431
+ request = ApiListModelInstanceVersionFilesRequest()
3432
+ request.owner_slug = owner_slug
3433
+ request.model_slug = model_slug
3434
+ request.framework = self.lookup_enum(ModelFramework, framework)
3435
+ request.instance_slug = instance_slug
3436
+ request.page_size = page_size
3437
+ request.page_token = page_token
3438
+ response = kaggle.models.model_api_client.list_model_instance_version_files(
3439
+ request)
3440
+
3441
+ if response:
3442
+ next_page_token = response.next_page_token
3443
+ if next_page_token:
3444
+ print('Next Page Token = {}'.format(next_page_token))
3445
+ return response
3446
+ else:
3447
+ print('No files found')
3448
+ return FileList({})
3449
+
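Paging works the same way as elsewhere in the client: a `next_page_token` is printed and returned whenever more results exist. A sketch with hypothetical identifiers:

```python
instance = 'alice/my-model/pytorch/default'
resp = api.model_instance_files(instance, page_size=50)
if resp.next_page_token:
    resp = api.model_instance_files(instance,
                                    page_token=resp.next_page_token,
                                    page_size=50)
```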
3450
+ def model_instance_files_cli(self,
3451
+ model_instance,
3452
+ page_token=None,
3453
+ page_size=20,
3454
+ csv_display=False):
3455
+ """ Client wrapper for model_instance_files.
3456
+
3457
+ Parameters
3458
+ ==========
3459
+ model_instance: the string identifier of the model instance
3460
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3461
+ page_token: token for pagination
3462
+ page_size: the number of items per page
3463
+ csv_display: if True, print comma separated values instead of table
3464
+ """
3465
+ result = self.model_instance_files(
3466
+ model_instance,
3467
+ page_token=page_token,
3468
+ page_size=page_size,
3469
+ csv_display=csv_display)
3470
+ if result and result.files is not None:
3471
+ fields = self.dataset_file_fields
3472
+ if csv_display:
3473
+ self.print_csv(result.files, fields)
3474
+ else:
3475
+ self.print_table(result.files, fields)
3476
+
3477
+ def model_instance_update(self, folder):
3478
+ """ Update a model instance.
3479
+ Parameters
3480
+ ==========
3481
+ folder: the folder to get the metadata file from
3482
+ """
3483
+ if not os.path.isdir(folder):
3484
+ raise ValueError('Invalid folder: ' + folder)
3485
+
3486
+ meta_file = self.get_model_instance_metadata_file(folder)
3487
+
3488
+ # read json
3489
+ with open(meta_file) as f:
3490
+ meta_data = json.load(f)
3491
+ owner_slug = self.get_or_fail(meta_data, 'ownerSlug')
3492
+ model_slug = self.get_or_fail(meta_data, 'modelSlug')
3493
+ framework = self.get_or_fail(meta_data, 'framework')
3494
+ instance_slug = self.get_or_fail(meta_data, 'instanceSlug')
3495
+ overview = self.get_or_default(meta_data, 'overview', '')
3496
+ usage = self.get_or_default(meta_data, 'usage', '')
3497
+ license_name = self.get_or_default(meta_data, 'licenseName', None)
3498
+ fine_tunable = self.get_or_default(meta_data, 'fineTunable', None)
3499
+ training_data = self.get_or_default(meta_data, 'trainingData', None)
3500
+ model_instance_type = self.get_or_default(meta_data, 'modelInstanceType',
3501
+ None)
3502
+ base_model_instance = self.get_or_default(meta_data, 'baseModelInstance',
3503
+ None)
3504
+ external_base_model_url = self.get_or_default(meta_data,
3505
+ 'externalBaseModelUrl', None)
3506
+
3507
+ # validations
3508
+ if owner_slug == 'INSERT_OWNER_SLUG_HERE':
3509
+ raise ValueError(
3510
+ 'Default ownerSlug detected, please change values before uploading')
3511
+ if model_slug == 'INSERT_EXISTING_MODEL_SLUG_HERE':
3512
+ raise ValueError(
3513
+ 'Default model slug detected, please change values before uploading')
3514
+ if instance_slug == 'INSERT_INSTANCE_SLUG_HERE':
3515
+ raise ValueError(
3516
+ 'Default instance slug detected, please change values before uploading'
3517
+ )
3518
+ if framework == 'INSERT_FRAMEWORK_HERE':
3519
+ raise ValueError(
3520
+ 'Default framework detected, please change values before uploading')
3521
+ if fine_tunable is not None and not isinstance(fine_tunable, bool):
3522
+ raise ValueError('modelInstance.fineTunable must be a boolean')
3523
+ if training_data is not None and not isinstance(training_data, list):
3524
+ raise ValueError('modelInstance.trainingData must be a list')
3525
+ if model_instance_type:
3526
+ model_instance_type = self.lookup_enum(ModelInstanceType,
3527
+ model_instance_type)
3528
+
3529
+ # mask
3530
+ update_mask = {'paths': []}
3531
+ if overview is not None:
3532
+ overview = self.sanitize_markdown(overview)
3533
+ update_mask['paths'].append('overview')
3534
+ if usage is not None:
3535
+ usage = self.sanitize_markdown(usage)
3536
+ update_mask['paths'].append('usage')
3537
+ if license_name is not None:
3538
+ update_mask['paths'].append('licenseName')
3539
+ else:
3540
+ license_name = "Apache 2.0" # default value even if not updated
3541
+ if fine_tunable is not None:
3542
+ update_mask['paths'].append('fineTunable')
3543
+ if training_data is not None:
3544
+ update_mask['paths'].append('trainingData')
3545
+ if model_instance_type is not None:
3546
+ update_mask['paths'].append('modelInstanceType')
3547
+ if base_model_instance is not None:
3548
+ update_mask['paths'].append('baseModelInstance')
3549
+ if external_base_model_url is not None:
3550
+ update_mask['paths'].append('externalBaseModelUrl')
3551
+
3552
+ with self.build_kaggle_client() as kaggle:
3553
+ fm = field_mask_pb2.FieldMask(paths=update_mask['paths'])
3554
+ fm = fm.FromJsonString(json.dumps(update_mask))
3555
+ request = ApiUpdateModelInstanceRequest()
3556
+ request.owner_slug = owner_slug
3557
+ request.model_slug = model_slug
3558
+ request.framework = self.lookup_enum(ModelFramework, framework)
3559
+ request.instance_slug = instance_slug
3560
+ request.overview = overview
3561
+ request.usage = usage
3562
+ request.license_name = license_name
3563
+ request.fine_tunable = fine_tunable
3564
+ request.training_data = training_data
3565
+ request.model_instance_type = model_instance_type
3566
+ request.base_model_instance = base_model_instance
3567
+ request.external_base_model_url = external_base_model_url
3568
+ request.update_mask = fm if len(update_mask['paths']) > 0 else None
3570
+ return kaggle.models.model_api_client.update_model_instance(request)
3571
+
3572
+ def model_instance_update_cli(self, folder=None):
3573
+ """ Client wrapper for updating a model instance.
3574
+ Parameters
3575
+ ==========
3576
+ folder: the folder to get the metadata file from
3577
+ """
3578
+ folder = folder or os.getcwd()
3579
+ result = self.model_instance_update(folder)
3580
+
3581
+ if len(result.error) == 0:
3582
+ print('Your model instance was updated. Id={}. Url={}'.format(
3583
+ result.id, result.url))
3584
+ else:
3585
+ print('Model update error: ' + result.error)
3586
+
3587
+ def model_instance_version_create(self,
3588
+ model_instance,
3589
+ folder,
3590
+ version_notes='',
3591
+ quiet=False,
3592
+ dir_mode='skip'):
3593
+ """ Create a new model instance version.
3594
+ Parameters
3595
+ ==========
3596
+ model_instance: the string identifier of the model instance
3597
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3598
+ folder: the folder to get the metadata file from
3599
+ version_notes: the version notes to record for this new version
3600
+ quiet: suppress verbose output (default is False)
3601
+ dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
3602
+ """
3603
+ owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string(
3604
+ model_instance)
3605
+
3606
+ request = ApiCreateModelInstanceVersionRequest()
3607
+ request.owner_slug = owner_slug
3608
+ request.model_slug = model_slug
3609
+ request.framework = self.lookup_enum(ModelFramework, framework)
3610
+ request.instance_slug = instance_slug
3611
+ body = ApiCreateModelInstanceVersionRequestBody()
3612
+ body.version_notes = version_notes
3613
+ request.body = body
3614
+ with self.build_kaggle_client() as kaggle:
3615
+ message = kaggle.models.model_api_client.create_model_instance_version
3616
+ with ResumableUploadContext() as upload_context:
3617
+ self.upload_files(body, None, folder, ApiBlobType.MODEL, upload_context,
3618
+ quiet, dir_mode)
3619
+ request.body.files = [
3620
+ self._api_dataset_new_file(file) for file in request.body.files
3621
+ ]
3622
+ response = self.with_retry(message)(request)
3623
+ return response
3624
+
3625
+ def model_instance_version_create_cli(self,
3626
+ model_instance,
3627
+ folder,
3628
+ version_notes='',
3629
+ quiet=False,
3630
+ dir_mode='skip'):
3631
+ """ Client wrapper for creating a new version of a model instance.
3632
+ Parameters
3633
+ ==========
3634
+ model_instance: the string identifier of the model instance
3635
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]
3636
+ folder: the folder to get the metadata file from
3637
+ version_notes: the version notes to record for this new version
3638
+ quiet: suppress verbose output (default is False)
3639
+ dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload
3640
+ """
3641
+ result = self.model_instance_version_create(model_instance, folder,
3642
+ version_notes, quiet, dir_mode)
3643
+
3644
+ if result.id != 0:
3645
+ print('Your model instance version was created. Url={}'.format(
3646
+ result.url))
3647
+ else:
3648
+ print('Model instance version creation error: ' + result.error)
3649
+
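A sketch of publishing a new version of an existing instance from a folder of updated files (identifiers and notes are hypothetical):

```python
api.model_instance_version_create_cli(
    'alice/my-model/pytorch/default',
    '/path/to/new-weights',
    version_notes='retrained with more data',
    dir_mode='zip')
```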
3650
+ def model_instance_version_download(self,
3651
+ model_instance_version,
3652
+ path=None,
3653
+ force=False,
3654
+ quiet=True,
3655
+ untar=False):
3656
+ """ Download all files for a model instance version.
3657
+
3658
+ Parameters
3659
+ ==========
3660
+ model_instance_version: the string identifier of the model instance version
3661
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
3662
+ path: the path to download the model instance version to
3663
+ force: force the download if the file already exists (default False)
3664
+ quiet: suppress verbose output (default is True)
3665
+ untar: if True, untar files upon download (default is False)
3666
+ """
3667
+ if model_instance_version is None:
3668
+ raise ValueError('A model_instance_version must be specified')
3669
+
3670
+ self.validate_model_instance_version_string(model_instance_version)
3671
+ urls = model_instance_version.split('/')
3672
+ owner_slug = urls[0]
3673
+ model_slug = urls[1]
3674
+ framework = urls[2]
3675
+ instance_slug = urls[3]
3676
+ version_number = urls[4]
3677
+
3678
+ if path is None:
3679
+ effective_path = self.get_default_download_dir('models', owner_slug,
3680
+ model_slug, framework,
3681
+ instance_slug,
3682
+ version_number)
3683
+ else:
3684
+ effective_path = path
3685
+
3686
+ request = ApiDownloadModelInstanceVersionRequest()
3687
+ request.owner_slug = owner_slug
3688
+ request.model_slug = model_slug
3689
+ request.framework = self.lookup_enum(ModelFramework, framework)
3690
+ request.instance_slug = instance_slug
3691
+ request.version_number = int(version_number)
3692
+ with self.build_kaggle_client() as kaggle:
3693
+ response = kaggle.models.model_api_client.download_model_instance_version(
3694
+ request)
3695
+
3696
+ outfile = os.path.join(effective_path, model_slug + '.tar.gz')
3697
+ if force or self.download_needed(response, outfile, quiet):
3698
+ self.download_file(response, outfile, quiet, not force)
3699
+ downloaded = True
3700
+ else:
3701
+ downloaded = False
3702
+
3703
+ if downloaded:
3704
+ if untar:
3705
+ try:
3706
+ with tarfile.open(outfile, mode='r:gz') as t:
3707
+ t.extractall(effective_path)
3708
+ except Exception as e:
3709
+ raise ValueError(
3710
+ 'Error extracting the tar.gz file, please report on '
3711
+ 'www.github.com/kaggle/kaggle-api', e)
3712
+
3713
+ try:
3714
+ os.remove(outfile)
3715
+ except OSError as e:
3716
+ print('Could not delete tar file, got %s' % e)
3717
+ return outfile
3718
+
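Downloads arrive as `<model-slug>.tar.gz`; with `untar=True` the archive is extracted into the target path and then removed. A sketch with hypothetical identifiers:

```python
outfile = api.model_instance_version_download(
    'alice/my-model/pytorch/default/1',
    path='./weights',
    untar=True)
```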
3719
+ def model_instance_version_download_cli(self,
3720
+ model_instance_version,
3721
+ path=None,
3722
+ untar=False,
3723
+ force=False,
3724
+ quiet=False):
3725
+ """ Client wrapper for model_instance_version_download.
3726
+
3727
+ Parameters
3728
+ ==========
3729
+ model_instance_version: the string identifier of the model instance version
3730
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
3731
+ path: the path to download the model instance version to
3732
+ force: force the download if the file already exists (default False)
3733
+ quiet: suppress verbose output (default is False)
3734
+ untar: if True, untar files upon download (default is False)
3735
+ """
3736
+ return self.model_instance_version_download(
3737
+ model_instance_version,
3738
+ path=path,
3739
+ untar=untar,
3740
+ force=force,
3741
+ quiet=quiet)
3742
+
3743
+ def model_instance_version_files(self,
3744
+ model_instance_version,
3745
+ page_token=None,
3746
+ page_size=20,
3747
+ csv_display=False):
3748
+ """ List all files for a model instance version.
3749
+
3750
+ Parameters
3751
+ ==========
3752
+ model_instance_version: the string identifier of the model instance version
3753
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
3754
+ page_token: token for pagination
3755
+ page_size: the number of items per page
3756
+ csv_display: if True, print comma separated values instead of table
3757
+ """
3758
+ if model_instance_version is None:
3759
+ raise ValueError('A model_instance_version must be specified')
3760
+
3761
+ self.validate_model_instance_version_string(model_instance_version)
3762
+ urls = model_instance_version.split('/')
3763
+ [owner_slug, model_slug, framework, instance_slug, version_number] = urls
3764
+
3765
+ request = ApiListModelInstanceVersionFilesRequest()
3766
+ request.owner_slug = owner_slug
3767
+ request.model_slug = model_slug
3768
+ request.framework = self.lookup_enum(ModelFramework, framework)
3769
+ request.instance_slug = instance_slug
3770
+ request.version_number = int(version_number)
3771
+ request.page_size = page_size
3772
+ request.page_token = page_token
3773
+ with self.build_kaggle_client() as kaggle:
3774
+ response = kaggle.models.model_api_client.list_model_instance_version_files(
3775
+ request)
3776
+
3777
+ if response:
3778
+ next_page_token = response.next_page_token
3779
+ if next_page_token:
3780
+ print('Next Page Token = {}'.format(next_page_token))
3781
+ return response
3782
+ else:
3783
+ print('No files found')
3784
+
3785
+ def model_instance_version_files_cli(self,
3786
+ model_instance_version,
3787
+ page_token=None,
3788
+ page_size=20,
3789
+ csv_display=False):
3790
+ """ Client wrapper for model_instance_version_files.
3791
+
3792
+ Parameters
3793
+ ==========
3794
+ model_instance_version: the string identifier of the model instance version
3795
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
3796
+ page_token: token for pagination
3797
+ page_size: the number of items per page
3798
+ csv_display: if True, print comma separated values instead of table
3799
+ """
3800
+ result = self.model_instance_version_files(
3801
+ model_instance_version,
3802
+ page_token=page_token,
3803
+ page_size=page_size,
3804
+ csv_display=csv_display)
3805
+ if result and result.files is not None:
3806
+ fields = ['name', 'size', 'creation_date']
3807
+ labels = ['name', 'size', 'creationDate']
3808
+ if csv_display:
3809
+ self.print_csv(result.files, fields, labels)
3810
+ else:
3811
+ self.print_table(result.files, fields, labels)
3812
+
3813
+ def model_instance_version_delete(self, model_instance_version, yes):
3814
+ """ Delete a model instance version.
3815
+ Parameters
3816
+ ==========
3817
+ model_instance_version: the string identifier of the model instance version
3818
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
3819
+ yes: automatic confirmation
3820
+ """
3821
+ if model_instance_version is None:
3822
+ raise ValueError('A model instance version must be specified')
3823
+
3824
+ self.validate_model_instance_version_string(model_instance_version)
3825
+ urls = model_instance_version.split('/')
3826
+ owner_slug = urls[0]
3827
+ model_slug = urls[1]
3828
+ framework = urls[2]
3829
+ instance_slug = urls[3]
3830
+ version_number = urls[4]
3831
+
3832
+ if not yes:
3833
+ if not self.confirmation():
3834
+ print('Deletion cancelled')
3835
+ exit(0)
3836
+
3837
+ request = ApiDeleteModelInstanceVersionRequest()
3838
+ request.owner_slug = owner_slug
3839
+ request.model_slug = model_slug
3840
+ request.framework = self.lookup_enum(ModelFramework, framework)
3841
+ request.instance_slug = instance_slug
3842
+ request.version_number = int(version_number)
3843
+ with self.build_kaggle_client() as kaggle:
3844
+ response = kaggle.models.model_api_client.delete_model_instance_version(
3845
+ request)
3846
+ return response
3847
+
3848
+ def model_instance_version_delete_cli(self, model_instance_version, yes):
3849
+ """ Client wrapper for model_instance_version_delete
3850
+ Parameters
3851
+ ==========
3852
+ model_instance_version: the string identifier of the model instance version
3853
+ should be in format [owner]/[model-name]/[framework]/[instance-slug]/[version-number]
3854
+ yes: automatic confirmation
3855
+ """
3856
+ result = self.model_instance_version_delete(model_instance_version, yes)
3857
+
3858
+ if len(result.error) > 0:
3859
+ print('Model instance version deletion error: ' + result.error)
3860
+ else:
3861
+ print('The model instance version was deleted.')
3862
+
3863
+ def files_upload_cli(self, local_paths, inbox_path, no_resume, no_compress):
3864
+ if len(local_paths) > self.MAX_NUM_INBOX_FILES_TO_UPLOAD:
3865
+ print('Cannot upload more than %d files!' %
3866
+ self.MAX_NUM_INBOX_FILES_TO_UPLOAD)
3867
+ return
3868
+
3869
+ files_to_create = []
3870
+ with ResumableUploadContext(no_resume) as upload_context:
3871
+ for local_path in local_paths:
3872
+ (upload_file, file_name) = self.file_upload_cli(local_path, inbox_path,
3873
+ no_compress,
3874
+ upload_context)
3875
+ if upload_file is None:
3876
+ continue
3877
+
3878
+ create_inbox_file_request = CreateInboxFileRequest()
3879
+ create_inbox_file_request.virtual_directory = inbox_path
3880
+ create_inbox_file_request.blob_file_token = upload_file.token
3881
+ files_to_create.append((create_inbox_file_request, file_name))
3882
+
3883
+ with self.build_kaggle_client() as kaggle:
3884
+ create_inbox_file = kaggle.admin.inbox_file_client.create_inbox_file
3885
+ for (create_inbox_file_request, file_name) in files_to_create:
3886
+ self.with_retry(create_inbox_file)(create_inbox_file_request)
3887
+ print('Inbox file created:', file_name)
3888
+
3889
+ def file_upload_cli(self, local_path, inbox_path, no_compress,
3890
+ upload_context):
3891
+ full_path = os.path.abspath(local_path)
3892
+ parent_path = os.path.dirname(full_path)
3893
+ file_or_folder_name = os.path.basename(full_path)
3894
+ dir_mode = 'tar' if no_compress else 'zip'
3895
+
3896
+ upload_file = self._upload_file_or_folder(parent_path, file_or_folder_name,
3897
+ ApiBlobType.INBOX, upload_context,
3898
+ dir_mode)
3899
+ return (upload_file, file_or_folder_name)
3900
+
3901
+ def print_obj(self, obj, indent=2):
3902
+ pretty = json.dumps(obj, indent=indent)
3903
+ print(pretty)
3904
+
3905
+ def download_needed(self, response, outfile, quiet=True):
3906
+ """ determine if a download is needed based on timestamp. Return True
3907
+ if needed (remote is newer) or False if the local copy is newer.
3908
+ Parameters
3909
+ ==========
3910
+ response: the response from the API
3911
+ outfile: the output file to write to
3912
+ quiet: suppress verbose output (default is True)
3913
+ """
3914
+ try:
3915
+ last_modified = response.headers.get('Last-Modified')
3916
+ if last_modified is None:
3917
+ remote_date = datetime.now()
3918
+ else:
3919
+ remote_date = datetime.strptime(last_modified,
3921
+ '%a, %d %b %Y %H:%M:%S %Z')
3921
+ file_exists = os.path.isfile(outfile)
3922
+ if file_exists:
3923
+ local_date = datetime.fromtimestamp(os.path.getmtime(outfile))
3924
+ remote_size = int(response.headers['Content-Length'])
3925
+ local_size = os.path.getsize(outfile)
3926
+ if local_size < remote_size:
3927
+ return True
3928
+ if remote_date <= local_date:
3929
+ if not quiet:
3930
+ print(
3931
+ os.path.basename(outfile) +
3932
+ ': Skipping, found more recently modified local '
3933
+ 'copy (use --force to force download)')
3934
+ return False
3935
+ except Exception:
3936
+ pass
3937
+ return True
3938
+
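The Last-Modified comparison above parses the standard HTTP date format; a self-contained illustration:

```python
from datetime import datetime

# RFC 1123 HTTP-date, as sent in the Last-Modified header
remote = datetime.strptime('Tue, 01 Oct 2024 12:00:00 GMT',
                           '%a, %d %b %Y %H:%M:%S %Z')
print(remote)  # 2024-10-01 12:00:00
```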
3939
+ def print_table(self, items, fields, labels=None):
3940
+ """ print a table of items, for a set of fields defined
3941
+
3942
+ Parameters
3943
+ ==========
3944
+ items: a list of items to print
3945
+ fields: a list of fields to select from items
3946
+ labels: labels for the fields, defaults to fields
3947
+ """
3948
+ if labels is None:
3949
+ labels = fields
3950
+ formats = []
3951
+ borders = []
3952
+ if len(items) == 0:
3953
+ return
3954
+ for f in fields:
3955
+ length = max(
3956
+ len(f),
3957
+ max([
3958
+ len(self.string(getattr(i, self.camel_to_snake(f))))
3959
+ for i in items
3960
+ ]))
3961
+ justify = '>' if isinstance(
3962
+ getattr(items[0], self.camel_to_snake(f)),
3963
+ int) or f == 'size' or f == 'reward' else '<'
3964
+ formats.append('{:' + justify + self.string(length + 2) + '}')
3965
+ borders.append('-' * length + ' ')
3966
+ row_format = u''.join(formats)
3967
+ headers = [f + ' ' for f in labels]
3968
+ print(row_format.format(*headers))
3969
+ print(row_format.format(*borders))
3970
+ for i in items:
3971
+ i_fields = [
3972
+ self.string(getattr(i, self.camel_to_snake(f))) + ' ' for f in fields
3973
+ ]
3974
+ try:
3975
+ print(row_format.format(*i_fields))
3976
+ except UnicodeEncodeError:
3977
+ print(row_format.format(*i_fields).encode('utf-8'))
3978
+
3979
+ def print_csv(self, items, fields, labels=None):
3980
+ """ print a set of fields in a set of items using a csv.writer
3981
+
3982
+ Parameters
3983
+ ==========
3984
+ items: a list of items to print
3985
+ fields: a list of fields to select from items
3986
+ labels: labels for the fields, defaults to fields
3987
+ """
3988
+ if labels is None:
3989
+ labels = fields
3990
+ writer = csv.writer(sys.stdout)
3991
+ writer.writerow(labels)
3992
+ for i in items:
3993
+ i_fields = [
3994
+ self.string(getattr(i, self.camel_to_snake(f))) for f in fields
3995
+ ]
3996
+ writer.writerow(i_fields)
3997
+
3998
+ def string(self, item):
3999
+ return item if isinstance(item, str) else str(item)
4000
+
4001
+ def get_or_fail(self, data, key):
4002
+ if key in data:
4003
+ return data[key]
4004
+ raise ValueError('Key ' + key + ' not found in data')
4005
+
4006
+ def get_or_default(self, data, key, default):
4007
+ if key in data:
4008
+ return data[key]
4009
+ return default
4010
+
4011
+ def get_bool(self, data, key, default):
4012
+ if key in data:
4013
+ val = data[key]
4014
+ if isinstance(val, str):
4015
+ val = val.lower()
4016
+ if val == 'true':
4017
+ return True
4018
+ elif val == 'false':
4019
+ return False
4020
+ else:
4021
+ raise ValueError('Invalid boolean value: ' + val)
4022
+ if isinstance(val, bool):
4023
+ return val
4024
+ raise ValueError('Invalid boolean value: ' + str(val))
4025
+ return default
4026
+
4027
+ def set_if_present(self, data, key, output, output_key):
4028
+ if key in data:
4029
+ output[output_key] = data[key]
4030
+
4031
+ def get_dataset_metadata_file(self, folder):
4032
+ meta_file = os.path.join(folder, self.DATASET_METADATA_FILE)
4033
+ if not os.path.isfile(meta_file):
4034
+ meta_file = os.path.join(folder, self.OLD_DATASET_METADATA_FILE)
4035
+ if not os.path.isfile(meta_file):
4036
+ raise ValueError('Metadata file not found: ' +
4037
+ self.DATASET_METADATA_FILE)
4038
+ return meta_file
4039
+
4040
+ def get_model_metadata_file(self, folder):
4041
+ meta_file = os.path.join(folder, self.MODEL_METADATA_FILE)
4042
+ if not os.path.isfile(meta_file):
4043
+ raise ValueError('Metadata file not found: ' + self.MODEL_METADATA_FILE)
4044
+ return meta_file
4045
+
4046
+ def get_model_instance_metadata_file(self, folder):
4047
+ meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE)
4048
+ if not os.path.isfile(meta_file):
4049
+ raise ValueError('Metadata file not found: ' +
4050
+ self.MODEL_INSTANCE_METADATA_FILE)
4051
+ return meta_file
4052
+
4053
+ def process_response(self, result):
4054
+ """ process a response from the API. We check the API version against
4055
+ the client's to see if it's old, and give them a warning (once)
4056
+
4057
+ Parameters
4058
+ ==========
4059
+ result: the result from the API
4060
+ """
4061
+ if len(result) == 3:
4062
+ data = result[0]
4063
+ headers = result[2]
4064
+ if self.HEADER_API_VERSION in headers:
4065
+ api_version = headers[self.HEADER_API_VERSION]
4066
+ if (not self.already_printed_version_warning and
4067
+ not self.is_up_to_date(api_version)):
4068
+ print(f'Warning: Looks like you\'re using an outdated `kaggle` '
4069
+ f'version (installed: {self.__version__}), please consider '
4070
+ f'upgrading to the latest version ({api_version})')
4071
+ self.already_printed_version_warning = True
4072
+ if isinstance(data, dict) and 'code' in data and data['code'] != 200:
4073
+ raise Exception(data['message'])
4074
+ return data
4075
+ return result
4076
+
4077
+ def is_up_to_date(self, server_version):
4078
+ """ determine if a client (on the local user's machine) is up to date
4079
+ with the version provided on the server. Return True if the client
4080
+ is at least as new as the server version, else False.
4081
+ Parameters
4082
+ ==========
4083
+ server_version: the server version string to compare to the host
4084
+ """
4085
+ client_split = self.__version__.split('.')
4086
+ client_len = len(client_split)
4087
+ server_split = server_version.split('.')
4088
+ server_len = len(server_split)
4089
+
4090
+ # Make both lists the same length
4091
+ for i in range(client_len, server_len):
4092
+ client_split.append('0')
4093
+ for i in range(server_len, client_len):
4094
+ server_split.append('0')
4095
+
4096
+ for i in range(0, client_len):
4097
+ if 'a' in client_split[i] or 'b' in client_split[i]:
4098
+ # Using an alpha/beta version, don't check
4099
+ return True
4100
+ client = int(client_split[i])
4101
+ server = int(server_split[i])
4102
+ if client < server:
4103
+ return False
4104
+ elif server < client:
4105
+ return True
4106
+
4107
+ return True
4108
+
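The comparison pads the shorter version with zeros, so client '1.7' against server '1.7.3' effectively compares 1.7.0 < 1.7.3 and reports out of date. A self-contained sketch of that padding step:

```python
client_split = '1.7'.split('.')    # ['1', '7']
server_split = '1.7.3'.split('.')  # ['1', '7', '3']
client_split += ['0'] * (len(server_split) - len(client_split))
print(client_split)  # ['1', '7', '0'] -> compares as out of date
```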
4109
+ def upload_files(self,
4110
+ request,
4111
+ resources,
4112
+ folder,
4113
+ blob_type,
4114
+ upload_context,
4115
+ quiet=False,
4116
+ dir_mode='skip'):
4117
+ """ upload files in a folder
4118
+ Parameters
4119
+ ==========
4120
+ request: the prepared request
4121
+ resources: the files to upload
4122
+ folder: the folder to upload from
4123
+ blob_type (ApiBlobType): To which entity the file/blob refers
4124
+ upload_context (ResumableUploadContext): Context for resumable uploads
4125
+ quiet: suppress verbose output (default is False)
+ dir_mode: what to do with directories: "skip" - ignore; "zip"/"tar" - archive and upload
4126
+ """
4127
+ for file_name in os.listdir(folder):
4128
+ if (file_name in [
4129
+ self.DATASET_METADATA_FILE, self.OLD_DATASET_METADATA_FILE,
4130
+ self.KERNEL_METADATA_FILE, self.MODEL_METADATA_FILE,
4131
+ self.MODEL_INSTANCE_METADATA_FILE
4132
+ ]):
4133
+ continue
4134
+ upload_file = self._upload_file_or_folder(folder, file_name, blob_type,
4135
+ upload_context, dir_mode, quiet,
4136
+ resources)
4137
+ if upload_file is not None:
4138
+ request.files.append(upload_file)
4139
+
4140
+ def _upload_file_or_folder(self,
4141
+ parent_path,
4142
+ file_or_folder_name,
4143
+ blob_type,
4144
+ upload_context,
4145
+ dir_mode,
4146
+ quiet=False,
4147
+ resources=None):
4148
+ full_path = os.path.join(parent_path, file_or_folder_name)
4149
+ if os.path.isfile(full_path):
4150
+ return self._upload_file(file_or_folder_name, full_path, blob_type,
4151
+ upload_context, quiet, resources)
4152
+
4153
+ elif os.path.isdir(full_path):
4154
+ if dir_mode in ['zip', 'tar']:
4155
+ with DirectoryArchive(full_path, dir_mode) as archive:
4156
+ return self._upload_file(archive.name, archive.path, blob_type,
4157
+ upload_context, quiet, resources)
4158
+ elif not quiet:
4159
+ print("Skipping folder: " + file_or_folder_name +
4160
+ "; use '--dir-mode' to upload folders")
4161
+ else:
4162
+ if not quiet:
4163
+ print('Skipping: ' + file_or_folder_name)
4164
+ return None
4165
+
4166
+ def _upload_file(self, file_name, full_path, blob_type, upload_context, quiet,
4167
+ resources):
4168
+ """ Helper function to upload a single file
4169
+ Parameters
4170
+ ==========
4171
+ file_name: name of the file to upload
4172
+ full_path: path to the file to upload
4173
+ blob_type (ApiBlobType): To which entity the file/blob refers
4174
+ upload_context (ResumableUploadContext): Context for resumable uploads
4175
+ quiet: suppress verbose output
4176
+ resources: optional file metadata
4177
+ :return: None - upload unsuccessful; instance of UploadFile - upload successful
4178
+ """
4179
+
4180
+ if not quiet:
4181
+ print('Starting upload for file ' + file_name)
4182
+
4183
+ content_length = os.path.getsize(full_path)
4184
+ token = self._upload_blob(full_path, quiet, blob_type, upload_context)
4185
+ if token is None:
4186
+ if not quiet:
4187
+ print('Upload unsuccessful: ' + file_name)
4188
+ return None
4189
+ if not quiet:
4190
+ print('Upload successful: ' + file_name + ' (' +
4191
+ File.get_size(content_length) + ')')
4192
+ upload_file = UploadFile()
4193
+ upload_file.token = token
4194
+ if resources:
4195
+ for item in resources:
4196
+ if file_name == item.get('path'):
4197
+ upload_file.description = item.get('description')
4198
+ if 'schema' in item:
4199
+ fields = self.get_or_default(item['schema'], 'fields', [])
4200
+ processed = []
4201
+ count = 0
4202
+ for field in fields:
4203
+ processed.append(self.process_column(field))
4204
+ processed[count].order = count
4205
+ count += 1
4206
+ upload_file.columns = processed
4207
+ return upload_file
4208
+
4209
+ def process_column(self, column):
4210
+ """ process a column, check for the type, and return the processed
4211
+ column
4212
+ Parameters
4213
+ ==========
4214
+ column: a list of values in a column to be processed
4215
+ """
4216
+ processed_column = DatasetColumn(
4217
+ name=self.get_or_fail(column, 'name'),
4218
+ description=self.get_or_default(column, 'description', ''))
4219
+ if 'type' in column:
4220
+ original_type = column['type'].lower()
4221
+ processed_column.original_type = original_type
4222
+ if (original_type == 'string' or original_type == 'date' or
4223
+ original_type == 'time' or original_type == 'yearmonth' or
4224
+ original_type == 'duration' or original_type == 'geopoint' or
4225
+ original_type == 'geojson'):
4226
+ processed_column.type = 'string'
4227
+ elif (original_type == 'numeric' or original_type == 'number' or
4228
+ original_type == 'year'):
4229
+ processed_column.type = 'numeric'
4230
+ elif original_type == 'boolean':
4231
+ processed_column.type = 'boolean'
4232
+ elif original_type == 'datetime':
4233
+ processed_column.type = 'datetime'
4234
+ else:
4235
+ # Possibly extended data type - not going to try to track those
4236
+ # here. Will set the type and let the server handle it.
4237
+ processed_column.type = original_type
4238
+ return processed_column
4239
+
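For example, a schema column declared as 'number' is coarsened to 'numeric' while the declared type is kept as original_type; a sketch using a hypothetical column entry:

```python
column = {'name': 'price', 'type': 'number', 'description': 'USD'}
processed = api.process_column(column)
# processed.type == 'numeric', processed.original_type == 'number'
```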
4240
+ def upload_complete(self, path, url, quiet, resume=False):
4241
+ """ function to complete an upload to retrieve a path from a url
4242
+ Parameters
4243
+ ==========
4244
+ path: the path of the file to upload
4245
+ url: the url to send the PUT to
4246
+ quiet: suppress verbose output (default is False)
+ resume: whether to attempt to resume an earlier partial upload
4247
+ """
4248
+ file_size = os.path.getsize(path)
4249
+ resumable_upload_result = ResumableUploadResult.Incomplete()
4250
+
4251
+ try:
4252
+ if resume:
4253
+ resumable_upload_result = self._resume_upload(path, url, file_size,
4254
+ quiet)
4255
+ if resumable_upload_result.result != ResumableUploadResult.INCOMPLETE:
4256
+ return resumable_upload_result.result
4257
+
4258
+ start_at = resumable_upload_result.start_at
4259
+ upload_size = file_size - start_at
4260
+
4261
+ with tqdm(
4262
+ total=upload_size,
4263
+ unit='B',
4264
+ unit_scale=True,
4265
+ unit_divisor=1024,
4266
+ disable=quiet) as progress_bar:
4267
+ with io.open(path, 'rb', buffering=0) as fp:
4268
+ session = requests.Session()
4269
+ if start_at > 0:
4270
+ fp.seek(start_at)
4271
+ session.headers.update({
4272
+ 'Content-Length':
4273
+ '%d' % upload_size,
4274
+ 'Content-Range':
4275
+ 'bytes %d-%d/%d' % (start_at, file_size - 1, file_size)
4276
+ })
4277
+ reader = TqdmBufferedReader(fp, progress_bar)
4278
+ retries = Retry(total=10, backoff_factor=0.5)
4279
+ adapter = HTTPAdapter(max_retries=retries)
4280
+ session.mount('http://', adapter)
4281
+ session.mount('https://', adapter)
4282
+ response = session.put(url, data=reader)
4283
+ if self._is_upload_successful(response):
4284
+ return ResumableUploadResult.COMPLETE
4285
+ if response.status_code == 503:
4286
+ return ResumableUploadResult.INCOMPLETE
4287
+ # Server returned a non-resumable error so give up.
4288
+ return ResumableUploadResult.FAILED
4289
+ except Exception as error:
4290
+ print(error)
4291
+ # There is probably some weird bug in our code so try to resume the upload
4292
+ # in case it works on the next try.
4293
+ return ResumableUploadResult.INCOMPLETE
4294
+
4295
+ def _resume_upload(self, path, url, content_length, quiet):
4296
+ # Documentation: https://developers.google.com/drive/api/guides/manage-uploads#resume-upload
4297
+ session = requests.Session()
4298
+ session.headers.update({
4299
+ 'Content-Length': '0',
4300
+ 'Content-Range': 'bytes */%d' % content_length,
4301
+ })
4302
+
4303
+ response = session.put(url)
4304
+
4305
+ if self._is_upload_successful(response):
4306
+ return ResumableUploadResult.Complete()
4307
+ if response.status_code == 404:
4308
+ # Upload expired so need to start from scratch.
4309
+ if not quiet:
4310
+ print('Upload of %s expired. Please try again.' % path)
4311
+ return ResumableUploadResult.Failed()
4312
+ if response.status_code == 308: # Resume Incomplete
4313
+ bytes_uploaded = self._get_bytes_already_uploaded(response, quiet)
4314
+ if bytes_uploaded is None:
4315
+ # There is an error with the Range header so need to start from scratch.
4316
+ return ResumableUploadResult.Failed()
4317
+ result = ResumableUploadResult.Incomplete(bytes_uploaded)
4318
+ if not quiet:
4319
+ print('Already uploaded %d bytes. Will resume upload at %d.' %
4320
+ (result.bytes_uploaded, result.start_at))
4321
+ return result
4322
+ else:
4323
+ if not quiet:
4324
+ print('Server returned %d. Please try again.' % response.status_code)
4325
+ return ResumableUploadResult.Failed()
4326
+
4327
+ def _is_upload_successful(self, response):
4328
+ return response.status_code == 200 or response.status_code == 201
4329
+
4330
+ def _get_bytes_already_uploaded(self, response, quiet):
4331
+ range_val = response.headers.get('Range')
4332
+ if range_val is None:
4333
+ return 0 # This means server hasn't received anything before.
4334
+ items = range_val.split('-') # Example: 'bytes=0-1000' => ['bytes=0', '1000']
4335
+ if len(items) != 2:
4336
+ if not quiet:
4337
+ print('Invalid Range header format: %s. Will try again.' % range_val)
4338
+ return None # Shouldn't happen, something's wrong with Range header format.
4339
+ bytes_uploaded_str = items[-1] # Example: ['bytes=0', '1000'] => '1000'
4340
+ try:
4341
+ return int(bytes_uploaded_str) # Example: '1000' => 1000
4342
+ except ValueError:
4343
+ if not quiet:
4344
+ print('Invalid Range header format: %s. Will try again.' % range_val)
4345
+ return None # Shouldn't happen, something's wrong with Range header format.
4346
+
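A self-contained sketch of the status probe `_resume_upload` performs, following the Google resumable-upload protocol linked above; the URL and helper name are hypothetical:

```python
import requests

def probe_uploaded_bytes(url, total_size):
    # an empty PUT with 'bytes */<total>' asks the server what it already has
    resp = requests.put(url, headers={
        'Content-Length': '0',
        'Content-Range': 'bytes */%d' % total_size,
    })
    if resp.status_code == 308:          # Resume Incomplete
        rng = resp.headers.get('Range')  # e.g. 'bytes=0-999'
        return int(rng.split('-')[-1]) if rng else 0
    return None  # complete (200/201), expired (404), or other error
```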
4347
+ def validate_dataset_string(self, dataset):
4348
+ """ determine if a dataset string is valid, meaning it is in the format
4349
+ of {username}/{dataset-slug} or {username}/{dataset-slug}/{version-number}.
4350
+ Parameters
4351
+ ==========
4352
+ dataset: the dataset name to validate
4353
+ """
4354
+ if dataset:
4355
+ if '/' not in dataset:
4356
+ raise ValueError('Dataset must be specified in the form of '
4357
+ '\'{username}/{dataset-slug}\'')
4358
+
4359
+ split = dataset.split('/')
4360
+ if not split[0] or not split[1] or len(split) > 3:
4361
+ raise ValueError('Invalid dataset specification ' + dataset)
4362
+
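Illustrative inputs for the validator (given any KaggleApi instance `api`):

```python
api.validate_dataset_string('alice/titanic')    # ok
api.validate_dataset_string('alice/titanic/2')  # ok, versioned form
api.validate_dataset_string('titanic')          # raises ValueError
```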
4363
+ def split_dataset_string(self, dataset):
4364
+ """ split a dataset string into owner_slug, dataset_slug,
4365
+ and optional version_number
4366
+ Parameters
4367
+ ==========
4368
+ dataset: the dataset name to split
4369
+ """
4370
+ if '/' in dataset:
4371
+ self.validate_dataset_string(dataset)
4372
+ urls = dataset.split('/')
4373
+ if len(urls) == 3:
4374
+ return urls[0], urls[1], urls[2]
4375
+ else:
4376
+ return urls[0], urls[1], None
4377
+ else:
4378
+ return self.get_config_value(self.CONFIG_NAME_USER), dataset, None
4379
+
4380
+ def validate_model_string(self, model):
4381
+ """ determine if a model string is valid, meaning it is in the format
4382
+ of {owner}/{model-slug}.
4383
+ Parameters
4384
+ ==========
4385
+ model: the model name to validate
4386
+ """
4387
+ if model:
4388
+ if model.count('/') != 1:
4389
+ raise ValueError('Model must be specified in the form of '
4390
+ '\'{owner}/{model-slug}\'')
4391
+
4392
+ split = model.split('/')
4393
+ if not split[0] or not split[1]:
4394
+ raise ValueError('Invalid model specification ' + model)
4395
+
4396
+ def split_model_string(self, model):
4397
+ """ split a model string into owner_slug, model_slug
4398
+ Parameters
4399
+ ==========
4400
+ model: the model name to split
4401
+ """
4402
+ if '/' in model:
4403
+ self.validate_model_string(model)
4404
+ model_urls = model.split('/')
4405
+ return model_urls[0], model_urls[1]
4406
+ else:
4407
+ return self.get_config_value(self.CONFIG_NAME_USER), model
4408
+
4409
+ def validate_model_instance_string(self, model_instance):
4410
+ """ determine if a model instance string is valid, meaning it is in the format
4411
+ of {owner}/{model-slug}/{framework}/{instance-slug}.
4412
+ Parameters
4413
+ ==========
4414
+ model_instance: the model instance name to validate
4415
+ """
4416
+ if model_instance:
4417
+ if model_instance.count('/') != 3:
4418
+ raise ValueError('Model instance must be specified in the form of '
4419
+ '\'{owner}/{model-slug}/{framework}/{instance-slug}\'')
4420
+
4421
+ split = model_instance.split('/')
4422
+ if not split[0] or not split[1] or not split[2] or not split[3]:
4423
+ raise ValueError('Invalid model instance specification ' +
4424
+ model_instance)
4425
+
4426
+ def split_model_instance_string(self, model_instance):
4427
+ """ split a model instance string into owner_slug, model_slug,
4428
+ framework, instance_slug
4429
+ Parameters
4430
+ ==========
4431
+ model_instance: the model instance name to validate
4432
+ """
4433
+ self.validate_model_instance_string(model_instance)
4434
+ urls = model_instance.split('/')
4435
+ return urls[0], urls[1], urls[2], urls[3]
4436
+
4437
+  def validate_model_instance_version_string(self, model_instance_version):
+    """ determine if a model instance version string is valid, meaning it is in the format
+        of {owner}/{model-slug}/{framework}/{instance-slug}/{version-number}.
+    Parameters
+    ==========
+    model_instance_version: the model instance version name to validate
+    """
+    if model_instance_version:
+      if model_instance_version.count('/') != 4:
+        raise ValueError(
+            'Model instance version must be specified in the form of '
+            '\'{owner}/{model-slug}/{framework}/{instance-slug}/{version-number}\''
+        )
+
+      split = model_instance_version.split('/')
+      if (not split[0] or not split[1] or not split[2] or not split[3] or
+          not split[4]):
+        raise ValueError('Invalid model instance version specification ' +
+                         model_instance_version)
+
+      try:
+        int(split[4])
+      except ValueError:
+        raise ValueError(
+            'Model instance version\'s version-number must be an integer')
+
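+  # Editor's sketch (not part of the package); slugs are hypothetical:
+  #
+  #   api.validate_model_instance_version_string('alice/my-model/pyTorch/v1/3')
+  #   # passes; a non-integer version such as '.../v1/latest' raises ValueError
+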
+  def validate_kernel_string(self, kernel):
+    """ determine if a kernel string is valid, meaning it is in the format
+        of {username}/{kernel-slug}.
+    Parameters
+    ==========
+    kernel: the kernel name to validate
+    """
+    if kernel:
+      if '/' not in kernel:
+        raise ValueError('Kernel must be specified in the form of '
+                         '\'{username}/{kernel-slug}\'')
+
+      split = kernel.split('/')
+      if not split[0] or not split[1]:
+        raise ValueError('Kernel must be specified in the form of '
+                         '\'{username}/{kernel-slug}\'')
+
+      if len(split[1]) < 5:
+        raise ValueError('Kernel slug must be at least five characters')
+
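+  # NOTE (editor's annotation): the definition below redefines
+  # validate_model_string from earlier in this class, so at runtime the
+  # two-part ({owner}/{model-slug}) check above is shadowed; only this
+  # looser check (at least one '/') is applied.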
+  def validate_model_string(self, model):
+    """ determine if a model string is valid, meaning it is in the format
+        of {username}/{model-slug}/{framework}/{variation-slug}/{version-number}.
+    Parameters
+    ==========
+    model: the model name to validate
+    """
+    if model:
+      if '/' not in model:
+        raise ValueError(
+            'Model must be specified in the form of '
+            '\'{username}/{model-slug}/{framework}/{variation-slug}/{version-number}\''
+        )
+
+      split = model.split('/')
+      if not split[0] or not split[1]:
+        raise ValueError('Invalid model specification ' + model)
+
+  def validate_resources(self, folder, resources):
+    """ validate resources is a wrapper to validate the existence of files
+        and that there are no duplicates for a folder and set of resources.
+
+    Parameters
+    ==========
+    folder: the folder to validate
+    resources: one or more resources to validate within the folder
+    """
+    self.validate_files_exist(folder, resources)
+    self.validate_no_duplicate_paths(resources)
+
+  def validate_files_exist(self, folder, resources):
+    """ ensure that one or more resource files exist in a folder
+
+    Parameters
+    ==========
+    folder: the folder to validate
+    resources: one or more resources to validate within the folder
+    """
+    for item in resources:
+      file_name = item.get('path')
+      full_path = os.path.join(folder, file_name)
+      if not os.path.isfile(full_path):
+        raise ValueError('%s does not exist' % full_path)
+
+  def validate_no_duplicate_paths(self, resources):
+    """ ensure that the user has not provided duplicate paths in
+        a list of resources.
+
+    Parameters
+    ==========
+    resources: one or more resources to check for duplicated paths
+    """
+    paths = set()
+    for item in resources:
+      file_name = item.get('path')
+      if file_name in paths:
+        raise ValueError(
+            '%s path was specified more than once in the metadata' % file_name)
+      paths.add(file_name)
+
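+  # Editor's sketch (not part of the package): `resources` is a list of dicts
+  # carrying at least a 'path' key, as in dataset metadata; the folder is
+  # hypothetical.
+  #
+  #   resources = [{'path': 'train.csv'}, {'path': 'test.csv'}]
+  #   api.validate_resources('/tmp/my-dataset', resources)
+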
+  def convert_to_dataset_file_metadata(self, file_data, path):
+    """ convert a set of file_data to a metadata file at path
+
+    Parameters
+    ==========
+    file_data: a dictionary of file data to write to file
+    path: the path to write the metadata to
+    """
+    as_metadata = {
+        'path': os.path.join(path, file_data['name']),
+        'description': file_data['description']
+    }
+
+    schema = {}
+    fields = []
+    for column in file_data['columns']:
+      field = {
+          'name': column['name'],
+          'title': column['description'],
+          'type': column['type']
+      }
+      fields.append(field)
+    schema['fields'] = fields
+    as_metadata['schema'] = schema
+
+    return as_metadata
+
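+  # Editor's sketch (not part of the package): shape of the conversion above,
+  # with hypothetical values.
+  #
+  #   file_data = {'name': 'train.csv', 'description': 'training split',
+  #                'columns': [{'name': 'id', 'description': 'row id',
+  #                             'type': 'integer'}]}
+  #   api.convert_to_dataset_file_metadata(file_data, 'data')
+  #   # -> {'path': 'data/train.csv', 'description': 'training split',
+  #   #     'schema': {'fields': [{'name': 'id', 'title': 'row id',
+  #   #                            'type': 'integer'}]}}
+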
+  def validate_date(self, date):
+    """ validate that a date string is in YYYY-MM-DD format; strptime
+        raises ValueError otherwise
+    """
+    datetime.strptime(date, "%Y-%m-%d")
+
+  def sanitize_markdown(self, markdown):
+    """ sanitize user-supplied markdown using bleach """
+    return bleach.clean(markdown)
+
+  def confirmation(self):
+    """ prompt the user with a yes/no question and return the answer as a
+        boolean, re-prompting on any other input
+    """
+    question = "Are you sure?"
+    prompt = "[yes/no]"
+    options = {"yes": True, "no": False}
+    while True:
+      sys.stdout.write('{} {} '.format(question, prompt))
+      choice = input().lower()
+      if choice in options:
+        return options[choice]
+      sys.stdout.write("Please respond with 'yes' or 'no'.\n")
+
+
+class TqdmBufferedReader(io.BufferedReader):
+
+  def __init__(self, raw, progress_bar):
+    """ helper class to implement an io.BufferedReader
+    Parameters
+    ==========
+    raw: the raw IO stream to wrap with the buffered reader
+    progress_bar: a progress bar to initialize the reader
+    """
+    io.BufferedReader.__init__(self, raw)
+    self.progress_bar = progress_bar
+
+  def read(self, *args, **kwargs):
+    """ read the buffer, passing positional and keyword arguments through
+        to the io.BufferedReader function.
+    """
+    buf = io.BufferedReader.read(self, *args, **kwargs)
+    self.increment(len(buf))
+    return buf
+
+  def increment(self, length):
+    """ increment the reader by some length
+
+    Parameters
+    ==========
+    length: bytes to increment the reader by
+    """
+    self.progress_bar.update(length)
+
+
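+# Editor's sketch (not part of the package): wiring the reader to a tqdm bar
+# for upload progress; `path` is hypothetical, and buffering=0 yields the raw
+# FileIO object that io.BufferedReader expects.
+#
+#   with tqdm(total=os.path.getsize(path), unit='B', unit_scale=True) as bar:
+#     reader = TqdmBufferedReader(open(path, 'rb', buffering=0), bar)
+#     data = reader.read()  # the bar advances as bytes are consumed
+
+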
+class FileList(object):
+
+  def __init__(self, init_dict):
+    self.error_message = ''
+    files = init_dict['files']
+    if files:
+      for f in files:
+        if 'size' in f:
+          f['totalBytes'] = f['size']
+      self.files = [File(f) for f in files]
+    else:
+      self.files = []
+    token = init_dict['nextPageToken']
+    if token:
+      self.nextPageToken = token
+    else:
+      self.nextPageToken = ""
+
+  def __repr__(self):
+    return ''
+
+
+# This defines print_attributes(), which is very handy for inspecting
+# objects returned by the Kaggle API.
+
+from pprint import pprint
+from inspect import getmembers
+from types import FunctionType
+
+
+def attributes(obj):
+  disallowed_names = {
+      name for name, value in getmembers(type(obj))
+      if isinstance(value, FunctionType)
+  }
+  return {
+      name: getattr(obj, name)
+      for name in dir(obj)
+      if name[0] != '_' and name not in disallowed_names and hasattr(obj, name)
+  }
+
+
+def print_attributes(obj):
+  pprint(attributes(obj))
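+
+# Editor's sketch (not part of the package): inspecting an API response with
+# the helper above, assuming an authenticated KaggleApi instance `api` and a
+# hypothetical dataset slug.
+#
+#   result = api.dataset_list_files('alice/titanic')
+#   print_attributes(result)  # pprints the public data attributes of result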