Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mellea/stdlib/requirements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MatplotlibHeadlessBackend,
PlotDependenciesAvailable,
PlotFileSaved,
python_plotting_requirements,
)
from .python_reqs import PythonExecutionReq
from .python_tools import (
Expand Down Expand Up @@ -49,6 +50,7 @@
"is_markdown_list",
"is_markdown_table",
"python_code_generation_requirements",
"python_plotting_requirements",
"req",
"reqify",
"requirement_check_to_bool",
Expand Down
8 changes: 7 additions & 1 deletion mellea/stdlib/requirements/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
MatplotlibHeadlessBackend,
PlotDependenciesAvailable,
PlotFileSaved,
python_plotting_requirements,
)

# NOTE(review): this one-line assignment is immediately overwritten by the
# multi-line assignment that follows. It looks like the pre-change side of a
# diff hunk whose +/- markers were lost — confirm before keeping both.
__all__ = ["MatplotlibHeadlessBackend", "PlotDependenciesAvailable", "PlotFileSaved"]
# Names exported by ``from <this package> import *``; the factory function
# ``python_plotting_requirements`` is listed alongside the three requirement
# classes imported above.
__all__ = [
"MatplotlibHeadlessBackend",
"PlotDependenciesAvailable",
"PlotFileSaved",
"python_plotting_requirements",
]
74 changes: 74 additions & 0 deletions mellea/stdlib/requirements/plotting/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from mellea.core import Context, Requirement, ValidationResult
from mellea.stdlib.requirements.python_reqs import _has_python_code_listing
from mellea.stdlib.requirements.python_tools import python_code_generation_requirements

# Matplotlib backends suitable for headless (non-interactive) execution.
# Includes standard raster (Agg, Cairo), vector (pdf, svg, pgf),
Expand Down Expand Up @@ -334,3 +335,76 @@ def _validate_dependencies(self, ctx: Context) -> ValidationResult:
return ValidationResult(
result=True, reason="All dependencies available (matplotlib, numpy)."
)


def python_plotting_requirements(
    output_path: str,
    allowed_imports: list[str] | None = None,
    output_limit_chars: int = 10_000,
    timeout_seconds: int = 5,
    use_sandbox: bool = False,
) -> list[Requirement]:
    """Build the full requirement bundle for validating matplotlib plotting code.

    The bundle starts with the general-purpose requirements produced by
    ``python_code_generation_requirements`` and then appends the three
    plotting-specific checks (headless backend, saved plot file, dependency
    availability), in that order.

    Args:
        output_path: Location the generated code is expected to save the plot to
            (e.g., '/tmp/plot.png'). Handed unchanged to ``PlotFileSaved``.
        allowed_imports: Top-level modules the generated code may import.
            ``None`` (the default) places no restriction on imports.
        output_limit_chars: Upper bound on captured stdout characters.
            Default 10,000.
        timeout_seconds: Upper bound on execution time in seconds. Default 5.
        use_sandbox: Run through llm-sandbox for Docker-isolated execution.
            Default False.

    Returns:
        list[Requirement]: Seven requirements in validation order:
            1-4. PythonCodeExtraction, PythonSyntaxValid, PythonExecutionReq,
                 ImportRestrictions/NoImportRestrictions
                 (all from python_code_generation_requirements)
            5. MatplotlibHeadlessBackend
            6. PlotFileSaved, bound to ``output_path``
            7. PlotDependenciesAvailable

    Raises:
        ValueError: If ``output_path`` is empty or whitespace-only, or if the
            delegated factory rejects ``timeout_seconds`` or
            ``output_limit_chars`` as non-positive.

    Examples:
        >>> output_path = "/tmp/plot.png"
        >>> reqs = python_plotting_requirements(output_path=output_path)
        >>> len(reqs)
        7
        >>> isinstance(reqs[4], MatplotlibHeadlessBackend)
        True
        >>> isinstance(reqs[5], PlotFileSaved)
        True
        >>> isinstance(reqs[6], PlotDependenciesAvailable)
        True
        >>> reqs_restricted = python_plotting_requirements(
        ...     output_path=output_path,
        ...     allowed_imports=["matplotlib", "numpy"]
        ... )
        >>> len(reqs_restricted)
        7
    """
    # Reject a blank path up front, before any requirement object is built.
    trimmed_path = output_path.strip()
    if not trimmed_path:
        raise ValueError("output_path cannot be empty or whitespace-only")

    # General Python checks come first; the delegated factory validates the
    # execution parameters itself.
    bundle = python_code_generation_requirements(
        allowed_imports=allowed_imports,
        output_limit_chars=output_limit_chars,
        timeout_seconds=timeout_seconds,
        use_sandbox=use_sandbox,
    )

    # Plotting-specific checks are appended after the general ones. The
    # original (untrimmed) path is what gets forwarded.
    headless_backend = MatplotlibHeadlessBackend()
    file_saved = PlotFileSaved(output_path=output_path)
    dependencies_available = PlotDependenciesAvailable()
    bundle.extend((headless_backend, file_saved, dependencies_available))
    return bundle
135 changes: 134 additions & 1 deletion test/stdlib/requirements/plotting/test_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

import pytest

