diff --git a/pkg/client/util.go b/pkg/client/util.go index b7a7fb4e8..d32313fd0 100644 --- a/pkg/client/util.go +++ b/pkg/client/util.go @@ -8,7 +8,8 @@ import ( var ( // List of keywords that are not allowed in read-only mode - reRestrictedKeywords = regexp.MustCompile(`(?mi)\s?(CREATE|INSERT|UPDATE|DROP|DELETE|TRUNCATE|GRANT|OPEN|IMPORT|COPY)\s`) + reRestrictedKeywords = regexp.MustCompile(`(?mi)\s?(CREATE|INSERT|UPDATE|DROP|DELETE|TRUNCATE|GRANT|OPEN|IMPORT|COPY)\s`) + reRestrictedFunctions = regexp.MustCompile(`(?mi)(pg_cancel_backend|pg_terminate_backend)\s*\(`) // Comment regular expressions reSlashComment = regexp.MustCompile(`(?m)/\*.+\*/`) @@ -85,7 +86,7 @@ func containsRestrictedKeywords(str string) bool { str = reSlashComment.ReplaceAllString(str, "") str = reDashComment.ReplaceAllString(str, "") - return reRestrictedKeywords.MatchString(str) + return reRestrictedKeywords.MatchString(str) || reRestrictedFunctions.MatchString(str) } func hasBinary(data string, checkLen int) bool { diff --git a/pkg/client/util_test.go b/pkg/client/util_test.go index 2551d87a1..8a46b1ac8 100644 --- a/pkg/client/util_test.go +++ b/pkg/client/util_test.go @@ -95,6 +95,32 @@ func TestGetMajorMinorVersion(t *testing.T) { } } +func TestContainsRestrictedKeywords(t *testing.T) { + examples := []struct { + input string + result bool + }{ + {"SELECT 1", false}, + {"SELECT * FROM users", false}, + {"CREATE TABLE foo (id int)", true}, + {"INSERT INTO foo VALUES (1)", true}, + {"DROP TABLE foo", true}, + {"DELETE FROM foo", true}, + {"SELECT pg_cancel_backend(1234)", true}, + {"SELECT pg_terminate_backend(1234)", true}, + {"select pg_cancel_backend( 1234 )", true}, + {"select pg_terminate_backend( 1234 )", true}, + {"SELECT PG_CANCEL_BACKEND(1234)", true}, + {"SELECT PG_TERMINATE_BACKEND(1234)", true}, + } + + for _, ex := range examples { + t.Run(ex.input, func(t *testing.T) { + assert.Equal(t, ex.result, containsRestrictedKeywords(ex.input)) + }) + } +} + func TestCheckVersionRequirement(t *testing.T) { examples := []struct { client string