codeshift 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,12 +24,151 @@ class PydanticV1ToV2Transformer(BaseTransformer):
24
24
  self._current_class: str | None = None
25
25
  # Track position info
26
26
  self._line_offset = 0
27
+ # Track Pydantic model classes defined in this file
28
+ self._pydantic_model_classes: set[str] = set()
29
+ # Track variables known to be Pydantic model instances
30
+ self._pydantic_instance_vars: set[str] = set()
31
+ # Track function parameters with Pydantic model type hints
32
+ self._pydantic_param_vars: set[str] = set()
33
+ # Track if BaseModel is imported from pydantic
34
+ self._has_basemodel_import = False
35
+
36
def visit_ImportFrom(self, node: cst.ImportFrom) -> bool:
    """Note whether ``BaseModel`` is imported from pydantic.

    Sets ``self._has_basemodel_import`` when a ``from pydantic... import``
    statement either names ``BaseModel`` or is a star import.

    Args:
        node: The ``from ... import ...`` statement being visited.

    Returns:
        Always True.
    """
    module = node.module
    if module is None:
        # No module name to inspect.
        return True

    module_name = self._get_module_name(module)
    from_pydantic = module_name == "pydantic" or module_name.startswith("pydantic.")
    if not from_pydantic:
        return True

    imported = node.names
    if isinstance(imported, cst.ImportStar):
        # A star import is assumed to bring BaseModel into scope.
        self._has_basemodel_import = True
    elif isinstance(imported, tuple):
        for alias in imported:
            if not isinstance(alias, cst.ImportAlias):
                continue
            if self._get_name_value(alias.name) == "BaseModel":
                self._has_basemodel_import = True
    return True
27
53
 
28
54
  def visit_ClassDef(self, node: cst.ClassDef) -> bool:
29
- """Track the current class being visited."""
55
+ """Track the current class being visited and detect Pydantic models."""
30
56
  self._current_class = node.name.value
57
+
58
+ # Check if this class inherits from BaseModel or another known Pydantic model
59
+ for base in node.bases:
60
+ base_name = self._get_base_class_name(base.value)
61
+ if (
62
+ base_name in ("BaseModel", "pydantic.BaseModel")
63
+ or base_name in self._pydantic_model_classes
64
+ ):
65
+ self._pydantic_model_classes.add(node.name.value)
66
+ break
67
+ return True
68
+
69
def _get_base_class_name(self, node: cst.BaseExpression) -> str:
    """Render a base-class expression as a (possibly dotted) name.

    Args:
        node: Expression appearing in a class's list of bases.

    Returns:
        The identifier for a plain name, ``"<owner>.<attr>"`` for an
        attribute access, the name of the subscripted expression for a
        subscript, and an empty string for anything else.
    """
    if isinstance(node, cst.Subscript):
        # Generic[T]-style base: only the subscripted expression matters.
        return self._get_base_class_name(node.value)
    if isinstance(node, cst.Attribute):
        owner = self._get_base_class_name(node.value)
        return f"{owner}.{node.attr.value}"
    return node.value if isinstance(node, cst.Name) else ""
79
+
80
def visit_Assign(self, node: cst.Assign) -> bool:
    """Record variables that are assigned a freshly constructed model.

    When the right-hand side is a call whose callee name is a known
    Pydantic model class, every plain-name assignment target is added to
    ``self._pydantic_instance_vars``.

    Args:
        node: The assignment statement being visited.

    Returns:
        Always True.
    """
    call = node.value
    if not isinstance(call, cst.Call):
        return True
    if self._get_call_func_name(call.func) not in self._pydantic_model_classes:
        return True

    assigned_names = (
        item.target.value for item in node.targets if isinstance(item.target, cst.Name)
    )
    self._pydantic_instance_vars.update(assigned_names)
    return True
