Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import io.modelcontextprotocol.util.Assert;

Expand Down Expand Up @@ -47,27 +46,18 @@ private DefaultServerTransportSecurityValidator(List<String> allowedOrigins, Lis
}

@Override
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
boolean missingHost = true;
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw new ServerTransportSecurityException(403, "Invalid Origin header");
}
validateOrigin(values.get(0));
}
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
missingHost = false;
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw new ServerTransportSecurityException(421, "Invalid Host header");
}
validateHost(values.get(0));
}
public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException {
List<String> originValues = accessor.getHeader(ORIGIN_HEADER);
if (originValues != null && !originValues.isEmpty()) {
validateOrigin(originValues.get(0));
}
if (!allowedHosts.isEmpty() && missingHost) {
throw new ServerTransportSecurityException(421, "Invalid Host header");

if (!allowedHosts.isEmpty()) {
List<String> hostValues = accessor.getHeader(HOST_HEADER);
if (hostValues == null || hostValues.isEmpty()) {
throw new ServerTransportSecurityException(421, "Invalid Host header");
}
validateHost(hostValues.get(0));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2026-2026 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import java.util.List;

/**
* Abstraction for accessing HTTP headers from an incoming request. Implementations should
* provide case-insensitive header name lookups (e.g., when backed by
* {@code HttpServletRequest}).
*
* @author Neeraj Bhatt
* @since 0.16.0
* @see ServerTransportSecurityValidator
*/
public interface HeaderAccessor {

/**
* Returns the values of the specified header, or an empty list if the header is not
* present.
* @param name the header name (case-insensitive)
* @return the list of header values, never {@code null}
*/
List<String> getHeader(String name);

/**
* Returns all header names present in the request.
* @return the list of header names, never {@code null}
*/
List<String> getHeaderNames();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2026-2026 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import java.util.Collections;
import java.util.List;

import jakarta.servlet.http.HttpServletRequest;

/**
* {@link HeaderAccessor} implementation backed by an {@link HttpServletRequest}. Header
* name lookups are case-insensitive as per the Servlet specification.
*
* <p>
* For internal use only.
*
* @author Neeraj Bhatt
* @since 0.16.0
* @see HeaderAccessor
*/
final class HttpServletHeaderAccessor implements HeaderAccessor {

private final HttpServletRequest request;

HttpServletHeaderAccessor(HttpServletRequest request) {
this.request = request;
}

@Override
public List<String> getHeader(String name) {
return Collections.list(this.request.getHeaders(name));
}

@Override
public List<String> getHeaderNames() {
return Collections.list(this.request.getHeaderNames());
}

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down Expand Up @@ -353,8 +352,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -134,8 +133,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

Expand Down Expand Up @@ -271,8 +270,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down Expand Up @@ -407,8 +405,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down Expand Up @@ -588,8 +585,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
}

try {
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
this.securityValidator.validateHeaders(headers);
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
}
catch (ServerTransportSecurityException e) {
response.sendError(e.getStatusCode(), e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,101 @@

package io.modelcontextprotocol.server.transport;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Interface for validating HTTP requests in server transports. Implementations can
* validate Origin headers, Host headers, or any other security-related headers according
* to the MCP specification.
*
* <p>
* New implementations should override {@link #validateHeaders(HeaderAccessor)
* validateHeaders(HeaderAccessor)} for more efficient, case-insensitive header access.
* The older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be
* removed in a future major version.
*
* @author Daniel Garnier-Moiroux
* @see DefaultServerTransportSecurityValidator
* @see ServerTransportSecurityException
*/
@FunctionalInterface
public interface ServerTransportSecurityValidator {

/**
* A no-op validator that accepts all requests without validation.
*/
ServerTransportSecurityValidator NOOP = headers -> {
ServerTransportSecurityValidator NOOP = new ServerTransportSecurityValidator() {
@Override
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
}

@Override
public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException {
}
};

/**
* Validates the HTTP headers from an incoming request.
*
* <p>
* The default implementation converts the map into a {@link HeaderAccessor} and
* delegates to {@link #validateHeaders(HeaderAccessor)}.
* @param headers A map of header names to their values (multi-valued headers
* supported)
* @throws ServerTransportSecurityException if validation fails
* @deprecated Use {@link #validateHeaders(HeaderAccessor)} instead for more
* efficient, case-insensitive header access. This method will be removed in a future
* major version.
*/
@Deprecated
default void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
validateHeaders(new HeaderAccessor() {
@Override
public List<String> getHeader(String name) {
return headers.entrySet()
.stream()
.filter(e -> e.getKey().equalsIgnoreCase(name))
.map(Map.Entry::getValue)
.findFirst()
.orElse(List.of());
}

@Override
public List<String> getHeaderNames() {
return List.copyOf(headers.keySet());
}
});
}

/**
* Validates the HTTP headers from an incoming request using a {@link HeaderAccessor}.
*
* <p>
* New implementations should override this method. Header name lookup through the
* accessor should be case-insensitive (e.g., when backed by
* {@code HttpServletRequest}).
*
* <p>
* The default implementation collects all headers from the accessor into a
* {@link Map} and delegates to the deprecated {@link #validateHeaders(Map)} method,
* so that existing implementations that only override {@link #validateHeaders(Map)}
* continue to work.
* @param accessor provides access to request headers
* @throws ServerTransportSecurityException if validation fails
*/
void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException;
default void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException {
var collectedHeaders = accessor.getHeaderNames()
.stream()
.collect(Collectors.<String, String, List<String>>toUnmodifiableMap(String::toLowerCase,
accessor::getHeader, (l1, l2) -> {
var merged = new ArrayList<>(l1);
merged.addAll(l2);
return Collections.unmodifiableList(merged);
}));
validateHeaders(collectedHeaders);
}

}
Loading