Files
Fu-Jie_openwebui-extensions/plugins/debug/openwebui-skills-manager/test_security_fixes.py

561 lines
19 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
独立测试脚本验证 OpenWebUI Skills Manager 的所有安全修复
不需要 OpenWebUI 环境可以直接运行
测试内容
1. SSRF 防护 (_is_safe_url)
2. 不安全 tar/zip 提取防护 (_safe_extract_zip, _safe_extract_tar)
3. 名称冲突检查 (update_skill)
4. URL 验证
"""
import asyncio
import json
import logging
import sys
import tempfile
import tarfile
import zipfile
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
# 配置日志
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# ==================== 模拟 OpenWebUI Skills 类 ====================
class MockSkill:
def __init__(self, id: str, name: str, description: str = "", content: str = ""):
self.id = id
self.name = name
self.description = description
self.content = content
self.is_active = True
self.updated_at = "2024-03-08T00:00:00Z"
class MockSkills:
"""Mock Skills 模型,用于测试"""
_skills: Dict[str, List[MockSkill]] = {}
@classmethod
def reset(cls):
cls._skills = {}
@classmethod
def get_skills_by_user_id(cls, user_id: str):
return cls._skills.get(user_id, [])
@classmethod
def insert_new_skill(cls, user_id: str, form_data):
if user_id not in cls._skills:
cls._skills[user_id] = []
skill = MockSkill(
form_data.id, form_data.name, form_data.description, form_data.content
)
cls._skills[user_id].append(skill)
return skill
@classmethod
def update_skill_by_id(cls, skill_id: str, updates: Dict[str, Any]):
for user_skills in cls._skills.values():
for skill in user_skills:
if skill.id == skill_id:
for key, value in updates.items():
setattr(skill, key, value)
return skill
return None
@classmethod
def delete_skill_by_id(cls, skill_id: str):
for user_id, user_skills in cls._skills.items():
for idx, skill in enumerate(user_skills):
if skill.id == skill_id:
user_skills.pop(idx)
return True
return False
# ==================== 提取安全测试的核心方法 ====================
import ipaddress
import urllib.parse
class SecurityTester:
"""提取出的安全测试核心类"""
def __init__(self):
# 模拟 Valves 配置
self.valves = type(
"Valves",
(),
{
"ENABLE_DOMAIN_WHITELIST": True,
"TRUSTED_DOMAINS": "github.com,raw.githubusercontent.com,huggingface.co",
},
)()
def _is_safe_url(self, url: str) -> tuple:
"""
验证 URL 是否指向内部/敏感目标
防止服务端请求伪造 (SSRF) 攻击
返回 (True, None) 如果 URL 是安全的否则返回 (False, error_message)
"""
try:
parsed = urllib.parse.urlparse(url)
hostname = parsed.hostname or ""
if not hostname:
return False, "URL is malformed: missing hostname"
# 拒绝 localhost 变体
if hostname.lower() in (
"localhost",
"127.0.0.1",
"::1",
"[::1]",
"0.0.0.0",
"[::ffff:127.0.0.1]",
"localhost.localdomain",
):
return False, "URL points to local host"
# 拒绝内部 IP 范围 (RFC 1918, link-local 等)
try:
ip = ipaddress.ip_address(hostname.lstrip("[").rstrip("]"))
# 拒绝私有、回环、链接本地和保留 IP
if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_reserved
):
return False, f"URL points to internal IP: {ip}"
except ValueError:
# 不是 IP 地址,检查 hostname 模式
pass
# 拒绝 file:// 和其他非 http(s) 方案
if parsed.scheme not in ("http", "https"):
return False, f"URL scheme not allowed: {parsed.scheme}"
# 域名白名单检查 (安全层 2)
if self.valves.ENABLE_DOMAIN_WHITELIST:
trusted_domains = [
d.strip().lower()
for d in (self.valves.TRUSTED_DOMAINS or "").split(",")
if d.strip()
]
if not trusted_domains:
# 没有配置授信域名,仅进行安全检查
return True, None
hostname_lower = hostname.lower()
# 检查 hostname 是否匹配任何授信域名(精确或子域名)
is_trusted = False
for trusted_domain in trusted_domains:
# 精确匹配
if hostname_lower == trusted_domain:
is_trusted = True
break
# 子域名匹配 (*.example.com 匹配 api.example.com)
if hostname_lower.endswith("." + trusted_domain):
is_trusted = True
break
if not is_trusted:
error_msg = f"URL domain '{hostname}' is not in whitelist. Trusted domains: {', '.join(trusted_domains)}"
return False, error_msg
return True, None
except Exception as e:
return False, f"Error validating URL: {e}"
def _safe_extract_zip(self, zip_path: Path, extract_dir: Path) -> None:
"""
安全地提取 ZIP 文件验证成员路径以防止路径遍历
"""
with zipfile.ZipFile(zip_path, "r") as zf:
for member in zf.namelist():
# 检查路径遍历尝试
member_path = Path(extract_dir) / member
try:
# 确保解析的路径在 extract_dir 内
member_path.resolve().relative_to(extract_dir.resolve())
except ValueError:
# 路径在 extract_dir 外(遍历尝试)
logger.warning(f"Skipping unsafe ZIP member: {member}")
continue
# 提取成员
zf.extract(member, extract_dir)
def _safe_extract_tar(self, tar_path: Path, extract_dir: Path) -> None:
"""
安全地提取 TAR 文件验证成员路径以防止路径遍历
"""
with tarfile.open(tar_path, "r:*") as tf:
for member in tf.getmembers():
# 检查路径遍历尝试
member_path = Path(extract_dir) / member.name
try:
# 确保解析的路径在 extract_dir 内
member_path.resolve().relative_to(extract_dir.resolve())
except ValueError:
# 路径在 extract_dir 外(遍历尝试)
logger.warning(f"Skipping unsafe TAR member: {member.name}")
continue
# 提取成员
tf.extract(member, extract_dir)
# ==================== 测试用例 ====================
def test_ssrf_protection():
"""测试 SSRF 防护"""
print("\n" + "=" * 60)
print("测试 1: SSRF 防护 (_is_safe_url)")
print("=" * 60)
tester = SecurityTester()
# 不安全的 URLs (应该被拒绝)
unsafe_urls = [
"http://localhost/skill",
"http://127.0.0.1:8000/skill",
"http://[::1]/skill",
"http://0.0.0.0/skill",
"http://192.168.1.1/skill", # 私有 IP (RFC 1918)
"http://10.0.0.1/skill",
"http://172.16.0.1/skill",
"http://169.254.1.1/skill", # link-local
"file:///etc/passwd", # file:// scheme
"gopher://example.com/skill", # 非 http(s)
]
print("\n❌ 不安全的 URLs (应该被拒绝):")
for url in unsafe_urls:
is_safe, error_msg = tester._is_safe_url(url)
status = "✗ 被拒绝 (正确)" if not is_safe else "✗ 被接受 (错误)"
error_info = f" - {error_msg}" if error_msg else ""
print(f" {url:<50} {status}{error_info}")
assert not is_safe, f"URL 不应该被接受: {url}"
# 安全的 URLs (应该被接受)
safe_urls = [
"https://github.com/Fu-Jie/openwebui-extensions/raw/main/SKILL.md",
"https://raw.githubusercontent.com/user/repo/main/skill.md",
"https://huggingface.co/spaces/user/skill",
]
print("\n✅ 安全且在白名单中的 URLs (应该被接受):")
for url in safe_urls:
is_safe, error_msg = tester._is_safe_url(url)
status = "✓ 被接受 (正确)" if is_safe else "✓ 被拒绝 (错误)"
error_info = f" - {error_msg}" if error_msg else ""
print(f" {url:<60} {status}{error_info}")
assert is_safe, f"URL 不应该被拒绝: {url} - {error_msg}"
print("\n✓ SSRF 防护测试通过!")
def test_tar_extraction_safety():
"""测试 TAR 提取路径遍历防护"""
print("\n" + "=" * 60)
print("测试 2: TAR 提取安全性 (_safe_extract_tar)")
print("=" * 60)
tester = SecurityTester()
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# 创建一个包含路径遍历尝试的 tar 文件
tar_path = tmpdir_path / "malicious.tar"
extract_dir = tmpdir_path / "extracted"
extract_dir.mkdir(parents=True, exist_ok=True)
print("\n创建测试 TAR 文件...")
with tarfile.open(tar_path, "w") as tf:
# 合法的成员
import io
info = tarfile.TarInfo(name="safe_file.txt")
info.size = 11
tf.addfile(tarinfo=info, fileobj=io.BytesIO(b"safe content"))
# 路径遍历尝试
info = tarfile.TarInfo(name="../../etc/passwd")
info.size = 10
tf.addfile(tarinfo=info, fileobj=io.BytesIO(b"evil data!"))
print(f" TAR 文件已创建: {tar_path}")
# 提取文件
print("\n提取 TAR 文件...")
try:
tester._safe_extract_tar(tar_path, extract_dir)
# 检查结果
safe_file = extract_dir / "safe_file.txt"
evil_file = extract_dir / "etc" / "passwd"
evil_file_alt = Path("/etc/passwd")
print(f" 检查合法文件: {safe_file.exists()} (应该为 True)")
assert safe_file.exists(), "合法文件应该被提取"
print(f" 检查恶意文件不存在: {not evil_file.exists()} (应该为 True)")
assert not evil_file.exists(), "恶意文件不应该被提取"
print("\n✓ TAR 提取安全性测试通过!")
except Exception as e:
print(f"✗ 提取失败: {e}")
raise
def test_zip_extraction_safety():
"""测试 ZIP 提取路径遍历防护"""
print("\n" + "=" * 60)
print("测试 3: ZIP 提取安全性 (_safe_extract_zip)")
print("=" * 60)
tester = SecurityTester()
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# 创建一个包含路径遍历尝试的 zip 文件
zip_path = tmpdir_path / "malicious.zip"
extract_dir = tmpdir_path / "extracted"
extract_dir.mkdir(parents=True, exist_ok=True)
print("\n创建测试 ZIP 文件...")
with zipfile.ZipFile(zip_path, "w") as zf:
# 合法的成员
zf.writestr("safe_file.txt", "safe content")
# 路径遍历尝试
zf.writestr("../../etc/passwd", "evil data!")
print(f" ZIP 文件已创建: {zip_path}")
# 提取文件
print("\n提取 ZIP 文件...")
try:
tester._safe_extract_zip(zip_path, extract_dir)
# 检查结果
safe_file = extract_dir / "safe_file.txt"
evil_file = extract_dir / "etc" / "passwd"
print(f" 检查合法文件: {safe_file.exists()} (应该为 True)")
assert safe_file.exists(), "合法文件应该被提取"
print(f" 检查恶意文件不存在: {not evil_file.exists()} (应该为 True)")
assert not evil_file.exists(), "恶意文件不应该被提取"
print("\n✓ ZIP 提取安全性测试通过!")
except Exception as e:
print(f"✗ 提取失败: {e}")
raise
def test_skill_name_collision():
"""测试技能名称冲突检查"""
print("\n" + "=" * 60)
print("测试 4: 技能名称冲突检查")
print("=" * 60)
# 模拟技能管理
user_id = "test_user_1"
MockSkills.reset()
# 创建第一个技能
print("\n创建技能 1: 'MySkill'...")
skill1 = MockSkill("skill_1", "MySkill", "First skill", "content1")
MockSkills._skills[user_id] = [skill1]
print(f" ✓ 技能已创建: {skill1.name}")
# 创建第二个技能
print("\n创建技能 2: 'AnotherSkill'...")
skill2 = MockSkill("skill_2", "AnotherSkill", "Second skill", "content2")
MockSkills._skills[user_id].append(skill2)
print(f" ✓ 技能已创建: {skill2.name}")
# 测试名称冲突检查逻辑
print("\n测试名称冲突检查...")
# 模拟尝试将 skill2 改名为 skill1 的名称
new_name = "MySkill" # 已被 skill1 占用
print(f"\n尝试将技能 2 改名为 '{new_name}'...")
print(f" 检查是否与其他技能冲突...")
# 这是 update_skill 中的冲突检查逻辑
collision_found = False
for other_skill in MockSkills._skills[user_id]:
# 跳过要更新的技能本身
if other_skill.id == "skill_2":
continue
# 检查是否存在同名技能
if other_skill.name.lower() == new_name.lower():
collision_found = True
print(f" ✓ 冲突检测成功!发现重复名称: {other_skill.name}")
break
assert collision_found, "应该检测到名称冲突"
# 测试允许的改名(改为不同的名称)
print(f"\n尝试将技能 2 改名为 'UniqueSkill'...")
new_name = "UniqueSkill"
collision_found = False
for other_skill in MockSkills._skills[user_id]:
if other_skill.id == "skill_2":
continue
if other_skill.name.lower() == new_name.lower():
collision_found = True
break
assert not collision_found, "不应该存在冲突"
print(f" ✓ 允许改名,没有冲突")
print("\n✓ 技能名称冲突检查测试通过!")
def test_url_normalization():
"""测试 URL 标准化"""
print("\n" + "=" * 60)
print("测试 5: URL 标准化")
print("=" * 60)
tester = SecurityTester()
# 测试无效的 URL
print("\n测试无效的 URL:")
invalid_urls = [
"not-a-url",
"ftp://example.com/file",
"",
" ",
]
for url in invalid_urls:
is_safe, error_msg = tester._is_safe_url(url)
print(f" '{url}' -> 被拒绝: {not is_safe}")
assert not is_safe, f"无效 URL 应该被拒绝: {url}"
print("\n✓ URL 标准化测试通过!")
def test_domain_whitelist():
"""测试域名白名单功能"""
print("\n" + "=" * 60)
print("测试 6: 域名白名单 (ENABLE_DOMAIN_WHITELIST)")
print("=" * 60)
# 创建启用白名单的测试器
tester = SecurityTester()
tester.valves.ENABLE_DOMAIN_WHITELIST = True
tester.valves.TRUSTED_DOMAINS = (
"github.com,raw.githubusercontent.com,huggingface.co"
)
print("\n配置信息:")
print(f" 白名单启用: {tester.valves.ENABLE_DOMAIN_WHITELIST}")
print(f" 授信域名: {tester.valves.TRUSTED_DOMAINS}")
# 白名单中的 URLs (应该被接受)
whitelisted_urls = [
"https://github.com/user/repo/raw/main/skill.md",
"https://raw.githubusercontent.com/user/repo/main/skill.md",
"https://api.github.com/repos/user/repo/contents",
"https://huggingface.co/spaces/user/skill",
]
print("\n✅ 白名单中的 URLs (应该被接受):")
for url in whitelisted_urls:
is_safe, error_msg = tester._is_safe_url(url)
status = "✓ 被接受 (正确)" if is_safe else "✗ 被拒绝 (错误)"
print(f" {url:<65} {status}")
assert is_safe, f"白名单中的 URL 应该被接受: {url} - {error_msg}"
# 不在白名单中的 URLs (应该被拒绝)
non_whitelisted_urls = [
"https://example.com/skill.md",
"https://evil.com/skill.zip",
"https://api.example.com/skill",
]
print("\n❌ 非白名单 URLs (应该被拒绝):")
for url in non_whitelisted_urls:
is_safe, error_msg = tester._is_safe_url(url)
status = "✗ 被拒绝 (正确)" if not is_safe else "✓ 被接受 (错误)"
print(f" {url:<65} {status}")
assert not is_safe, f"非白名单 URL 应该被拒绝: {url}"
# 测试禁用白名单
print("\n禁用白名单进行测试...")
tester.valves.ENABLE_DOMAIN_WHITELIST = False
is_safe, error_msg = tester._is_safe_url("https://example.com/skill.md")
print(f" example.com without whitelist: {is_safe}")
assert is_safe, "禁用白名单时example.com 应该被接受"
print("\n✓ 域名白名单测试通过!")
# ==================== 主函数 ====================
def main():
print("\n" + "🔒 OpenWebUI Skills Manager 安全修复测试".center(60, "="))
print("版本: 0.2.2")
print("=" * 60)
try:
# 运行所有测试
test_ssrf_protection()
test_tar_extraction_safety()
test_zip_extraction_safety()
test_skill_name_collision()
test_url_normalization()
test_domain_whitelist()
# 测试总结
print("\n" + "=" * 60)
print("🎉 所有测试通过!".center(60))
print("=" * 60)
print("\n修复验证:")
print(" ✓ SSRF 防护:阻止指向内部 IP 的请求")
print(" ✓ TAR/ZIP 安全提取:防止路径遍历攻击")
print(" ✓ 名称冲突检查:防止技能名称重复")
print(" ✓ URL 验证:仅接受安全的 HTTP(S) URL")
print(" ✓ 域名白名单:只允许授信域名下载技能")
print("\n所有安全功能都已成功实现!")
print("=" * 60 + "\n")
return 0
except AssertionError as e:
print(f"\n❌ 测试失败: {e}\n")
return 1
except Exception as e:
print(f"\n❌ 测试错误: {e}\n")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
sys.exit(main())