Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions ast/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ import (
"testing"

"github.com/sqlc-dev/marino/charset"
"github.com/stretchr/testify/require"

"reflect"
)

func TestNodeSetText(t *testing.T) {
Expand All @@ -37,8 +38,12 @@ func TestNodeSetText(t *testing.T) {
}
for _, tt := range tests {
n.SetText(tt.enc, tt.text)
require.Equal(t, tt.expectUTF8Text, n.Text())
require.Equal(t, tt.expectText, n.OriginalText())
if !reflect.DeepEqual(tt.expectUTF8Text, n.Text()) {
t.Fatalf("got %v, want %v", n.Text(), tt.expectUTF8Text)
}
if !reflect.DeepEqual(tt.expectText, n.OriginalText()) {
t.Fatalf("got %v, want %v", n.OriginalText(), tt.expectText)
}
}
}

Expand Down Expand Up @@ -66,7 +71,9 @@ func TestBinaryStringLiteralConversion(t *testing.T) {
}
for _, tt := range printableTests {
n.SetText(charset.EncodingUTF8Impl, tt.text)
require.Equal(t, tt.want, n.Text(), tt.name)
if !reflect.DeepEqual(tt.want, n.Text()) {
t.Fatalf("%v: got %v, want %v", tt.name, n.Text(), tt.want)
}
}

// Binary (non-printable) strings — should convert to 0x hex literals
Expand Down Expand Up @@ -98,7 +105,9 @@ func TestBinaryStringLiteralConversion(t *testing.T) {
}
for _, tt := range binaryTests {
n.SetText(charset.EncodingUTF8Impl, tt.text)
require.Equal(t, tt.want, n.Text(), tt.name)
if !reflect.DeepEqual(tt.want, n.Text()) {
t.Fatalf("%v: got %v, want %v", tt.name, n.Text(), tt.want)
}
}
}

Expand Down Expand Up @@ -206,7 +215,9 @@ func TestBinaryStringLiteralSkipsComments(t *testing.T) {
}
for _, tt := range tests {
n.SetText(charset.EncodingUTF8Impl, tt.text)
require.Equal(t, tt.want, n.Text(), tt.name)
if !reflect.DeepEqual(tt.want, n.Text()) {
t.Fatalf("%v: got %v, want %v", tt.name, n.Text(), tt.want)
}
}
}

Expand All @@ -215,15 +226,21 @@ func TestBinaryStringLiteralNoBackslashEscapes(t *testing.T) {

n.SetText(charset.EncodingUTF8Impl, "SELECT '\\n'")
n.SetNoBackslashEscapes(true)
require.Equal(t, "SELECT '\\n'", n.Text(), "NO_BACKSLASH_ESCAPES literal \\n")
if !reflect.DeepEqual("SELECT '\\n'", n.Text()) {
t.Fatalf("%v: got %v, want %v", "NO_BACKSLASH_ESCAPES literal \\n", n.Text(), "SELECT '\\n'")
}

n.SetText(charset.EncodingUTF8Impl, "SELECT '\\' , 'after'")
n.SetNoBackslashEscapes(true)
require.Equal(t, "SELECT '\\' , 'after'", n.Text(), "NO_BACKSLASH_ESCAPES quote boundary")
if !reflect.DeepEqual("SELECT '\\' , 'after'", n.Text()) {
t.Fatalf("%v: got %v, want %v", "NO_BACKSLASH_ESCAPES quote boundary", n.Text(), "SELECT '\\' , 'after'")
}

n.SetText(charset.EncodingUTF8Impl, "SELECT '\xd2\xe4'")
n.SetNoBackslashEscapes(true)
require.Equal(t, "SELECT 0xd2e4", n.Text(), "NO_BACKSLASH_ESCAPES binary")
if !reflect.DeepEqual("SELECT 0xd2e4", n.Text()) {
t.Fatalf("%v: got %v, want %v", "NO_BACKSLASH_ESCAPES binary", n.Text(), "SELECT 0xd2e4")
}
}

