* Dropped unused codekit config * Integrated dynamic and static bindata for public * Ignore public bindata * Add a general generate make task * Integrated flexible public assets into web command * Updated vendoring, added all missing govendor deps * Made the linter happy with the bindata and dynamic code * Moved public bindata definition to modules directory * Ignoring the new bindata path now * Updated to the new public modules import path * Updated public bindata command and drop the new prefix
		
			
				
	
	
		
			718 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			718 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2015 PingCAP, Inc.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package evaluator
 | |
| 
 | |
| import (
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/juju/errors"
 | |
| 	"github.com/pingcap/tidb/ast"
 | |
| 	"github.com/pingcap/tidb/context"
 | |
| 	"github.com/pingcap/tidb/mysql"
 | |
| 	"github.com/pingcap/tidb/parser/opcode"
 | |
| 	"github.com/pingcap/tidb/sessionctx/variable"
 | |
| 	"github.com/pingcap/tidb/terror"
 | |
| 	"github.com/pingcap/tidb/util/types"
 | |
| )
 | |
| 
 | |
| // Error instances.
 | |
| var (
 | |
| 	ErrInvalidOperation = terror.ClassEvaluator.New(CodeInvalidOperation, "invalid operation")
 | |
| )
 | |
| 
 | |
| // Error codes.
 | |
| const (
 | |
| 	CodeInvalidOperation terror.ErrCode = 1
 | |
| )
 | |
| 
 | |
| // Eval evaluates an expression to a value.
 | |
| func Eval(ctx context.Context, expr ast.ExprNode) (interface{}, error) {
 | |
| 	e := &Evaluator{ctx: ctx}
 | |
| 	expr.Accept(e)
 | |
| 	if e.err != nil {
 | |
| 		return nil, errors.Trace(e.err)
 | |
| 	}
 | |
| 	return expr.GetValue(), nil
 | |
| }
 | |
| 
 | |
| // EvalBool evalueates an expression to a boolean value.
 | |
| func EvalBool(ctx context.Context, expr ast.ExprNode) (bool, error) {
 | |
| 	val, err := Eval(ctx, expr)
 | |
| 	if err != nil {
 | |
| 		return false, errors.Trace(err)
 | |
| 	}
 | |
| 	if val == nil {
 | |
| 		return false, nil
 | |
| 	}
 | |
| 
 | |
| 	i, err := types.ToBool(val)
 | |
| 	if err != nil {
 | |
| 		return false, errors.Trace(err)
 | |
| 	}
 | |
| 	return i != 0, nil
 | |
| }
 | |
| 
 | |
// boolToInt64 maps true to 1 and false to 0.
func boolToInt64(v bool) int64 {
	var n int64
	if v {
		n = 1
	}
	return n
}
 | |
| 
 | |
| // Evaluator is an ast Visitor that evaluates an expression.
 | |
| type Evaluator struct {
 | |
| 	ctx          context.Context
 | |
| 	err          error
 | |
| 	multipleRows bool
 | |
| 	existRow     bool
 | |
| }
 | |
| 
 | |
| // Enter implements ast.Visitor interface.
 | |
| func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
 | |
| 	switch v := in.(type) {
 | |
| 	case *ast.SubqueryExpr:
 | |
| 		if v.Evaluated && !v.UseOuterContext {
 | |
| 			return in, true
 | |
| 		}
 | |
| 	case *ast.PatternInExpr, *ast.CompareSubqueryExpr:
 | |
| 		e.multipleRows = true
 | |
| 	case *ast.ExistsSubqueryExpr:
 | |
| 		e.existRow = true
 | |
| 	}
 | |
| 	return in, false
 | |
| }
 | |
| 
 | |
| // Leave implements ast.Visitor interface.
 | |