91
+
92
def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
    """Record variables whose annotation names a known Pydantic model.

    Args:
        node: The annotated assignment being visited.

    Returns:
        Always True.
    """
    target = node.target
    if not isinstance(target, cst.Name):
        return True

    annotated_type = self._get_annotation_name(node.annotation.annotation)
    if annotated_type in self._pydantic_model_classes:
        self._pydantic_instance_vars.add(target.value)
    return True
99
+
100
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
    """Record parameters whose annotation names a known Pydantic model.

    Positional-only, regular and keyword-only parameters are all inspected;
    ``*args`` / ``**kwargs`` are skipped because they bind a tuple / dict
    rather than a model instance. Matching parameter names are added to
    ``self._pydantic_param_vars``.

    Args:
        node: The function definition being visited.

    Returns:
        Always True.
    """
    parameters = node.params
    declared = (
        *parameters.posonly_params,
        *parameters.params,
        *parameters.kwonly_params,
    )
    for param in declared:
        if param.annotation is None:
            continue
        type_name = self._get_annotation_name(param.annotation.annotation)
        if type_name in self._pydantic_model_classes:
            self._pydantic_param_vars.add(param.name.value)
    return True
32
108
 
109
def leave_FunctionDef_params(self, node: cst.FunctionDef) -> None:
    """Placeholder for clearing function-scoped parameter tracking.

    Currently a no-op: nothing is removed from the tracked parameter names.
    """
    # Note: This is a simplified approach - ideally we'd use proper scope analysis
    return None
113
+
114
def _get_call_func_name(self, node: cst.BaseExpression) -> str:
    """Return the callee name of a call's ``func`` expression.

    Args:
        node: The expression being called.

    Returns:
        The identifier for a plain name, the final attribute name for an
        attribute access (any qualifier is dropped), otherwise ``""``.
    """
    if isinstance(node, cst.Attribute):
        # Qualified reference: keep only the last component.
        return node.attr.value
    if isinstance(node, cst.Name):
        return node.value
    return ""
121
+
122
def _get_annotation_name(self, node: cst.BaseExpression) -> str:
    """Extract the type name that a type annotation refers to.

    Args:
        node: The annotation expression.

    Returns:
        * the identifier for a plain name;
        * the final attribute name for a qualified name;
        * for ``Optional[X]`` / ``Annotated[X, ...]``, the name extracted
          from ``X`` (the value is still an ``X`` when present);
        * for any other subscript, the name of the subscripted expression,
          so ``GenericModel[int]`` yields ``"GenericModel"`` and
          ``List[Model]`` yields ``"List"``;
        * ``""`` for anything else.
    """
    if isinstance(node, cst.Name):
        return node.value
    if isinstance(node, cst.Attribute):
        return node.attr.value  # Return just the class name part
    if isinstance(node, cst.Subscript):
        outer = self._get_annotation_name(node.value)
        # Previously Optional[Model] resolved to "Optional" and never matched
        # a model class. Only these transparent wrappers are unwrapped:
        # containers such as List[Model] are not model instances, and
        # unwrapping them would misclassify e.g. ``items.copy()``.
        if outer in ("Optional", "Annotated") and node.slice:
            first = node.slice[0].slice
            if isinstance(first, cst.Index):
                return self._get_annotation_name(first.value)
        return outer
    return ""
132
+
133
def _is_pydantic_instance(self, node: cst.BaseExpression) -> bool:
    """Tell whether an expression is known to be a Pydantic model instance.

    Args:
        node: The expression a method is being called on.

    Returns:
        True only when the expression can be confirmed as a model instance:
        a direct call of a known model class, or a name that is a tracked
        instance variable, a tracked parameter, or equal (ignoring case) to
        a known model class name. False when this cannot be confirmed,
        which covers both "unknown" and "definitely not Pydantic".
    """
    if isinstance(node, cst.Call):
        # Direct call like Model().json()
        return self._get_call_func_name(node.func) in self._pydantic_model_classes

    if not isinstance(node, cst.Name):
        # Attribute access (and anything else) would need full type analysis.
        return False

    identifier = node.value
    if identifier in self._pydantic_instance_vars or identifier in self._pydantic_param_vars:
        return True

    # Heuristic: variable name matches a model class name (case-insensitive)
    lowered = identifier.lower()
    return any(lowered == model_class.lower() for model_class in self._pydantic_model_classes)
