summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Johnson <benbjohnson@yahoo.com>2022-03-30 18:41:31 -0600
committerBen Johnson <benbjohnson@yahoo.com>2022-03-30 18:41:31 -0600
commitf3f3928a96bb5705300164c7575aa5599f56a669 (patch)
tree1b3b2791c80755d7fba64378dbf549b6b7e1a7ed
parente36da875ef8d722b1d6c50e83c7dee4ba321c077 (diff)
Add prepared statement handling
-rw-r--r--server.go224
1 files changed, 198 insertions, 26 deletions
diff --git a/server.go b/server.go
index 8bf6894..931b745 100644
--- a/server.go
+++ b/server.go
@@ -9,15 +9,50 @@ import (
"net"
"os"
"path/filepath"
+ "regexp"
"strings"
"sync"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
- _ "github.com/mattn/go-sqlite3"
+ "github.com/mattn/go-sqlite3"
"golang.org/x/sync/errgroup"
)
+// Postgres settings.
+const (
+ ServerVersion = "13.0.0"
+)
+
+func init() {
+ sql.Register("postlite-sqlite3", &sqlite3.SQLiteDriver{
+ ConnectHook: func(conn *sqlite3.SQLiteConn) error {
+ if err := conn.RegisterFunc("current_catalog", currentCatalog, true); err != nil {
+ return fmt.Errorf("cannot register current_catalog() function")
+ }
+ if err := conn.RegisterFunc("current_schema", currentSchema, true); err != nil {
+ return fmt.Errorf("cannot register current_schema() function")
+ }
+ if err := conn.RegisterFunc("current_user", currentUser, true); err != nil {
+ return fmt.Errorf("cannot register current_schema() function")
+ }
+ if err := conn.RegisterFunc("session_user", sessionUser, true); err != nil {
+ return fmt.Errorf("cannot register session_user() function")
+ }
+ if err := conn.RegisterFunc("user", user, true); err != nil {
+ return fmt.Errorf("cannot register user() function")
+ }
+ return nil
+ },
+ })
+}
+
+func currentCatalog() string { return "public" }
+func currentSchema() string { return "public" }
+func currentUser() string { return "sqlite3" }
+func sessionUser() string { return "sqlite3" }
+func user() string { return "sqlite3" }
+
type Server struct {
mu sync.Mutex
ln net.Listener
@@ -125,7 +160,7 @@ func (s *Server) serve() error {
defer s.CloseClientConnection(conn)
if err := s.serveConn(s.ctx, conn); err != nil && s.ctx.Err() == nil {
- log.Println("connection error, closing: %s", err)
+ log.Printf("connection error, closing: %s", err)
return nil
}
@@ -146,13 +181,25 @@ func (s *Server) serveConn(ctx context.Context, c *Conn) error {
return fmt.Errorf("receive message: %w", err)
}
+ log.Printf("[recv] %#v", msg)
+
switch msg := msg.(type) {
case *pgproto3.Query:
if err := s.handleQueryMessage(ctx, c, msg); err != nil {
return fmt.Errorf("query message: %w", err)
}
+
+ case *pgproto3.Parse:
+ if err := s.handleParseMessage(ctx, c, msg); err != nil {
+ return fmt.Errorf("parse message: %w", err)
+ }
+
+ case *pgproto3.Sync: // ignore
+ continue
+
case *pgproto3.Terminate:
return nil // exit
+
default:
return fmt.Errorf("unexpected message type: %#v", msg)
}
@@ -193,17 +240,19 @@ func (s *Server) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto
}
// Open SQL database & attach to the connection.
- if c.db, err = sql.Open("sqlite3", filepath.Join(s.DataDir, name)); err != nil {
+ if c.db, err = sql.Open("postlite-sqlite3", filepath.Join(s.DataDir, name)); err != nil {
return err
}
return writeMessages(c,
&pgproto3.AuthenticationOk{},
+ &pgproto3.ParameterStatus{Name: "server_version", Value: ServerVersion},
&pgproto3.ReadyForQuery{TxStatus: 'I'},
)
}
func (s *Server) handleSSLRequestMessage(ctx context.Context, c *Conn, msg *pgproto3.SSLRequest) error {
+ log.Printf("received ssl request message: %#v", msg)
if _, err := c.Write([]byte("N")); err != nil {
return err
}
@@ -221,13 +270,36 @@ func (s *Server) handleQueryMessage(ctx context.Context, c *Conn, msg *pgproto3.
&pgproto3.ReadyForQuery{TxStatus: 'I'},
)
}
+ defer rows.Close()
- // Encode header.
+ // Encode column header.
cols, err := rows.ColumnTypes()
if err != nil {
- return fmt.Errorf("columns: %w", err)
+ return fmt.Errorf("column types: %w", err)
}
+ buf := toRowDescription(cols).Encode(nil)
+ // Iterate over each row and encode it to the wire protocol.
+ for rows.Next() {
+ row, err := scanRow(rows, cols)
+ if err != nil {
+ return fmt.Errorf("scan: %w", err)
+ }
+ buf = row.Encode(buf)
+ }
+ if err := rows.Err(); err != nil {
+ return fmt.Errorf("rows: %w", err)
+ }
+
+ // Mark command complete and ready for next query.
+ buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
+ buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
+
+ _, err = c.Write(buf)
+ return err
+}
+
+func toRowDescription(cols []*sql.ColumnType) *pgproto3.RowDescription {
var desc pgproto3.RowDescription
for _, col := range cols {
desc.Fields = append(desc.Fields, pgproto3.FieldDescription{
@@ -240,36 +312,113 @@ func (s *Server) handleQueryMessage(ctx context.Context, c *Conn, msg *pgproto3.
Format: 0,
})
}
- buf := desc.Encode(nil)
+ return &desc
+}
- // Iterate over each row and encode it to the wire protocol.
- for rows.Next() {
- refs := make([]interface{}, len(cols))
- values := make([]interface{}, len(cols))
- for i := range refs {
- refs[i] = &values[i]
- }
+func scanRow(rows *sql.Rows, cols []*sql.ColumnType) (*pgproto3.DataRow, error) {
+ refs := make([]interface{}, len(cols))
+ values := make([]interface{}, len(cols))
+ for i := range refs {
+ refs[i] = &values[i]
+ }
- // Scan from SQLite database.
- if err := rows.Scan(refs...); err != nil {
- return fmt.Errorf("scan: %w", err)
+ // Scan from SQLite database.
+ if err := rows.Scan(refs...); err != nil {
+ return nil, fmt.Errorf("scan: %w", err)
+ }
+
+ // Convert to TEXT values to return over Postgres wire protocol.
+ row := pgproto3.DataRow{Values: make([][]byte, len(values))}
+ for i := range values {
+ row.Values[i] = []byte(fmt.Sprint(values[i]))
+ }
+ return &row, nil
+}
+
+func (s *Server) handleParseMessage(ctx context.Context, c *Conn, pmsg *pgproto3.Parse) error {
+ // Rewrite system-information queries so they're tolerable by SQLite.
+ query := rewriteQuery(pmsg.Query)
+
+ if pmsg.Query != query {
+ log.Printf("query rewrite: %s", query)
+ }
+
+ // Prepare the query.
+ stmt, err := c.db.PrepareContext(ctx, query)
+ if err != nil {
+ return err
+ }
+
+ var rows *sql.Rows
+ var cols []*sql.ColumnType
+ exec := func() (err error) {
+ if rows != nil {
+ return nil
+ }
+ if rows, err = stmt.QueryContext(ctx); err != nil {
+ return err
}
+ if cols, err = rows.ColumnTypes(); err != nil {
+ return err
+ }
+ return nil
+ }
- // Convert to TEXT values to return over Postgres wire protocol.
- row := pgproto3.DataRow{Values: make([][]byte, len(values))}
- for i := range values {
- row.Values[i] = []byte(fmt.Sprint(values[i]))
+ // LOOP:
+ for {
+ msg, err := c.backend.Receive()
+ if err != nil {
+ return fmt.Errorf("receive message during parse: %w", err)
}
- // Encode row.
- buf = row.Encode(buf)
+ log.Printf("[recv(p)] %#v", msg)
+
+ switch msg := msg.(type) {
+ case *pgproto3.Bind:
+ // ignore
+
+ case *pgproto3.Describe:
+ if err := exec(); err != nil {
+ return fmt.Errorf("exec: %w", err)
+ }
+ if _, err := c.Write(toRowDescription(cols).Encode(nil)); err != nil {
+ return err
+ }
+
+ case *pgproto3.Execute:
+ // TODO: Send pgproto3.ParseComplete?
+ if err := exec(); err != nil {
+ return fmt.Errorf("exec: %w", err)
+ }
+
+ var buf []byte
+ for rows.Next() {
+ row, err := scanRow(rows, cols)
+ if err != nil {
+ return fmt.Errorf("scan: %w", err)
+ }
+ buf = row.Encode(buf)
+ }
+ if err := rows.Err(); err != nil {
+ return fmt.Errorf("rows: %w", err)
+ }
+
+ // Mark command complete and ready for next query.
+ buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
+ buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
+ _, err := c.Write(buf)
+ return err
+
+ default:
+ return fmt.Errorf("unexpected message type during parse: %#v", msg)
+ }
}
+}
- // Mark command complete and ready for next query.
- buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
+func (s *Server) execSetQuery(ctx context.Context, c *Conn, query string) error {
+ buf := (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(nil)
buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
-
- _, err = c.Write(buf)
+ _, err := c.Write(buf)
return err
}
@@ -302,3 +451,26 @@ func writeMessages(w io.Writer, msgs ...pgproto3.Message) error {
_, err := w.Write(buf)
return err
}
+
+func rewriteQuery(q string) string {
+ // Ignore SET queries by rewriting them to empty resultsets.
+ if strings.HasPrefix(q, "SET ") {
+ return `SELECT 'SET'`
+ }
+
+ // Rewrite system information variables so they are functions so we can inject them.
+ // https://www.postgresql.org/docs/9.1/functions-info.html
+ q = systemFunctionRegex.ReplaceAllString(q, "$1()$2")
+
+ // Rewrite double-colon casting by simply removing it.
+ // https://www.postgresql.org/docs/7.3/sql-expressions.html#SQL-SYNTAX-TYPE-CASTS
+ q = castRegex.ReplaceAllString(q, "")
+
+ return q
+}
+
+var (
+ systemFunctionRegex = regexp.MustCompile(`\b(current_catalog|current_schema|current_user|session_user|user)\b([^\(]|$)`)
+
+ castRegex = regexp.MustCompile(`::(regclass)`)
+)