147 lines
3.2 KiB
Go
147 lines
3.2 KiB
Go
|
|
package helpers
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"io"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// TLVReader 提供 TLV(Type-Length-Value)格式的顺序读取能力。
|
|||
|
|
// 支持无需反序列化全部报文即可读取特定字段。
|
|||
|
|
type TLVReader struct {
|
|||
|
|
r io.Reader
|
|||
|
|
br io.ByteReader
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewTLVReader 创建新的 TLVReader。
|
|||
|
|
func NewTLVReader(r io.Reader) *TLVReader {
|
|||
|
|
return &TLVReader{
|
|||
|
|
r: r,
|
|||
|
|
br: newByteReader(r),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ReadField 读取下一个 TLV 字段。
|
|||
|
|
// 返回字段的长度和值。
|
|||
|
|
func (tr *TLVReader) ReadField() ([]byte, error) {
|
|||
|
|
length, err := readVarint(tr.br)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to read field length: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if length == 0 {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
value := make([]byte, length)
|
|||
|
|
if _, errRead := io.ReadFull(tr.r, value); errRead != nil {
|
|||
|
|
return nil, fmt.Errorf("failed to read field value: %w", errRead)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return value, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ReadStringField 读取下一个 TLV 字段并转换为字符串。
|
|||
|
|
func (tr *TLVReader) ReadStringField() (string, error) {
|
|||
|
|
data, err := tr.ReadField()
|
|||
|
|
if err != nil {
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
return string(data), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TLVWriter 提供 TLV 格式的顺序写入能力。
|
|||
|
|
type TLVWriter struct {
|
|||
|
|
w io.Writer
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewTLVWriter 创建新的 TLVWriter。
|
|||
|
|
func NewTLVWriter(w io.Writer) *TLVWriter {
|
|||
|
|
return &TLVWriter{w: w}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WriteField 写入一个 TLV 字段。
|
|||
|
|
func (tw *TLVWriter) WriteField(value []byte) error {
|
|||
|
|
if err := writeVarint(tw.w, uint64(len(value))); err != nil {
|
|||
|
|
return fmt.Errorf("failed to write field length: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(value) > 0 {
|
|||
|
|
if _, err := tw.w.Write(value); err != nil {
|
|||
|
|
return fmt.Errorf("failed to write field value: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WriteStringField 写入一个字符串 TLV 字段。
|
|||
|
|
func (tw *TLVWriter) WriteStringField(value string) error {
|
|||
|
|
return tw.WriteField([]byte(value))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Varint 编码/解码函数
|
|||
|
|
|
|||
|
|
const (
|
|||
|
|
// varintContinueBit 表示 varint 还有后续字节的标志位。
|
|||
|
|
varintContinueBit = 0x80
|
|||
|
|
// varintDataMask 用于提取 varint 数据位的掩码。
|
|||
|
|
varintDataMask = 0x7f
|
|||
|
|
// varintMaxShift 表示 varint 最大的位移量,防止溢出。
|
|||
|
|
varintMaxShift = 64
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// writeVarint 写入变长整数(类似 Protobuf 的 varint 编码)。
|
|||
|
|
// 将 uint64 编码为变长格式,节省存储空间。
|
|||
|
|
//
|
|||
|
|
|
|||
|
|
func writeVarint(w io.Writer, x uint64) error {
|
|||
|
|
var buf [10]byte
|
|||
|
|
n := 0
|
|||
|
|
for x >= varintContinueBit {
|
|||
|
|
buf[n] = byte(x) | varintContinueBit
|
|||
|
|
x >>= 7
|
|||
|
|
n++
|
|||
|
|
}
|
|||
|
|
buf[n] = byte(x)
|
|||
|
|
_, err := w.Write(buf[:n+1])
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// readVarint 读取变长整数。
|
|||
|
|
// 从字节流中解码 varint 格式的整数。
|
|||
|
|
func readVarint(r io.ByteReader) (uint64, error) {
|
|||
|
|
var x uint64
|
|||
|
|
var shift uint
|
|||
|
|
for {
|
|||
|
|
b, err := r.ReadByte()
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, err
|
|||
|
|
}
|
|||
|
|
x |= uint64(b&varintDataMask) << shift
|
|||
|
|
if b&varintContinueBit == 0 {
|
|||
|
|
return x, nil
|
|||
|
|
}
|
|||
|
|
shift += 7
|
|||
|
|
if shift >= varintMaxShift {
|
|||
|
|
return 0, errors.New("varint overflow")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// byteReader 为 io.Reader 实现 io.ByteReader 接口。
|
|||
|
|
// 提供逐字节读取能力,用于 varint 解码。
|
|||
|
|
type byteReader struct {
|
|||
|
|
r io.Reader
|
|||
|
|
b [1]byte
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func newByteReader(r io.Reader) io.ByteReader {
|
|||
|
|
return &byteReader{r: r}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (br *byteReader) ReadByte() (byte, error) {
|
|||
|
|
_, err := br.r.Read(br.b[:])
|
|||
|
|
return br.b[0], err
|
|||
|
|
}
|