161
+
162
def _is_class_method_call(self, node: cst.BaseExpression) -> bool:
    """Tell whether the receiver is a known model class rather than an instance.

    Used for class-level calls such as ``Model.parse_obj``.

    Args:
        node: The expression the method or attribute is accessed on.

    Returns:
        True when ``node`` is a plain name listed in
        ``self._pydantic_model_classes``; False otherwise.
    """
    return isinstance(node, cst.Name) and node.value in self._pydantic_model_classes
171
+
33
172
  def leave_ClassDef(
34
173
  self, original_node: cst.ClassDef, updated_node: cst.ClassDef
35
174
  ) -> cst.ClassDef:
@@ -372,11 +511,17 @@ class PydanticV1ToV2Transformer(BaseTransformer):
372
511
  # Handle method calls on objects
373
512
  if isinstance(updated_node.func, cst.Attribute):
374
513
  method_name = updated_node.func.attr.value
514
+ obj = updated_node.func.value
375
515
 
376
- method_mappings = {
516
+ # Methods that can only be called on instances
517
+ instance_method_mappings = {
377
518
  "dict": "model_dump",
378
519
  "json": "model_dump_json",
379
520
  "copy": "model_copy",
521
+ }
522
+
523
+ # Methods that are typically called on the class (class methods)
524
+ class_method_mappings = {
380
525
  "parse_obj": "model_validate",
381
526
  "parse_raw": "model_validate_json",
382
527
  "schema": "model_json_schema",
@@ -384,19 +529,40 @@ class PydanticV1ToV2Transformer(BaseTransformer):
384
529
  "update_forward_refs": "model_rebuild",
385
530
  }
386
531
 
387
- if method_name in method_mappings:
388
- new_method = method_mappings[method_name]
389
- new_attr = updated_node.func.with_changes(attr=cst.Name(new_method))
532
+ # Handle instance methods - need to verify the object is a Pydantic instance
533
+ if method_name in instance_method_mappings:
534
+ # Only transform if we can confirm this is a Pydantic model instance
535
+ if self._is_pydantic_instance(obj):
536
+ new_method = instance_method_mappings[method_name]
537
+ new_attr = updated_node.func.with_changes(attr=cst.Name(new_method))
390
538
 
391
- self.record_change(
392
- description=f"Convert .{method_name}() to .{new_method}()",
393
- line_number=1,
394
- original=f".{method_name}()",
395
- replacement=f".{new_method}()",
396
- transform_name=f"{method_name}_to_{new_method}",
397
- )
539
+ self.record_change(
540
+ description=f"Convert .{method_name}() to .{new_method}()",
541
+ line_number=1,
542
+ original=f".{method_name}()",
543
+ replacement=f".{new_method}()",
544
+ transform_name=f"{method_name}_to_{new_method}",
545
+ )
546
+
547
+ return updated_node.with_changes(func=new_attr)
548
+ # If we can't confirm it's a Pydantic instance, skip transformation
549
+ # This prevents false positives like response.json() on requests.Response
398
550
 
399
- return updated_node.with_changes(func=new_attr)
551
+ # Handle class methods - verify the object is a Pydantic model class
552
+ if method_name in class_method_mappings:
553
+ if self._is_class_method_call(obj):
554
+ new_method = class_method_mappings[method_name]
555
+ new_attr = updated_node.func.with_changes(attr=cst.Name(new_method))
556
+
557
+ self.record_change(
558
+ description=f"Convert .{method_name}() to .{new_method}()",
559
+ line_number=1,
560
+ original=f".{method_name}()",
561
+ replacement=f".{new_method}()",
562
+ transform_name=f"{method_name}_to_{new_method}",
563
+ )
564
+
565
+ return updated_node.with_changes(func=new_attr)
400
566
 
401
567
  # Handle Field(regex=...) -> Field(pattern=...)
402
568
  if isinstance(updated_node.func, cst.Name) and updated_node.func.value == "Field":
@@ -461,17 +627,20 @@ class PydanticV1ToV2Transformer(BaseTransformer):
461
627
  }
