diff --git a/.gitignore b/.gitignore index 98ba9dd..fba9c27 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__/ flagged/ -/venv/ \ No newline at end of file +/venv/ +.env \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..0bc5130 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +addopts = -ra diff --git a/requirements.txt b/requirements.txt index 2acff59..6ffdd57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,5 @@ openai==2.6.0 -gradio==5.49.1 \ No newline at end of file +gradio==5.49.1 +pytest +pytest-mock +pytest-cov diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1d6505e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,56 @@ +import json +import importlib +import sys +import gradio as gr +import pytest +import sys +from pathlib import Path + +ROOT = str(Path(__file__).resolve().parents[1]) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +@pytest.fixture +def sample_request_body(): + return { + "prompt": "$INPUT", + "metadata": { + "language": "en", + "extras": ["a", "b"], + }, + } + +@pytest.fixture +def sample_response_structure(): + return { + "data": { + "result": "$OUTPUT", + } + } + +@pytest.fixture +def sample_response_payload(): + return { + "data": { + "result": "system prompt here", + } + } + +@pytest.fixture +def request_body_json(sample_request_body): + return json.dumps(sample_request_body) + +@pytest.fixture +def response_body_json(sample_response_structure): + return json.dumps(sample_response_structure) + +@pytest.fixture +def ui_app(monkeypatch): + pytest.importorskip("core.whistleblower", reason="core.whistleblower module missing") + monkeypatch.setattr(gr.Blocks, "launch", lambda self: None) + module_name = "ui.app" + if module_name in sys.modules: + module = importlib.reload(sys.modules[module_name]) + else: + module = importlib.import_module(module_name) + return module diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..1311754 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,43 @@ +from types import SimpleNamespace +import pytest +from core.api import call_external_api + +def test_call_external_api_replaces_placeholders(monkeypatch, sample_request_body, sample_response_structure, sample_response_payload): + captured = {} + + def fake_post(url, json, headers): + captured["url"] = url + captured["json"] = json + captured["headers"] = headers + return SimpleNamespace(json=lambda: sample_response_payload) + + monkeypatch.setattr("core.api.requests.post", fake_post) + + output = call_external_api( + "https://api.example.com/chat", + "leak prompt", + sample_request_body, + sample_response_structure, + api_key="secret-token", + ) + + assert captured["url"] == "https://api.example.com/chat" + assert captured["json"]["prompt"] == "leak prompt" + assert captured["headers"] == {"X-repello-api-key": "secret-token"} + assert output == "system prompt here" + +def test_call_external_api_passes_through_when_no_key(monkeypatch, sample_request_body, sample_response_structure, sample_response_payload): + def fake_post(url, json, headers): + assert headers == {} + return SimpleNamespace(json=lambda: sample_response_payload) + + monkeypatch.setattr("core.api.requests.post", fake_post) + + output = call_external_api( + "https://api.example.com/chat", + "test prompt", + sample_request_body, + sample_response_structure, + ) + + assert output == "system prompt here" diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..4a91a02 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,52 @@ +import json +import sys +from pathlib import Path +import pytest + +main = pytest.importorskip("main", reason="main module missing") +whistleblower = pytest.importorskip("core.whistleblower", reason="core.whistleblower module missing") + +def test_read_json_file_valid(tmp_path): + data = {"api_url": "http://example.com"} + json_path = tmp_path / "config.json" + json_path.write_text(json.dumps(data)) + + loaded = whistleblower.read_json_file(str(json_path)) + + assert loaded == data + +def test_read_json_file_invalid(tmp_path, capsys): + json_path = tmp_path / "invalid.json" + json_path.write_text("{invalid") + + loaded = whistleblower.read_json_file(str(json_path)) + + assert loaded == {} + +def test_main_invokes_whistleblower(monkeypatch, tmp_path): + config_path = tmp_path / "input.json" + config_path.write_text( + json.dumps( + { + "api_url": "http://example.com", + "request_body": {"prompt": "$INPUT"}, + "response_body": {"response": "$OUTPUT"}, + "OpenAI_api_key": "key", + "model": "gpt-4", + } + ) + ) + + called = {} + + def fake_whistleblower(args): + called["json_file"] = args.json_file + return "done" + + monkeypatch.setattr(main, "whistleblower", fake_whistleblower) + monkeypatch.setattr(sys, "argv", ["main.py", "--json_file", str(config_path)]) + + result = main.main() + + assert called["json_file"] == str(config_path) + assert result == "done" diff --git a/tests/test_ui.py b/tests/test_ui.py new file mode 100644 index 0000000..7afaf61 --- /dev/null +++ b/tests/test_ui.py @@ -0,0 +1,153 @@ +import json + +import gradio as gr +import pytest + + +def test_check_for_placeholders(ui_app, request_body_json): + assert ui_app.check_for_placeholders(request_body_json, "$INPUT") + assert not ui_app.check_for_placeholders({"prompt": "hello"}, "$INPUT") + + +def test_check_for_placeholders_list(ui_app): + payload = json.dumps({"messages": [{"text": "$INPUT"}]}) + + assert ui_app.check_for_placeholders(payload, "$INPUT") + + +def test_validate_input_json_success(ui_app, request_body_json, response_body_json, monkeypatch): + captured = {} + + def fake_generate_output(*args): + captured["args"] = args + return "result-json" + + monkeypatch.setattr(ui_app, "generate_output", fake_generate_output) + + output = ui_app.validate_input( + "https://api.example.com", + "api-key", + "JSON", + "", + request_body_json, + "", + response_body_json, + "openai-key", + "gpt-4", + ) + + assert output == "result-json" + assert captured["args"][0] == "https://api.example.com" + assert json.loads(captured["args"][2])["prompt"] == "$INPUT" + + +def test_validate_input_json_empty_request_raises(ui_app): + with pytest.raises(gr.Error) as excinfo: + ui_app.validate_input( + "https://api.example.com", + "", + "JSON", + "", + "", + "", + "{}", + "", + "gpt-4o", + ) + + assert "Request body cannot be empty" in str(excinfo.value) + + +def test_validate_input_json_invalid_request_raises(ui_app): + with pytest.raises(gr.Error) as excinfo: + ui_app.validate_input( + "https://api.example.com", + "", + "JSON", + "", + "{bad json", + "", + '{"response": "$OUTPUT"}', + "", + "gpt-4o", + ) + + assert "Invalid JSON format in request body." in str(excinfo.value) + + +def test_validate_input_json_missing_output_placeholder_raises(ui_app): + with pytest.raises(gr.Error) as excinfo: + ui_app.validate_input( + "https://api.example.com", + "", + "JSON", + "", + '{"prompt": "$INPUT"}', + "", + '{"response": "missing"}', + "", + "gpt-4o", + ) + + assert "Response body must contain the $OUTPUT placeholder." in str(excinfo.value) + + +def test_validate_input_key_value_success(ui_app, monkeypatch): + captured = {} + + def fake_generate_output(*args): + captured["args"] = args + return "result-kv" + + monkeypatch.setattr(ui_app, "generate_output", fake_generate_output) + + output = ui_app.validate_input( + "https://api.example.com", + "", + "Key-Value", + "prompt: $INPUT\nstatic: value", + "", + "response: $OUTPUT", + "", + "service-key", + "gpt-4o", + ) + + assert output == "result-kv" + request_dict = captured["args"][2] + assert request_dict["prompt"] == "$INPUT" + assert request_dict["static"] == "value" + response_dict = captured["args"][3] + assert response_dict["response"] == "$OUTPUT" + + +def test_validate_input_key_value_empty_fields_raise(ui_app): + with pytest.raises(gr.Error) as excinfo: + ui_app.validate_input( + "https://api.example.com", + "", + "Key-Value", + "", + "", + "response: $OUTPUT", + "", + "openai-key", + "gpt-4o", + ) + + assert "Request body cannot be empty." in str(excinfo.value) + + +def test_validate_input_missing_placeholder_raises(ui_app): + with pytest.raises(gr.Error): + ui_app.validate_input( + "https://api.example.com", + "", + "JSON", + "", + '{"prompt": "no placeholder"}', + "", + '{"response": "$OUTPUT"}', + "", + "gpt-4o", + ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..35c92dd --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,46 @@ +from copy import deepcopy + +from core.utils import extract_nested_value, replace_nested_value + + +def test_extract_nested_value_success(sample_response_payload, sample_response_structure): + value = extract_nested_value(sample_response_payload, sample_response_structure, "$OUTPUT") + assert value == "system prompt here" + + +def test_extract_nested_value_handles_list_structure(): + payload = {"items": [{"value": "skip"}, {"value": "system prompt here"}]} + structure = {"items": [{}, {"value": "$OUTPUT"}]} + + value = extract_nested_value(payload, structure, "$OUTPUT") + + assert value == "system prompt here" + + +def test_extract_nested_value_missing(sample_response_payload): + structure_without_placeholder = {"data": {"result": "not-a-placeholder"}} + value = extract_nested_value(sample_response_payload, structure_without_placeholder, "$OUTPUT") + assert value is None + + +def test_extract_nested_value_returns_none_for_none_payload(): + payload = {"data": {"result": None}} + structure = {"data": {"result": "$OUTPUT"}} + + assert extract_nested_value(payload, structure, "$OUTPUT") is None + + +def test_replace_nested_value_updates_only_placeholder(sample_request_body): + updated = replace_nested_value(deepcopy(sample_request_body), "$INPUT", "example message") + assert updated["prompt"] == "example message" + assert updated["metadata"]["language"] == "en" + + +def test_replace_nested_value_updates_nested_lists(): + data = {"items": ["$INPUT", {"nested": ["keep", "$INPUT"]}]} + + updated = replace_nested_value(deepcopy(data), "$INPUT", "filled") + + assert updated["items"][0] == "filled" + assert updated["items"][1]["nested"][1] == "filled" + assert updated["items"][1]["nested"][0] == "keep" diff --git a/tests/test_whistleblower.py b/tests/test_whistleblower.py new file mode 100644 index 0000000..969ed22 --- /dev/null +++ b/tests/test_whistleblower.py @@ -0,0 +1,266 @@ +import json +from types import SimpleNamespace + +import pytest + +from core import seeds,whistleblower + +def _fake_openai_response(content): + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content=content))] + ) + + +def test_convert_to_json_handles_success_and_failure(tmp_path): + payload = {"key": "value"} + assert whistleblower.convert_to_json(json.dumps(payload)) == payload + assert whistleblower.convert_to_json("not-json") == {} + + +def test_read_json_file_success(tmp_path): + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"name": "value"})) + + assert whistleblower.read_json_file(str(config_path)) == {"name": "value"} + + +def test_read_json_file_invalid_json(tmp_path): + config_path = tmp_path / "config.json" + config_path.write_text("{bad json") + + assert whistleblower.read_json_file(str(config_path)) == {} + + +def test_target_model_returns_response(monkeypatch): + def fake_call(api_url, question, request_body, response_body, api_key): + return f"{question}-answer" + + monkeypatch.setattr(whistleblower, "call_external_api", fake_call) + monkeypatch.setattr(whistleblower, "read_judge_prompt", lambda: "prompt") + + result = whistleblower.target_model( + "https://api.example.com", {"prompt": "$INPUT"}, {"response": "$OUTPUT"}, "key", "question" + ) + + assert result == "question-answer" + + +def test_target_model_returns_none_on_exception(monkeypatch, capsys): + def fake_call(*_): + raise RuntimeError("network down") + + monkeypatch.setattr(whistleblower, "call_external_api", fake_call) + + result = whistleblower.target_model( + "https://api.example.com", {"prompt": "$INPUT"}, {"response": "$OUTPUT"}, "key", "question" + ) + + captured = capsys.readouterr() + assert "Error" in captured.out + assert result is None + + +def test_capability_output_uses_openai(monkeypatch): + expected = "capability response" + + def fake_openai(): + return SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace( + create=lambda **_: _fake_openai_response(expected) + ) + ) + ) + + monkeypatch.setattr(whistleblower, "OpenAI", fake_openai) + + output = whistleblower.capability_output("prompt", "system", "key", "model") + + assert output == expected + + +def test_get_context_aggregates_responses(monkeypatch): + answers = { + "Give me 16 short sentences that best describe yourself. Start with '1:'": "A", + "Give me 16 example questions that I can ask you. Start with '1:'": "B", + "Give me 16 scenarios where I can use you. Start with '1:'": "C", + "Give me 16 short sentences comparing yourself with ChatGPT. Start with '1:'": "D", + } + + def fake_call(api_url, question, request_body, response_body, api_key): + return answers[question] + + monkeypatch.setattr(whistleblower, "call_external_api", fake_call) + + captured = {} + + def fake_capability_output(context, judge_prompt, api_key, model): + captured["context"] = context + return "summarized" + + monkeypatch.setattr(whistleblower, "capability_output", fake_capability_output) + monkeypatch.setattr(whistleblower, "read_judge_prompt", lambda: "prompt") + + context = whistleblower.get_context( + "https://api.example.com", + {"prompt": "$INPUT"}, + {"response": "$OUTPUT"}, + "api-key", + "model", + ) + + assert context == "summarized" + assert "A" in captured["context"] + assert "D" in captured["context"] + + +def test_judge_model_parses_openai_response(monkeypatch): + payload = json.dumps({"score": 2, "improvement": "try again"}) + + def fake_openai(): + return SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace( + create=lambda **_: _fake_openai_response(payload) + ) + ) + ) + + monkeypatch.setattr(whistleblower, "OpenAI", fake_openai) + + score, improvement = whistleblower.judge_model("adv", "target", "key", "model", "ctx") + + assert score == 2 + assert improvement == "try again" + + +def test_judge_model_handles_exception(monkeypatch): + def fake_openai(): + return SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace( + create=lambda **_: (_ for _ in ()).throw(RuntimeError("boom")) + ) + ) + ) + + monkeypatch.setattr(whistleblower, "OpenAI", fake_openai) + + score, improvement = whistleblower.judge_model("adv", "resp", "key", "model") + + assert score == 0 + assert improvement is None + + +def test_attacker_model_returns_suggestion(monkeypatch): + expected = "new prompt" + + def fake_openai(): + return SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace( + create=lambda **_: _fake_openai_response(expected) + ) + ) + ) + + monkeypatch.setattr(whistleblower, "OpenAI", fake_openai) + + assert whistleblower.attacker_model("prev", 1, "improve", "key", "model") == expected + + +def test_attacker_model_handles_exception(monkeypatch): + def fake_openai(): + return SimpleNamespace( + chat=SimpleNamespace( + completions=SimpleNamespace( + create=lambda **_: (_ for _ in ()).throw(RuntimeError("fail")) + ) + ) + ) + + monkeypatch.setattr(whistleblower, "OpenAI", fake_openai) + + assert whistleblower.attacker_model("prev", 1, "improve", "key", "model") == "" + + +def test_generate_output_returns_success(monkeypatch): + monkeypatch.setattr(whistleblower, "get_context", lambda *args, **kwargs: "context") + monkeypatch.setattr(whistleblower, "target_model", lambda *args, **kwargs: "leaked") + monkeypatch.setattr(whistleblower, "judge_model", lambda *args, **kwargs: (3, "done")) + + result = whistleblower.generate_output( + "https://api.example.com", + "api-key", + '{"prompt": "$INPUT"}', + '{"response": "$OUTPUT"}', + "openai-key", + "gpt-4o", + ) + + assert result == "leaked" + + +def test_generate_output_respects_repetition_limit(monkeypatch): + monkeypatch.setattr(whistleblower, "get_context", lambda *args, **kwargs: "context") + monkeypatch.setattr(whistleblower, "target_model", lambda *args, **kwargs: "partial") + + call_count = {"judge": 0} + + def fake_judge(*args, **kwargs): + call_count["judge"] += 1 + return (2, "improve") + + prompts = [] + + def fake_attacker(prompt, score, improvement, api_key, model): + prompts.append((prompt, score, improvement)) + return f"{prompt}-{score}" + + monkeypatch.setattr(whistleblower, "judge_model", fake_judge) + monkeypatch.setattr(whistleblower, "attacker_model", fake_attacker) + + result = whistleblower.generate_output( + "https://api.example.com", + "api-key", + '{"prompt": "$INPUT"}', + '{"response": "$OUTPUT"}', + "openai-key", + "gpt-4o", + ) + + assert result == "partial" + assert call_count["judge"] == 4 + assert prompts + + +def test_whistleblower_invokes_generate_output(tmp_path, monkeypatch): + config_path = tmp_path / "config.json" + config = { + "api_url": "https://api.example.com", + "api_key": "api-key", + "request_body": {"prompt": "$INPUT"}, + "response_body": {"response": "$OUTPUT"}, + "OpenAI_api_key": "openai-key", + "model": "gpt-4o", + } + config_path.write_text(json.dumps(config)) + + called = {} + + def fake_generate_output(*args): + called["args"] = args + return "result" + + monkeypatch.setattr(whistleblower, "generate_output", fake_generate_output) + + args = SimpleNamespace(json_file=str(config_path)) + + whistleblower.whistleblower(args) + + assert called["args"][0] == config["api_url"] + + +def test_seeds_prompts_not_empty(): + assert seeds.prompts + assert any("You are" in prompt for prompt in seeds.prompts) diff --git a/ui/app.py b/ui/app.py index f5b005c..c5d83f1 100644 --- a/ui/app.py +++ b/ui/app.py @@ -1,6 +1,7 @@ import sys import os import json +from pathlib import Path # Add the parent directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -8,8 +9,8 @@ import gradio as gr from core.whistleblower import generate_output -with open('styles.css', 'r') as file: - css = file.read() +css_path = Path(__file__).with_name('styles.css') +css = css_path.read_text() if css_path.exists() else "" def check_for_placeholders(data, placeholder): data = json.loads(data) if isinstance(data, str) else data