init - add project files
This commit is contained in:
303
tools/jet-2.12.0/internal/jet/sql_builder.go
Normal file
303
tools/jet-2.12.0/internal/jet/sql_builder.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package jet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet/v2/internal/3rdparty/pq"
|
||||
"github.com/go-jet/jet/v2/internal/utils/is"
|
||||
"github.com/google/uuid"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// SQLBuilder generates output SQL
|
||||
type SQLBuilder struct {
|
||||
Dialect Dialect
|
||||
Buff bytes.Buffer
|
||||
Args []interface{}
|
||||
|
||||
lastChar byte
|
||||
ident int
|
||||
|
||||
Debug bool
|
||||
}
|
||||
|
||||
const tabSize = 4
|
||||
const defaultIdent = 5
|
||||
|
||||
// IncreaseIdent adds ident or defaultIdent number of spaces to each new line
|
||||
func (s *SQLBuilder) IncreaseIdent(ident ...int) {
|
||||
if len(ident) > 0 {
|
||||
s.ident += ident[0]
|
||||
} else {
|
||||
s.ident += defaultIdent
|
||||
}
|
||||
}
|
||||
|
||||
// DecreaseIdent removes ident or defaultIdent number of spaces for each new line
|
||||
func (s *SQLBuilder) DecreaseIdent(ident ...int) {
|
||||
toDecrease := defaultIdent
|
||||
|
||||
if len(ident) > 0 {
|
||||
toDecrease = ident[0]
|
||||
}
|
||||
|
||||
if s.ident < toDecrease {
|
||||
s.ident = 0
|
||||
}
|
||||
|
||||
s.ident -= toDecrease
|
||||
}
|
||||
|
||||
// WriteProjections func
|
||||
func (s *SQLBuilder) WriteProjections(statement StatementType, projections []Projection) {
|
||||
s.IncreaseIdent()
|
||||
SerializeProjectionList(statement, projections, s)
|
||||
s.DecreaseIdent()
|
||||
}
|
||||
|
||||
// NewLine adds new line to output SQL
|
||||
func (s *SQLBuilder) NewLine() {
|
||||
s.write([]byte{'\n'})
|
||||
s.write(bytes.Repeat([]byte{' '}, s.ident))
|
||||
}
|
||||
|
||||
func (s *SQLBuilder) write(data []byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 {
|
||||
s.Buff.WriteByte(' ')
|
||||
}
|
||||
|
||||
s.Buff.Write(data)
|
||||
s.lastChar = data[len(data)-1]
|
||||
}
|
||||
|
||||
func isPreSeparator(b byte) bool {
|
||||
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
|
||||
}
|
||||
|
||||
func isPostSeparator(b byte) bool {
|
||||
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
|
||||
}
|
||||
|
||||
// WriteAlias is used to add alias to output SQL
|
||||
func (s *SQLBuilder) WriteAlias(str string) {
|
||||
aliasQuoteChar := string(s.Dialect.AliasQuoteChar())
|
||||
s.WriteString(aliasQuoteChar + str + aliasQuoteChar)
|
||||
}
|
||||
|
||||
// WriteString writes sting to output SQL
|
||||
func (s *SQLBuilder) WriteString(str string) {
|
||||
s.write([]byte(str))
|
||||
}
|
||||
|
||||
// WriteIdentifier adds identifier to output SQL
|
||||
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
|
||||
if s.shouldQuote(name, alwaysQuote...) {
|
||||
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
|
||||
s.WriteString(identQuoteChar + name + identQuoteChar)
|
||||
} else {
|
||||
s.WriteString(name)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SQLBuilder) shouldQuote(name string, alwaysQuote ...bool) bool {
|
||||
return s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0
|
||||
}
|
||||
|
||||
// WriteByte writes byte to output SQL
|
||||
func (s *SQLBuilder) WriteByte(b byte) {
|
||||
s.write([]byte{b})
|
||||
}
|
||||
|
||||
func (s *SQLBuilder) finalize() (string, []interface{}) {
|
||||
return s.Buff.String() + ";\n", s.Args
|
||||
}
|
||||
|
||||
func (s *SQLBuilder) insertConstantArgument(arg interface{}) {
|
||||
s.WriteString(argToString(arg))
|
||||
}
|
||||
|
||||
func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) {
|
||||
if s.Debug {
|
||||
s.insertConstantArgument(arg)
|
||||
return
|
||||
}
|
||||
|
||||
s.Args = append(s.Args, arg)
|
||||
argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args))
|
||||
|
||||
s.WriteString(argPlaceholder)
|
||||
}
|
||||
|
||||
func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) {
|
||||
type namedArgumentPosition struct {
|
||||
Name string
|
||||
Value interface{}
|
||||
Position int
|
||||
}
|
||||
|
||||
var namedArgumentPositions []namedArgumentPosition
|
||||
|
||||
for namedArg, value := range namedArg {
|
||||
rawCopy := raw
|
||||
rawIndex := 0
|
||||
exists := false
|
||||
|
||||
// one named argument can occur multiple times inside raw string
|
||||
for {
|
||||
index := strings.Index(rawCopy, namedArg)
|
||||
if index == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
exists = true
|
||||
namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{
|
||||
Name: namedArg,
|
||||
Value: value,
|
||||
Position: rawIndex + index,
|
||||
})
|
||||
|
||||
rawCopy = rawCopy[index+len(namedArg):]
|
||||
rawIndex += index + len(namedArg)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
panic("jet: named argument '" + namedArg + "' does not appear in raw query")
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(namedArgumentPositions, func(i, j int) bool {
|
||||
return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position
|
||||
})
|
||||
|
||||
for _, namedArgumentPos := range namedArgumentPositions {
|
||||
// if named argument does not exists in raw string do not add argument to the list of arguments
|
||||
// It can happen if the same argument occurs multiple times in postgres query.
|
||||
if !strings.Contains(raw, namedArgumentPos.Name) {
|
||||
continue
|
||||
}
|
||||
s.Args = append(s.Args, namedArgumentPos.Value)
|
||||
currentArgNum := len(s.Args)
|
||||
|
||||
placeholder := s.Dialect.ArgumentPlaceholder()(currentArgNum)
|
||||
// if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurrence of the argument
|
||||
toReplace := -1 // all occurrences
|
||||
if placeholder == "?" {
|
||||
toReplace = 1 // just one occurrence
|
||||
}
|
||||
|
||||
if s.Debug {
|
||||
placeholder = argToString(namedArgumentPos.Value)
|
||||
}
|
||||
|
||||
raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace)
|
||||
}
|
||||
|
||||
s.WriteString(raw)
|
||||
}
|
||||
|
||||
func argToString(value interface{}) string {
|
||||
if is.Nil(value) {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
switch bindVal := value.(type) {
|
||||
case bool:
|
||||
if bindVal {
|
||||
return "TRUE"
|
||||
}
|
||||
return "FALSE"
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return integerTypesToString(bindVal)
|
||||
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
|
||||
case float64:
|
||||
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
|
||||
|
||||
case string:
|
||||
return stringQuote(bindVal)
|
||||
case []byte:
|
||||
return stringQuote(string(bindVal))
|
||||
case uuid.UUID:
|
||||
return stringQuote(bindVal.String())
|
||||
case time.Time:
|
||||
return stringQuote(string(pq.FormatTimestamp(bindVal)))
|
||||
default:
|
||||
if strBindValue, ok := bindVal.(fmt.Stringer); ok {
|
||||
return stringQuote(strBindValue.String())
|
||||
}
|
||||
|
||||
if valuer, ok := bindVal.(driver.Valuer); ok {
|
||||
val, err := valuer.Value()
|
||||
|
||||
if err != nil {
|
||||
// If valuer for some reason returns an error, we return error string representation.
|
||||
// This is fine because argToString is called only from DebugSQL, and DebugSQL shouldn't be used in production.
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
return argToString(val)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))
|
||||
}
|
||||
}
|
||||
|
||||
func integerTypesToString(value interface{}) string {
|
||||
switch bindVal := value.(type) {
|
||||
case int:
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(bindVal, 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(bindVal, 10)
|
||||
}
|
||||
panic("jet: Unsupported integer type: " + reflect.TypeOf(value).String())
|
||||
}
|
||||
|
||||
func shouldQuoteIdentifier(identifier string) bool {
|
||||
_, err := strconv.ParseInt(identifier, 10, 64)
|
||||
|
||||
if err == nil { // if it is a number we should quote it
|
||||
return true
|
||||
}
|
||||
|
||||
// check if contains non ascii characters
|
||||
for _, c := range identifier {
|
||||
if unicode.IsNumber(c) || c == '_' {
|
||||
continue
|
||||
}
|
||||
if c > unicode.MaxASCII || !unicode.IsLetter(c) || unicode.IsUpper(c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringQuote(value string) string {
|
||||
return `'` + strings.Replace(value, "'", "''", -1) + `'`
|
||||
}
|
||||
Reference in New Issue
Block a user