462
628
 
463
629
  if attr_name in attr_mappings:
464
- new_attr = attr_mappings[attr_name]
630
+ # Only transform if the object is a known Pydantic model class
631
+ obj = updated_node.value
632
+ if self._is_class_method_call(obj):
633
+ new_attr = attr_mappings[attr_name]
465
634
 
466
- self.record_change(
467
- description=f"Convert {attr_name} to {new_attr}",
468
- line_number=1,
469
- original=attr_name,
470
- replacement=new_attr,
471
- transform_name=f"{attr_name}_rename",
472
- )
635
+ self.record_change(
636
+ description=f"Convert {attr_name} to {new_attr}",
637
+ line_number=1,
638
+ original=attr_name,
639
+ replacement=new_attr,
640
+ transform_name=f"{attr_name}_rename",
641
+ )
473
642
 
474
- return updated_node.with_changes(attr=cst.Name(new_attr))
643
+ return updated_node.with_changes(attr=cst.Name(new_attr))
475
644
 
476
645
  return updated_node
477
646
 
@@ -6,6 +6,17 @@ from pathlib import Path
6
6
  import libcst as cst
7
7
  from libcst.metadata import MetadataWrapper, PositionProvider
8
8
 
9
# Distribution (package) name -> module names it can be imported under.
# Only packages whose import name differs from the package name are listed.
PACKAGE_IMPORT_ALIASES: dict[str, list[str]] = {
    # attrs package can be imported as "attr" or "attrs"
    "attrs": ["attr", "attrs"],
    # pillow package is imported as PIL
    "pillow": ["PIL"],
    # scikit-learn is imported as sklearn
    "scikit-learn": ["sklearn"],
    # beautifulsoup4 is imported as bs4
    "beautifulsoup4": ["bs4"],
    # pyyaml is imported as yaml
    "pyyaml": ["yaml"],
    # python-dateutil is imported as dateutil
    "python-dateutil": ["dateutil"],
}
19
+
9
20
 
10
21
  @dataclass
11
22
  class ImportInfo:
@@ -53,15 +64,24 @@ class ImportVisitor(cst.CSTVisitor):
53
64
 
54
65
  def __init__(self, target_library: str):
55
66
  self.target_library = target_library
67
+ # Get all possible import names for this library
68
+ self.import_names = PACKAGE_IMPORT_ALIASES.get(target_library.lower(), [target_library])
56
69
  self.imports: list[ImportInfo] = []
57
70
  self._imported_names: set[str] = set()
58
71
 
72
+ def _matches_target_library(self, module_name: str) -> bool:
73
+ """Check if a module name matches the target library or its aliases."""
74
+ for import_name in self.import_names:
75
+ if module_name == import_name or module_name.startswith(f"{import_name}."):
76
+ return True
77
+ return False
78
+
59
79
  def visit_Import(self, node: cst.Import) -> None:
60
80
  """Visit import statements like 'import pydantic'."""
61
81
  for name in node.names if isinstance(node.names, tuple) else []:
62
82
  if isinstance(name, cst.ImportAlias):
63
83
  module_name = self._get_name_value(name.name)
64
- if module_name and module_name.startswith(self.target_library):
84
+ if module_name and self._matches_target_library(module_name):
65
85
  alias = None
