561 lines
19 KiB
Python
561 lines
19 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
独立测试脚本:验证 OpenWebUI Skills Manager 的所有安全修复
|
|||
|
|
不需要 OpenWebUI 环境,可以直接运行
|
|||
|
|
|
|||
|
|
测试内容:
|
|||
|
|
1. SSRF 防护 (_is_safe_url)
|
|||
|
|
2. 不安全 tar/zip 提取防护 (_safe_extract_zip, _safe_extract_tar)
|
|||
|
|
3. 名称冲突检查 (update_skill)
|
|||
|
|
4. URL 验证
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import sys
|
|||
|
|
import tempfile
|
|||
|
|
import tarfile
|
|||
|
|
import zipfile
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Optional, Dict, Any, List, Tuple
|
|||
|
|
|
|||
|
|
# 配置日志
|
|||
|
|
logging.basicConfig(
|
|||
|
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|||
|
|
)
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
# ==================== 模拟 OpenWebUI Skills 类 ====================
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MockSkill:
|
|||
|
|
def __init__(self, id: str, name: str, description: str = "", content: str = ""):
|
|||
|
|
self.id = id
|
|||
|
|
self.name = name
|
|||
|
|
self.description = description
|
|||
|
|
self.content = content
|
|||
|
|
self.is_active = True
|
|||
|
|
self.updated_at = "2024-03-08T00:00:00Z"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MockSkills:
|
|||
|
|
"""Mock Skills 模型,用于测试"""
|
|||
|
|
|
|||
|
|
_skills: Dict[str, List[MockSkill]] = {}
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def reset(cls):
|
|||
|
|
cls._skills = {}
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def get_skills_by_user_id(cls, user_id: str):
|
|||
|
|
return cls._skills.get(user_id, [])
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def insert_new_skill(cls, user_id: str, form_data):
|
|||
|
|
if user_id not in cls._skills:
|
|||
|
|
cls._skills[user_id] = []
|
|||
|
|
skill = MockSkill(
|
|||
|
|
form_data.id, form_data.name, form_data.description, form_data.content
|
|||
|
|
)
|
|||
|
|
cls._skills[user_id].append(skill)
|
|||
|
|
return skill
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def update_skill_by_id(cls, skill_id: str, updates: Dict[str, Any]):
|
|||
|
|
for user_skills in cls._skills.values():
|
|||
|
|
for skill in user_skills:
|
|||
|
|
if skill.id == skill_id:
|
|||
|
|
for key, value in updates.items():
|
|||
|
|
setattr(skill, key, value)
|
|||
|
|
return skill
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def delete_skill_by_id(cls, skill_id: str):
|
|||
|
|
for user_id, user_skills in cls._skills.items():
|
|||
|
|
for idx, skill in enumerate(user_skills):
|
|||
|
|
if skill.id == skill_id:
|
|||
|
|
user_skills.pop(idx)
|
|||
|
|
return True
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ==================== 提取安全测试的核心方法 ====================
|
|||
|
|
|
|||
|
|
import ipaddress
|
|||
|
|
import urllib.parse
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SecurityTester:
|
|||
|
|
"""提取出的安全测试核心类"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
# 模拟 Valves 配置
|
|||
|
|
self.valves = type(
|
|||
|
|
"Valves",
|
|||
|
|
(),
|
|||
|
|
{
|
|||
|
|
"ENABLE_DOMAIN_WHITELIST": True,
|
|||
|
|
"TRUSTED_DOMAINS": "github.com,raw.githubusercontent.com,huggingface.co",
|
|||
|
|
},
|
|||
|
|
)()
|
|||
|
|
|
|||
|
|
def _is_safe_url(self, url: str) -> tuple:
|
|||
|
|
"""
|
|||
|
|
验证 URL 是否指向内部/敏感目标。
|
|||
|
|
防止服务端请求伪造 (SSRF) 攻击。
|
|||
|
|
|
|||
|
|
返回 (True, None) 如果 URL 是安全的,否则返回 (False, error_message)。
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
parsed = urllib.parse.urlparse(url)
|
|||
|
|
hostname = parsed.hostname or ""
|
|||
|
|
|
|||
|
|
if not hostname:
|
|||
|
|
return False, "URL is malformed: missing hostname"
|
|||
|
|
|
|||
|
|
# 拒绝 localhost 变体
|
|||
|
|
if hostname.lower() in (
|
|||
|
|
"localhost",
|
|||
|
|
"127.0.0.1",
|
|||
|
|
"::1",
|
|||
|
|
"[::1]",
|
|||
|
|
"0.0.0.0",
|
|||
|
|
"[::ffff:127.0.0.1]",
|
|||
|
|
"localhost.localdomain",
|
|||
|
|
):
|
|||
|
|
return False, "URL points to local host"
|
|||
|
|
|
|||
|
|
# 拒绝内部 IP 范围 (RFC 1918, link-local 等)
|
|||
|
|
try:
|
|||
|
|
ip = ipaddress.ip_address(hostname.lstrip("[").rstrip("]"))
|
|||
|
|
# 拒绝私有、回环、链接本地和保留 IP
|
|||
|
|
if (
|
|||
|
|
ip.is_private
|
|||
|
|
or ip.is_loopback
|
|||
|
|
or ip.is_link_local
|
|||
|
|
or ip.is_reserved
|
|||
|
|
):
|
|||
|
|
return False, f"URL points to internal IP: {ip}"
|
|||
|
|
except ValueError:
|
|||
|
|
# 不是 IP 地址,检查 hostname 模式
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
# 拒绝 file:// 和其他非 http(s) 方案
|
|||
|
|
if parsed.scheme not in ("http", "https"):
|
|||
|
|
return False, f"URL scheme not allowed: {parsed.scheme}"
|
|||
|
|
|
|||
|
|
# 域名白名单检查 (安全层 2)
|
|||
|
|
if self.valves.ENABLE_DOMAIN_WHITELIST:
|
|||
|
|
trusted_domains = [
|
|||
|
|
d.strip().lower()
|
|||
|
|
for d in (self.valves.TRUSTED_DOMAINS or "").split(",")
|
|||
|
|
if d.strip()
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
if not trusted_domains:
|
|||
|
|
# 没有配置授信域名,仅进行安全检查
|
|||
|
|
return True, None
|
|||
|
|
|
|||
|
|
hostname_lower = hostname.lower()
|
|||
|
|
|
|||
|
|
# 检查 hostname 是否匹配任何授信域名(精确或子域名)
|
|||
|
|
is_trusted = False
|
|||
|
|
for trusted_domain in trusted_domains:
|
|||
|
|
# 精确匹配
|
|||
|
|
if hostname_lower == trusted_domain:
|
|||
|
|
is_trusted = True
|
|||
|
|
break
|
|||
|
|
# 子域名匹配 (*.example.com 匹配 api.example.com)
|
|||
|
|
if hostname_lower.endswith("." + trusted_domain):
|
|||
|
|
is_trusted = True
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
if not is_trusted:
|
|||
|
|
error_msg = f"URL domain '{hostname}' is not in whitelist. Trusted domains: {', '.join(trusted_domains)}"
|
|||
|
|
return False, error_msg
|
|||
|
|
|
|||
|
|
return True, None
|
|||
|
|
except Exception as e:
|
|||
|
|
return False, f"Error validating URL: {e}"
|
|||
|
|
|
|||
|
|
def _safe_extract_zip(self, zip_path: Path, extract_dir: Path) -> None:
|
|||
|
|
"""
|
|||
|
|
安全地提取 ZIP 文件,验证成员路径以防止路径遍历。
|
|||
|
|
"""
|
|||
|
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
|||
|
|
for member in zf.namelist():
|
|||
|
|
# 检查路径遍历尝试
|
|||
|
|
member_path = Path(extract_dir) / member
|
|||
|
|
try:
|
|||
|
|
# 确保解析的路径在 extract_dir 内
|
|||
|
|
member_path.resolve().relative_to(extract_dir.resolve())
|
|||
|
|
except ValueError:
|
|||
|
|
# 路径在 extract_dir 外(遍历尝试)
|
|||
|
|
logger.warning(f"Skipping unsafe ZIP member: {member}")
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 提取成员
|
|||
|
|
zf.extract(member, extract_dir)
|
|||
|
|
|
|||
|
|
def _safe_extract_tar(self, tar_path: Path, extract_dir: Path) -> None:
|
|||
|
|
"""
|
|||
|
|
安全地提取 TAR 文件,验证成员路径以防止路径遍历。
|
|||
|
|
"""
|
|||
|
|
with tarfile.open(tar_path, "r:*") as tf:
|
|||
|
|
for member in tf.getmembers():
|
|||
|
|
# 检查路径遍历尝试
|
|||
|
|
member_path = Path(extract_dir) / member.name
|
|||
|
|
try:
|
|||
|
|
# 确保解析的路径在 extract_dir 内
|
|||
|
|
member_path.resolve().relative_to(extract_dir.resolve())
|
|||
|
|
except ValueError:
|
|||
|
|
# 路径在 extract_dir 外(遍历尝试)
|
|||
|
|
logger.warning(f"Skipping unsafe TAR member: {member.name}")
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 提取成员
|
|||
|
|
tf.extract(member, extract_dir)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ==================== 测试用例 ====================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ssrf_protection():
|
|||
|
|
"""测试 SSRF 防护"""
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("测试 1: SSRF 防护 (_is_safe_url)")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
tester = SecurityTester()
|
|||
|
|
|
|||
|
|
# 不安全的 URLs (应该被拒绝)
|
|||
|
|
unsafe_urls = [
|
|||
|
|
"http://localhost/skill",
|
|||
|
|
"http://127.0.0.1:8000/skill",
|
|||
|
|
"http://[::1]/skill",
|
|||
|
|
"http://0.0.0.0/skill",
|
|||
|
|
"http://192.168.1.1/skill", # 私有 IP (RFC 1918)
|
|||
|
|
"http://10.0.0.1/skill",
|
|||
|
|
"http://172.16.0.1/skill",
|
|||
|
|
"http://169.254.1.1/skill", # link-local
|
|||
|
|
"file:///etc/passwd", # file:// scheme
|
|||
|
|
"gopher://example.com/skill", # 非 http(s)
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
print("\n❌ 不安全的 URLs (应该被拒绝):")
|
|||
|
|
for url in unsafe_urls:
|
|||
|
|
is_safe, error_msg = tester._is_safe_url(url)
|
|||
|
|
status = "✗ 被拒绝 (正确)" if not is_safe else "✗ 被接受 (错误)"
|
|||
|
|
error_info = f" - {error_msg}" if error_msg else ""
|
|||
|
|
print(f" {url:<50} {status}{error_info}")
|
|||
|
|
assert not is_safe, f"URL 不应该被接受: {url}"
|
|||
|
|
|
|||
|
|
# 安全的 URLs (应该被接受)
|
|||
|
|
safe_urls = [
|
|||
|
|
"https://github.com/Fu-Jie/openwebui-extensions/raw/main/SKILL.md",
|
|||
|
|
"https://raw.githubusercontent.com/user/repo/main/skill.md",
|
|||
|
|
"https://huggingface.co/spaces/user/skill",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
print("\n✅ 安全且在白名单中的 URLs (应该被接受):")
|
|||
|
|
for url in safe_urls:
|
|||
|
|
is_safe, error_msg = tester._is_safe_url(url)
|
|||
|
|
status = "✓ 被接受 (正确)" if is_safe else "✓ 被拒绝 (错误)"
|
|||
|
|
error_info = f" - {error_msg}" if error_msg else ""
|
|||
|
|
print(f" {url:<60} {status}{error_info}")
|
|||
|
|
assert is_safe, f"URL 不应该被拒绝: {url} - {error_msg}"
|
|||
|
|
|
|||
|
|
print("\n✓ SSRF 防护测试通过!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_tar_extraction_safety():
|
|||
|
|
"""测试 TAR 提取路径遍历防护"""
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("测试 2: TAR 提取安全性 (_safe_extract_tar)")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
tester = SecurityTester()
|
|||
|
|
|
|||
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|||
|
|
tmpdir_path = Path(tmpdir)
|
|||
|
|
|
|||
|
|
# 创建一个包含路径遍历尝试的 tar 文件
|
|||
|
|
tar_path = tmpdir_path / "malicious.tar"
|
|||
|
|
extract_dir = tmpdir_path / "extracted"
|
|||
|
|
extract_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
|
|
|||
|
|
print("\n创建测试 TAR 文件...")
|
|||
|
|
with tarfile.open(tar_path, "w") as tf:
|
|||
|
|
# 合法的成员
|
|||
|
|
import io
|
|||
|
|
|
|||
|
|
info = tarfile.TarInfo(name="safe_file.txt")
|
|||
|
|
info.size = 11
|
|||
|
|
tf.addfile(tarinfo=info, fileobj=io.BytesIO(b"safe content"))
|
|||
|
|
|
|||
|
|
# 路径遍历尝试
|
|||
|
|
info = tarfile.TarInfo(name="../../etc/passwd")
|
|||
|
|
info.size = 10
|
|||
|
|
tf.addfile(tarinfo=info, fileobj=io.BytesIO(b"evil data!"))
|
|||
|
|
|
|||
|
|
print(f" TAR 文件已创建: {tar_path}")
|
|||
|
|
|
|||
|
|
# 提取文件
|
|||
|
|
print("\n提取 TAR 文件...")
|
|||
|
|
try:
|
|||
|
|
tester._safe_extract_tar(tar_path, extract_dir)
|
|||
|
|
|
|||
|
|
# 检查结果
|
|||
|
|
safe_file = extract_dir / "safe_file.txt"
|
|||
|
|
evil_file = extract_dir / "etc" / "passwd"
|
|||
|
|
evil_file_alt = Path("/etc/passwd")
|
|||
|
|
|
|||
|
|
print(f" 检查合法文件: {safe_file.exists()} (应该为 True)")
|
|||
|
|
assert safe_file.exists(), "合法文件应该被提取"
|
|||
|
|
|
|||
|
|
print(f" 检查恶意文件不存在: {not evil_file.exists()} (应该为 True)")
|
|||
|
|
assert not evil_file.exists(), "恶意文件不应该被提取"
|
|||
|
|
|
|||
|
|
print("\n✓ TAR 提取安全性测试通过!")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"✗ 提取失败: {e}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_zip_extraction_safety():
|
|||
|
|
"""测试 ZIP 提取路径遍历防护"""
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("测试 3: ZIP 提取安全性 (_safe_extract_zip)")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
tester = SecurityTester()
|
|||
|
|
|
|||
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|||
|
|
tmpdir_path = Path(tmpdir)
|
|||
|
|
|
|||
|
|
# 创建一个包含路径遍历尝试的 zip 文件
|
|||
|
|
zip_path = tmpdir_path / "malicious.zip"
|
|||
|
|
extract_dir = tmpdir_path / "extracted"
|
|||
|
|
extract_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
|
|
|||
|
|
print("\n创建测试 ZIP 文件...")
|
|||
|
|
with zipfile.ZipFile(zip_path, "w") as zf:
|
|||
|
|
# 合法的成员
|
|||
|
|
zf.writestr("safe_file.txt", "safe content")
|
|||
|
|
|
|||
|
|
# 路径遍历尝试
|
|||
|
|
zf.writestr("../../etc/passwd", "evil data!")
|
|||
|
|
|
|||
|
|
print(f" ZIP 文件已创建: {zip_path}")
|
|||
|
|
|
|||
|
|
# 提取文件
|
|||
|
|
print("\n提取 ZIP 文件...")
|
|||
|
|
try:
|
|||
|
|
tester._safe_extract_zip(zip_path, extract_dir)
|
|||
|
|
|
|||
|
|
# 检查结果
|
|||
|
|
safe_file = extract_dir / "safe_file.txt"
|
|||
|
|
evil_file = extract_dir / "etc" / "passwd"
|
|||
|
|
|
|||
|
|
print(f" 检查合法文件: {safe_file.exists()} (应该为 True)")
|
|||
|
|
assert safe_file.exists(), "合法文件应该被提取"
|
|||
|
|
|
|||
|
|
print(f" 检查恶意文件不存在: {not evil_file.exists()} (应该为 True)")
|
|||
|
|
assert not evil_file.exists(), "恶意文件不应该被提取"
|
|||
|
|
|
|||
|
|
print("\n✓ ZIP 提取安全性测试通过!")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"✗ 提取失败: {e}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_skill_name_collision():
|
|||
|
|
"""测试技能名称冲突检查"""
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("测试 4: 技能名称冲突检查")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
# 模拟技能管理
|
|||
|
|
user_id = "test_user_1"
|
|||
|
|
MockSkills.reset()
|
|||
|
|
|
|||
|
|
# 创建第一个技能
|
|||
|
|
print("\n创建技能 1: 'MySkill'...")
|
|||
|
|
skill1 = MockSkill("skill_1", "MySkill", "First skill", "content1")
|
|||
|
|
MockSkills._skills[user_id] = [skill1]
|
|||
|
|
print(f" ✓ 技能已创建: {skill1.name}")
|
|||
|
|
|
|||
|
|
# 创建第二个技能
|
|||
|
|
print("\n创建技能 2: 'AnotherSkill'...")
|
|||
|
|
skill2 = MockSkill("skill_2", "AnotherSkill", "Second skill", "content2")
|
|||
|
|
MockSkills._skills[user_id].append(skill2)
|
|||
|
|
print(f" ✓ 技能已创建: {skill2.name}")
|
|||
|
|
|
|||
|
|
# 测试名称冲突检查逻辑
|
|||
|
|
print("\n测试名称冲突检查...")
|
|||
|
|
|
|||
|
|
# 模拟尝试将 skill2 改名为 skill1 的名称
|
|||
|
|
new_name = "MySkill" # 已被 skill1 占用
|
|||
|
|
print(f"\n尝试将技能 2 改名为 '{new_name}'...")
|
|||
|
|
print(f" 检查是否与其他技能冲突...")
|
|||
|
|
|
|||
|
|
# 这是 update_skill 中的冲突检查逻辑
|
|||
|
|
collision_found = False
|
|||
|
|
for other_skill in MockSkills._skills[user_id]:
|
|||
|
|
# 跳过要更新的技能本身
|
|||
|
|
if other_skill.id == "skill_2":
|
|||
|
|
continue
|
|||
|
|
# 检查是否存在同名技能
|
|||
|
|
if other_skill.name.lower() == new_name.lower():
|
|||
|
|
collision_found = True
|
|||
|
|
print(f" ✓ 冲突检测成功!发现重复名称: {other_skill.name}")
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
assert collision_found, "应该检测到名称冲突"
|
|||
|
|
|
|||
|
|
# 测试允许的改名(改为不同的名称)
|
|||
|
|
print(f"\n尝试将技能 2 改名为 'UniqueSkill'...")
|
|||
|
|
new_name = "UniqueSkill"
|
|||
|
|
collision_found = False
|
|||
|
|
for other_skill in MockSkills._skills[user_id]:
|
|||
|
|
if other_skill.id == "skill_2":
|
|||
|
|
continue
|
|||
|
|
if other_skill.name.lower() == new_name.lower():
|
|||
|
|
collision_found = True
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
assert not collision_found, "不应该存在冲突"
|
|||
|
|
print(f" ✓ 允许改名,没有冲突")
|
|||
|
|
|
|||
|
|
print("\n✓ 技能名称冲突检查测试通过!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_url_normalization():
|
|||
|
|
"""测试 URL 标准化"""
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("测试 5: URL 标准化")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
tester = SecurityTester()
|
|||
|
|
|
|||
|
|
# 测试无效的 URL
|
|||
|
|
print("\n测试无效的 URL:")
|
|||
|
|
invalid_urls = [
|
|||
|
|
"not-a-url",
|
|||
|
|
"ftp://example.com/file",
|
|||
|
|
"",
|
|||
|
|
" ",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
for url in invalid_urls:
|
|||
|
|
is_safe, error_msg = tester._is_safe_url(url)
|
|||
|
|
print(f" '{url}' -> 被拒绝: {not is_safe} ✓")
|
|||
|
|
assert not is_safe, f"无效 URL 应该被拒绝: {url}"
|
|||
|
|
|
|||
|
|
print("\n✓ URL 标准化测试通过!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_domain_whitelist():
|
|||
|
|
"""测试域名白名单功能"""
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("测试 6: 域名白名单 (ENABLE_DOMAIN_WHITELIST)")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
# 创建启用白名单的测试器
|
|||
|
|
tester = SecurityTester()
|
|||
|
|
tester.valves.ENABLE_DOMAIN_WHITELIST = True
|
|||
|
|
tester.valves.TRUSTED_DOMAINS = (
|
|||
|
|
"github.com,raw.githubusercontent.com,huggingface.co"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print("\n配置信息:")
|
|||
|
|
print(f" 白名单启用: {tester.valves.ENABLE_DOMAIN_WHITELIST}")
|
|||
|
|
print(f" 授信域名: {tester.valves.TRUSTED_DOMAINS}")
|
|||
|
|
|
|||
|
|
# 白名单中的 URLs (应该被接受)
|
|||
|
|
whitelisted_urls = [
|
|||
|
|
"https://github.com/user/repo/raw/main/skill.md",
|
|||
|
|
"https://raw.githubusercontent.com/user/repo/main/skill.md",
|
|||
|
|
"https://api.github.com/repos/user/repo/contents",
|
|||
|
|
"https://huggingface.co/spaces/user/skill",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
print("\n✅ 白名单中的 URLs (应该被接受):")
|
|||
|
|
for url in whitelisted_urls:
|
|||
|
|
is_safe, error_msg = tester._is_safe_url(url)
|
|||
|
|
status = "✓ 被接受 (正确)" if is_safe else "✗ 被拒绝 (错误)"
|
|||
|
|
print(f" {url:<65} {status}")
|
|||
|
|
assert is_safe, f"白名单中的 URL 应该被接受: {url} - {error_msg}"
|
|||
|
|
|
|||
|
|
# 不在白名单中的 URLs (应该被拒绝)
|
|||
|
|
non_whitelisted_urls = [
|
|||
|
|
"https://example.com/skill.md",
|
|||
|
|
"https://evil.com/skill.zip",
|
|||
|
|
"https://api.example.com/skill",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
print("\n❌ 非白名单 URLs (应该被拒绝):")
|
|||
|
|
for url in non_whitelisted_urls:
|
|||
|
|
is_safe, error_msg = tester._is_safe_url(url)
|
|||
|
|
status = "✗ 被拒绝 (正确)" if not is_safe else "✓ 被接受 (错误)"
|
|||
|
|
print(f" {url:<65} {status}")
|
|||
|
|
assert not is_safe, f"非白名单 URL 应该被拒绝: {url}"
|
|||
|
|
|
|||
|
|
# 测试禁用白名单
|
|||
|
|
print("\n禁用白名单进行测试...")
|
|||
|
|
tester.valves.ENABLE_DOMAIN_WHITELIST = False
|
|||
|
|
is_safe, error_msg = tester._is_safe_url("https://example.com/skill.md")
|
|||
|
|
print(f" example.com without whitelist: {is_safe} ✓")
|
|||
|
|
assert is_safe, "禁用白名单时,example.com 应该被接受"
|
|||
|
|
|
|||
|
|
print("\n✓ 域名白名单测试通过!")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ==================== 主函数 ====================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
print("\n" + "🔒 OpenWebUI Skills Manager 安全修复测试".center(60, "="))
|
|||
|
|
print("版本: 0.2.2")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 运行所有测试
|
|||
|
|
test_ssrf_protection()
|
|||
|
|
test_tar_extraction_safety()
|
|||
|
|
test_zip_extraction_safety()
|
|||
|
|
test_skill_name_collision()
|
|||
|
|
test_url_normalization()
|
|||
|
|
test_domain_whitelist()
|
|||
|
|
|
|||
|
|
# 测试总结
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("🎉 所有测试通过!".center(60))
|
|||
|
|
print("=" * 60)
|
|||
|
|
print("\n修复验证:")
|
|||
|
|
print(" ✓ SSRF 防护:阻止指向内部 IP 的请求")
|
|||
|
|
print(" ✓ TAR/ZIP 安全提取:防止路径遍历攻击")
|
|||
|
|
print(" ✓ 名称冲突检查:防止技能名称重复")
|
|||
|
|
print(" ✓ URL 验证:仅接受安全的 HTTP(S) URL")
|
|||
|
|
print(" ✓ 域名白名单:只允许授信域名下载技能")
|
|||
|
|
print("\n所有安全功能都已成功实现!")
|
|||
|
|
print("=" * 60 + "\n")
|
|||
|
|
|
|||
|
|
return 0
|
|||
|
|
except AssertionError as e:
|
|||
|
|
print(f"\n❌ 测试失败: {e}\n")
|
|||
|
|
return 1
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"\n❌ 测试错误: {e}\n")
|
|||
|
|
import traceback
|
|||
|
|
|
|||
|
|
traceback.print_exc()
|
|||
|
|
return 1
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
sys.exit(main())
|