diff --git a/logfire/_internal/cli/__init__.py b/logfire/_internal/cli/__init__.py index cb49be100..7c719f1da 100644 --- a/logfire/_internal/cli/__init__.py +++ b/logfire/_internal/cli/__init__.py @@ -25,6 +25,7 @@ from ..client import LogfireClient from ..config import REGIONS, LogfireCredentials, get_base_url_from_token from ..config_params import ParamManager +from ..server_response import install_logfire_response_hook from ..tracer import SDKTracerProvider from .auth import parse_auth, parse_logout from .prompt import parse_prompt @@ -434,8 +435,9 @@ def log_trace_id(response: requests.Response, context: ContextCarrier, *args: An else: with tracer.start_as_current_span('logfire._internal.cli'), requests.Session() as session: context = get_context() - session.hooks = {'response': functools.partial(log_trace_id, context=context)} + session.hooks = {'response': [functools.partial(log_trace_id, context=context)]} session.headers.update(context) + install_logfire_response_hook(session) namespace._session = session namespace.func(namespace) diff --git a/logfire/_internal/client.py b/logfire/_internal/client.py index 9bd9c77de..a66a35b57 100644 --- a/logfire/_internal/client.py +++ b/logfire/_internal/client.py @@ -10,6 +10,7 @@ from logfire.version import VERSION from .auth import UserToken, UserTokenCollection +from .server_response import ServerResponseCallback, install_logfire_response_hook from .utils import UnexpectedResponse UA_HEADER = f'logfire/{VERSION}' @@ -29,18 +30,29 @@ class LogfireClient: Args: user_token: The user token to use when authenticating against the API. + server_response_hook: Optional override for the API response hook (see + `AdvancedOptions.server_response_hook`). """ - def __init__(self, user_token: UserToken) -> None: + def __init__( + self, + user_token: UserToken, + server_response_hook: ServerResponseCallback | None = None, + ) -> None: if user_token.is_expired: raise RuntimeError('The provided user token is expired') self.base_url = user_token.base_url self._token = user_token.token self._session = Session() self._session.headers.update({'Authorization': self._token, 'User-Agent': UA_HEADER}) + install_logfire_response_hook(self._session, server_response_hook) @classmethod - def from_url(cls, base_url: str | None) -> Self: + def from_url( + cls, + base_url: str | None, + server_response_hook: ServerResponseCallback | None = None, + ) -> Self: """Create a client from the provided base URL. Args: @@ -48,8 +60,13 @@ def from_url(cls, base_url: str | None) -> Self: the user into selecting a token from the token collection (or, if only one available, use it directly). The token collection will be created from the `~/.logfire/default.toml` file (or an empty one if no such file exists). + server_response_hook: Optional override for the API response hook (see + `AdvancedOptions.server_response_hook`). """ - return cls(user_token=UserTokenCollection().get_token(base_url)) + return cls( + user_token=UserTokenCollection().get_token(base_url), + server_response_hook=server_response_hook, + ) def _get_raw(self, endpoint: str, params: dict[str, Any] | None = None) -> Response: response = self._session.get(urljoin(self.base_url, endpoint), params=params) diff --git a/logfire/_internal/config.py b/logfire/_internal/config.py index 98e73d46b..05c55488c 100644 --- a/logfire/_internal/config.py +++ b/logfire/_internal/config.py @@ -110,6 +110,7 @@ from .logs import ProxyLoggerProvider from .metrics import ProxyMeterProvider from .scrubbing import NOOP_SCRUBBER, BaseScrubber, Scrubber, ScrubbingOptions +from .server_response import ServerResponseCallback, install_logfire_response_hook from .stack_info import warn_at_user_stacklevel from .tracer import OPEN_SPANS, PendingSpanProcessor, ProxyTracerProvider from .utils import ( @@ -216,6 +217,32 @@ class AdvancedOptions: This log and configuration is experimental and may be modified or removed. """ + server_response_hook: ServerResponseCallback | None = None + """Optional callback invoked for every HTTP response received from the Logfire API. + + This is experimental and may be modified or removed. + + This applies to OTLP exports, credential / project initialisation, and the remote + variables provider. The default surfaces the `X-Logfire-Warning` header as a + `LogfireServerWarning`. + + Setting this replaces the default; pass `lambda response: None` to opt out entirely. + + Example usage: + + ```python skip-run="true" skip-reason="needs metric/logfire setup" + from logfire.types import ServerResponseCallbackHelper + + def hook(helper: ServerResponseCallbackHelper): + my_metric.inc(helper.response.status_code) + helper.default_hook() # call this to keep the default warning behavior + + logfire.configure(advanced=AdvancedOptions(server_response_hook=hook)) + ``` + + Raise from the hook to abort the calling code path. + """ + def generate_base_url(self, token: str) -> str: if self.base_url is not None: return self.base_url @@ -1097,7 +1124,7 @@ def add_span_processor(span_processor: SpanProcessor) -> None: # If we don't have tokens or credentials from a file, # try initializing a new project and writing a new creds file. # note, we only do this if `send_to_logfire` is explicitly `True`, not 'if-token-present' - client = LogfireClient.from_url(self.advanced.base_url) + client = LogfireClient.from_url(self.advanced.base_url, self.advanced.server_response_hook) credentials = LogfireCredentials.initialize_project(client=client) credentials.write_creds_file(self.data_dir) @@ -1148,6 +1175,7 @@ def check_tokens(): base_url = self.advanced.generate_base_url(token) headers = {'User-Agent': f'logfire/{VERSION}', 'Authorization': token} session = OTLPExporterHttpSession() + install_logfire_response_hook(session, self.advanced.server_response_hook) span_exporter = BodySizeCheckingOTLPSpanExporter( endpoint=urljoin(base_url, '/v1/traces'), session=session, @@ -1324,6 +1352,7 @@ def fix_pid(): # pragma: no cover base_url=base_url, token=self.api_key, options=self.variables, + server_response_hook=self.advanced.server_response_hook, ) multi_log_processor = SynchronousMultiLogRecordProcessor() for processor in log_record_processors: @@ -1456,6 +1485,7 @@ def _lazy_init_variable_provider(self) -> VariableProvider: base_url=base_url, token=api_key, options=options, + server_response_hook=self.advanced.server_response_hook, ) self._variable_provider = provider provider.start(Logfire(config=self)) @@ -1472,7 +1502,9 @@ def warn_if_not_initialized(self, message: str): ) def _initialize_credentials_from_token(self, token: str) -> LogfireCredentials | None: - return LogfireCredentials.from_token(token, requests.Session(), self.advanced.generate_base_url(token)) + session = requests.Session() + install_logfire_response_hook(session, self.advanced.server_response_hook) + return LogfireCredentials.from_token(token, session, self.advanced.generate_base_url(token)) def _ensure_flush_after_aws_lambda(self): """Ensure that `force_flush` is called after an AWS Lambda invocation. diff --git a/logfire/_internal/server_response.py b/logfire/_internal/server_response.py new file mode 100644 index 000000000..3001bdadc --- /dev/null +++ b/logfire/_internal/server_response.py @@ -0,0 +1,44 @@ +"""Surface out-of-band signals the Logfire backend wants every SDK request to know about. + +The server attaches the `X-Logfire-Warning` header to API responses to signal an +out-of-band warning the server wants the user to see. It is surfaced via +`warnings.warn(..., LogfireServerWarning)`. Python's standard "default" filter +dedupes identical messages, so a chatty server only warns once. + +`install_logfire_response_hook(session)` wires this into a `requests.Session` as +a response hook so every Logfire-bound HTTP response is inspected. Callers can +pass a custom `hook` to replace the default behavior (see +`AdvancedOptions.server_response_hook`). +""" + +from __future__ import annotations + +from typing import Any + +import requests + +from logfire.types import ServerResponseCallback, ServerResponseCallbackHelper + + +def install_logfire_response_hook( + session: requests.Session, + hook: ServerResponseCallback | None = None, +) -> None: + """Install a `requests` response hook on `session` for every Logfire API response. + + By default, calls `ServerResponseCallbackHelper.default_hook()`, which emits a warning + if the `X-Logfire-Warning` response header is present. + + Pass a custom callable to replace the default behavior (e.g. opt out by passing `lambda _: None`). + """ + + def _hook(response: requests.Response, *args: Any, **kwargs: Any) -> requests.Response: + helper = ServerResponseCallbackHelper(response, args, kwargs) + if hook is not None: + hook(helper) + else: + helper.default_hook() + return response + + response_hooks: list[Any] = session.hooks.setdefault('response', []) + response_hooks.append(_hook) diff --git a/logfire/exceptions.py b/logfire/exceptions.py index 617fba04f..4a3ac54fc 100644 --- a/logfire/exceptions.py +++ b/logfire/exceptions.py @@ -3,3 +3,7 @@ class LogfireConfigError(ValueError): """Error raised when there is a problem with the Logfire configuration.""" + + +class LogfireServerWarning(UserWarning): + """Warning emitted when the Logfire server returns an `X-Logfire-Warning` header on a response.""" diff --git a/logfire/types.py b/logfire/types.py index 1f1561700..75b61ff75 100644 --- a/logfire/types.py +++ b/logfire/types.py @@ -2,7 +2,9 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +import requests from logfire._internal.constants import ( ATTRIBUTES_LOG_LEVEL_NUM_KEY, @@ -11,8 +13,10 @@ LevelName, log_level_attributes, ) +from logfire._internal.stack_info import warn_at_user_stacklevel from logfire._internal.tracer import get_parent_span from logfire._internal.utils import canonicalize_exception_traceback +from logfire.exceptions import LogfireServerWarning if TYPE_CHECKING: from opentelemetry.sdk.trace import ReadableSpan, Span @@ -260,3 +264,40 @@ def my_callback(helper: logfire.types.ExceptionCallbackHelper): helper.no_record_exception() """ + + +@dataclass +class ServerResponseCallbackHelper: + """Helper object passed to the server response callback. + + This is experimental and may change significantly in future releases. + """ + + response: requests.Response + """The raw HTTP response from the Logfire API.""" + + args: tuple[Any, ...] + """Positional arguments passed to the response hook by `requests`.""" + + kwargs: dict[str, Any] + """Keyword arguments passed to the response hook by `requests`.""" + + WARNING_HEADER_NAME = 'X-Logfire-Warning' + + @property + def warning_header(self) -> str | None: + """Value of the Logfire warning header, or `None` if not present.""" + return self.response.headers.get(self.WARNING_HEADER_NAME) + + def default_hook(self) -> None: + """The default hook behavior: emit a `LogfireServerWarning` if the warning header is present.""" + warning_message = self.warning_header + if warning_message: + warn_at_user_stacklevel(warning_message, LogfireServerWarning) + + +ServerResponseCallback = Callable[[ServerResponseCallbackHelper], None] +"""Callable invoked for every Logfire API response received by the SDK. + +This is experimental and may change significantly in future releases. +""" diff --git a/logfire/variables/remote.py b/logfire/variables/remote.py index b780a658f..0de792837 100644 --- a/logfire/variables/remote.py +++ b/logfire/variables/remote.py @@ -17,6 +17,7 @@ from logfire._internal.client import UA_HEADER from logfire._internal.config import VariablesOptions +from logfire._internal.server_response import ServerResponseCallback, install_logfire_response_hook from logfire._internal.utils import UnexpectedResponse from logfire.variables.abstract import ( ResolvedVariable, @@ -54,21 +55,31 @@ class LogfireRemoteVariableProvider(VariableProvider): The threading implementation draws heavily from opentelemetry.sdk._shared_internal.BatchProcessor. """ - def __init__(self, base_url: str, token: str, options: VariablesOptions): + def __init__( + self, + base_url: str, + token: str, + options: VariablesOptions, + server_response_hook: ServerResponseCallback | None = None, + ): """Create a new remote variable provider. Args: base_url: The base URL of the Logfire API. token: Authentication token for the Logfire API. options: Options for retrieving remote variables. + server_response_hook: Optional override for the API response hook + (see `AdvancedOptions.server_response_hook`). """ block_before_first_resolve = options.block_before_first_resolve polling_interval = options.polling_interval self._base_url = base_url self._token = token + self._server_response_hook = server_response_hook self._session = Session() self._session.headers.update({'Authorization': f'bearer {token}', 'User-Agent': UA_HEADER}) + install_logfire_response_hook(self._session, server_response_hook) self._timeout = options.timeout self._block_before_first_fetch = block_before_first_resolve self._polling_interval: timedelta = ( @@ -197,6 +208,7 @@ def _sse_listener(self): # pragma: no cover 'Cache-Control': 'no-cache', } ) + install_logfire_response_hook(sse_session, self._server_response_hook) # Open streaming connection response = sse_session.get(sse_url, stream=True, timeout=(10, None)) diff --git a/tests/test_server_response.py b/tests/test_server_response.py new file mode 100644 index 000000000..503a0f41c --- /dev/null +++ b/tests/test_server_response.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import warnings + +import requests +import requests_mock +from inline_snapshot import snapshot + +from logfire.exceptions import LogfireServerWarning +from logfire.types import ServerResponseCallbackHelper + + +def test_process_response_warning_header_emits_warning(): + response = requests.Response() + response.headers[ServerResponseCallbackHelper.WARNING_HEADER_NAME] = ( + 'The /foo/bar endpoint is deprecated, please use /bar/baz' + ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + ServerResponseCallbackHelper(response, (), {}).default_hook() + assert [(w.category, str(w.message)) for w in caught] == snapshot( + [(LogfireServerWarning, 'The /foo/bar endpoint is deprecated, please use /bar/baz')] + ) + + +def test_process_response_warning_header_dedupes(): + """Python's default `warnings` filter should fold repeats of the same message into one entry.""" + response = requests.Response() + response.headers[ServerResponseCallbackHelper.WARNING_HEADER_NAME] = 'a duplicated warning' + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('default') + for _ in range(5): + ServerResponseCallbackHelper(response, (), {}).default_hook() + messages = [str(w.message) for w in caught] + assert messages == ['a duplicated warning'] + + +def test_response_hook_installed_on_logfire_client(): + from logfire._internal.auth import UserToken + from logfire._internal.client import LogfireClient + + token = UserToken( + token='pylf_v1_us_xxx', + base_url='https://logfire-us.pydantic.dev', + expiration='2099-12-31T23:59:59', + ) + client = LogfireClient(user_token=token) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + headers={ServerResponseCallbackHelper.WARNING_HEADER_NAME: 'deprecated endpoint'}, + ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + client.get_user_information() + + assert any(isinstance(w.message, LogfireServerWarning) for w in caught) + + +def test_custom_server_response_hook_replaces_default(): + """A custom hook replaces the built-in header processor entirely.""" + from logfire._internal.auth import UserToken + from logfire._internal.client import LogfireClient + + seen: list[requests.Response] = [] + + def my_hook(helper: ServerResponseCallbackHelper) -> None: + seen.append(helper.response) + + token = UserToken( + token='pylf_v1_us_xxx', + base_url='https://logfire-us.pydantic.dev', + expiration='2099-12-31T23:59:59', + ) + client = LogfireClient(user_token=token, server_response_hook=my_hook) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + headers={ServerResponseCallbackHelper.WARNING_HEADER_NAME: 'deprecated'}, + ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + client.get_user_information() + + assert len(seen) == 1 + assert not any(isinstance(w.message, LogfireServerWarning) for w in caught) + + +def test_server_response_hook_can_opt_out(): + """`lambda response: None` disables the default warning behavior.""" + from logfire._internal.auth import UserToken + from logfire._internal.client import LogfireClient + + token = UserToken( + token='pylf_v1_us_xxx', + base_url='https://logfire-us.pydantic.dev', + expiration='2099-12-31T23:59:59', + ) + client = LogfireClient(user_token=token, server_response_hook=lambda response: None) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + headers={ServerResponseCallbackHelper.WARNING_HEADER_NAME: 'deprecated'}, + ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + assert client.get_user_information() == {'name': 'me'} + + assert not any(isinstance(w.message, LogfireServerWarning) for w in caught)