Files
Fu-Jie_openwebui-extensions/plugins/debug/openwebui-skills-manager/test_security_fixes.py
fujie d29c24ba4a feat(openwebui-skills-manager): enhance auto-discovery and structural refactoring
- Enable default overwrite installation policy for overlapping skills
- Support deep recursive GitHub trees discovery mechanism to resolve #58
- Refactor internal architecture to fully decouple stateless helper logic
- READMEs and docs synced (v0.3.0)
2026-03-08 18:21:21 +08:00

561 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
独立测试脚本:验证 OpenWebUI Skills Manager 的所有安全修复
不需要 OpenWebUI 环境,可以直接运行
测试内容:
1. SSRF 防护 (_is_safe_url)
2. 不安全 tar/zip 提取防护 (_safe_extract_zip, _safe_extract_tar)
3. 名称冲突检查 (update_skill)
4. URL 验证
"""
import asyncio
import json
import logging
import sys
import tempfile
import tarfile
import zipfile
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
# Logging setup: INFO level with a timestamped "time - name - level - message" format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ==================== 模拟 OpenWebUI Skills 类 ====================
class MockSkill:
    """Minimal in-memory stand-in for a skill record.

    Stores the four caller-supplied fields and marks every new skill as
    active with a fixed ``updated_at`` timestamp.
    """

    def __init__(self, id: str, name: str, description: str = "", content: str = ""):
        # Caller-supplied fields.
        self.id, self.name = id, name
        self.description, self.content = description, content
        # Fixed defaults shared by every mock skill.
        self.is_active = True
        self.updated_at = "2024-03-08T00:00:00Z"
class MockSkills:
    """In-memory mock of the Skills model, used only by the tests.

    All state lives in the class-level ``_skills`` mapping, so it is shared
    by every caller until ``reset()`` replaces it with an empty mapping.
    """

    # user_id -> list of that user's skills
    _skills: Dict[str, List["MockSkill"]] = {}

    @classmethod
    def reset(cls):
        """Drop every stored skill."""
        cls._skills = {}

    @classmethod
    def get_skills_by_user_id(cls, user_id: str):
        """Return the skills stored for ``user_id`` (empty list if none)."""
        return cls._skills.get(user_id, [])

    @classmethod
    def insert_new_skill(cls, user_id: str, form_data):
        """Create a skill from ``form_data`` and append it to the user's list."""
        owned = cls._skills.setdefault(user_id, [])
        created = MockSkill(
            form_data.id, form_data.name, form_data.description, form_data.content
        )
        owned.append(created)
        return created

    @classmethod
    def update_skill_by_id(cls, skill_id: str, updates: Dict[str, Any]):
        """Apply ``updates`` as attributes on the first skill with ``skill_id``.

        Returns the updated skill, or ``None`` when no skill has that id.
        """
        target = next(
            (
                candidate
                for owned in cls._skills.values()
                for candidate in owned
                if candidate.id == skill_id
            ),
            None,
        )
        if target is None:
            return None
        for attribute, value in updates.items():
            setattr(target, attribute, value)
        return target

    @classmethod
    def delete_skill_by_id(cls, skill_id: str):
        """Remove the first skill with ``skill_id``; return whether one was found."""
        for owned in cls._skills.values():
            for position, candidate in enumerate(owned):
                if candidate.id == skill_id:
                    del owned[position]
                    return True
        return False
