Files
Fu-Jie_openwebui-extensions/plugins/debug/copilot_sdk_research/test_permission_comprehensive.py

203 lines
6.1 KiB
Python

"""
Comprehensive Permission Control Test Suite
Tests all permission control scenarios for GitHub Copilot SDK
"""
import argparse
import asyncio
import logging
import re
from typing import Any, Dict, List, Tuple
from copilot import CopilotClient
# Script-wide logging: INFO and above, each record prefixed with a timestamp
# and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by the permission handler, the test runner and
# the summary report below.
logger = logging.getLogger(__name__)
def build_permission_handler(allow_all: bool, allow_shell: bool, pattern: str):
    """Build an async ``on_permission_request`` callback for a Copilot session.

    Decision rules, checked in order:

    * ``allow_all`` approves every request.
    * ``read`` and ``url`` requests are always approved.
    * ``shell`` requests are approved when ``allow_shell`` is set, or when the
      command matches ``pattern``.
    * Anything else is denied with a ``denied-by-rules`` result.

    Args:
        allow_all: Approve every request regardless of its kind.
        allow_shell: Approve every shell request.
        pattern: Regular expression tested against the shell command with
            ``match`` semantics. These anchor only at the start of the
            string, so a prefix pattern such as ``"^ls"`` also approves a
            chained command like ``"ls; rm -rf /"``. An empty string disables
            regex approval. An invalid pattern is logged once, here, and then
            never matches.

    Returns:
        An ``async (request, context) -> dict`` coroutine function that
        returns ``{"kind": "approved"}`` or a ``denied-by-rules`` result.
    """
    # Compile once up front instead of on every request. A bad pattern is
    # reported immediately rather than on each shell request.
    compiled = None
    if pattern:
        try:
            compiled = re.compile(pattern)
        except re.error as exc:
            logger.error("Invalid regex pattern: %s (%s)", pattern, exc)

    async def on_permission_request(request: Dict[str, Any], context: Dict[str, str]):
        kind = request.get("kind")
        # Shell requests use 'fullCommandText' not 'command'
        command = request.get("fullCommandText", "") or request.get("command", "")

        if allow_all:
            logger.info("✅ Approved (allow-all): kind=%s command=%r", kind, command)
            return {"kind": "approved"}

        if kind in ("read", "url"):
            logger.info("✅ Approved (safe): kind=%s", kind)
            return {"kind": "approved"}

        if kind == "shell":
            if allow_shell:
                logger.info("✅ Approved (allow-shell): command=%r", command)
                return {"kind": "approved"}
            if compiled is not None and command and compiled.match(command):
                logger.info(
                    "✅ Approved (regex match): pattern=%r command=%r",
                    pattern,
                    command,
                )
                return {"kind": "approved"}

        logger.warning("❌ Denied: kind=%s command=%r", kind, command)
        return {"kind": "denied-by-rules", "rules": [{"kind": "test-suite"}]}

    return on_permission_request
async def run_test(
    model: str, allow_all: bool, allow_shell: bool, pattern: str, prompt: str
) -> Tuple[bool, str]:
    """Run a single permission scenario against a fresh Copilot session.

    Starts a ``CopilotClient`` and opens a session whose permission callback
    comes from ``build_permission_handler(allow_all, allow_shell, pattern)``.
    It then sends ``prompt`` and waits for the reply. The client is always
    stopped afterwards, even when starting the session or sending the prompt
    fails.

    "Approved" is inferred heuristically from the reply text. Any denial
    phrase, matched case-insensitively, marks the run as denied.
    NOTE(review): this does not observe the handler's actual decision.

    Returns:
        ``(approved, text)``. ``text`` is the reply content, the exception
        message when the run failed, or a placeholder when no content came
        back.
    """
    # Phrases that suggest the model reported a refusal. The Chinese entries
    # mean "not allowed", "unable to" and "sorry". All are lower-case because
    # the reply is lower-cased before matching.
    denied_keywords = (
        "不允许",
        "无法",
        "对不起",
        "sorry",
        "can't",
        "can’t",
        "cannot",
        "not have permission",
    )

    client = None
    try:
        client = CopilotClient()
        await client.start()
        session = await client.create_session(
            {
                "model": model,
                "on_permission_request": build_permission_handler(
                    allow_all=allow_all,
                    allow_shell=allow_shell,
                    pattern=pattern,
                ),
            }
        )
        response = await session.send_and_wait({"prompt": prompt})
    except Exception as e:
        logger.exception("Test failed with exception: %s", e)
        return (False, str(e))
    finally:
        # Previously stop() was skipped whenever an earlier call raised,
        # leaking the client. Best-effort: a failing stop is only logged.
        if client is not None:
            try:
                await client.stop()
            except Exception as stop_exc:
                logger.warning("Failed to stop Copilot client: %s", stop_exc)

    # send_and_wait may return nothing usable (e.g. None); treat that as
    # "not approved" explicitly rather than via an AttributeError.
    content = getattr(getattr(response, "data", None), "content", None)
    if not content:
        logger.warning("No response content received")
        return (False, "(no response content)")

    lowered = content.lower()
    is_denied = any(kw in lowered for kw in denied_keywords)
    return (not is_denied, content)