66
86
  if name.asname and isinstance(name.asname, cst.AsName):
67
87
  alias = self._get_name_value(name.asname.name)
@@ -84,7 +104,7 @@ class ImportVisitor(cst.CSTVisitor):
84
104
  return
85
105
 
86
106
  module_name = self._get_name_value(node.module)
87
- if not module_name or not module_name.startswith(self.target_library):
107
+ if not module_name or not self._matches_target_library(module_name):
88
108
  return
89
109
 
90
110
  names = []
@@ -4,12 +4,109 @@ This client calls the Codeshift API instead of Anthropic directly,
4
4
  ensuring that LLM features are gated behind the subscription model.
5
5
  """
6
6
 
7
+ import logging
7
8
  from dataclasses import dataclass
9
+ from urllib.parse import urlparse
8
10
 
9
11
  import httpx
10
12
 
11
13
  from codeshift.cli.commands.auth import get_api_key, get_api_url
12
14
 
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class InsecureURLError(Exception):
19
+ """Raised when an insecure (non-HTTPS) URL is used for API communication.
20
+
21
+ HTTPS is required to protect API keys and sensitive data in transit.
22
+ Man-in-the-middle attacks could intercept API keys if HTTP is used.
23
+ """
24
+
25
+ def __init__(self, url: str, message: str | None = None):
26
+ self.url = url
27
+ default_msg = (
28
+ f"Insecure URL detected: {url}. "
29
+ "HTTPS is required for API communication to protect your API key. "
30
+ "Use HTTPS or set CODESHIFT_ALLOW_INSECURE=true for local development only."
31
+ )
32
+ super().__init__(message or default_msg)
33
+
34
+
35
def validate_api_url(url: str) -> str:
    """Validate and normalize the API URL.

    Enforces HTTPS for all non-localhost hosts to prevent API key
    interception. Plain HTTP is accepted only for a localhost host, or when
    the ``CODESHIFT_ALLOW_INSECURE`` environment variable is ``true``; both
    cases log a warning.

    Args:
        url: The API URL to validate.

    Returns:
        The URL with any trailing slashes removed.

    Raises:
        InsecureURLError: If the URL uses HTTP for a non-localhost host and
            the insecure override is not set.
        ValueError: If the URL is empty, malformed, lacks a scheme or host,
            or uses a scheme other than http/https.
    """
    import os

    if not url:
        raise ValueError("API URL cannot be empty")

    try:
        parsed = urlparse(url)
        host = parsed.hostname or ""
    except (AttributeError, TypeError, ValueError) as e:
        raise ValueError(f"Malformed URL: {url}") from e

    if not parsed.scheme:
        raise ValueError(f"URL must include a scheme (http/https): {url}")

    # ``netloc`` alone is not enough: "https://:8443" has a netloc but no
    # host, so the parsed hostname must be checked as well.
    if not parsed.netloc or not host:
        raise ValueError(f"URL must include a host: {url}")

    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"URL scheme must be http or https, got: {parsed.scheme}")

    # ``hostname`` never includes the port, so exact comparison is correct.
    # A "<pattern>:" prefix test would wrongly accept IPv6 literals such as
    # "::1:2" as localhost.
    is_localhost = host in ("localhost", "127.0.0.1", "::1", "0.0.0.0")

    # Allow HTTP only for localhost (development) or with explicit override.
    if parsed.scheme == "http":
        allow_insecure = os.environ.get("CODESHIFT_ALLOW_INSECURE", "").lower() == "true"

        if is_localhost:
            logger.warning(
                "Using HTTP for localhost development. " "This should not be used in production."
            )
        elif allow_insecure:
            logger.warning(
                "SECURITY WARNING: CODESHIFT_ALLOW_INSECURE is set. "
                "HTTP is being used for API communication. "
                "Your API key may be exposed to network interception. "
                "This should ONLY be used for local testing."
            )
        else:
            raise InsecureURLError(
                url,
                f"HTTP is not allowed for non-localhost hosts: {host}. "
                "Use HTTPS to protect your API key from interception.",
            )

    # Remove trailing slash for consistency
    return url.rstrip("/")
109
+
13
110
 
14
111
  @dataclass
15
112
  class APIResponse:
@@ -30,24 +127,46 @@ class CodeshiftAPIClient:
30
127
  - Authentication and authorization
31
128
  - Quota checking and billing
32
129
  - Server-side Anthropic API calls
130
+
131
+ Security features:
132
+ - HTTPS enforcement for all non-localhost URLs
133
+ - API key protection via secure headers
134
+ - SSL verification enabled by default
33
135
  """