# ==================== 提取安全测试的核心方法 ====================
import ipaddress
import urllib.parse
class SecurityTester:
    """Standalone copy of the security helpers exercised by this script.

    Bundles the URL validation (SSRF guard plus optional domain whitelist)
    and the path-traversal-safe ZIP/TAR extraction routines, together with a
    mock ``valves`` configuration object.
    """

    # Names that always refer to the local machine. ``urlparse().hostname``
    # is already lower-cased and stripped of IPv6 brackets, so only the bare
    # forms can ever match.
    _LOCAL_HOSTNAMES = frozenset(
        {
            "localhost",
            "localhost.localdomain",
            "127.0.0.1",
            "::1",
            "0.0.0.0",
            "::ffff:127.0.0.1",
        }
    )

    def __init__(self):
        # Mock Valves configuration
        self.valves = type(
            "Valves",
            (),
            {
                "ENABLE_DOMAIN_WHITELIST": True,
                "TRUSTED_DOMAINS": "github.com,raw.githubusercontent.com,huggingface.co",
            },
        )()

    def _is_safe_url(self, url: str) -> Tuple[bool, Optional[str]]:
        """Validate that ``url`` does not point at an internal/sensitive target.

        Guards against server-side request forgery (SSRF).

        Args:
            url: The URL to validate.

        Returns:
            ``(True, None)`` when the URL is acceptable, otherwise
            ``(False, error_message)``. Any exception raised while parsing is
            reported as a rejection rather than propagated.
        """
        try:
            parsed = urllib.parse.urlparse(url)
            # Drop a trailing root dot so "localhost." / "127.0.0.1." cannot
            # slip past the literal and IP checks below.
            hostname = (parsed.hostname or "").rstrip(".")
            if not hostname:
                return False, "URL is malformed: missing hostname"

            # Reject localhost variants, including "*.localhost" names.
            if hostname in self._LOCAL_HOSTNAMES or hostname.endswith(".localhost"):
                return False, "URL points to local host"

            # Reject internal IP ranges (RFC 1918, loopback, link-local, ...).
            try:
                ip = ipaddress.ip_address(hostname)
            except ValueError:
                ip = None  # not an IP literal; treated as a DNS name
            if ip is not None and (
                ip.is_private
                or ip.is_loopback
                or ip.is_link_local
                or ip.is_reserved
                or ip.is_multicast
                or ip.is_unspecified
            ):
                return False, f"URL points to internal IP: {ip}"

            # Reject file:// and every other non-http(s) scheme.
            if parsed.scheme not in ("http", "https"):
                return False, f"URL scheme not allowed: {parsed.scheme}"

            # Domain whitelist (second security layer).
            if self.valves.ENABLE_DOMAIN_WHITELIST:
                trusted_domains = [
                    d.strip().lower()
                    for d in (self.valves.TRUSTED_DOMAINS or "").split(",")
                    if d.strip()
                ]
                # An empty whitelist means only the checks above apply.
                # Otherwise the host must equal a trusted domain or be one of
                # its subdomains (api.example.com matches example.com).
                if trusted_domains and not any(
                    hostname == domain or hostname.endswith("." + domain)
                    for domain in trusted_domains
                ):
                    error_msg = f"URL domain '{hostname}' is not in whitelist. Trusted domains: {', '.join(trusted_domains)}"
                    return False, error_msg
            return True, None
        except Exception as e:
            # Fail closed: a URL that cannot be analysed is never "safe".
            return False, f"Error validating URL: {e}"

    @staticmethod
    def _is_within_directory(root: Path, candidate: Path) -> bool:
        """Return True when ``candidate`` resolves to a path inside ``root``.

        ``root`` must already be resolved.
        """
        try:
            candidate.resolve().relative_to(root)
        except ValueError:
            return False
        return True

    def _safe_extract_zip(self, zip_path: Path, extract_dir: Path) -> None:
        """Extract a ZIP archive, skipping members that escape ``extract_dir``.

        Args:
            zip_path: Path of the archive to read.
            extract_dir: Directory that every extracted member must stay in.
        """
        log = logging.getLogger(__name__)
        root = Path(extract_dir).resolve()  # resolved once, not per member
        with zipfile.ZipFile(zip_path, "r") as zf:
            for member in zf.namelist():
                if not self._is_within_directory(root, root / member):
                    # Path lands outside extract_dir (traversal attempt).
                    log.warning("Skipping unsafe ZIP member: %s", member)
                    continue
                zf.extract(member, extract_dir)

    def _safe_extract_tar(self, tar_path: Path, extract_dir: Path) -> None:
        """Extract a TAR archive, skipping members that escape ``extract_dir``.

        Besides members whose own path leaves ``extract_dir``, this skips
        symlinks/hard links whose target leaves it and anything that is not a
        regular file, directory or link (devices, FIFOs).

        Args:
            tar_path: Path of the archive to read (any compression).
            extract_dir: Directory that every extracted member must stay in.
        """
        log = logging.getLogger(__name__)
        root = Path(extract_dir).resolve()  # resolved once, not per member
        # Opt in to tarfile's "data" extraction filter where the running
        # Python provides it; older versions do not accept the argument.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        # Empty tuple => the except clause below matches nothing.
        filter_error = getattr(tarfile, "FilterError", ())
        with tarfile.open(tar_path, "r:*") as tf:
            for member in tf.getmembers():
                target = root / member.name
                if not self._is_within_directory(root, target):
                    # Path lands outside extract_dir (traversal attempt).
                    log.warning("Skipping unsafe TAR member: %s", member.name)
                    continue
                if member.issym() or member.islnk():
                    # Symlink targets are relative to the link's directory,
                    # hard-link targets to the archive root.
                    link_base = target.parent if member.issym() else root
                    if not self._is_within_directory(
                        root, link_base / member.linkname
                    ):
                        log.warning(
                            "Skipping unsafe TAR member: %s (link target %s is outside the extraction directory)",
                            member.name,
                            member.linkname,
                        )
                        continue
                elif not (member.isfile() or member.isdir()):
                    log.warning(
                        "Skipping unsafe TAR member: %s (special file)", member.name
                    )
                    continue
                try:
                    tf.extract(member, extract_dir, **extract_kwargs)
                except filter_error as exc:
                    # The filter rejected something the checks above let
                    # through; keep the skip-and-continue contract.
                    log.warning(
                        "Skipping unsafe TAR member: %s (%s)", member.name, exc
                    )
