diff options
author | Ben Johnson <benbjohnson@yahoo.com> | 2022-03-30 18:41:31 -0600 |
---|---|---|
committer | Ben Johnson <benbjohnson@yahoo.com> | 2022-03-30 18:41:31 -0600 |
commit | f3f3928a96bb5705300164c7575aa5599f56a669 (patch) | |
tree | 1b3b2791c80755d7fba64378dbf549b6b7e1a7ed | |
parent | e36da875ef8d722b1d6c50e83c7dee4ba321c077 (diff) |
Add prepared statement handling
-rw-r--r-- | server.go | 224 |
1 files changed, 198 insertions, 26 deletions
@@ -9,15 +9,50 @@ import ( "net" "os" "path/filepath" + "regexp" "strings" "sync" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" - _ "github.com/mattn/go-sqlite3" + "github.com/mattn/go-sqlite3" "golang.org/x/sync/errgroup" ) +// Postgres settings. +const ( + ServerVersion = "13.0.0" +) + +func init() { + sql.Register("postlite-sqlite3", &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + if err := conn.RegisterFunc("current_catalog", currentCatalog, true); err != nil { + return fmt.Errorf("cannot register current_catalog() function") + } + if err := conn.RegisterFunc("current_schema", currentSchema, true); err != nil { + return fmt.Errorf("cannot register current_schema() function") + } + if err := conn.RegisterFunc("current_user", currentUser, true); err != nil { + return fmt.Errorf("cannot register current_schema() function") + } + if err := conn.RegisterFunc("session_user", sessionUser, true); err != nil { + return fmt.Errorf("cannot register session_user() function") + } + if err := conn.RegisterFunc("user", user, true); err != nil { + return fmt.Errorf("cannot register user() function") + } + return nil + }, + }) +} + +func currentCatalog() string { return "public" } +func currentSchema() string { return "public" } +func currentUser() string { return "sqlite3" } +func sessionUser() string { return "sqlite3" } +func user() string { return "sqlite3" } + type Server struct { mu sync.Mutex ln net.Listener @@ -125,7 +160,7 @@ func (s *Server) serve() error { defer s.CloseClientConnection(conn) if err := s.serveConn(s.ctx, conn); err != nil && s.ctx.Err() == nil { - log.Println("connection error, closing: %s", err) + log.Printf("connection error, closing: %s", err) return nil } @@ -146,13 +181,25 @@ func (s *Server) serveConn(ctx context.Context, c *Conn) error { return fmt.Errorf("receive message: %w", err) } + log.Printf("[recv] %#v", msg) + switch msg := msg.(type) { case *pgproto3.Query: if err := s.handleQueryMessage(ctx, c, msg); err != nil { return fmt.Errorf("query message: %w", err) } + + case *pgproto3.Parse: + if err := s.handleParseMessage(ctx, c, msg); err != nil { + return fmt.Errorf("parse message: %w", err) + } + + case *pgproto3.Sync: // ignore + continue + case *pgproto3.Terminate: return nil // exit + default: return fmt.Errorf("unexpected message type: %#v", msg) } @@ -193,17 +240,19 @@ func (s *Server) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto } // Open SQL database & attach to the connection. - if c.db, err = sql.Open("sqlite3", filepath.Join(s.DataDir, name)); err != nil { + if c.db, err = sql.Open("postlite-sqlite3", filepath.Join(s.DataDir, name)); err != nil { return err } return writeMessages(c, &pgproto3.AuthenticationOk{}, + &pgproto3.ParameterStatus{Name: "server_version", Value: ServerVersion}, &pgproto3.ReadyForQuery{TxStatus: 'I'}, ) } func (s *Server) handleSSLRequestMessage(ctx context.Context, c *Conn, msg *pgproto3.SSLRequest) error { + log.Printf("received ssl request message: %#v", msg) if _, err := c.Write([]byte("N")); err != nil { return err } @@ -221,13 +270,36 @@ func (s *Server) handleQueryMessage(ctx context.Context, c *Conn, msg *pgproto3. &pgproto3.ReadyForQuery{TxStatus: 'I'}, ) } + defer rows.Close() - // Encode header. + // Encode column header. cols, err := rows.ColumnTypes() if err != nil { - return fmt.Errorf("columns: %w", err) + return fmt.Errorf("column types: %w", err) } + buf := toRowDescription(cols).Encode(nil) + // Iterate over each row and encode it to the wire protocol. + for rows.Next() { + row, err := scanRow(rows, cols) + if err != nil { + return fmt.Errorf("scan: %w", err) + } + buf = row.Encode(buf) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("rows: %w", err) + } + + // Mark command complete and ready for next query. + buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) + buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + + _, err = c.Write(buf) + return err +} + +func toRowDescription(cols []*sql.ColumnType) *pgproto3.RowDescription { var desc pgproto3.RowDescription for _, col := range cols { desc.Fields = append(desc.Fields, pgproto3.FieldDescription{ @@ -240,36 +312,113 @@ func (s *Server) handleQueryMessage(ctx context.Context, c *Conn, msg *pgproto3. Format: 0, }) } - buf := desc.Encode(nil) + return &desc +} - // Iterate over each row and encode it to the wire protocol. - for rows.Next() { - refs := make([]interface{}, len(cols)) - values := make([]interface{}, len(cols)) - for i := range refs { - refs[i] = &values[i] - } +func scanRow(rows *sql.Rows, cols []*sql.ColumnType) (*pgproto3.DataRow, error) { + refs := make([]interface{}, len(cols)) + values := make([]interface{}, len(cols)) + for i := range refs { + refs[i] = &values[i] + } - // Scan from SQLite database. - if err := rows.Scan(refs...); err != nil { - return fmt.Errorf("scan: %w", err) + // Scan from SQLite database. + if err := rows.Scan(refs...); err != nil { + return nil, fmt.Errorf("scan: %w", err) + } + + // Convert to TEXT values to return over Postgres wire protocol. + row := pgproto3.DataRow{Values: make([][]byte, len(values))} + for i := range values { + row.Values[i] = []byte(fmt.Sprint(values[i])) + } + return &row, nil +} + +func (s *Server) handleParseMessage(ctx context.Context, c *Conn, pmsg *pgproto3.Parse) error { + // Rewrite system-information queries so they're tolerable by SQLite. + query := rewriteQuery(pmsg.Query) + + if pmsg.Query != query { + log.Printf("query rewrite: %s", query) + } + + // Prepare the query. + stmt, err := c.db.PrepareContext(ctx, query) + if err != nil { + return err + } + + var rows *sql.Rows + var cols []*sql.ColumnType + exec := func() (err error) { + if rows != nil { + return nil + } + if rows, err = stmt.QueryContext(ctx); err != nil { + return err } + if cols, err = rows.ColumnTypes(); err != nil { + return err + } + return nil + } - // Convert to TEXT values to return over Postgres wire protocol. - row := pgproto3.DataRow{Values: make([][]byte, len(values))} - for i := range values { - row.Values[i] = []byte(fmt.Sprint(values[i])) + // LOOP: + for { + msg, err := c.backend.Receive() + if err != nil { + return fmt.Errorf("receive message during parse: %w", err) } - // Encode row. - buf = row.Encode(buf) + log.Printf("[recv(p)] %#v", msg) + + switch msg := msg.(type) { + case *pgproto3.Bind: + // ignore + + case *pgproto3.Describe: + if err := exec(); err != nil { + return fmt.Errorf("exec: %w", err) + } + if _, err := c.Write(toRowDescription(cols).Encode(nil)); err != nil { + return err + } + + case *pgproto3.Execute: + // TODO: Send pgproto3.ParseComplete? + if err := exec(); err != nil { + return fmt.Errorf("exec: %w", err) + } + + var buf []byte + for rows.Next() { + row, err := scanRow(rows, cols) + if err != nil { + return fmt.Errorf("scan: %w", err) + } + buf = row.Encode(buf) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("rows: %w", err) + } + + // Mark command complete and ready for next query. + buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) + buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + _, err := c.Write(buf) + return err + + default: + return fmt.Errorf("unexpected message type during parse: %#v", msg) + } } +} - // Mark command complete and ready for next query. - buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) +func (s *Server) execSetQuery(ctx context.Context, c *Conn, query string) error { + buf := (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(nil) buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) - - _, err = c.Write(buf) + _, err := c.Write(buf) return err } @@ -302,3 +451,26 @@ func writeMessages(w io.Writer, msgs ...pgproto3.Message) error { _, err := w.Write(buf) return err } + +func rewriteQuery(q string) string { + // Ignore SET queries by rewriting them to empty resultsets. + if strings.HasPrefix(q, "SET ") { + return `SELECT 'SET'` + } + + // Rewrite system information variables so they are functions so we can inject them. + // https://www.postgresql.org/docs/9.1/functions-info.html + q = systemFunctionRegex.ReplaceAllString(q, "$1()$2") + + // Rewrite double-colon casting by simply removing it. + // https://www.postgresql.org/docs/7.3/sql-expressions.html#SQL-SYNTAX-TYPE-CASTS + q = castRegex.ReplaceAllString(q, "") + + return q +} + +var ( + systemFunctionRegex = regexp.MustCompile(`\b(current_catalog|current_schema|current_user|session_user|user)\b([^\(]|$)`) + + castRegex = regexp.MustCompile(`::(regclass)`) +) |