func TestBinaryStringLiteralGBK(t *testing.T) {
Expand All @@ -233,23 +250,33 @@ func TestBinaryStringLiteralGBK(t *testing.T) {
// This should be decoded as valid GBK and left as a printable string,
// not converted to a hex literal.
n.SetText(charset.EncodingGBKImpl, "select '\xb1\xed\x31'")
require.Equal(t, "select '表1'", n.Text(), "GBK printable")
if !reflect.DeepEqual("select '表1'", n.Text()) {
t.Fatalf("%v: got %v, want %v", "GBK printable", n.Text(), "select '表1'")
}

// GBK with actual invalid bytes should still convert to hex
n.SetText(charset.EncodingGBKImpl, "select '\x80\xff'")
require.Equal(t, "select 0x80ff", n.Text(), "GBK binary")
if !reflect.DeepEqual("select 0x80ff", n.Text()) {
t.Fatalf("%v: got %v, want %v", "GBK binary", n.Text(), "select 0x80ff")
}

// 筡 = \xb9\x5c in GBK; trail byte 0x5c must not be mistaken for backslash
n.SetText(charset.EncodingGBKImpl, "select '\xb9\x5c'")
require.Equal(t, "select '筡'", n.Text(), "GBK 0x5c trail byte")
if !reflect.DeepEqual("select '筡'", n.Text()) {
t.Fatalf("%v: got %v, want %v", "GBK 0x5c trail byte", n.Text(), "select '筡'")
}

// Multiple GBK chars with 0x5c trail bytes: 筡 = \xb9\x5c, 臷 = \xc5\x5c
n.SetText(charset.EncodingGBKImpl, "select '\xb9\x5c\xc5\x5c'")
require.Equal(t, "select '筡臷'", n.Text(), "GBK multiple 0x5c trail bytes")
if !reflect.DeepEqual("select '筡臷'", n.Text()) {
t.Fatalf("%v: got %v, want %v", "GBK multiple 0x5c trail bytes", n.Text(), "select '筡臷'")
}

// 0x5c trail byte right before closing quote must not escape the quote
n.SetText(charset.EncodingGBKImpl, "select '\xb9\x5c', 'after'")
require.Equal(t, "select '筡', 'after'", n.Text(), "GBK 0x5c before quote")
if !reflect.DeepEqual("select '筡', 'after'", n.Text()) {
t.Fatalf("%v: got %v, want %v", "GBK 0x5c before quote", n.Text(), "select '筡', 'after'")
}
}

func buildBinaryClause() string {
Expand Down
11 changes: 8 additions & 3 deletions ast/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (

. "github.com/sqlc-dev/marino/ast"
"github.com/sqlc-dev/marino/format"
"github.com/stretchr/testify/require"

"reflect"
)

func TestDDLVisitorCover(t *testing.T) {
Expand Down Expand Up @@ -59,8 +60,12 @@ func TestDDLVisitorCover(t *testing.T) {
for _, v := range stmts {
ce.reset()
v.node.Accept(checkVisitor{})
require.Equal(t, v.expectedEnterCnt, ce.enterCnt)
require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt)
if !reflect.DeepEqual(v.expectedEnterCnt, ce.enterCnt) {
t.Fatalf("got %v, want %v", ce.enterCnt, v.expectedEnterCnt)
}
if !reflect.DeepEqual(v.expectedLeaveCnt, ce.leaveCnt) {
t.Fatalf("got %v, want %v", ce.leaveCnt, v.expectedLeaveCnt)
}
v.node.Accept(visitor1{})
}
}
Expand Down
51 changes: 38 additions & 13 deletions ast/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ import (
"fmt"
"testing"

"github.com/sqlc-dev/marino/parser"
. "github.com/sqlc-dev/marino/ast"
"github.com/sqlc-dev/marino/format"
"github.com/stretchr/testify/require"
"github.com/sqlc-dev/marino/parser"

"reflect"
"regexp"
"strings"
)

func TestDMLVisitorCover(t *testing.T) {
Expand Down Expand Up @@ -68,8 +71,12 @@ func TestDMLVisitorCover(t *testing.T) {
for _, v := range stmts {
ce.reset()
v.node.Accept(checkVisitor{})
require.Equal(t, v.expectedEnterCnt, ce.enterCnt)
require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt)
if !reflect.DeepEqual(v.expectedEnterCnt, ce.enterCnt) {
t.Fatalf("got %v, want %v", ce.enterCnt, v.expectedEnterCnt)
}
if !reflect.DeepEqual(v.expectedLeaveCnt, ce.leaveCnt) {
t.Fatalf("got %v, want %v", ce.leaveCnt, v.expectedLeaveCnt)
}
v.node.Accept(visitor1{})
}
}
Expand Down Expand Up @@ -630,9 +637,15 @@ func TestImportIntoRestore(t *testing.T) {
}

func TestFulltextSearchModifier(t *testing.T) {
require.False(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsBooleanMode())
require.True(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsNaturalLanguageMode())
require.False(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).WithQueryExpansion())
if FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsBooleanMode() {
t.Fatal("expected false")
}
if !(FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsNaturalLanguageMode()) {
t.Fatal("expected true")
}
if FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).WithQueryExpansion() {
t.Fatal("expected false")
}
}