# ==================== 测试用例 ====================
def test_ssrf_protection():
    """Check the SSRF guard in ``SecurityTester._is_safe_url``.

    Every URL in the unsafe list must be rejected and every URL in the safe
    list must be accepted. Each result is printed before it is asserted; the
    leading mark reflects the actual outcome (✓ accepted, ✗ rejected), the
    same convention ``test_domain_whitelist`` uses.

    Raises:
        AssertionError: If any URL is classified incorrectly.
    """
    print("\n" + "=" * 60)
    print("测试 1: SSRF 防护 (_is_safe_url)")
    print("=" * 60)
    tester = SecurityTester()

    # Unsafe URLs (must be rejected)
    unsafe_urls = [
        "http://localhost/skill",
        "http://127.0.0.1:8000/skill",
        "http://[::1]/skill",
        "http://0.0.0.0/skill",
        "http://192.168.1.1/skill",  # private IP (RFC 1918)
        "http://10.0.0.1/skill",
        "http://172.16.0.1/skill",
        "http://169.254.1.1/skill",  # link-local
        "file:///etc/passwd",  # file:// scheme
        "gopher://example.com/skill",  # not http(s)
    ]
    print("\n❌ 不安全的 URLs (应该被拒绝):")
    for url in unsafe_urls:
        is_safe, error_msg = tester._is_safe_url(url)
        status = "✗ 被拒绝 (正确)" if not is_safe else "✓ 被接受 (错误)"
        error_info = f" - {error_msg}" if error_msg else ""
        print(f" {url:<50} {status}{error_info}")
        assert not is_safe, f"URL 不应该被接受: {url}"

    # Safe URLs (must be accepted)
    safe_urls = [
        "https://github.com/Fu-Jie/openwebui-extensions/raw/main/SKILL.md",
        "https://raw.githubusercontent.com/user/repo/main/skill.md",
        "https://huggingface.co/spaces/user/skill",
    ]
    print("\n✅ 安全且在白名单中的 URLs (应该被接受):")
    for url in safe_urls:
        is_safe, error_msg = tester._is_safe_url(url)
        status = "✓ 被接受 (正确)" if is_safe else "✗ 被拒绝 (错误)"
        error_info = f" - {error_msg}" if error_msg else ""
        print(f" {url:<60} {status}{error_info}")
        assert is_safe, f"URL 不应该被拒绝: {url} - {error_msg}"
    print("\n✓ SSRF 防护测试通过!")
