Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions src/bentoml/_internal/utils/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,71 @@ def is_http_url(url: str) -> bool:


original_create_connection = None
_original_getaddrinfo = None


@contextlib.contextmanager
def make_safe_connect():
"""Patch loop.create_connection() method to reject unsafe URLs."""
"""Patch network connection to reject requests to private/internal IP addresses.

On Linux/macOS with uvloop: patches uvloop.Loop.create_connection.
On Windows or without uvloop: patches socket.getaddrinfo as fallback.
"""

from urllib.request import getproxies

import httpx
from uvloop import Loop

from bentoml.exceptions import BadInput

global original_create_connection
try:
from uvloop import Loop
except ImportError:
Loop = None

if original_create_connection is None:
original_create_connection = Loop.create_connection
global original_create_connection
global _original_getaddrinfo

# Do not check connections with proxy servers
proxies = [
(parsed.hostname, parsed.port)
for parsed in map(urlparse, getproxies().values())
]

if Loop is None:
# Fallback for platforms without uvloop (e.g. Windows):
# Patch socket.getaddrinfo to check resolved IPs before connection.
if _original_getaddrinfo is None:
_original_getaddrinfo = socket.getaddrinfo

def safe_getaddrinfo(host, port, *args, **kwargs):
results = _original_getaddrinfo(host, port, *args, **kwargs)
if host is not None and (host, port) not in proxies:
for family, type_, proto, canonname, sockaddr in results:
try:
ip = ipaddress.ip_address(sockaddr[0])
except ValueError:
continue
if ip.is_private or ip.is_loopback or ip.is_link_local:
raise socket.gaierror(
f"Blocked private IP address {sockaddr[0]}"
)
return results

socket.getaddrinfo = safe_getaddrinfo
try:
yield
except httpx.ConnectError as e:
if "All connection attempts failed" in str(e):
raise BadInput("Connection blocked due to insecure input URL") from e
finally:
socket.getaddrinfo = _original_getaddrinfo
return

# uvloop available: use original Loop.create_connection patching
if original_create_connection is None:
original_create_connection = Loop.create_connection

@no_type_check
async def safe_create_connection(
self, protocol_factory, host=None, port=None, **kwargs
Expand Down