postgres协议学习

 

参考

基于postgres官方文档v15进行学习,协议为v3

基本操作

基本操作分为两个阶段:

  • Startup: 启动连接。
  • Normal: 连接建立之后前后端的正常交互。这里的前端(Frontend)通常指的是连接pg的客户端,后端(Backend)指的是pg服务端。

SQL的协议类型

大致有两种协议:

  • 简单查询协议(Simple Query Protocol): 直接传递纯文本进行查询,pg服务端可以直接解析并进行执行
  • 扩展查询协议(Extended Query Protocol): 会分步骤做查询,更灵活、性能也更好
    • 解析预定义语句(parsing)
    • 绑定参数(binding of parameter values)
    • 执行语句(execution)

消息

基本定义

一个message stream的基本格式为:

  • 第1字节: 定义消息类型 (StartupMessage没有这个字节)
  • 第2-5字节: 定义剩余的消息长度 (包括这4个字节)
  • 剩余: 具体数据

数据类型

pg支持数据类型text、binary和其他,根据传递的format code来区分:

  • 0代表text
  • 1代表binary。binary模式的integer使用网络字节序(大端字节序 Big Endian)

状态流转

pg的消息传递交互并不是都有固定顺序的。例如:

  • 客户端收到了RowDescription的消息,下一个消息将会是DataRow,但是由于行数并不是固定的,所以再往后的消息可能还是DataRow,或者是CommandComplete等等。
  • 客户端发送StartupMessage消息建立连接,如果服务端要求验证密码,则可能中间会插入AuthenticationMD5Password之类的消息。又如果需要tls连接,则需要将当前的连接升级。整个过程都是不确定的,由客户端的连接参数和服务端的配置共同决定。

因此处理pg的消息最好是用类似于状态机流转的方式,收到什么消息就处理什么消息,并做出回应。
这里可以参考pgx/pgproto3的实现,因为它是按照protocol原生实现的,所以对于理解protocol有比较大的帮助。
这边提供一个自己写的精简版proxy的状态机demo。可以看到里面有很多的流转分支,当然有的部分并没有完全实现。但是逻辑的转变主要有3大关键节点: Startup => AuthenticationOK => ReadyForQuery

type PgProxyConnection struct {
	clientConn    net.Conn           // tcp connection between client and proxy
	pgConn        net.Conn           // tcp connection between proxy and postgres
	clientControl *pgproto3.Backend  // clientConn controller
	pgControl     *pgproto3.Frontend // pgConn controller
	TlsOn         bool               // is tls on
	ctx           context.Context
}

func (pc *PgProxyConnection) startUp() error {
	if pc.TlsOn {
		crt, err := tls.LoadX509KeyPair("server.crt", "server.key")
		if err != nil {
			return err
		}
		cfg := &tls.Config{
			Certificates:       []tls.Certificate{crt},
			InsecureSkipVerify: true,
			MinVersion:         tls.VersionTLS12,
			MaxVersion:         tls.VersionTLS12,
		}
		// upgrade to tls connection
		pc.clientConn = tls.Server(pc.clientConn, cfg)
		pc.clientControl = pgproto3.NewBackend(pgproto3.NewChunkReader(pc.clientConn), pc.clientConn)
	}
	startUpMsg, err := pc.clientControl.ReceiveStartupMessage()
	if err != nil {
		return err
	}

	switch msg := startUpMsg.(type) {
	case *pgproto3.StartupMessage:
		return pc.handShake(msg)
	case *pgproto3.SSLRequest:
		// set TLS option on
		pc.TlsOn = true
		// ACK for ssl request
		_, err = pc.clientConn.Write([]byte{'S'})
		if err != nil {
			return err
		}
		return pc.startUp()
	default:
		return fmt.Errorf("unknown startup message: %#v", startUpMsg)
	}
}

func (pc *PgProxyConnection) handShake(msg *pgproto3.StartupMessage) error {
	host := msg.Parameters["host"]
	port := msg.Parameters["port"]
	var err error
	dialer := net.Dialer{KeepAlive: 5 * time.Minute}
	if pc.pgConn, err = dialer.Dial("tcp", fmt.Sprintf("%s:%s", host, port)); err != nil {
		return err
	}
	if pc.TlsOn {
		sslMode, _ := config.Get("TCP.SslMode").(string)
		rootCrt, _ := config.Get("TCP.RootCrt").(string)
		serverCrt, _ := config.Get("TCP.ServerCrt").(string)
		serverKey, _ := config.Get("TCP.ServerKey").(string)
		var cfg *tls.Config
		if cfg, err = configTLS(sslMode, rootCrt, serverCrt, serverKey, host); err != nil {
			return err
		}
		if pc.pgConn, err = startTLS(pc.pgConn, cfg); err != nil {
			return err
		}
	}
	cr, err := chunkreader.NewConfig(pc.pgConn, chunkreader.Config{MinBufLen: 8192})
	if err != nil {
		return err
	}
	pc.pgControl = pgproto3.NewFrontend(cr, pc.pgConn)

	if _, err := pc.pgConn.Write(msg.Encode(nil)); err != nil {
		pc.pgConn.Close()
		return err
	}

	for {
		pdbStartUpMsg, err := pc.pgControl.Receive()
		if err != nil {
			return err
		}
		switch parsedMsg := pdbStartUpMsg.(type) {
		case *pgproto3.ReadyForQuery:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
			return nil
		case *pgproto3.ParameterStatus:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.ErrorResponse:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
			return errors.New(parsedMsg.Message)
		case *pgproto3.NoticeResponse:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.NotificationResponse:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.BackendKeyData:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.AuthenticationOk:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.AuthenticationCleartextPassword:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.AuthenticationMD5Password:
            // 对密码的md5生成规则为:
            // concat('md5', md5(concat(md5(concat(password, username)), random-salt)))
			_, err = pc.clientConn.Write(parsedMsg.Encode(nil))
			if err != nil {
				return err
			}
			md5pw, err := pc.clientControl.Receive()
			if err != nil {
				return err
			}
			_, err = pc.pgConn.Write(md5pw.Encode(nil))
			if err != nil {
				return err
			}
		case *pgproto3.AuthenticationSASL:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		case *pgproto3.AuthenticationGSS:
			_, _ = pc.clientConn.Write(parsedMsg.Encode(nil))
		default:
			return errors.New("unknown connect message")
		}
	}
}

其他流程

取消请求

pg取消query是通过启动另一个新的连接,发送CancelRequest消息来处理的(而非StartupMessage)。
需要携带对应query的pid和secret key。因为是启动新的连接,所以可以由其他实例发送取消的请求,只要pid和secret key对的上就行。

相关文档