392 lines
9.4 KiB
Go
392 lines
9.4 KiB
Go
|
|
package persistence
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"database/sql"
|
|||
|
|
"testing"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
_ "github.com/mattn/go-sqlite3"
|
|||
|
|
|
|||
|
|
"go.yandata.net/iod/iod/go-trustlog/api/logger"
|
|||
|
|
"go.yandata.net/iod/iod/go-trustlog/api/model"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// setupTestDB 创建测试用的 SQLite 内存数据库
|
|||
|
|
func setupTestDB(t *testing.T) *sql.DB {
|
|||
|
|
db, err := sql.Open("sqlite3", ":memory:")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to open test database: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建表
|
|||
|
|
opDDL, cursorDDL, retryDDL, err := GetDialectDDL("sqlite3")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to get DDL: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if _, err := db.Exec(opDDL); err != nil {
|
|||
|
|
t.Fatalf("failed to create operation table: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if _, err := db.Exec(cursorDDL); err != nil {
|
|||
|
|
t.Fatalf("failed to create cursor table: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if _, err := db.Exec(retryDDL); err != nil {
|
|||
|
|
t.Fatalf("failed to create retry table: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return db
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// createTestOperation 创建测试用的 Operation
|
|||
|
|
func createTestOperation(t *testing.T, opID string) *model.Operation {
|
|||
|
|
op, err := model.NewFullOperation(
|
|||
|
|
model.OpSourceDOIP,
|
|||
|
|
model.OpTypeCreate,
|
|||
|
|
"10.1000",
|
|||
|
|
"test-repo",
|
|||
|
|
"10.1000/test-repo/"+opID,
|
|||
|
|
"producer-001",
|
|||
|
|
"test-actor",
|
|||
|
|
[]byte(`{"test":"request"}`),
|
|||
|
|
[]byte(`{"test":"response"}`),
|
|||
|
|
time.Now(),
|
|||
|
|
)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to create test operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
op.OpID = opID // 覆盖自动生成的 ID
|
|||
|
|
return op
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestOperationRepository_Save(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewOperationRepository(db, log)
|
|||
|
|
|
|||
|
|
op := createTestOperation(t, "test-op-001")
|
|||
|
|
|
|||
|
|
// 设置 IP 字段
|
|||
|
|
clientIP := "192.168.1.100"
|
|||
|
|
serverIP := "10.0.0.50"
|
|||
|
|
op.ClientIP = &clientIP
|
|||
|
|
op.ServerIP = &serverIP
|
|||
|
|
|
|||
|
|
// 测试保存
|
|||
|
|
err := repo.Save(ctx, op, StatusNotTrustlogged)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to save operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证保存结果
|
|||
|
|
savedOp, status, err := repo.FindByID(ctx, "test-op-001")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if savedOp.OpID != "test-op-001" {
|
|||
|
|
t.Errorf("expected OpID to be 'test-op-001', got %s", savedOp.OpID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if status != StatusNotTrustlogged {
|
|||
|
|
t.Errorf("expected status to be StatusNotTrustlogged, got %v", status)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if savedOp.ClientIP == nil || *savedOp.ClientIP != "192.168.1.100" {
|
|||
|
|
t.Error("ClientIP not saved correctly")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if savedOp.ServerIP == nil || *savedOp.ServerIP != "10.0.0.50" {
|
|||
|
|
t.Error("ServerIP not saved correctly")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestOperationRepository_SaveWithNullIP(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewOperationRepository(db, log)
|
|||
|
|
|
|||
|
|
op := createTestOperation(t, "test-op-002")
|
|||
|
|
// IP 字段保持为 nil
|
|||
|
|
|
|||
|
|
err := repo.Save(ctx, op, StatusNotTrustlogged)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to save operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
savedOp, _, err := repo.FindByID(ctx, "test-op-002")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if savedOp.ClientIP != nil {
|
|||
|
|
t.Error("ClientIP should be nil")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if savedOp.ServerIP != nil {
|
|||
|
|
t.Error("ServerIP should be nil")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestOperationRepository_UpdateStatus(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewOperationRepository(db, log)
|
|||
|
|
|
|||
|
|
op := createTestOperation(t, "test-op-003")
|
|||
|
|
|
|||
|
|
// 先保存
|
|||
|
|
err := repo.Save(ctx, op, StatusNotTrustlogged)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to save operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新状态
|
|||
|
|
err = repo.UpdateStatus(ctx, "test-op-003", StatusTrustlogged)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to update status: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证更新结果
|
|||
|
|
_, status, err := repo.FindByID(ctx, "test-op-003")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find operation: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if status != StatusTrustlogged {
|
|||
|
|
t.Errorf("expected status to be StatusTrustlogged, got %v", status)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestOperationRepository_FindUntrustlogged(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewOperationRepository(db, log)
|
|||
|
|
|
|||
|
|
// 保存多个操作
|
|||
|
|
for i := 1; i <= 5; i++ {
|
|||
|
|
op := createTestOperation(t, "test-op-00"+string(rune('0'+i)))
|
|||
|
|
status := StatusNotTrustlogged
|
|||
|
|
if i%2 == 0 {
|
|||
|
|
status = StatusTrustlogged
|
|||
|
|
}
|
|||
|
|
err := repo.Save(ctx, op, status)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to save operation %d: %v", i, err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 查询未存证的操作
|
|||
|
|
ops, err := repo.FindUntrustlogged(ctx, 10)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find untrustlogged operations: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应该有 3 个未存证的操作(1, 3, 5)
|
|||
|
|
if len(ops) != 3 {
|
|||
|
|
t.Errorf("expected 3 untrustlogged operations, got %d", len(ops))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestCursorRepository_GetAndUpdate(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewCursorRepository(db, log)
|
|||
|
|
|
|||
|
|
cursorKey := "test-cursor"
|
|||
|
|
|
|||
|
|
// 初始化游标
|
|||
|
|
now := time.Now().Format(time.RFC3339Nano)
|
|||
|
|
err := repo.InitCursor(ctx, cursorKey, now)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to init cursor: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 获取游标值
|
|||
|
|
cursorValue, err := repo.GetCursor(ctx, cursorKey)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to get cursor: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if cursorValue != now {
|
|||
|
|
t.Errorf("expected cursor value to be %s, got %s", now, cursorValue)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新游标
|
|||
|
|
newTime := time.Now().Add(1 * time.Hour).Format(time.RFC3339Nano)
|
|||
|
|
err = repo.UpdateCursor(ctx, cursorKey, newTime)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to update cursor: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证更新结果
|
|||
|
|
cursorValue, err = repo.GetCursor(ctx, cursorKey)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to get cursor: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if cursorValue != newTime {
|
|||
|
|
t.Errorf("expected cursor value to be %s, got %s", newTime, cursorValue)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestRetryRepository_AddAndFind(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewRetryRepository(db, log)
|
|||
|
|
|
|||
|
|
// 添加重试记录(立即可以重试)
|
|||
|
|
nextRetry := time.Now().Add(-1 * time.Second) // 过去的时间,立即可以查询到
|
|||
|
|
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to add retry: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 查找待重试的记录
|
|||
|
|
records, err := repo.FindPendingRetries(ctx, 10)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find pending retries: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(records) != 1 {
|
|||
|
|
t.Errorf("expected 1 retry record, got %d", len(records))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(records) > 0 {
|
|||
|
|
if records[0].OpID != "test-op-001" {
|
|||
|
|
t.Errorf("expected OpID to be 'test-op-001', got %s", records[0].OpID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if records[0].RetryStatus != RetryStatusPending {
|
|||
|
|
t.Errorf("expected status to be PENDING, got %v", records[0].RetryStatus)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestRetryRepository_IncrementRetry(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewRetryRepository(db, log)
|
|||
|
|
|
|||
|
|
// 添加重试记录
|
|||
|
|
nextRetry := time.Now().Add(-1 * time.Second)
|
|||
|
|
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to add retry: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 增加重试次数(立即可以重试)
|
|||
|
|
nextRetry2 := time.Now().Add(-1 * time.Second)
|
|||
|
|
err = repo.IncrementRetry(ctx, "test-op-001", "test error 2", nextRetry2)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to increment retry: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证重试次数
|
|||
|
|
records, err := repo.FindPendingRetries(ctx, 10)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find pending retries: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(records) != 1 {
|
|||
|
|
t.Fatalf("expected 1 retry record, got %d", len(records))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if records[0].RetryCount != 1 {
|
|||
|
|
t.Errorf("expected RetryCount to be 1, got %d", records[0].RetryCount)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if records[0].RetryStatus != RetryStatusRetrying {
|
|||
|
|
t.Errorf("expected status to be RETRYING, got %v", records[0].RetryStatus)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestRetryRepository_MarkAsDeadLetter(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewRetryRepository(db, log)
|
|||
|
|
|
|||
|
|
// 添加重试记录
|
|||
|
|
nextRetry := time.Now().Add(-1 * time.Second)
|
|||
|
|
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to add retry: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 标记为死信
|
|||
|
|
err = repo.MarkAsDeadLetter(ctx, "test-op-001", "max retries exceeded")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to mark as dead letter: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证状态(死信不应该在待重试列表中)
|
|||
|
|
records, err := repo.FindPendingRetries(ctx, 10)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find pending retries: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(records) != 0 {
|
|||
|
|
t.Errorf("expected 0 pending retry records, got %d", len(records))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestRetryRepository_DeleteRetry(t *testing.T) {
|
|||
|
|
db := setupTestDB(t)
|
|||
|
|
defer db.Close()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
log := logger.GetGlobalLogger()
|
|||
|
|
repo := NewRetryRepository(db, log)
|
|||
|
|
|
|||
|
|
// 添加重试记录
|
|||
|
|
nextRetry := time.Now().Add(-1 * time.Second)
|
|||
|
|
err := repo.AddRetry(ctx, "test-op-001", "test error", nextRetry)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to add retry: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 删除重试记录
|
|||
|
|
err = repo.DeleteRetry(ctx, "test-op-001")
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to delete retry: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证已删除
|
|||
|
|
records, err := repo.FindPendingRetries(ctx, 10)
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("failed to find pending retries: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(records) != 0 {
|
|||
|
|
t.Errorf("expected 0 retry records, got %d", len(records))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|