from mellea.core import Context, ModelOutputThunk
from mellea.core import Context, ModelOutputThunk, Requirement
from mellea.stdlib.context import ChatContext
from mellea.stdlib.requirements.plotting import (
MatplotlibHeadlessBackend,
PlotDependenciesAvailable,
PlotFileSaved,
python_plotting_requirements,
)
from mellea.stdlib.requirements.python_tools import (
ImportRestrictions,
NoImportRestrictions,
PythonCodeExtraction,
PythonExecutionReq,
PythonSyntaxValid,
)


Expand Down Expand Up @@ -449,3 +457,128 @@ def test_path_as_keyword_with_different_names(self):
result = req.validation_fn(ctx)
# fname= is the correct keyword; savefig has no 'filename' parameter
assert result.as_bool() is True


class TestPythonPlottingRequirementsFactory:
    """Exercise the ``python_plotting_requirements`` factory function."""

    def test_factory_returns_list_of_requirements(self):
        """The factory yields a list holding exactly seven Requirement objects."""
        bundle = python_plotting_requirements(output_path="/tmp/plot.png")

        assert isinstance(bundle, list)
        assert len(bundle) == 7
        for item in bundle:
            assert isinstance(item, Requirement)

    def test_factory_returns_correct_requirement_order(self):
        """General Python requirements come first, plotting ones last."""
        bundle = python_plotting_requirements(output_path="/tmp/plot.png")

        expected_types = (
            # Four entries contributed by python_code_generation_requirements;
            # without allowed_imports the fourth is NoImportRestrictions.
            PythonCodeExtraction,
            PythonSyntaxValid,
            PythonExecutionReq,
            NoImportRestrictions,
            # Three plotting-specific entries.
            MatplotlibHeadlessBackend,
            PlotFileSaved,
            PlotDependenciesAvailable,
        )
        for position, expected_type in enumerate(expected_types):
            assert isinstance(bundle[position], expected_type)

    def test_factory_with_allowed_imports(self):
        """Supplying allowed_imports swaps in an ImportRestrictions requirement."""
        whitelist = ["matplotlib", "numpy"]
        bundle = python_plotting_requirements(
            output_path="/tmp/plot.png", allowed_imports=whitelist
        )

        assert len(bundle) == 7
        import_req = bundle[3]
        assert isinstance(import_req, ImportRestrictions)
        assert import_req.allowed_imports == whitelist

    def test_factory_propagates_output_path(self):
        """PlotFileSaved receives the output_path given to the factory."""
        target = "/output/my_plot.png"
        bundle = python_plotting_requirements(output_path=target)

        saved_req = bundle[5]
        assert isinstance(saved_req, PlotFileSaved)
        assert saved_req.output_path == target
        assert target in saved_req.description

    def test_factory_with_different_paths(self):
        """Absolute, nested and bare-filename paths are all accepted."""
        candidate_paths = (
            "/tmp/plot.png",
            "/output/figures/chart.svg",
            "plot.pdf",
            "/var/tmp/matplotlib_output.jpg",
        )

        for candidate in candidate_paths:
            bundle = python_plotting_requirements(output_path=candidate)
            assert len(bundle) == 7
            assert bundle[5].output_path == candidate

    def test_factory_raises_on_empty_path(self):
        """Empty and whitespace-only output_path values raise ValueError."""
        for blank_path in ("", " "):
            with pytest.raises(ValueError):
                python_plotting_requirements(output_path=blank_path)

    def test_factory_propagates_timeout_seconds_validation(self):
        """A non-positive timeout_seconds surfaces the delegated ValueError."""
        with pytest.raises(ValueError, match="timeout_seconds must be positive"):
            python_plotting_requirements(output_path="/tmp/plot.png", timeout_seconds=0)

    def test_factory_propagates_output_limit_chars_validation(self):
        """A non-positive output_limit_chars surfaces the delegated ValueError."""
        with pytest.raises(ValueError, match="output_limit_chars must be positive"):
            python_plotting_requirements(
                output_path="/tmp/plot.png", output_limit_chars=-1
            )

    def test_factory_can_be_unpacked(self):
        """The result destructures into exactly seven typed requirements."""
        (
            extraction,
            syntax,
            execution,
            imports,
            backend,
            file_saved,
            dependencies,
        ) = python_plotting_requirements(output_path="/tmp/plot.png")

        assert isinstance(extraction, PythonCodeExtraction)
        assert isinstance(syntax, PythonSyntaxValid)
        assert isinstance(execution, PythonExecutionReq)
        assert isinstance(imports, NoImportRestrictions)
        assert isinstance(backend, MatplotlibHeadlessBackend)
        assert isinstance(file_saved, PlotFileSaved)
        assert isinstance(dependencies, PlotDependenciesAvailable)

    def test_factory_requirements_are_independent_instances(self):
        """Separate factory calls produce separate PlotFileSaved objects."""
        first_bundle = python_plotting_requirements(output_path="/tmp/plot1.png")
        second_bundle = python_plotting_requirements(output_path="/tmp/plot2.png")
        first_saved = first_bundle[5]
        second_saved = second_bundle[5]

        # Distinct objects...
        assert first_saved is not second_saved
        # ...each carrying the path from its own call.
        assert first_saved.output_path != second_saved.output_path

    def test_factory_with_custom_execution_parameters(self):
        """Non-default execution parameters are accepted by the factory."""
        bundle = python_plotting_requirements(
            output_path="/tmp/plot.png",
            output_limit_chars=5000,
            timeout_seconds=10,
            use_sandbox=True,
        )

        assert len(bundle) == 7
        runner_req = bundle[2]
        assert isinstance(runner_req, PythonExecutionReq)
Loading