diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py index ff2882b35..97ec2f0a1 100644 --- a/mellea/stdlib/requirements/__init__.py +++ b/mellea/stdlib/requirements/__init__.py @@ -6,6 +6,7 @@ MatplotlibHeadlessBackend, PlotDependenciesAvailable, PlotFileSaved, + python_plotting_requirements, ) from .python_reqs import PythonExecutionReq from .python_tools import ( @@ -49,6 +50,7 @@ "is_markdown_list", "is_markdown_table", "python_code_generation_requirements", + "python_plotting_requirements", "req", "reqify", "requirement_check_to_bool", diff --git a/mellea/stdlib/requirements/plotting/__init__.py b/mellea/stdlib/requirements/plotting/__init__.py index 6886f73bd..3314ea88d 100644 --- a/mellea/stdlib/requirements/plotting/__init__.py +++ b/mellea/stdlib/requirements/plotting/__init__.py @@ -11,6 +11,12 @@ MatplotlibHeadlessBackend, PlotDependenciesAvailable, PlotFileSaved, + python_plotting_requirements, ) -__all__ = ["MatplotlibHeadlessBackend", "PlotDependenciesAvailable", "PlotFileSaved"] +__all__ = [ + "MatplotlibHeadlessBackend", + "PlotDependenciesAvailable", + "PlotFileSaved", + "python_plotting_requirements", +] diff --git a/mellea/stdlib/requirements/plotting/matplotlib.py b/mellea/stdlib/requirements/plotting/matplotlib.py index 754f6de7a..8feb89986 100644 --- a/mellea/stdlib/requirements/plotting/matplotlib.py +++ b/mellea/stdlib/requirements/plotting/matplotlib.py @@ -8,6 +8,7 @@ from mellea.core import Context, Requirement, ValidationResult from mellea.stdlib.requirements.python_reqs import _has_python_code_listing +from mellea.stdlib.requirements.python_tools import python_code_generation_requirements # Matplotlib backends suitable for headless (non-interactive) execution. # Includes standard raster (Agg, Cairo), vector (pdf, svg, pgf), @@ -334,3 +335,76 @@ def _validate_dependencies(self, ctx: Context) -> ValidationResult: return ValidationResult( result=True, reason="All dependencies available (matplotlib, numpy)." ) + + +def python_plotting_requirements( + output_path: str, + allowed_imports: list[str] | None = None, + output_limit_chars: int = 10_000, + timeout_seconds: int = 5, + use_sandbox: bool = False, +) -> list[Requirement]: + """Bundle matplotlib-specific requirements for plotting code validation. + + Factory function that creates a complete set of requirements for validating + matplotlib plotting code, composing general Python code generation requirements + with plotting-specific constraints for headless backend configuration, file + output, and dependency availability. + + Args: + output_path: File path where the plot should be saved (e.g., '/tmp/plot.png'). + This path must match the savefig() call in the generated code. + allowed_imports: Whitelist of importable top-level modules. None allows all. + Default None. + output_limit_chars: Maximum allowed characters of captured stdout. + Default 10,000. + timeout_seconds: Maximum execution time in seconds. Default 5. + use_sandbox: Use llm-sandbox for Docker-isolated execution. Default False. + + Returns: + list[Requirement]: Seven requirements in validation order: + 1-4. PythonCodeExtraction, PythonSyntaxValid, PythonExecutionReq, + ImportRestrictions/NoImportRestrictions (from python_code_generation_requirements) + 5. MatplotlibHeadlessBackend — validates headless backend configuration + 6. PlotFileSaved — validates plot is saved to the specified output_path + 7. PlotDependenciesAvailable — validates matplotlib and numpy are available + + Raises: + ValueError: If output_path is empty or whitespace-only, or if timeout_seconds + or output_limit_chars is not positive. + + Examples: + >>> output_path = "/tmp/plot.png" + >>> reqs = python_plotting_requirements(output_path=output_path) + >>> len(reqs) + 7 + >>> isinstance(reqs[4], MatplotlibHeadlessBackend) + True + >>> isinstance(reqs[5], PlotFileSaved) + True + >>> isinstance(reqs[6], PlotDependenciesAvailable) + True + >>> reqs_restricted = python_plotting_requirements( + ... output_path=output_path, + ... allowed_imports=["matplotlib", "numpy"] + ... ) + >>> len(reqs_restricted) + 7 + """ + if not output_path.strip(): + raise ValueError("output_path cannot be empty or whitespace-only") + + requirements = python_code_generation_requirements( + allowed_imports=allowed_imports, + output_limit_chars=output_limit_chars, + timeout_seconds=timeout_seconds, + use_sandbox=use_sandbox, + ) + requirements.extend( + [ + MatplotlibHeadlessBackend(), + PlotFileSaved(output_path=output_path), + PlotDependenciesAvailable(), + ] + ) + return requirements diff --git a/test/stdlib/requirements/plotting/test_matplotlib.py b/test/stdlib/requirements/plotting/test_matplotlib.py index 95e1b7e11..fa201795b 100644 --- a/test/stdlib/requirements/plotting/test_matplotlib.py +++ b/test/stdlib/requirements/plotting/test_matplotlib.py @@ -2,12 +2,20 @@ import pytest -from mellea.core import Context, ModelOutputThunk +from mellea.core import Context, ModelOutputThunk, Requirement from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements.plotting import ( MatplotlibHeadlessBackend, PlotDependenciesAvailable, PlotFileSaved, + python_plotting_requirements, +) +from mellea.stdlib.requirements.python_tools import ( + ImportRestrictions, + NoImportRestrictions, + PythonCodeExtraction, + PythonExecutionReq, + PythonSyntaxValid, ) @@ -449,3 +457,128 @@ def test_path_as_keyword_with_different_names(self): result = req.validation_fn(ctx) # fname= is the correct keyword; savefig has no 'filename' parameter assert result.as_bool() is True + + +class TestPythonPlottingRequirementsFactory: + """Tests for python_plotting_requirements factory function.""" + + def test_factory_returns_list_of_requirements(self): + """Test that factory returns a list of seven Requirement instances.""" + output_path = "/tmp/plot.png" + reqs = python_plotting_requirements(output_path=output_path) + + assert isinstance(reqs, list) + assert len(reqs) == 7 + assert all(isinstance(r, Requirement) for r in reqs) + + def test_factory_returns_correct_requirement_order(self): + """Test requirements are in expected order (python tools + plotting).""" + output_path = "/tmp/plot.png" + reqs = python_plotting_requirements(output_path=output_path) + + # First 4 from python_code_generation_requirements + assert isinstance(reqs[0], PythonCodeExtraction) + assert isinstance(reqs[1], PythonSyntaxValid) + assert isinstance(reqs[2], PythonExecutionReq) + assert isinstance( + reqs[3], NoImportRestrictions + ) # No allowed_imports, so NoImportRestrictions + + # Last 3 plotting-specific + assert isinstance(reqs[4], MatplotlibHeadlessBackend) + assert isinstance(reqs[5], PlotFileSaved) + assert isinstance(reqs[6], PlotDependenciesAvailable) + + def test_factory_with_allowed_imports(self): + """Test that allowed_imports parameter creates ImportRestrictions.""" + output_path = "/tmp/plot.png" + allowed_imports = ["matplotlib", "numpy"] + reqs = python_plotting_requirements( + output_path=output_path, allowed_imports=allowed_imports + ) + + assert len(reqs) == 7 + assert isinstance(reqs[3], ImportRestrictions) + assert reqs[3].allowed_imports == allowed_imports + + def test_factory_propagates_output_path(self): + """Test that output_path is correctly passed to PlotFileSaved.""" + output_path = "/output/my_plot.png" + reqs = python_plotting_requirements(output_path=output_path) + + plot_saved_req = reqs[5] + assert isinstance(plot_saved_req, PlotFileSaved) + assert plot_saved_req.output_path == output_path + assert output_path in plot_saved_req.description + + def test_factory_with_different_paths(self): + """Test factory works with various output path formats.""" + test_paths = [ + "/tmp/plot.png", + "/output/figures/chart.svg", + "plot.pdf", + "/var/tmp/matplotlib_output.jpg", + ] + + for path in test_paths: + reqs = python_plotting_requirements(output_path=path) + assert len(reqs) == 7 + assert reqs[5].output_path == path + + def test_factory_raises_on_empty_path(self): + """Test that factory raises ValueError for empty output_path.""" + with pytest.raises(ValueError): + python_plotting_requirements(output_path="") + + with pytest.raises(ValueError): + python_plotting_requirements(output_path=" ") + + def test_factory_propagates_timeout_seconds_validation(self): + """Test that invalid timeout_seconds from delegated factory propagates.""" + with pytest.raises(ValueError, match="timeout_seconds must be positive"): + python_plotting_requirements(output_path="/tmp/plot.png", timeout_seconds=0) + + def test_factory_propagates_output_limit_chars_validation(self): + """Test that invalid output_limit_chars from delegated factory propagates.""" + with pytest.raises(ValueError, match="output_limit_chars must be positive"): + python_plotting_requirements( + output_path="/tmp/plot.png", output_limit_chars=-1 + ) + + def test_factory_can_be_unpacked(self): + """Test that factory result can be unpacked into all 7 requirements.""" + output_path = "/tmp/plot.png" + r1, r2, r3, r4, r5, r6, r7 = python_plotting_requirements( + output_path=output_path + ) + assert isinstance(r1, PythonCodeExtraction) + assert isinstance(r2, PythonSyntaxValid) + assert isinstance(r3, PythonExecutionReq) + assert isinstance(r4, NoImportRestrictions) + assert isinstance(r5, MatplotlibHeadlessBackend) + assert isinstance(r6, PlotFileSaved) + assert isinstance(r7, PlotDependenciesAvailable) + + def test_factory_requirements_are_independent_instances(self): + """Test that multiple factory calls create independent requirement instances.""" + reqs1 = python_plotting_requirements(output_path="/tmp/plot1.png") + reqs2 = python_plotting_requirements(output_path="/tmp/plot2.png") + + # Different instances + assert reqs1[5] is not reqs2[5] + # But different output paths + assert reqs1[5].output_path != reqs2[5].output_path + + def test_factory_with_custom_execution_parameters(self): + """Test that custom execution parameters can be passed through.""" + output_path = "/tmp/plot.png" + reqs = python_plotting_requirements( + output_path=output_path, + output_limit_chars=5000, + timeout_seconds=10, + use_sandbox=True, + ) + + assert len(reqs) == 7 + execution_req = reqs[2] + assert isinstance(execution_req, PythonExecutionReq)