| func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) {
 | |
| 	switch v := in.(type) {
 | |
| 	case *ast.AggregateFuncExpr:
 | |
| 		ok = e.aggregateFunc(v)
 | |
| 	case *ast.BetweenExpr:
 | |
| 		ok = e.between(v)
 | |
| 	case *ast.BinaryOperationExpr:
 | |
| 		ok = e.binaryOperation(v)
 | |
| 	case *ast.CaseExpr:
 | |
| 		ok = e.caseExpr(v)
 | |
| 	case *ast.ColumnName:
 | |
| 		ok = true
 | |
| 	case *ast.ColumnNameExpr:
 | |
| 		ok = e.columnName(v)
 | |
| 	case *ast.CompareSubqueryExpr:
 | |
| 		e.multipleRows = false
 | |
| 		ok = e.compareSubquery(v)
 | |
| 	case *ast.DefaultExpr:
 | |
| 		ok = e.defaultExpr(v)
 | |
| 	case *ast.ExistsSubqueryExpr:
 | |
| 		e.existRow = false
 | |
| 		ok = e.existsSubquery(v)
 | |
| 	case *ast.FuncCallExpr:
 | |
| 		ok = e.funcCall(v)
 | |
| 	case *ast.FuncCastExpr:
 | |
| 		ok = e.funcCast(v)
 | |
| 	case *ast.IsNullExpr:
 | |
| 		ok = e.isNull(v)
 | |
| 	case *ast.IsTruthExpr:
 | |
| 		ok = e.isTruth(v)
 | |
| 	case *ast.ParamMarkerExpr:
 | |
| 		ok = e.paramMarker(v)
 | |
| 	case *ast.ParenthesesExpr:
 | |
| 		ok = e.parentheses(v)
 | |
| 	case *ast.PatternInExpr:
 | |
| 		e.multipleRows = false
 | |
| 		ok = e.patternIn(v)
 | |
| 	case *ast.PatternLikeExpr:
 | |
| 		ok = e.patternLike(v)
 | |
| 	case *ast.PatternRegexpExpr:
 | |
| 		ok = e.patternRegexp(v)
 | |
| 	case *ast.PositionExpr:
 | |
| 		ok = e.position(v)
 | |
| 	case *ast.RowExpr:
 | |
| 		ok = e.row(v)
 | |
| 	case *ast.SubqueryExpr:
 | |
| 		ok = e.subqueryExpr(v)
 | |
| 	case ast.SubqueryExec:
 | |
| 		ok = e.subqueryExec(v)
 | |
| 	case *ast.UnaryOperationExpr:
 | |
| 		ok = e.unaryOperation(v)
 | |
| 	case *ast.ValueExpr:
 | |
| 		ok = true
 | |
| 	case *ast.ValuesExpr:
 | |
| 		ok = e.values(v)
 | |
| 	case *ast.VariableExpr:
 | |
| 		ok = e.variable(v)
 | |
| 	case *ast.WhenClause:
 | |
| 		ok = true
 | |
| 	}
 | |
| 	out = in
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) between(v *ast.BetweenExpr) bool {
 | |
| 	var l, r ast.ExprNode
 | |
| 	op := opcode.AndAnd
 | |
| 
 | |
| 	if v.Not {
 | |
| 		// v < lv || v > rv
 | |
| 		op = opcode.OrOr
 | |
| 		l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left}
 | |
| 		r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right}
 | |
| 	} else {
 | |
| 		// v >= lv && v <= rv
 | |
| 		l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left}
 | |
| 		r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right}
 | |
| 	}
 | |
| 
 | |
| 	ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r}
 | |
| 	ret.Accept(e)
 | |
