Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mellea/stdlib/requirements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MatplotlibHeadlessBackend,
PlotDependenciesAvailable,
PlotFileSaved,
python_plotting_requirements,
)
from .python_reqs import PythonExecutionReq
from .python_tools import (
Expand Down Expand Up @@ -49,6 +50,7 @@
"is_markdown_list",
"is_markdown_table",
"python_code_generation_requirements",
"python_plotting_requirements",
"req",
"reqify",
"requirement_check_to_bool",
Expand Down
8 changes: 7 additions & 1 deletion mellea/stdlib/requirements/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
MatplotlibHeadlessBackend,
PlotDependenciesAvailable,
PlotFileSaved,
python_plotting_requirements,
)

__all__ = ["MatplotlibHeadlessBackend", "PlotDependenciesAvailable", "PlotFileSaved"]
__all__ = [
"MatplotlibHeadlessBackend",
"PlotDependenciesAvailable",
"PlotFileSaved",
"python_plotting_requirements",
]
79 changes: 79 additions & 0 deletions mellea/stdlib/requirements/plotting/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from mellea.core import Context, Requirement, ValidationResult
from mellea.stdlib.requirements.python_reqs import _has_python_code_listing
from mellea.stdlib.requirements.python_tools import python_code_generation_requirements

# Matplotlib backends suitable for headless (non-interactive) execution.
# Includes standard raster (Agg, Cairo), vector (pdf, svg, pgf),
Expand Down Expand Up @@ -334,3 +335,81 @@ def _validate_dependencies(self, ctx: Context) -> ValidationResult:
return ValidationResult(
result=True, reason="All dependencies available (matplotlib, numpy)."
)


def python_plotting_requirements(
    output_path: str,
    allowed_imports: list[str] | None = None,
    output_limit_chars: int = 10_000,
    timeout_seconds: int = 5,
    use_sandbox: bool = False,
) -> list[Requirement]:
    """Build the full requirement list for validating matplotlib plotting code.

    This factory starts from the general-purpose requirements returned by
    ``python_code_generation_requirements`` and appends the three
    plotting-specific requirements (headless backend, saved plot file,
    dependency availability), so callers get one ready-to-use bundle.

    Args:
        output_path: File path where the plot should be saved
            (e.g., '/tmp/plot.png'). This path must match the savefig() call
            in the generated code.
        allowed_imports: Whitelist of importable top-level modules. None
            (the default) allows all.
        output_limit_chars: Maximum allowed characters of captured stdout.
            Default 10,000.
        timeout_seconds: Maximum execution time in seconds. Default 5.
        use_sandbox: Use llm-sandbox for Docker-isolated execution.
            Default False.

    Returns:
        list[Requirement]: Seven requirements in validation order:
            1-4. PythonCodeExtraction, PythonSyntaxValid, PythonExecutionReq,
                 ImportRestrictions/NoImportRestrictions
                 (from python_code_generation_requirements)
            5. MatplotlibHeadlessBackend — validates headless backend
               configuration
            6. PlotFileSaved — validates plot is saved to the specified
               output_path
            7. PlotDependenciesAvailable — validates matplotlib and numpy
               are available

    Raises:
        TypeError: If output_path is not a string.
        ValueError: If output_path is empty or contains only whitespace.

    Examples:
        >>> output_path = "/tmp/plot.png"
        >>> reqs = python_plotting_requirements(output_path=output_path)
        >>> len(reqs)
        7
        >>> isinstance(reqs[4], MatplotlibHeadlessBackend)
        True
        >>> isinstance(reqs[5], PlotFileSaved)
        True
        >>> isinstance(reqs[6], PlotDependenciesAvailable)
        True
        >>> reqs_restricted = python_plotting_requirements(
        ...     output_path=output_path,
        ...     allowed_imports=["matplotlib", "numpy"]
        ... )
        >>> len(reqs_restricted)
        7
    """
    # The type check comes first so that the str method below is never
    # called on a non-string argument.
    if not isinstance(output_path, str):
        received_type = type(output_path).__name__
        raise TypeError(f"output_path must be a string, got {received_type}")

    # Stripping first means a whitespace-only path is rejected as empty too.
    if output_path.strip() == "":
        raise ValueError("output_path cannot be empty")

    # General code-generation checks form the head of the bundle.
    bundle = python_code_generation_requirements(
        allowed_imports=allowed_imports,
        output_limit_chars=output_limit_chars,
        timeout_seconds=timeout_seconds,
        use_sandbox=use_sandbox,
    )

    # All plotting checks are constructed before the bundle is touched, then
    # appended in their documented order.
    plotting_checks = (
        MatplotlibHeadlessBackend(),
        PlotFileSaved(output_path=output_path),
        PlotDependenciesAvailable(),
    )
    bundle.extend(plotting_checks)
    return bundle
130 changes: 129 additions & 1 deletion test/stdlib/requirements/plotting/test_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

import pytest

from mellea.core import Context, ModelOutputThunk
from mellea.core import Context, ModelOutputThunk, Requirement
from mellea.stdlib.context import ChatContext
from mellea.stdlib.requirements.plotting import (
MatplotlibHeadlessBackend,
PlotDependenciesAvailable,
PlotFileSaved,
python_plotting_requirements,
)
from mellea.stdlib.requirements.python_tools import (
ImportRestrictions,
NoImportRestrictions,
PythonCodeExtraction,
PythonExecutionReq,
PythonSyntaxValid,
)