func TestImportIntoSecureText(t *testing.T) {
Expand All @@ -658,19 +671,31 @@ func TestImportIntoSecureText(t *testing.T) {
for _, tc := range testCases {
comment := fmt.Sprintf("input = %s", tc.input)
node, err := p.ParseOneStmt(tc.input, "", "")
require.NoError(t, err, comment)
if err != nil {
t.Fatalf("%v: %v", comment, err)
}
n, ok := node.(SensitiveStmtNode)
require.True(t, ok, comment)
require.Regexp(t, tc.secured, n.SecureText(), comment)
if !(ok) {
t.Fatal(comment)
}
if !regexp.MustCompile(tc.secured).MatchString(n.SecureText()) {
t.Fatalf("%v: expected %q to match %q", comment, n.SecureText(), tc.secured)
}
}
}

func TestImportIntoFromSelectInvalidStmt(t *testing.T) {
p := parser.New()
_, err := p.ParseOneStmt("IMPORT INTO t1(a, @1) FROM select * from t2;", "", "")
require.ErrorContains(t, err, "Cannot use user variable(1) in IMPORT INTO FROM SELECT statement")
if err == nil || !strings.Contains(err.Error(), "Cannot use user variable(1) in IMPORT INTO FROM SELECT statement") {
t.Fatalf("expected error containing %q, got %v", "Cannot use user variable(1) in IMPORT INTO FROM SELECT statement", err)
}
_, err = p.ParseOneStmt("IMPORT INTO t1(a, @b) FROM select * from t2;", "", "")
require.ErrorContains(t, err, "Cannot use user variable(b) in IMPORT INTO FROM SELECT statement")
if err == nil || !strings.Contains(err.Error(), "Cannot use user variable(b) in IMPORT INTO FROM SELECT statement") {
t.Fatalf("expected error containing %q, got %v", "Cannot use user variable(b) in IMPORT INTO FROM SELECT statement", err)
}
_, err = p.ParseOneStmt("IMPORT INTO t1(a) set a=1 FROM select a from t2;", "", "")
require.ErrorContains(t, err, "Cannot use SET clause in IMPORT INTO FROM SELECT statement.")
if err == nil || !strings.Contains(err.Error(), "Cannot use SET clause in IMPORT INTO FROM SELECT statement.") {
t.Fatalf("expected error containing %q, got %v", "Cannot use SET clause in IMPORT INTO FROM SELECT statement.", err)
}
}
11 changes: 8 additions & 3 deletions ast/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import (
. "github.com/sqlc-dev/marino/ast"
"github.com/sqlc-dev/marino/format"
"github.com/sqlc-dev/marino/mysql"
"github.com/stretchr/testify/require"

"reflect"
)

type checkVisitor struct{}
Expand Down Expand Up @@ -94,8 +95,12 @@ func TestExpresionsVisitorCover(t *testing.T) {
for _, v := range stmts {
ce.reset()
v.node.Accept(checkVisitor{})
require.Equal(t, v.expectedEnterCnt, ce.enterCnt)
require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt)
if !reflect.DeepEqual(v.expectedEnterCnt, ce.enterCnt) {
t.Fatalf("got %v, want %v", ce.enterCnt, v.expectedEnterCnt)
}
if !reflect.DeepEqual(v.expectedLeaveCnt, ce.leaveCnt) {
t.Fatalf("got %v, want %v", ce.leaveCnt, v.expectedLeaveCnt)
}
v.node.Accept(visitor1{})
}
}
Expand Down
18 changes: 13 additions & 5 deletions ast/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
package ast_test

import (
"fmt"
"testing"

"github.com/sqlc-dev/marino/parser"
"github.com/sqlc-dev/marino/ast"
"github.com/stretchr/testify/require"
"github.com/sqlc-dev/marino/parser"

"reflect"
)

func TestHasAggFlag(t *testing.T) {
Expand All @@ -33,7 +35,9 @@ func TestHasAggFlag(t *testing.T) {
}
for _, tt := range flagTests {
expr.SetFlag(tt.flag)
require.Equal(t, tt.hasAgg, ast.HasAggFlag(expr))
if !reflect.DeepEqual(tt.hasAgg, ast.HasAggFlag(expr)) {
t.Fatalf("got %v, want %v", ast.HasAggFlag(expr), tt.hasAgg)
}
}
}

Expand Down Expand Up @@ -130,10 +134,14 @@ func TestFlag(t *testing.T) {
p := parser.New()
for _, tt := range flagTests {
stmt, err := p.ParseOneStmt("select "+tt.expr, "", "")
require.NoError(t, err)
if err != nil {
t.Fatal(err)
}
selectStmt := stmt.(*ast.SelectStmt)
ast.SetFlag(selectStmt)
expr := selectStmt.Fields.Fields[0].Expr
require.Equalf(t, tt.flag, expr.GetFlag(), "For %s", tt.expr)
if !reflect.DeepEqual(tt.flag, expr.GetFlag()) {
t.Fatalf("%s: got %v, want %v", fmt.Sprintf("For %s", tt.expr), expr.GetFlag(), tt.flag)
}
}
}
13 changes: 9 additions & 4 deletions ast/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import (
"fmt"
"testing"

"github.com/sqlc-dev/marino/parser"
"github.com/sqlc-dev/marino/ast"
"github.com/stretchr/testify/require"
"github.com/sqlc-dev/marino/parser"

"reflect"
)

func getDefaultCharsetAndCollate() (string, string) {
Expand Down Expand Up @@ -89,10 +90,14 @@ func TestAstFormat(t *testing.T) {
charset, collation := getDefaultCharsetAndCollate()
stmts, _, err := parser.New().Parse(expr, charset, collation)
node := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr
require.NoError(t, err)
if err != nil {
t.Fatal(err)
}

writer := bytes.NewBufferString("")
node.Format(writer)
require.Equal(t, tt.output, writer.String())
if !reflect.DeepEqual(tt.output, writer.String()) {
t.Fatalf("got %v, want %v", writer.String(), tt.output)
}
}
}
Loading
Loading