Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ var (
ErrCursorQueryOrdered = errors.New("cursor query already has order by")
// ErrCursorPageOrdered signals page-level ordering that does not match the cursor order.
ErrCursorPageOrdered = errors.New("cursor page order does not match cursor order")
// ErrCursorPaged signals a page carrying both a cursor and a page number.
ErrCursorPaged = errors.New("cursor and page number are mutually exclusive")
)

// EncodeCursor produces an opaque cursor: base64-JSON, not signed, never use it for authorization.
Expand Down Expand Up @@ -135,6 +137,7 @@ func (p CursorPaginator[T, C, PC]) PrepareResult(result []T, page *Page) ([]T, e
limit := int(page.Limit())
page.Size = uint32(limit)
page.More = len(result) > limit
page.NextCursor = ""
if !page.More {
return result, nil
}
Expand Down
44 changes: 22 additions & 22 deletions page.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgkit

import (
"cmp"
"context"
"fmt"
"regexp"
Expand All @@ -25,6 +26,20 @@ const (
Asc Order = "ASC"
)

// IsValid reports whether o is one of the defined sort directions.
func (o Order) IsValid() bool {
return o == Asc || o == Desc
}

// Sanitize normalizes case and surrounding whitespace, defaulting any unrecognized value to Asc.
func (o Order) Sanitize() Order {
o = Order(strings.ToUpper(strings.TrimSpace(string(o))))
if !o.IsValid() {
return Asc
}
return o
}

type Sort struct {
Column string
Order Order
Expand All @@ -39,14 +54,7 @@ func (s Sort) sanitize(columnFunc func(string) string) Sort {
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()
}

switch strings.ToUpper(strings.TrimSpace(string(s.Order))) {
case string(Desc):
s.Order = Desc
case string(Asc):
s.Order = Asc
default:
s.Order = Asc
}
s.Order = s.Order.Sanitize()
return s
}

Expand Down Expand Up @@ -100,23 +108,13 @@ func (p *Page) SetDefaults(o *PaginatorSettings) {
if o == nil {
o = &PaginatorSettings{}
}
defaultSize := o.DefaultSize
if defaultSize == 0 {
defaultSize = DefaultPageSize
}
maxSize := o.MaxSize
if maxSize == 0 {
maxSize = MaxPageSize
}
if p.Size == 0 {
p.Size = defaultSize
}
defaultSize := cmp.Or(o.DefaultSize, DefaultPageSize)
maxSize := cmp.Or(o.MaxSize, MaxPageSize)
p.Size = cmp.Or(p.Size, defaultSize)
if p.Size > maxSize {
p.Size = maxSize
}
if p.Page == 0 {
p.Page = 1
}
p.Page = cmp.Or(p.Page, 1)
}

func (p *Page) GetOrder(columnFunc func(string) string, defaultSort ...string) []Sort {
Expand Down Expand Up @@ -303,6 +301,8 @@ func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string,
func (p Paginator[T]) PrepareResult(result []T, page *Page) []T {
limit := int(page.Limit())
page.More = len(result) > limit
// Offset pagination yields no cursor - clear any stale value from a reused page.
page.NextCursor = ""
if page.More {
result = result[:limit]
}
Expand Down
91 changes: 91 additions & 0 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,33 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s
return records, nil
}

// idCursor is the keyset cursor ListPaged encodes when ordering by IDColumn.
// It carries its own direction so a bare {Cursor: next} page continues the walk.
type idCursor[I ID] struct {
ID I `json:"id"`
Order Order `json:"order"`
}

// ListPaged returns paginated records matching the condition.
//
// When the effective order is exactly IDColumn (the default), pages with more
// rows get NextCursor populated. A Page carrying that cursor continues as a
// keyset walk over IDColumn instead of offset pagination: forward-only, no
// random page access, but pages never skip or duplicate rows under concurrent
// writes. The cursor encodes its direction; a conflicting page order returns
// ErrCursorPageOrdered and a cursor combined with a page number > 1 returns
// ErrCursorPaged. IDColumn must be unique for the keyset ordering to be stable.
func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *Page) ([]P, *Page, error) {
if page == nil {
page = &Page{}
}
if page.Cursor != "" {
if page.Page > 1 {
return nil, nil, ErrCursorPaged
}
return t.listKeyset(ctx, where, page)
}

// Ensure deterministic ordering for stable pagination.
if len(page.Sort) == 0 && page.Column == "" && len(t.Paginator.settings.Sort) == 0 {
page.Sort = []Sort{{Column: t.IDColumn, Order: Asc}}
Expand All @@ -413,6 +435,75 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *
return nil, nil, err
}
result = t.Paginator.PrepareResult(result, page)
if order, ok := t.idOrder(page); ok && page.More {
next, err := EncodeCursor(idCursor[I]{ID: result[len(result)-1].GetID(), Order: order})
if err != nil {
return nil, nil, err
}
page.NextCursor = next
}
return result, page, nil
}

// idOrder reports whether the page's effective order is exactly IDColumn, and in which direction.
func (t *Table[T, P, I]) idOrder(page *Page) (Order, bool) {
sorts := page.GetOrder(t.Paginator.settings.ColumnFunc, t.Paginator.settings.Sort...)
if len(sorts) != 1 || sorts[0].Column != (Sort{Column: t.IDColumn}).sanitize(nil).Column {
return "", false
}
return sorts[0].Order, true
}

// listKeyset continues a cursor walk over IDColumn; see ListPaged.
func (t *Table[T, P, I]) listKeyset(ctx context.Context, where sq.Sqlizer, page *Page) ([]P, *Page, error) {
cursor, err := DecodeCursor[idCursor[I]](page.Cursor)
if err != nil {
return nil, nil, err
}
// The cursor is minted by ListPaged with an exact direction - anything else is a
// forged or corrupted token, so reject it rather than coerce like user sort input.
order := cursor.Order
if !order.IsValid() {
return nil, nil, ErrInvalidCursor
}
if sorts := page.GetOrder(t.Paginator.settings.ColumnFunc, t.Paginator.settings.Sort...); len(sorts) != 0 {
if len(sorts) != 1 || sorts[0] != (Sort{Column: t.IDColumn, Order: order}).sanitize(nil) {
return nil, nil, ErrCursorPageOrdered
}
}

page.SetDefaults(&t.Paginator.settings)
page.More = false
page.NextCursor = ""

q := t.SQL.Select("*").From(t.Name).Where(where).
OrderBy(Sort{Column: t.IDColumn, Order: order}.String())
if order == Desc {
q = q.Where(sq.Lt{t.IDColumn: cursor.ID})
} else {
q = q.Where(sq.Gt{t.IDColumn: cursor.ID})
}

limit := int(page.Limit())
q = q.Limit(uint64(limit) + 1)

result := make([]P, 0, limit+1)
if err := t.Query.GetAll(ctx, q, &result); err != nil {
return nil, nil, err
}

page.Size = uint32(limit)
if len(result) <= limit {
return result, page, nil
}
page.More = true
result = result[:limit]
next, err := EncodeCursor(idCursor[I]{ID: result[len(result)-1].GetID(), Order: order})
if err != nil {
return nil, nil, err
}
page.NextCursor = next

return result, page, nil
}

Expand Down
124 changes: 124 additions & 0 deletions tests/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,130 @@ func TestCursorPaginatorPaginateReturnsPage(t *testing.T) {
require.NotEqual(t, a.ID, b.ID, "cursor pages should not overlap")
}
}

page.Cursor = page.NextCursor
third, thirdPage, err := paginator.Paginate(ctx, db.Query, q, page)
require.NoError(t, err)
require.Len(t, third, 1)
require.False(t, thirdPage.More)
require.Empty(t, thirdPage.NextCursor, "final page must not leak a stale cursor")
}

