diff --git a/src/bentoml/_internal/utils/uri.py b/src/bentoml/_internal/utils/uri.py index c3fbb97616c..ae579d38fe9 100644 --- a/src/bentoml/_internal/utils/uri.py +++ b/src/bentoml/_internal/utils/uri.py @@ -57,23 +57,30 @@ def is_http_url(url: str) -> bool: original_create_connection = None +_original_getaddrinfo = None @contextlib.contextmanager def make_safe_connect(): - """Patch loop.create_connection() method to reject unsafe URLs.""" + """Patch network connection to reject requests to private/internal IP addresses. + + On Linux/macOS with uvloop: patches uvloop.Loop.create_connection. + On Windows or without uvloop: patches socket.getaddrinfo as fallback. + """ from urllib.request import getproxies import httpx - from uvloop import Loop from bentoml.exceptions import BadInput - global original_create_connection + try: + from uvloop import Loop + except ImportError: + Loop = None - if original_create_connection is None: - original_create_connection = Loop.create_connection + global original_create_connection + global _original_getaddrinfo # Do not check connections with proxy servers proxies = [ @@ -81,6 +88,40 @@ def make_safe_connect(): for parsed in map(urlparse, getproxies().values()) ] + if Loop is None: + # Fallback for platforms without uvloop (e.g. Windows): + # Patch socket.getaddrinfo to check resolved IPs before connection. + if _original_getaddrinfo is None: + _original_getaddrinfo = socket.getaddrinfo + + def safe_getaddrinfo(host, port, *args, **kwargs): + results = _original_getaddrinfo(host, port, *args, **kwargs) + if host is not None and (host, port) not in proxies: + for family, type_, proto, canonname, sockaddr in results: + try: + ip = ipaddress.ip_address(sockaddr[0]) + except ValueError: + continue + if ip.is_private or ip.is_loopback or ip.is_link_local: + raise socket.gaierror( + f"Blocked private IP address {sockaddr[0]}" + ) + return results + + socket.getaddrinfo = safe_getaddrinfo + try: + yield + except httpx.ConnectError as e: + if "All connection attempts failed" in str(e): + raise BadInput("Connection blocked due to insecure input URL") from e + finally: + socket.getaddrinfo = _original_getaddrinfo + return + + # uvloop available: use original Loop.create_connection patching + if original_create_connection is None: + original_create_connection = Loop.create_connection + @no_type_check async def safe_create_connection( self, protocol_factory, host=None, port=None, **kwargs