def test_tar_extraction_safety():
    """Check that ``_safe_extract_tar`` blocks path traversal.

    Builds a TAR archive holding one legitimate file and one member named
    ``../../etc/passwd``, extracts it, and asserts that only the legitimate
    file was written. The extraction directory sits two levels below the
    temporary directory, so the traversal target resolves to a path that is
    still inside the temporary directory: it can be asserted on reliably and
    is cleaned up even if the protection fails.

    Raises:
        AssertionError: If the safe file is missing or the malicious member
            was written anywhere.
    """
    import io  # only this test needs in-memory file objects

    print("\n" + "=" * 60)
    print("测试 2: TAR 提取安全性 (_safe_extract_tar)")
    print("=" * 60)
    tester = SecurityTester()
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir_path = Path(tmpdir)
        tar_path = tmpdir_path / "malicious.tar"
        # Two levels deep: "../../etc/passwd" then lands at tmpdir/etc/passwd.
        extract_dir = tmpdir_path / "sandbox" / "extracted"
        extract_dir.mkdir(parents=True, exist_ok=True)

        print("\n创建测试 TAR 文件...")
        members = (
            ("safe_file.txt", b"safe content"),  # legitimate member
            ("../../etc/passwd", b"evil data!"),  # path traversal attempt
        )
        with tarfile.open(tar_path, "w") as tf:
            for name, payload in members:
                info = tarfile.TarInfo(name=name)
                # Size must match the payload, otherwise the stored data is
                # truncated.
                info.size = len(payload)
                tf.addfile(tarinfo=info, fileobj=io.BytesIO(payload))
        print(f" TAR 文件已创建: {tar_path}")

        print("\n提取 TAR 文件...")
        try:
            tester._safe_extract_tar(tar_path, extract_dir)
            safe_file = extract_dir / "safe_file.txt"
            # Where the member would land if ".." were stripped ...
            evil_file = extract_dir / "etc" / "passwd"
            # ... and where it would land if the traversal succeeded.
            escaped_file = tmpdir_path / "etc" / "passwd"
            print(f" 检查合法文件: {safe_file.exists()} (应该为 True)")
            assert safe_file.exists(), "合法文件应该被提取"
            evil_absent = not evil_file.exists() and not escaped_file.exists()
            print(f" 检查恶意文件不存在: {evil_absent} (应该为 True)")
            assert not evil_file.exists(), "恶意文件不应该被提取"
            assert not escaped_file.exists(), "恶意文件不应该被提取"
            print("\n✓ TAR 提取安全性测试通过!")
        except Exception as e:
            print(f"✗ 提取失败: {e}")
            raise
