196 lines
4.2 KiB
Go
196 lines
4.2 KiB
Go
|
|
package adapter
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"net"
|
|||
|
|
"sync"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/ThreeDotsLabs/watermill/message"
|
|||
|
|
|
|||
|
|
"go.yandata.net/iod/iod/trustlog-sdk/api/logger"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// Default configuration values that NewTCPPublisher substitutes when the
// corresponding TCPPublisherConfig field is left at zero.
const (
	// defaultConnectTimeout bounds how long the initial TCP dial may take.
	defaultConnectTimeout = 10 * time.Second

	// defaultMaxRetries is the fallback for TCPPublisherConfig.MaxRetries.
	defaultMaxRetries = 3
)
|
|||
|
|
|
|||
|
|
// Sentinel errors produced by the TCP publisher; match them with errors.Is.
var (
	// ErrServerAddrRequired is returned by NewTCPPublisher when
	// TCPPublisherConfig.ServerAddr is empty.
	ErrServerAddrRequired = errors.New("server address is required")

	// ErrPublisherClosed is returned by Publish once Close has been called.
	ErrPublisherClosed = errors.New("publisher is closed")
)
|
|||
|
|
|
|||
|
|
// TCPPublisherConfig holds the settings used to build a TCPPublisher.
type TCPPublisherConfig struct {
	// ServerAddr is the TCP server to dial, in "host:port" form. Required.
	ServerAddr string

	// ConnectTimeout bounds the initial dial. NewTCPPublisher substitutes
	// defaultConnectTimeout when it is left at zero.
	ConnectTimeout time.Duration

	// MaxRetries is the maximum number of retries. NewTCPPublisher substitutes
	// defaultMaxRetries when it is left at zero.
	// NOTE(review): nothing in this file reads this value after defaulting —
	// confirm whether retries are implemented elsewhere in the package.
	MaxRetries int
}
|
|||
|
|
|
|||
|
|
// TCPPublisher 实现基于 TCP 的 watermill Publisher
|
|||
|
|
type TCPPublisher struct {
|
|||
|
|
config TCPPublisherConfig
|
|||
|
|
conn net.Conn
|
|||
|
|
logger logger.Logger
|
|||
|
|
|
|||
|
|
closed bool
|
|||
|
|
closedMu sync.RWMutex
|
|||
|
|
closeChan chan struct{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewTCPPublisher 创建一个新的 TCP Publisher.
|
|||
|
|
func NewTCPPublisher(config TCPPublisherConfig, logger logger.Logger) (*TCPPublisher, error) {
|
|||
|
|
if config.ServerAddr == "" {
|
|||
|
|
return nil, ErrServerAddrRequired
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if config.ConnectTimeout == 0 {
|
|||
|
|
config.ConnectTimeout = defaultConnectTimeout
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if config.MaxRetries == 0 {
|
|||
|
|
config.MaxRetries = defaultMaxRetries
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
p := &TCPPublisher{
|
|||
|
|
config: config,
|
|||
|
|
logger: logger,
|
|||
|
|
closeChan: make(chan struct{}),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 连接到服务器
|
|||
|
|
if err := p.connect(); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 不再接收 ACK/NACK,发送即成功模式
|
|||
|
|
// go p.receiveAcks() // 已移除
|
|||
|
|
|
|||
|
|
return p, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// connect 连接到 TCP 服务器
|
|||
|
|
func (p *TCPPublisher) connect() error {
|
|||
|
|
ctx, cancel := context.WithTimeout(context.Background(), p.config.ConnectTimeout)
|
|||
|
|
defer cancel()
|
|||
|
|
|
|||
|
|
var d net.Dialer
|
|||
|
|
conn, err := d.DialContext(ctx, "tcp", p.config.ServerAddr)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to connect to %s: %w", p.config.ServerAddr, err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
p.conn = conn
|
|||
|
|
p.logger.InfoContext(context.Background(), "Connected to TCP server", "addr", p.config.ServerAddr)
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Publish 发布消息.
|
|||
|
|
func (p *TCPPublisher) Publish(topic string, messages ...*message.Message) error {
|
|||
|
|
p.closedMu.RLock()
|
|||
|
|
if p.closed {
|
|||
|
|
p.closedMu.RUnlock()
|
|||
|
|
return ErrPublisherClosed
|
|||
|
|
}
|
|||
|
|
p.closedMu.RUnlock()
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
|
|||
|
|
// 使用 WaitGroup 和 errChan 来并发发送消息并收集错误
|
|||
|
|
var wg sync.WaitGroup
|
|||
|
|
errs := make([]error, 0, len(messages))
|
|||
|
|
var errMu sync.Mutex
|
|||
|
|
errChan := make(chan error, len(messages))
|
|||
|
|
|
|||
|
|
for _, msg := range messages {
|
|||
|
|
if msg == nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
wg.Add(1)
|
|||
|
|
go func(m *message.Message) {
|
|||
|
|
defer wg.Done()
|
|||
|
|
|
|||
|
|
if err := p.publishSingle(ctx, topic, m); err != nil {
|
|||
|
|
errChan <- err
|
|||
|
|
}
|
|||
|
|
}(msg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 等待所有消息发送完成
|
|||
|
|
wg.Wait()
|
|||
|
|
close(errChan)
|
|||
|
|
|
|||
|
|
// 检查是否有错误
|
|||
|
|
for err := range errChan {
|
|||
|
|
errMu.Lock()
|
|||
|
|
errs = append(errs, err)
|
|||
|
|
errMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(errs) > 0 {
|
|||
|
|
return fmt.Errorf("failed to publish %d messages: %w", len(errs), errors.Join(errs...))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// publishSingle 发送单条消息,不等待 ACK
|
|||
|
|
func (p *TCPPublisher) publishSingle(ctx context.Context, topic string, msg *message.Message) error {
|
|||
|
|
tcpMsg := &TCPMessage{
|
|||
|
|
Type: MessageTypeData,
|
|||
|
|
Topic: topic,
|
|||
|
|
UUID: msg.UUID,
|
|||
|
|
Payload: msg.Payload,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 编码消息
|
|||
|
|
data, err := EncodeTCPMessage(tcpMsg)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("failed to encode message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
p.logger.DebugContext(ctx, "Sending message", "uuid", msg.UUID, "topic", topic)
|
|||
|
|
|
|||
|
|
// 发送消息
|
|||
|
|
if _, err := p.conn.Write(data); err != nil {
|
|||
|
|
return fmt.Errorf("failed to write message: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
p.logger.DebugContext(ctx, "Message sent successfully", "uuid", msg.UUID)
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// receiveAcks, shouldStopReceiving, handleDecodeError 方法已移除
|
|||
|
|
// 不再接收 ACK/NACK,采用发送即成功模式以提高性能
|
|||
|
|
|
|||
|
|
// Close 关闭发布者
|
|||
|
|
func (p *TCPPublisher) Close() error {
|
|||
|
|
p.closedMu.Lock()
|
|||
|
|
if p.closed {
|
|||
|
|
p.closedMu.Unlock()
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
p.closed = true
|
|||
|
|
p.closedMu.Unlock()
|
|||
|
|
|
|||
|
|
close(p.closeChan)
|
|||
|
|
|
|||
|
|
if p.conn != nil {
|
|||
|
|
if err := p.conn.Close(); err != nil {
|
|||
|
|
return fmt.Errorf("failed to close connection: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
p.logger.InfoContext(context.Background(), "TCP Publisher closed")
|
|||
|
|
return nil
|
|||
|
|
}
|