34
136
 
35
137
  def __init__(
36
138
  self,
37
139
  api_key: str | None = None,
38
140
  api_url: str | None = None,
39
- timeout: int = 60,
141
+ timeout: int = 180,
142
+ verify_ssl: bool = True,
40
143
  ):
41
144
  """Initialize the API client.
42
145
 
43
146
  Args:
44
147
  api_key: Codeshift API key. Defaults to stored credentials.
45
148
  api_url: API base URL. Defaults to stored URL.
46
- timeout: Request timeout in seconds.
149
+ timeout: Request timeout in seconds (default 180 for LLM calls).
150
+ verify_ssl: Whether to verify SSL certificates (default True).
151
+
152
+ Raises:
153
+ InsecureURLError: If the URL uses HTTP for a non-localhost host.
47
154
  """
48
155
  self.api_key = api_key or get_api_key()
49
- self.api_url = api_url or get_api_url()
156
+
157
+ # Validate and normalize the API URL
158
+ raw_url = api_url or get_api_url()
159
+ self.api_url = validate_api_url(raw_url)
160
+
50
161
  self.timeout = timeout
162
+ self.verify_ssl = verify_ssl
163
+
164
+ # Log SSL verification status
165
+ if not verify_ssl:
166
+ logger.warning(
167
+ "SSL verification is disabled. "
168
+ "This exposes the connection to man-in-the-middle attacks."
169
+ )
51
170
 
52
171
  @property
53
172
  def is_available(self) -> bool:
@@ -76,9 +195,13 @@ class CodeshiftAPIClient:
76
195
 
77
196
  return httpx.post(
78
197
  f"{self.api_url}{endpoint}",
79
- headers={"X-API-Key": self.api_key},
198
+ headers={
199
+ "X-API-Key": self.api_key,
200
+ "Content-Type": "application/json",
201
+ },
80
202
  json=payload,
81
203
  timeout=self.timeout,
204
+ verify=self.verify_ssl,
82
205
  )
83
206
 
84
207
  def migrate_code(
@@ -157,6 +280,15 @@ class CodeshiftAPIClient:
157
280
  error="LLM migrations require Pro tier or higher. Run 'codeshift upgrade-plan' to upgrade.",
158
281
  )
159
282
 
283
+ elif response.status_code == 429:
284
+ # Rate limited
285
+ retry_after = response.headers.get("Retry-After", "60")
286
+ return APIResponse(
287
+ success=False,
288
+ content=code,
289
+ error=f"Rate limited. Please wait {retry_after} seconds before retrying.",
290
+ )
291
+
160
292
  elif response.status_code == 503:
161
293
  return APIResponse(
162
294
  success=False,
@@ -233,6 +365,14 @@ class CodeshiftAPIClient:
233
365
  error="This feature requires Pro tier or higher.",
234
366
  )
235
367
 
368
+ elif response.status_code == 429:
369
+ retry_after = response.headers.get("Retry-After", "60")
370
+ return APIResponse(
371
+ success=False,
372
+ content="",
373
+ error=f"Rate limited. Please wait {retry_after} seconds before retrying.",
374
+ )
375
+
236
376
  else:
237
377
  return APIResponse(
238
378
  success=False,