def test_zip_extraction_safety():
    """Check that ``_safe_extract_zip`` blocks path traversal.

    Builds a ZIP archive with one legitimate entry and one entry named
    ``../../etc/passwd``, extracts it, and asserts that only the legitimate
    entry ends up in the extraction directory.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("测试 3: ZIP 提取安全性 (_safe_extract_zip)")
    print(banner)
    tester = SecurityTester()
    with tempfile.TemporaryDirectory() as workdir:
        base = Path(workdir)
        zip_path = base / "malicious.zip"
        extract_dir = base / "extracted"
        extract_dir.mkdir(parents=True, exist_ok=True)

        print("\n创建测试 ZIP 文件...")
        entries = (
            ("safe_file.txt", "safe content"),  # legitimate entry
            ("../../etc/passwd", "evil data!"),  # path traversal attempt
        )
        with zipfile.ZipFile(zip_path, "w") as archive:
            for entry_name, text in entries:
                archive.writestr(entry_name, text)
        print(f" ZIP 文件已创建: {zip_path}")

        print("\n提取 ZIP 文件...")
        try:
            tester._safe_extract_zip(zip_path, extract_dir)
            # Inspect the result
            expected = extract_dir / "safe_file.txt"
            forbidden = extract_dir / "etc" / "passwd"
            print(f" 检查合法文件: {expected.exists()} (应该为 True)")
            assert expected.exists(), "合法文件应该被提取"
            print(f" 检查恶意文件不存在: {not forbidden.exists()} (应该为 True)")
            assert not forbidden.exists(), "恶意文件不应该被提取"
            print("\n✓ ZIP 提取安全性测试通过!")
        except Exception as exc:
            print(f"✗ 提取失败: {exc}")
            raise
def test_skill_name_collision():
    """Check the case-insensitive name-collision rule used when renaming a skill.

    Registers two mock skills for one user, then verifies that renaming the
    second to the first's name is detected as a collision while renaming it
    to an unused name is not.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("测试 4: 技能名称冲突检查")
    print(banner)

    # Simulated skill store
    user_id = "test_user_1"
    MockSkills.reset()

    def find_collision(candidate_name, renamed_id):
        """Return the first other skill whose name equals ``candidate_name``
        ignoring case, or ``None``; the skill being renamed is skipped."""
        wanted = candidate_name.lower()
        for existing in MockSkills._skills[user_id]:
            if existing.id != renamed_id and existing.name.lower() == wanted:
                return existing
        return None

    # First skill
    print("\n创建技能 1: 'MySkill'...")
    skill1 = MockSkill("skill_1", "MySkill", "First skill", "content1")
    MockSkills._skills[user_id] = [skill1]
    print(f" ✓ 技能已创建: {skill1.name}")

    # Second skill
    print("\n创建技能 2: 'AnotherSkill'...")
    skill2 = MockSkill("skill_2", "AnotherSkill", "Second skill", "content2")
    MockSkills._skills[user_id].append(skill2)
    print(f" ✓ 技能已创建: {skill2.name}")

    print("\n测试名称冲突检查...")
    # Rename skill 2 to the name skill 1 already holds.
    new_name = "MySkill"
    print(f"\n尝试将技能 2 改名为 '{new_name}'...")
    print(" 检查是否与其他技能冲突...")
    clash = find_collision(new_name, "skill_2")
    if clash is not None:
        print(f" ✓ 冲突检测成功!发现重复名称: {clash.name}")
    assert clash is not None, "应该检测到名称冲突"

    # Renaming to an unused name must be allowed.
    print("\n尝试将技能 2 改名为 'UniqueSkill'...")
    new_name = "UniqueSkill"
    assert find_collision(new_name, "skill_2") is None, "不应该存在冲突"
    print(" ✓ 允许改名,没有冲突")
    print("\n✓ 技能名称冲突检查测试通过!")
def test_url_normalization():
    """Check that malformed or non-HTTP(S) URLs are rejected by ``_is_safe_url``."""
    rule = "=" * 60
    print("\n" + rule)
    print("测试 5: URL 标准化")
    print(rule)
    checker = SecurityTester()

    # Inputs that are not usable http(s) URLs
    print("\n测试无效的 URL:")
    for candidate in ("not-a-url", "ftp://example.com/file", "", " "):
        accepted, _reason = checker._is_safe_url(candidate)
        print(f" '{candidate}' -> 被拒绝: {not accepted}")
        assert not accepted, f"无效 URL 应该被拒绝: {candidate}"
    print("\n✓ URL 标准化测试通过!")