async def main(argv: "List[str] | None" = None):
    """Run every permission scenario and log a pass/fail summary.

    Args:
        argv: Optional argument list for ``argparse``. When ``None``,
            ``argparse`` reads ``sys.argv[1:]`` as before.

    Returns:
        Exit status: ``0`` when every case matched its expected outcome,
        ``1`` otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Comprehensive permission control test suite."
    )
    parser.add_argument("--model", default="gpt-4.1", help="Model ID for testing.")
    args = parser.parse_args(argv)

    # Test cases: (name, allow_all, allow_shell, pattern, prompt, expected_approved).
    # The Chinese prompts "请执行: <cmd>" mean "Please run: <cmd>".
    test_cases = [
        ("Default Deny Shell", False, False, "", "请执行: ls -la", False),
        ("Allow All", True, False, "", "请执行: ls -la", True),
        ("Allow Shell", False, True, "", "请执行: pwd", True),
        ("Regex Match: ^ls", False, False, "^ls", "请执行: ls -la", True),
        ("Regex No Match: ^ls vs pwd", False, False, "^ls", "请执行: pwd", False),
        (
            "Regex Complex: ^(ls|pwd|echo)",
            False,
            False,
            "^(ls|pwd|echo)",
            "请执行: pwd",
            True,
        ),
        (
            "Regex Complex No Match: git",
            False,
            False,
            "^(ls|pwd|echo)",
            "请执行: git status",
            False,
        ),
        (
            "Read Permission (Always Allow)",
            False,
            False,
            "",
            "Read the file: README.md",
            True,
        ),
    ]
    results: List[Tuple[str, bool]] = []

    logger.info("=" * 80)
    logger.info("Starting Comprehensive Permission Control Test Suite")
    logger.info("Model: %s", args.model)
    logger.info("=" * 80)

    for i, (name, allow_all, allow_shell, pattern, prompt, expected) in enumerate(
        test_cases, 1
    ):
        logger.info("\n[Test %d/%d] %s", i, len(test_cases), name)
        logger.info(
            "  Config: allow_all=%s, allow_shell=%s, pattern=%r",
            allow_all,
            allow_shell,
            pattern,
        )
        logger.info("  Prompt: %s", prompt)

        approved, response = await run_test(
            args.model, allow_all, allow_shell, pattern, prompt
        )
        passed = approved == expected
        status = "✅ PASS" if passed else "❌ FAIL"
        results.append((name, passed))

        logger.info(
            "  Expected: %s, Got: %s - %s",
            "Approved" if expected else "Denied",
            "Approved" if approved else "Denied",
            status,
        )
        logger.info(
            "  Response: %s",
            (response[:100] + "...") if len(response) > 100 else response,
        )

    # Summary
    logger.info("\n" + "=" * 80)
    logger.info("Test Summary")
    logger.info("=" * 80)
    passed_count = sum(1 for _, passed in results if passed)
    total_count = len(results)
    for name, passed in results:
        # Per-test marker. These were previously both "" and so never visible.
        logger.info("%s %s", "✅" if passed else "❌", name)
    logger.info("-" * 80)
    # Guard against an empty case list, which would otherwise divide by zero.
    pass_rate = 100 * passed_count / total_count if total_count else 0.0
    logger.info(
        "Total: %d/%d tests passed (%.1f%%)",
        passed_count,
        total_count,
        pass_rate,
    )

    if passed_count == total_count:
        logger.info("🎉 All tests passed!")
        return 0
    logger.warning("⚠️  Some tests failed. Please review the logs.")
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so CI can detect
    # failing runs. A None return still exits with status 0.
    raise SystemExit(asyncio.run(main()))