114 lines
2.3 KiB
Go
114 lines
2.3 KiB
Go
|
|
package grpcclient
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"sync"
|
|||
|
|
"sync/atomic"
|
|||
|
|
|
|||
|
|
"google.golang.org/grpc"
|
|||
|
|
"google.golang.org/grpc/credentials/insecure"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// ClientFactory 客户端工厂函数类型.
|
|||
|
|
type ClientFactory[T any] func(grpc.ClientConnInterface) T
|
|||
|
|
|
|||
|
|
// ServerClient 封装单个服务器的连接.
|
|||
|
|
type ServerClient[T any] struct {
|
|||
|
|
addr string
|
|||
|
|
conn *grpc.ClientConn
|
|||
|
|
client T
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// LoadBalancer 轮询负载均衡器(泛型版本).
|
|||
|
|
type LoadBalancer[T any] struct {
|
|||
|
|
servers []*ServerClient[T]
|
|||
|
|
counter atomic.Uint64
|
|||
|
|
mu sync.RWMutex
|
|||
|
|
closed bool
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewLoadBalancer 创建新的负载均衡器.
|
|||
|
|
func NewLoadBalancer[T any](
|
|||
|
|
addrs []string,
|
|||
|
|
dialOpts []grpc.DialOption,
|
|||
|
|
factory ClientFactory[T],
|
|||
|
|
) (*LoadBalancer[T], error) {
|
|||
|
|
if len(addrs) == 0 {
|
|||
|
|
return nil, errors.New("at least one server address is required")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
lb := &LoadBalancer[T]{
|
|||
|
|
servers: make([]*ServerClient[T], 0, len(addrs)),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 默认使用不安全的连接(生产环境应使用TLS)
|
|||
|
|
opts := []grpc.DialOption{
|
|||
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|||
|
|
}
|
|||
|
|
opts = append(opts, dialOpts...)
|
|||
|
|
|
|||
|
|
// 连接所有服务器
|
|||
|
|
for _, addr := range addrs {
|
|||
|
|
conn, err := grpc.NewClient(addr, opts...)
|
|||
|
|
if err != nil {
|
|||
|
|
// 关闭已创建的连接
|
|||
|
|
_ = lb.Close()
|
|||
|
|
return nil, fmt.Errorf("failed to connect to server %s: %w", addr, err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
client := factory(conn)
|
|||
|
|
lb.servers = append(lb.servers, &ServerClient[T]{
|
|||
|
|
addr: addr,
|
|||
|
|
conn: conn,
|
|||
|
|
client: client,
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return lb, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Next 使用轮询算法获取下一个客户端.
|
|||
|
|
func (lb *LoadBalancer[T]) Next() T {
|
|||
|
|
lb.mu.RLock()
|
|||
|
|
defer lb.mu.RUnlock()
|
|||
|
|
|
|||
|
|
if len(lb.servers) == 0 || lb.closed {
|
|||
|
|
var zero T
|
|||
|
|
return zero
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 原子递增计数器并取模
|
|||
|
|
idx := lb.counter.Add(1) % uint64(len(lb.servers))
|
|||
|
|
return lb.servers[idx].client
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Close 关闭所有连接.
|
|||
|
|
func (lb *LoadBalancer[T]) Close() error {
|
|||
|
|
lb.mu.Lock()
|
|||
|
|
defer lb.mu.Unlock()
|
|||
|
|
|
|||
|
|
// 如果已经关闭,直接返回
|
|||
|
|
if lb.closed {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var lastErr error
|
|||
|
|
for _, server := range lb.servers {
|
|||
|
|
if err := server.conn.Close(); err != nil {
|
|||
|
|
lastErr = err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 标记为已关闭
|
|||
|
|
lb.closed = true
|
|||
|
|
return lastErr
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ServerCount 返回服务器数量.
|
|||
|
|
func (lb *LoadBalancer[T]) ServerCount() int {
|
|||
|
|
lb.mu.RLock()
|
|||
|
|
defer lb.mu.RUnlock()
|
|||
|
|
return len(lb.servers)
|
|||
|
|
}
|