func TestTableListPagedCursor(t *testing.T) {
ctx := t.Context()
db := initDB(DB)

account := &Account{Name: "ListPagedCursor Account"}
require.NoError(t, db.Accounts.Save(ctx, account))

for range 5 {
require.NoError(t, db.Articles.Save(ctx, &Article{
AccountID: account.ID,
Author: "Cursor Author",
}))
}
where := sq.Eq{"account_id": account.ID}

// pageAll starts from the given page and walks every page by NextCursor, collecting ids in server order.
pageAll := func(t *testing.T, page *pgkit.Page) []uint64 {
t.Helper()
var ids []uint64
for {
rows, p, err := db.Articles.ListPaged(ctx, where, page)
require.NoError(t, err)
require.LessOrEqual(t, len(rows), 2)
for _, r := range rows {
ids = append(ids, r.ID)
}
if !p.More {
require.Empty(t, p.NextCursor)
break
}
require.NotEmpty(t, p.NextCursor)
page = &pgkit.Page{Size: 2, Cursor: p.NextCursor}
}
return ids
}

descPage := func() *pgkit.Page {
return &pgkit.Page{Size: 2, Sort: []pgkit.Sort{{Column: "id", Order: pgkit.Desc}}}
}

t.Run("Desc walks newest first without gaps or overlap", func(t *testing.T) {
ids := pageAll(t, descPage())
require.Len(t, ids, 5)
for i := 1; i < len(ids); i++ {
require.Greater(t, ids[i-1], ids[i], "ids must strictly descend across pages")
}
})

t.Run("default order walks oldest first without gaps or overlap", func(t *testing.T) {
ids := pageAll(t, &pgkit.Page{Size: 2})
require.Len(t, ids, 5)
for i := 1; i < len(ids); i++ {
require.Less(t, ids[i-1], ids[i], "ids must strictly ascend across pages")
}
})

t.Run("round-tripping the returned page continues the walk", func(t *testing.T) {
page := descPage()
var ids []uint64
for {
rows, p, err := db.Articles.ListPaged(ctx, where, page)
require.NoError(t, err)
for _, r := range rows {
ids = append(ids, r.ID)
}
if !p.More {
require.Empty(t, p.NextCursor, "final page must not leak a stale cursor")
break
}
p.Cursor = p.NextCursor
page = p
}
require.Len(t, ids, 5)
})

t.Run("non-id order emits no cursor", func(t *testing.T) {
_, p, err := db.Articles.ListPaged(ctx, where, &pgkit.Page{Size: 2, Sort: []pgkit.Sort{{Column: "author"}}})
require.NoError(t, err)
require.True(t, p.More)
require.Empty(t, p.NextCursor)
})

t.Run("cursor with a conflicting page order errors", func(t *testing.T) {
_, first, err := db.Articles.ListPaged(ctx, where, descPage())
require.NoError(t, err)
_, _, err = db.Articles.ListPaged(ctx, where, &pgkit.Page{Cursor: first.NextCursor, Sort: []pgkit.Sort{{Column: "id", Order: pgkit.Asc}}})
require.ErrorIs(t, err, pgkit.ErrCursorPageOrdered)
_, _, err = db.Articles.ListPaged(ctx, where, &pgkit.Page{Cursor: first.NextCursor, Sort: []pgkit.Sort{{Column: "author"}}})
require.ErrorIs(t, err, pgkit.ErrCursorPageOrdered)
})

t.Run("cursor with a page number errors", func(t *testing.T) {
_, first, err := db.Articles.ListPaged(ctx, where, &pgkit.Page{Size: 2})
require.NoError(t, err)
_, _, err = db.Articles.ListPaged(ctx, where, &pgkit.Page{Page: 2, Cursor: first.NextCursor})
require.ErrorIs(t, err, pgkit.ErrCursorPaged)
})

t.Run("rejects an undecodable cursor", func(t *testing.T) {
_, _, err := db.Articles.ListPaged(ctx, where, &pgkit.Page{Cursor: "not-a-cursor"})
require.ErrorIs(t, err, pgkit.ErrInvalidCursor)
})

t.Run("rejects a forged cursor order", func(t *testing.T) {
type forgedCursor struct {
ID uint64 `json:"id"`
Order pgkit.Order `json:"order"`
}
for _, order := range []pgkit.Order{"sideways", "asc", ""} {
forged, err := pgkit.EncodeCursor(forgedCursor{ID: 1, Order: order})
require.NoError(t, err)
_, _, err = db.Articles.ListPaged(ctx, where, &pgkit.Page{Cursor: forged})
require.ErrorIs(t, err, pgkit.ErrInvalidCursor, "order %q must be rejected", order)
}
})
}

func TestPaginatorPaginateReturnsPage(t *testing.T) {
Expand Down
Loading