| 	if e.err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	v.SetDatum(*ret.GetDatum())
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool {
 | |
| 	tmp := types.NewDatum(boolToInt64(true))
 | |
| 	target := &tmp
 | |
| 	if v.Value != nil {
 | |
| 		target = v.Value.GetDatum()
 | |
| 	}
 | |
| 	if target.Kind() != types.KindNull {
 | |
| 		for _, val := range v.WhenClauses {
 | |
| 			cmp, err := target.CompareDatum(*val.Expr.GetDatum())
 | |
| 			if err != nil {
 | |
| 				e.err = errors.Trace(err)
 | |
| 				return false
 | |
| 			}
 | |
| 			if cmp == 0 {
 | |
| 				v.SetDatum(*val.Result.GetDatum())
 | |
| 				return true
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if v.ElseClause != nil {
 | |
| 		v.SetDatum(*v.ElseClause.GetDatum())
 | |
| 	} else {
 | |
| 		v.SetNull()
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool {
 | |
| 	v.SetDatum(*v.Refer.Expr.GetDatum())
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool {
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) compareSubquery(cs *ast.CompareSubqueryExpr) bool {
 | |
| 	lvDatum := cs.L.GetDatum()
 | |
| 	if lvDatum.Kind() == types.KindNull {
 | |
| 		cs.SetNull()
 | |
| 		return true
 | |
| 	}
 | |
| 	lv := lvDatum.GetValue()
 | |
| 	x, err := e.checkResult(cs, lv, cs.R.GetValue().([]interface{}))
 | |
| 	if err != nil {
 | |
| 		e.err = errors.Trace(err)
 | |
| 		return false
 | |
| 	}
 | |
| 	cs.SetValue(x)
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) checkResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) {
 | |
| 	if cs.All {
 | |
| 		return e.checkAllResult(cs, lv, result)
 | |
| 	}
 | |
| 	return e.checkAnyResult(cs, lv, result)
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) checkAllResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) {
 | |
| 	hasNull := false
 | |
| 	for _, v := range result {
 | |
| 		if v == nil {
 | |
| 			hasNull = true
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		comRes, err := types.Compare(lv, v)
 | |
| 		if err != nil {
 | |
| 			return nil, errors.Trace(err)
 | |
| 		}
 | |
| 
 | |
| 		res, err := getCompResult(cs.Op, comRes)
 | |
| 		if err != nil {
 | |
| 			return nil, errors.Trace(err)
 | |
| 		}
 | |
| 		if !res {
 | |
| 			return false, nil
 | |
| 		}
 | |
| 	}
 | |
| 	if hasNull {
 | |
| 		// If no matched but we get null, return null.
 | |
| 		// Like `insert t (c) values (1),(2),(null)`, then
 | |
| 		// `select 3 > all (select c from t)`, returns null.
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 	return true, nil
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) {
 | |
| 	hasNull := false
 | |
| 	for _, v := range result {
 | |
| 		if v == nil {
 | |
| 			hasNull = true
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		comRes, err := types.Compare(lv, v)
 | |
| 		if err != nil {
 | |
| 			return nil, errors.Trace(err)
 | |
| 		}
 | |
| 
 | |
| 		res, err := getCompResult(cs.Op, comRes)
 | |
| 		if err != nil {
 | |
| 			return nil, errors.Trace(err)
 | |
| 		}
 | |
| 		if res {
 | |
| 			return true, nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if hasNull {
 | |
| 		// If no matched but we get null, return null.
 | |
| 		// Like `insert t (c) values (1),(2),(null)`, then
 | |
| 		// `select 0 > any (select c from t)`, returns null.
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 
 | |
| 	return false, nil
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
 | |
| 	datum := v.Sel.GetDatum()
 | |
| 	if datum.Kind() == types.KindNull {
 | |
| 		v.SetInt64(0)
 | |
| 		return true
 | |
| 	}
 | |
| 	r := datum.GetValue()
 | |
| 	rows, _ := r.([]interface{})
 | |
| 	if len(rows) > 0 {
 | |
| 		v.SetInt64(1)
 | |
| 	} else {
 | |
| 		v.SetInt64(0)
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // Evaluate SubqueryExpr.
 | |
| // Get the value from v.SubQuery and set it to v.
 | |
| func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool {
 | |
| 	if v.SubqueryExec != nil {
 | |
| 		v.SetDatum(*v.SubqueryExec.GetDatum())
 | |
| 	}
 | |
| 	v.Evaluated = true
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // Do the real work to evaluate subquery.
 | |
| func (e *Evaluator) subqueryExec(v ast.SubqueryExec) bool {
 | |
| 	rowCount := 2
 | |
| 	if e.multipleRows {
 | |
| 		rowCount = -1
 | |
| 	} else if e.existRow {
 | |
| 		rowCount = 1
 | |
| 	}
 | |
| 	rows, err := v.EvalRows(e.ctx, rowCount)
 | |
| 	if err != nil {
 | |
| 		e.err = errors.Trace(err)
 | |
| 		return false
 | |
| 	}
 | |
| 	if e.multipleRows || e.existRow {
 | |
| 		v.SetValue(rows)
 | |
| 		return true
 | |
| 	}
 | |
| 	switch len(rows) {
 | |
| 	case 0:
 | |
| 		v.GetDatum().SetNull()
 | |
| 	case 1:
 | |
| 		v.SetValue(rows[0])
 | |
| 	default:
 | |
| 		e.err = errors.New("Subquery returns more than 1 row")
 | |
| 		return false
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) interface{} {
 | |
| 	hasNull := false
 | |
| 	for _, v := range list {
 | |
| 		if v == nil {
 | |
| 			hasNull = true
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		r, err := types.Compare(types.Coerce(in, v))
 | |
| 		if err != nil {
 | |
| 			e.err = errors.Trace(err)
 | |
| 			return nil
 | |
| 		}
 | |
| 		if r == 0 {
 | |
| 			if !not {
 | |
| 				return 1
 | |
| 			}
 | |
| 			return 0
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if hasNull {
 | |
| 		// if no matched but we got null in In, return null
 | |
| 		// e.g 1 in (null, 2, 3) returns null
 | |
| 		return nil
 | |
| 	}
 | |
| 	if not {
 | |
| 		return 1
 | |
| 	}
 | |
| 	return 0
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
 | |
| 	lhs := n.Expr.GetDatum()
 | |
| 	if lhs.Kind() == types.KindNull {
 | |
| 		n.SetNull()
 | |
| 		return true
 | |
| 	}
 | |
| 	if n.Sel == nil {
 | |
| 		values := make([]interface{}, 0, len(n.List))
 | |
| 		for _, ei := range n.List {
 | |
| 			values = append(values, ei.GetValue())
 | |
| 		}
 | |
| 		x := e.checkInList(n.Not, lhs.GetValue(), values)
 | |
| 		if e.err != nil {
 | |
| 			return false
 | |
| 		}
 | |
| 		n.SetValue(x)
 | |
| 		return true
 | |
| 	}
 | |
| 	se := n.Sel.(*ast.SubqueryExpr)
 | |
| 	sel := se.SubqueryExec
 | |
| 
 | |
| 	res := sel.GetValue().([]interface{})
 | |
| 	x := e.checkInList(n.Not, lhs.GetValue(), res)
 | |
| 	if e.err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	n.SetValue(x)
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) isNull(v *ast.IsNullExpr) bool {
 | |
| 	var boolVal bool
 | |
| 	if v.Expr.GetDatum().Kind() == types.KindNull {
 | |
| 		boolVal = true
 | |
| 	}
 | |
| 	if v.Not {
 | |
| 		boolVal = !boolVal
 | |
| 	}
 | |
| 	v.SetInt64(boolToInt64(boolVal))
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
 | |
| 	var boolVal bool
 | |
| 	datum := v.Expr.GetDatum()
 | |
| 	if datum.Kind() != types.KindNull {
 | |
| 		ival, err := datum.ToBool()
 | |
| 		if err != nil {
 | |
| 			e.err = errors.Trace(err)
 | |
| 			return false
 | |
| 		}
 | |
| 		if ival == v.True {
 | |
| 			boolVal = true
 | |
| 		}
 | |
| 	}
 | |
| 	if v.Not {
 | |
| 		boolVal = !boolVal
 | |
| 	}
 | |
| 	v.GetDatum().SetInt64(boolToInt64(boolVal))
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool {
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool {
 | |
| 	v.SetDatum(*v.Expr.GetDatum())
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) position(v *ast.PositionExpr) bool {
 | |
| 	v.SetDatum(*v.Refer.Expr.GetDatum())
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) row(v *ast.RowExpr) bool {
 | |
| 	row := make([]interface{}, 0, len(v.Values))
 | |
| 	for _, val := range v.Values {
 | |
| 		row = append(row, val.GetValue())
 | |
| 	}
 | |
| 	v.SetValue(row)
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
 | |
| 	defer func() {
 | |
| 		if er := recover(); er != nil {
 | |
| 			e.err = errors.Errorf("%v", er)
 | |
| 		}
 | |
| 	}()
 | |
| 	aDatum := u.V.GetDatum()
 | |
| 	if aDatum.Kind() == types.KindNull {
 | |
| 		u.SetNull()
 | |
| 		return true
 | |
| 	}
 | |
| 	switch op := u.Op; op {
 | |
| 	case opcode.Not:
 | |
| 		n, err := aDatum.ToBool()
 | |
| 		if err != nil {
 | |
| 			e.err = errors.Trace(err)
 | |
| 		} else if n == 0 {
 | |
| 			u.SetInt64(1)
 | |
| 		} else {
 | |
| 			u.SetInt64(0)
 | |
| 		}
 | |
| 	case opcode.BitNeg:
 | |
| 		// for bit operation, we will use int64 first, then return uint64
 | |
| 		n, err := aDatum.ToInt64()
 | |
| 		if err != nil {
 | |
| 			e.err = errors.Trace(err)
 | |
| 			return false
 | |
| 		}
 | |
| 		u.SetUint64(uint64(^n))
 | |
| 	case opcode.Plus:
 | |
| 		switch aDatum.Kind() {
 | |
| 		case types.KindInt64,
 | |
| 			types.KindUint64,
 | |
| 			types.KindFloat64,
 | |
| 			types.KindFloat32,
 | |
| 			types.KindMysqlDuration,
 | |
| 			types.KindMysqlTime,
 | |
| 			types.KindString,
 | |
| 			types.KindMysqlDecimal,
 | |
| 			types.KindBytes,
 | |
| 			types.KindMysqlHex,
 | |
| 			types.KindMysqlBit,
 | |
| 			types.KindMysqlEnum,
 | |
| 			types.KindMysqlSet:
 | |
| 			u.SetDatum(*aDatum)
 | |
| 		default:
 | |
| 			e.err = ErrInvalidOperation
 | |
| 			return false
 | |
| 		}
 | |
| 	case opcode.Minus:
 | |
| 		switch aDatum.Kind() {
 | |
| 		case types.KindInt64:
 | |
| 			u.SetInt64(-aDatum.GetInt64())
 | |
| 		case types.KindUint64:
 | |
| 			u.SetInt64(-int64(aDatum.GetUint64()))
 | |
| 		case types.KindFloat64:
 | |
| 			u.SetFloat64(-aDatum.GetFloat64())
 | |
| 		case types.KindFloat32:
 | |
| 			u.SetFloat32(-aDatum.GetFloat32())
 | |
| 		case types.KindMysqlDuration:
 | |
| 			u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlDuration().ToNumber()))
 | |
| 		case types.KindMysqlTime:
 | |
| 			u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlTime().ToNumber()))
 | |
| 		case types.KindString:
 | |
| 			f, err := types.StrToFloat(aDatum.GetString())
 | |
| 			e.err = errors.Trace(err)
 | |
| 			u.SetFloat64(-f)
 | |
| 		case types.KindMysqlDecimal:
 | |
| 			f, _ := aDatum.GetMysqlDecimal().Float64()
 | |
| 			u.SetValue(mysql.NewDecimalFromFloat(-f))
 | |
| 		case types.KindBytes:
 | |
| 			f, err := types.StrToFloat(string(aDatum.GetBytes()))
 | |
| 			e.err = errors.Trace(err)
 | |
| 			u.SetFloat64(-f)
 | |
| 		case types.KindMysqlHex:
 | |
| 			u.SetFloat64(-aDatum.GetMysqlHex().ToNumber())
 | |
| 		case types.KindMysqlBit:
 | |
| 			u.SetFloat64(-aDatum.GetMysqlBit().ToNumber())
 | |
| 		case types.KindMysqlEnum:
 | |
| 			u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber())
 | |
| 		case types.KindMysqlSet:
 | |
| 			u.SetFloat64(-aDatum.GetMysqlSet().ToNumber())
 | |
| 		default:
 | |
| 			e.err = ErrInvalidOperation
 | |
| 			return false
 | |
| 		}
 | |
| 	default:
 | |
| 		e.err = ErrInvalidOperation
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) values(v *ast.ValuesExpr) bool {
 | |
| 	v.SetDatum(*v.Column.GetDatum())
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) variable(v *ast.VariableExpr) bool {
 | |
| 	name := strings.ToLower(v.Name)
 | |
| 	sessionVars := variable.GetSessionVars(e.ctx)
 | |
| 	globalVars := variable.GetGlobalVarAccessor(e.ctx)
 | |
| 	if !v.IsSystem {
 | |
| 		// user vars
 | |
| 		if value, ok := sessionVars.Users[name]; ok {
 | |
| 			v.SetString(value)
 | |
| 			return true
 | |
| 		}
 | |
| 		// select null user vars is permitted.
 | |
| 		v.SetNull()
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	_, ok := variable.SysVars[name]
 | |
| 	if !ok {
 | |
| 		// select null sys vars is not permitted
 | |
| 		e.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	if !v.IsGlobal {
 | |
| 		if value, ok := sessionVars.Systems[name]; ok {
 | |
| 			v.SetString(value)
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	value, err := globalVars.GetGlobalSysVar(e.ctx, name)
 | |
| 	if err != nil {
 | |
| 		e.err = errors.Trace(err)
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	v.SetString(value)
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool {
 | |
| 	of, ok := OldFuncs[v.FnName.L]
 | |
| 	if ok {
 | |
| 		if len(v.Args) < of.MinArgs || (of.MaxArgs != -1 && len(v.Args) > of.MaxArgs) {
 | |
| 			e.err = ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", of.MinArgs, of.MaxArgs)
 | |
| 			return false
 | |
| 		}
 | |
| 		a := make([]interface{}, len(v.Args))
 | |
| 		for i, arg := range v.Args {
 | |
| 			a[i] = arg.GetValue()
 | |
| 		}
 | |
| 		val, err := of.F(a, e.ctx)
 | |
| 		if err != nil {
 | |
| 			e.err = errors.Trace(err)
 | |
| 			return false
 | |
| 		}
 | |
| 		v.SetValue(val)
 | |
| 		return true
 | |
| 	}
 | |
| 	f, ok := Funcs[v.FnName.L]
 | |
| 	if !ok {
 | |
| 		e.err = ErrInvalidOperation.Gen("unknown function %s", v.FnName.O)
 | |
| 		return false
 | |
| 	}
 | |
| 	if len(v.Args) < f.MinArgs || (f.MaxArgs != -1 && len(v.Args) > f.MaxArgs) {
 | |
| 		e.err = ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", f.MinArgs, f.MaxArgs)
 | |
| 		return false
 | |
| 	}
 | |
| 	a := make([]types.Datum, len(v.Args))
 | |
| 	for i, arg := range v.Args {
 | |
| 		a[i] = *arg.GetDatum()
 | |
| 	}
 | |
| 	val, err := f.F(a, e.ctx)
 | |
| 	if err != nil {
 | |
| 		e.err = errors.Trace(err)
 | |
| 		return false
 | |
| 	}
 | |
| 	v.SetDatum(val)
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool {
 | |
| 	value := v.Expr.GetValue()
 | |
| 	// Casting nil to any type returns null
 | |
| 	if value == nil {
 | |
| 		v.SetNull()
 | |
| 		return true
 | |
| 	}
 | |
| 	var err error
 | |
| 	value, err = types.Cast(value, v.Tp)
 | |
| 	if err != nil {
 | |
| 		e.err = errors.Trace(err)
 | |
| 		return false
 | |
| 	}
 | |
| 	v.SetValue(value)
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
 | |
| 	name := strings.ToLower(v.F)
 | |
| 	switch name {
 | |
| 	case ast.AggFuncAvg:
 | |
| 		e.evalAggAvg(v)
 | |
| 	case ast.AggFuncCount:
 | |
| 		e.evalAggCount(v)
 | |
| 	case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum:
 | |
| 		e.evalAggSetValue(v)
 | |
| 	case ast.AggFuncGroupConcat:
 | |
| 		e.evalAggGroupConcat(v)
 | |
| 	}
 | |
| 	return e.err == nil
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) evalAggCount(v *ast.AggregateFuncExpr) {
 | |
| 	ctx := v.GetContext()
 | |
| 	v.SetInt64(ctx.Count)
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) {
 | |
| 	ctx := v.GetContext()
 | |
| 	v.SetValue(ctx.Value)
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) {
 | |
| 	ctx := v.GetContext()
 | |
| 	switch x := ctx.Value.(type) {
 | |
| 	case float64:
 | |
| 		ctx.Value = x / float64(ctx.Count)
 | |
| 	case mysql.Decimal:
 | |
| 		ctx.Value = x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0))
 | |
| 	}
 | |
| 	v.SetValue(ctx.Value)
 | |
| }
 | |
| 
 | |
| func (e *Evaluator) evalAggGroupConcat(v *ast.AggregateFuncExpr) {
 | |
| 	ctx := v.GetContext()
 | |
| 	if ctx.Buffer != nil {
 | |
| 		v.SetValue(ctx.Buffer.String())
 | |
| 	} else {
 | |
| 		v.SetValue(nil)
 | |
| 	}
 | |
| }
 |