def test_domain_whitelist():
    """Check the domain whitelist layer of ``_is_safe_url``.

    With the whitelist enabled, trusted domains and their subdomains must be
    accepted and every other domain rejected; with it disabled, a
    non-whitelisted public domain must be accepted.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("测试 6: 域名白名单 (ENABLE_DOMAIN_WHITELIST)")
    print(rule)

    # Tester with the whitelist switched on
    tester = SecurityTester()
    valves = tester.valves
    valves.ENABLE_DOMAIN_WHITELIST = True
    valves.TRUSTED_DOMAINS = "github.com,raw.githubusercontent.com,huggingface.co"
    print("\n配置信息:")
    print(f" 白名单启用: {valves.ENABLE_DOMAIN_WHITELIST}")
    print(f" 授信域名: {valves.TRUSTED_DOMAINS}")

    # URLs on the whitelist (must be accepted)
    print("\n✅ 白名单中的 URLs (应该被接受):")
    for url in (
        "https://github.com/user/repo/raw/main/skill.md",
        "https://raw.githubusercontent.com/user/repo/main/skill.md",
        "https://api.github.com/repos/user/repo/contents",
        "https://huggingface.co/spaces/user/skill",
    ):
        accepted, reason = tester._is_safe_url(url)
        if accepted:
            status = "✓ 被接受 (正确)"
        else:
            status = "✗ 被拒绝 (错误)"
        print(f" {url:<65} {status}")
        assert accepted, f"白名单中的 URL 应该被接受: {url} - {reason}"

    # URLs off the whitelist (must be rejected)
    print("\n❌ 非白名单 URLs (应该被拒绝):")
    for url in (
        "https://example.com/skill.md",
        "https://evil.com/skill.zip",
        "https://api.example.com/skill",
    ):
        accepted, reason = tester._is_safe_url(url)
        if accepted:
            status = "✓ 被接受 (错误)"
        else:
            status = "✗ 被拒绝 (正确)"
        print(f" {url:<65} {status}")
        assert not accepted, f"非白名单 URL 应该被拒绝: {url}"

    # Whitelist switched off
    print("\n禁用白名单进行测试...")
    valves.ENABLE_DOMAIN_WHITELIST = False
    accepted, reason = tester._is_safe_url("https://example.com/skill.md")
    print(f" example.com without whitelist: {accepted}")
    assert accepted, "禁用白名单时example.com 应该被接受"
    print("\n✓ 域名白名单测试通过!")
# ==================== 主函数 ====================
def main():
    """Run every test in order and print a summary.

    Returns:
        ``0`` when all tests pass, ``1`` when an assertion fails or any other
        exception is raised (the latter also prints a traceback).
    """
    rule = "=" * 60
    print("\n" + "🔒 OpenWebUI Skills Manager 安全修复测试".center(60, "="))
    print("版本: 0.2.2")
    print(rule)
    try:
        # Run all tests
        for run_test in (
            test_ssrf_protection,
            test_tar_extraction_safety,
            test_zip_extraction_safety,
            test_skill_name_collision,
            test_url_normalization,
            test_domain_whitelist,
        ):
            run_test()
    except AssertionError as failure:
        print(f"\n❌ 测试失败: {failure}\n")
        return 1
    except Exception as error:
        print(f"\n❌ 测试错误: {error}\n")
        import traceback

        traceback.print_exc()
        return 1

    # Summary
    print("\n" + rule)
    print("🎉 所有测试通过!".center(60))
    print(rule)
    print("\n修复验证:")
    for verified in (
        " ✓ SSRF 防护:阻止指向内部 IP 的请求",
        " ✓ TAR/ZIP 安全提取:防止路径遍历攻击",
        " ✓ 名称冲突检查:防止技能名称重复",
        " ✓ URL 验证:仅接受安全的 HTTP(S) URL",
        " ✓ 域名白名单:只允许授信域名下载技能",
    ):
        print(verified)
    print("\n所有安全功能都已成功实现!")
    print(rule + "\n")
    return 0
# Script entry point: the process exit status is main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())