Expand Down Expand Up @@ -449,3 +457,123 @@ def test_path_as_keyword_with_different_names(self):
result = req.validation_fn(ctx)
# fname= is the correct keyword; savefig has no 'filename' parameter
assert result.as_bool() is True


class TestPythonPlottingRequirementsFactory:
    """Tests for the python_plotting_requirements factory function.

    The factory is expected to return seven requirements: four general
    code-generation requirements followed by the three plotting-specific
    ones. The index-based assertions below (reqs[2], reqs[3], reqs[5], ...)
    rely on that ordering.
    """

    def test_factory_returns_list_of_requirements(self):
        """Test that factory returns a list of seven Requirement instances."""
        output_path = "/tmp/plot.png"
        reqs = python_plotting_requirements(output_path=output_path)

        assert isinstance(reqs, list)
        assert len(reqs) == 7
        assert all(isinstance(r, Requirement) for r in reqs)

    def test_factory_returns_correct_requirement_order(self):
        """Test requirements are in expected order (python tools + plotting)."""
        output_path = "/tmp/plot.png"
        reqs = python_plotting_requirements(output_path=output_path)

        # First 4 from python_code_generation_requirements
        assert isinstance(reqs[0], PythonCodeExtraction)
        assert isinstance(reqs[1], PythonSyntaxValid)
        assert isinstance(reqs[2], PythonExecutionReq)
        # No allowed_imports were given, so NoImportRestrictions is expected
        assert isinstance(reqs[3], NoImportRestrictions)

        # Last 3 plotting-specific
        assert isinstance(reqs[4], MatplotlibHeadlessBackend)
        assert isinstance(reqs[5], PlotFileSaved)
        assert isinstance(reqs[6], PlotDependenciesAvailable)

    def test_factory_with_allowed_imports(self):
        """Test that allowed_imports parameter creates ImportRestrictions."""
        output_path = "/tmp/plot.png"
        allowed_imports = ["matplotlib", "numpy"]
        reqs = python_plotting_requirements(
            output_path=output_path, allowed_imports=allowed_imports
        )

        assert len(reqs) == 7
        assert isinstance(reqs[3], ImportRestrictions)
        assert reqs[3].allowed_imports == allowed_imports

    def test_factory_propagates_output_path(self):
        """Test that output_path is correctly passed to PlotFileSaved."""
        output_path = "/output/my_plot.png"
        reqs = python_plotting_requirements(output_path=output_path)

        plot_saved_req = reqs[5]
        assert isinstance(plot_saved_req, PlotFileSaved)
        assert plot_saved_req.output_path == output_path
        assert output_path in plot_saved_req.description

    # Parametrized so each path format is reported as its own test case
    # instead of the first failure hiding the remaining paths.
    @pytest.mark.parametrize(
        "path",
        [
            "/tmp/plot.png",
            "/output/figures/chart.svg",
            "plot.pdf",
            "/var/tmp/matplotlib_output.jpg",
        ],
    )
    def test_factory_with_different_paths(self, path):
        """Test factory works with various output path formats."""
        reqs = python_plotting_requirements(output_path=path)

        assert len(reqs) == 7
        assert reqs[5].output_path == path

    # match= pins the factory's own validation message, so an unrelated
    # TypeError (e.g. from a changed signature) cannot satisfy the test.
    @pytest.mark.parametrize("bad_path", [123, None, ["/tmp/plot.png"]])
    def test_factory_raises_on_non_string_path(self, bad_path):
        """Test that factory raises TypeError for non-string output_path."""
        with pytest.raises(TypeError, match="output_path must be a string"):
            python_plotting_requirements(output_path=bad_path)

    # Whitespace-only strings count as empty because the factory strips the
    # path before checking it.
    @pytest.mark.parametrize("blank_path", ["", " ", " \t\n"])
    def test_factory_raises_on_empty_path(self, blank_path):
        """Test that factory raises ValueError for empty or blank output_path."""
        with pytest.raises(ValueError, match="output_path cannot be empty"):
            python_plotting_requirements(output_path=blank_path)

    def test_factory_can_be_unpacked(self):
        """Test that factory result can be unpacked like in usage examples."""
        output_path = "/tmp/plot.png"
        reqs = python_plotting_requirements(output_path=output_path)

        # Verify unpacking pattern works
        unpacked = [*reqs]
        assert len(unpacked) == 7
        assert unpacked == reqs

    def test_factory_requirements_are_independent_instances(self):
        """Test that multiple factory calls create independent requirement instances."""
        reqs1 = python_plotting_requirements(output_path="/tmp/plot1.png")
        reqs2 = python_plotting_requirements(output_path="/tmp/plot2.png")

        # Each call builds its own PlotFileSaved object...
        assert reqs1[5] is not reqs2[5]
        # ...and each object keeps the output path it was created with.
        assert reqs1[5].output_path != reqs2[5].output_path

    def test_factory_with_custom_execution_parameters(self):
        """Test that the factory accepts custom execution parameters."""
        output_path = "/tmp/plot.png"
        reqs = python_plotting_requirements(
            output_path=output_path,
            output_limit_chars=5000,
            timeout_seconds=10,
            use_sandbox=True,
        )

        assert len(reqs) == 7
        execution_req = reqs[2]
        assert isinstance(execution_req, PythonExecutionReq)
        # NOTE(review): this only checks that the call succeeds and the
        # execution requirement has the right type; it does not verify that
        # the custom values reached PythonExecutionReq. Add value assertions
        # once the attribute names on PythonExecutionReq are confirmed.
Loading