Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ go_test(
"//pkg/parser/mysql",
"//pkg/parser/opcode",
"//pkg/parser/terror",
"//pkg/parser/test_driver",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//require",
"@org_uber_go_goleak//:goleak",
Expand Down
5 changes: 4 additions & 1 deletion ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@ go_library(
srcs = [
"ast.go",
"base.go",
"datum.go",
"datum_helper.go",
"ddl.go",
"dml.go",
"expressions.go",
"flag.go",
"functions.go",
"misc.go",
"model.go",
"mydecimal.go",
"procedure.go",
"sem.go",
"stats.go",
"util.go",
"value_expr.go",
],
importpath = "github.com/sqlc-dev/marino/ast",
visibility = ["//visibility:public"],
Expand Down Expand Up @@ -60,7 +64,6 @@ go_test(
"//pkg/parser/charset",
"//pkg/parser/format",
"//pkg/parser/mysql",
"//pkg/parser/test_driver",
"@com_github_stretchr_testify//require",
],
)
3 changes: 0 additions & 3 deletions ast/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,6 @@ type exprNode struct {
flag uint64
}

// TexprNode is exported for parser driver.
type TexprNode = exprNode

// SetType implements ExprNode interface.
func (en *exprNode) SetType(tp *types.FieldType) {
en.Type = *tp
Expand Down
50 changes: 25 additions & 25 deletions test_driver/test_driver_datum.go → ast/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

//go:build !codes

package test_driver
package ast

import (
"bytes"
Expand Down Expand Up @@ -155,12 +155,12 @@ func (d *Datum) SetNull() {
}

// GetBinaryLiteral gets Bit value
func (d *Datum) GetBinaryLiteral() BinaryLiteral {
func (d *Datum) GetBinaryLiteral() BinaryLit {
return d.b
}

// SetBinaryLiteral sets Bit value
func (d *Datum) SetBinaryLiteral(b BinaryLiteral) {
func (d *Datum) SetBinaryLiteral(b BinaryLit) {
d.k = KindBinaryLiteral
d.b = b
}
Expand Down Expand Up @@ -227,12 +227,12 @@ func (d *Datum) SetValue(val any) {
d.SetBytes(x)
case *MyDecimal:
d.SetMysqlDecimal(x)
case BinaryLiteral:
case BinaryLit:
d.SetBinaryLiteral(x)
case BitLiteral: // Store as BinaryLiteral for Bit and Hex literals
d.SetBinaryLiteral(BinaryLiteral(x))
case BitLiteral: // Store as BinaryLit for Bit and Hex literals
d.SetBinaryLiteral(BinaryLit(x))
case HexLiteral:
d.SetBinaryLiteral(BinaryLiteral(x))
d.SetBinaryLiteral(BinaryLit(x))
default:
d.SetInterface(x)
}
Expand Down Expand Up @@ -270,33 +270,33 @@ func MakeDatums(args ...any) []Datum {
return datums
}

// BinaryLiteral is the internal type for storing bit / hex literal type.
type BinaryLiteral []byte
// BinaryLit is the internal type for storing bit / hex literal type.
type BinaryLit []byte

// BitLiteral is the bit literal type.
type BitLiteral BinaryLiteral
type BitLiteral BinaryLit

// HexLiteral is the hex literal type.
type HexLiteral BinaryLiteral
type HexLiteral BinaryLit

// ZeroBinaryLiteral is a BinaryLiteral literal with zero value.
var ZeroBinaryLiteral = BinaryLiteral{}
// ZeroBinaryLit is a BinaryLit literal with zero value.
var ZeroBinaryLit = BinaryLit{}

// String implements fmt.Stringer interface.
func (b BinaryLiteral) String() string {
func (b BinaryLit) String() string {
if len(b) == 0 {
return ""
}
return "0x" + hex.EncodeToString(b)
}

// ToString returns the string representation for the literal.
func (b BinaryLiteral) ToString() string {
func (b BinaryLit) ToString() string {
return string(b)
}

// ToBitLiteralString returns the bit literal representation for the literal.
func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string {
func (b BinaryLit) ToBitLiteralString(trimLeadingZero bool) string {
if len(b) == 0 {
return "b''"
}
Expand All @@ -317,7 +317,7 @@ func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string {
// ParseBitStr parses bit string.
// The string format can be b'val', B'val' or 0bval, val must be 0 or 1.
// See https://dev.mysql.com/doc/refman/5.7/en/bit-value-literals.html
func ParseBitStr(s string) (BinaryLiteral, error) {
func ParseBitStr(s string) (BinaryLit, error) {
if len(s) == 0 {
return nil, fmt.Errorf("invalid empty string for parsing bit type")
}
Expand All @@ -333,7 +333,7 @@ func ParseBitStr(s string) (BinaryLiteral, error) {
}

if len(s) == 0 {
return ZeroBinaryLiteral, nil
return ZeroBinaryLit, nil
}

alignedLength := (len(s) + 7) &^ 7
Expand Down Expand Up @@ -362,14 +362,14 @@ func NewBitLiteral(s string) (BitLiteral, error) {
return BitLiteral(b), nil
}

// ToString implement ast.BinaryLiteral interface
// ToString implement BinaryLiteral interface
func (b BitLiteral) ToString() string {
return BinaryLiteral(b).ToString()
return BinaryLit(b).ToString()
}

// ParseHexStr parses hexadecimal string literal.
// See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html
func ParseHexStr(s string) (BinaryLiteral, error) {
func ParseHexStr(s string) (BinaryLit, error) {
if len(s) == 0 {
return nil, fmt.Errorf("invalid empty string for parsing hexadecimal literal")
}
Expand All @@ -388,7 +388,7 @@ func ParseHexStr(s string) (BinaryLiteral, error) {
}

if len(s) == 0 {
return ZeroBinaryLiteral, nil
return ZeroBinaryLit, nil
}

if len(s)%2 != 0 {
Expand All @@ -410,9 +410,9 @@ func NewHexLiteral(s string) (HexLiteral, error) {
return HexLiteral(h), nil
}

// ToString implement ast.BinaryLiteral interface
// ToString implement BinaryLiteral interface
func (b HexLiteral) ToString() string {
return BinaryLiteral(b).ToString()
return BinaryLit(b).ToString()
}

// SetBinChsClnFlag sets charset, collation as 'binary' and adds binaryFlag to FieldType.
Expand Down Expand Up @@ -491,7 +491,7 @@ func DefaultTypeForValue(value any, tp *types.FieldType, charset string, collate
tp.SetDecimal(0)
tp.AddFlag(mysql.UnsignedFlag)
SetBinChsClnFlag(tp)
case BinaryLiteral:
case BinaryLit:
tp.SetType(mysql.TypeBit)
tp.SetFlen(len(x) * 8)
tp.SetDecimal(0)
Expand Down
6 changes: 3 additions & 3 deletions test_driver/test_driver_helper.go → ast/datum_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

//go:build !codes

package test_driver
package ast

import (
"math"
Expand All @@ -38,7 +38,7 @@ func pow10(x int) int32 {
return int32(math.Pow10(x))
}

func Abs(n int64) int64 {
func absInt64(n int64) int64 {
y := n >> 63
return (n ^ y) - y
}
Expand Down Expand Up @@ -68,5 +68,5 @@ func StrLenOfInt64Fast(x int64) int {
if x < 0 {
size = 1 // add "-" sign on the length count
}
return size + StrLenOfUint64Fast(uint64(Abs(x)))
return size + StrLenOfUint64Fast(uint64(absInt64(x)))
}
6 changes: 0 additions & 6 deletions ast/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ type ValueExpr interface {
SetProjectionOffset(offset int)
}

// NewValueExpr creates a ValueExpr with value, and sets default field type.
var NewValueExpr func(value any, charset string, collate string) ValueExpr

// NewParamMarkerExpr creates a ParamMarkerExpr.
var NewParamMarkerExpr func(offset int) ParamMarkerExpr

// BetweenExpr is for "between and" or "not between and" expression.
type BetweenExpr struct {
exprNode
Expand Down
5 changes: 2 additions & 3 deletions ast/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/sqlc-dev/marino/format"
"github.com/sqlc-dev/marino/mysql"
"github.com/sqlc-dev/marino/parser"
"github.com/sqlc-dev/marino/test_driver"

"reflect"
)
Expand Down Expand Up @@ -182,7 +181,7 @@ func TestConvert(t *testing.T) {

st := stmt.(*SelectStmt)
expr := st.Fields.Fields[0].Expr.(*FuncCallExpr)
charsetArg := expr.Args[1].(*test_driver.ValueExpr)
charsetArg := expr.Args[1].(*ValueExprBase)
if !reflect.DeepEqual(testCase.CharsetName, charsetArg.GetString()) {
t.Fatalf("got %v, want %v", charsetArg.GetString(), testCase.CharsetName)
}
Expand Down Expand Up @@ -217,7 +216,7 @@ func TestChar(t *testing.T) {

st := stmt.(*SelectStmt)
expr := st.Fields.Fields[0].Expr.(*FuncCallExpr)
charsetArg := expr.Args[1].(*test_driver.ValueExpr)
charsetArg := expr.Args[1].(*ValueExprBase)
if !reflect.DeepEqual(testCase.CharsetName, charsetArg.GetString()) {
t.Fatalf("got %v, want %v", charsetArg.GetString(), testCase.CharsetName)
}
Expand Down
11 changes: 2 additions & 9 deletions ast/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4226,19 +4226,12 @@ type TextString struct {
IsBinaryLiteral bool
}

// BinaryLiteral abstracts over the concrete bit/hex literal types so the
// parser can stringify any of them without a type switch.
type BinaryLiteral interface {
ToString() string
}

// NewDecimal creates a types.Decimal value, it's provided by parser driver.
var NewDecimal func(string) (any, error)

// NewHexLiteral creates a types.HexLiteral value, it's provided by parser driver.
var NewHexLiteral func(string) (any, error)

// NewBitLiteral creates a types.BitLiteral value, it's provided by parser driver.
var NewBitLiteral func(string) (any, error)

// SetResourceGroupStmt is a statement to set the resource group name for current session.
type SetResourceGroupStmt struct {
stmtNode
Expand Down
2 changes: 1 addition & 1 deletion test_driver/test_driver_mydecimal.go → ast/mydecimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

//go:build !codes

package test_driver
package ast

const panicInfo = "This branch is not implemented. " +
"This is because you are trying to test something specific to TiDB's MyDecimal implementation. " +
Expand Down
3 changes: 1 addition & 2 deletions ast/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
. "github.com/sqlc-dev/marino/format"
"github.com/sqlc-dev/marino/mysql"
"github.com/sqlc-dev/marino/parser"
"github.com/sqlc-dev/marino/test_driver"

"reflect"
)
Expand Down Expand Up @@ -194,7 +193,7 @@ func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) {
node.FnName.O = strings.ToLower(node.FnName.O)
switch node.FnName.L {
case "convert":
node.Args[1].(*test_driver.ValueExpr).Datum.SetBytes(nil)
node.Args[1].(*ValueExprBase).Datum.SetBytes(nil)
}
case *AggregateFuncExpr:
node.F = strings.ToLower(node.F)
Expand Down
Loading
Loading