* Dropped unused codekit config * Integrated dynamic and static bindata for public * Ignore public bindata * Add a general generate make task * Integrated flexible public assets into web command * Updated vendoring, added all missiong govendor deps * Made the linter happy with the bindata and dynamic code * Moved public bindata definition to modules directory * Ignoring the new bindata path now * Updated to the new public modules import path * Updated public bindata command and drop the new prefix
		
			
				
	
	
		
			925 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			925 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2015 PingCAP, Inc.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package optimizer
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 
 | |
| 	"github.com/juju/errors"
 | |
| 	"github.com/pingcap/tidb/ast"
 | |
| 	"github.com/pingcap/tidb/column"
 | |
| 	"github.com/pingcap/tidb/context"
 | |
| 	"github.com/pingcap/tidb/infoschema"
 | |
| 	"github.com/pingcap/tidb/model"
 | |
| 	"github.com/pingcap/tidb/mysql"
 | |
| 	"github.com/pingcap/tidb/sessionctx/db"
 | |
| 	"github.com/pingcap/tidb/util/types"
 | |
| )
 | |
| 
 | |
| // ResolveName resolves table name and column name.
 | |
| // It generates ResultFields for ResultSetNode and resolves ColumnNameExpr to a ResultField.
 | |
| func ResolveName(node ast.Node, info infoschema.InfoSchema, ctx context.Context) error {
 | |
| 	defaultSchema := db.GetCurrentSchema(ctx)
 | |
| 	resolver := nameResolver{Info: info, Ctx: ctx, DefaultSchema: model.NewCIStr(defaultSchema)}
 | |
| 	node.Accept(&resolver)
 | |
| 	return errors.Trace(resolver.Err)
 | |
| }
 | |
| 
 | |
| // nameResolver is the visitor to resolve table name and column name.
 | |
| // In general, a reference can only refer to information that are available for it.
 | |
| // So children elements are visited in the order that previous elements make information
 | |
| // available for following elements.
 | |
| //
 | |
| // During visiting, information are collected and stored in resolverContext.
 | |
| // When we enter a subquery, a new resolverContext is pushed to the contextStack, so subquery
 | |
| // information can overwrite outer query information. When we look up for a column reference,
 | |
| // we look up from top to bottom in the contextStack.
 | |
| type nameResolver struct {
 | |
| 	Info            infoschema.InfoSchema
 | |
| 	Ctx             context.Context
 | |
| 	DefaultSchema   model.CIStr
 | |
| 	Err             error
 | |
| 	useOuterContext bool
 | |
| 
 | |
| 	contextStack []*resolverContext
 | |
| }
 | |
| 
 | |
| // resolverContext stores information in a single level of select statement
 | |
| // that table name and column name can be resolved.
 | |
| type resolverContext struct {
 | |
| 	/* For Select Statement. */
 | |
| 	// table map to lookup and check table name conflict.
 | |
| 	tableMap map[string]int
 | |
| 	// table map to lookup and check derived-table(subselect) name conflict.
 | |
| 	derivedTableMap map[string]int
 | |
| 	// tableSources collected in from clause.
 | |
| 	tables []*ast.TableSource
 | |
| 	// result fields collected in select field list.
 | |
| 	fieldList []*ast.ResultField
 | |
| 	// result fields collected in group by clause.
 | |
| 	groupBy []*ast.ResultField
 | |
| 
 | |
| 	// The join node stack is used by on condition to find out
 | |
| 	// available tables to reference. On condition can only
 | |
| 	// refer to tables involved in current join.
 | |
| 	joinNodeStack []*ast.Join
 | |
| 
 | |
| 	// When visiting TableRefs, tables in this context are not available
 | |
| 	// because it is being collected.
 | |
| 	inTableRefs bool
 | |
| 	// When visiting on conditon only tables in current join node are available.
 | |
| 	inOnCondition bool
 | |
| 	// When visiting field list, fieldList in this context are not available.
 | |
| 	inFieldList bool
 | |
| 	// When visiting group by, groupBy fields are not available.
 | |
| 	inGroupBy bool
 | |
| 	// When visiting having, only fieldList and groupBy fields are available.
 | |
| 	inHaving bool
 | |
| 	// When visiting having, checks if the expr is an aggregate function expr.
 | |
| 	inHavingAgg bool
 | |
| 	// OrderBy clause has different resolving rule than group by.
 | |
| 	inOrderBy bool
 | |
| 	// When visiting column name in ByItem, we should know if the column name is in an expression.
 | |
| 	inByItemExpression bool
 | |
| 	// If subquery use outer context.
 | |
| 	useOuterContext bool
 | |
| 	// When visiting multi-table delete stmt table list.
 | |
| 	inDeleteTableList bool
 | |
| 	// When visiting create/drop table statement.
 | |
| 	inCreateOrDropTable bool
 | |
| 	// When visiting show statement.
 | |
| 	inShow bool
 | |
| }
 | |
| 
 | |
| // currentContext gets the current resolverContext.
 | |
| func (nr *nameResolver) currentContext() *resolverContext {
 | |
| 	stackLen := len(nr.contextStack)
 | |
| 	if stackLen == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 	return nr.contextStack[stackLen-1]
 | |
| }
 | |
| 
 | |
| // pushContext is called when we enter a statement.
 | |
| func (nr *nameResolver) pushContext() {
 | |
| 	nr.contextStack = append(nr.contextStack, &resolverContext{
 | |
| 		tableMap:        map[string]int{},
 | |
| 		derivedTableMap: map[string]int{},
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // popContext is called when we leave a statement.
 | |
| func (nr *nameResolver) popContext() {
 | |
| 	nr.contextStack = nr.contextStack[:len(nr.contextStack)-1]
 | |
| }
 | |
| 
 | |
| // pushJoin is called when we enter a join node.
 | |
| func (nr *nameResolver) pushJoin(j *ast.Join) {
 | |
| 	ctx := nr.currentContext()
 | |
| 	ctx.joinNodeStack = append(ctx.joinNodeStack, j)
 | |
| }
 | |
| 
 | |
| // popJoin is called when we leave a join node.
 | |
| func (nr *nameResolver) popJoin() {
 | |
| 	ctx := nr.currentContext()
 | |
| 	ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1]
 | |
| }
 | |
| 
 | |
| // Enter implements ast.Visitor interface.
 | |
| func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
 | |
| 	switch v := inNode.(type) {
 | |
| 	case *ast.AdminStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.AggregateFuncExpr:
 | |
| 		ctx := nr.currentContext()
 | |
| 		if ctx.inHaving {
 | |
| 			ctx.inHavingAgg = true
 | |
| 		}
 | |
| 	case *ast.AlterTableStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.ByItem:
 | |
| 		if _, ok := v.Expr.(*ast.ColumnNameExpr); !ok {
 | |
| 			// If ByItem is not a single column name expression,
 | |
| 			// the resolving rule is different from order by clause.
 | |
| 			nr.currentContext().inByItemExpression = true
 | |
| 		}
 | |
| 		if nr.currentContext().inGroupBy {
 | |
| 			// make sure item is not aggregate function
 | |
| 			if ast.HasAggFlag(v.Expr) {
 | |
| 				nr.Err = ErrInvalidGroupFuncUse
 | |
| 				return inNode, true
 | |
| 			}
 | |
| 		}
 | |
| 	case *ast.CreateIndexStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.CreateTableStmt:
 | |
| 		nr.pushContext()
 | |
| 		nr.currentContext().inCreateOrDropTable = true
 | |
| 	case *ast.DeleteStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.DeleteTableList:
 | |
| 		nr.currentContext().inDeleteTableList = true
 | |
| 	case *ast.DoStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.DropTableStmt:
 | |
| 		nr.pushContext()
 | |
| 		nr.currentContext().inCreateOrDropTable = true
 | |
| 	case *ast.DropIndexStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.FieldList:
 | |
| 		nr.currentContext().inFieldList = true
 | |
| 	case *ast.GroupByClause:
 | |
| 		nr.currentContext().inGroupBy = true
 | |
| 	case *ast.HavingClause:
 | |
| 		nr.currentContext().inHaving = true
 | |
| 	case *ast.InsertStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.Join:
 | |
| 		nr.pushJoin(v)
 | |
| 	case *ast.OnCondition:
 | |
| 		nr.currentContext().inOnCondition = true
 | |
| 	case *ast.OrderByClause:
 | |
| 		nr.currentContext().inOrderBy = true
 | |
| 	case *ast.SelectStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.SetStmt:
 | |
| 		for _, assign := range v.Variables {
 | |
| 			if cn, ok := assign.Value.(*ast.ColumnNameExpr); ok && cn.Name.Table.L == "" {
 | |
| 				// Convert column name expression to string value expression.
 | |
| 				assign.Value = ast.NewValueExpr(cn.Name.Name.O)
 | |
| 			}
 | |
| 		}
 | |
| 		nr.pushContext()
 | |
| 	case *ast.ShowStmt:
 | |
| 		nr.pushContext()
 | |
| 		nr.currentContext().inShow = true
 | |
| 		nr.fillShowFields(v)
 | |
| 	case *ast.TableRefsClause:
 | |
| 		nr.currentContext().inTableRefs = true
 | |
| 	case *ast.TruncateTableStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.UnionStmt:
 | |
| 		nr.pushContext()
 | |
| 	case *ast.UpdateStmt:
 | |
| 		nr.pushContext()
 | |
| 	}
 | |
| 	return inNode, false
 | |
| }
 | |
| 
 | |
| // Leave implements ast.Visitor interface.
 | |
| func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) {
 | |
| 	switch v := inNode.(type) {
 | |
| 	case *ast.AdminStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.AggregateFuncExpr:
 | |
| 		ctx := nr.currentContext()
 | |
| 		if ctx.inHaving {
 | |
| 			ctx.inHavingAgg = false
 | |
| 		}
 | |
| 	case *ast.AlterTableStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.TableName:
 | |
| 		nr.handleTableName(v)
 | |
| 	case *ast.ColumnNameExpr:
 | |
| 		nr.handleColumnName(v)
 | |
| 	case *ast.CreateIndexStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.CreateTableStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.DeleteTableList:
 | |
| 		nr.currentContext().inDeleteTableList = false
 | |
| 	case *ast.DoStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.DropIndexStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.DropTableStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.TableSource:
 | |
| 		nr.handleTableSource(v)
 | |
| 	case *ast.OnCondition:
 | |
| 		nr.currentContext().inOnCondition = false
 | |
| 	case *ast.Join:
 | |
| 		nr.handleJoin(v)
 | |
| 		nr.popJoin()
 | |
| 	case *ast.TableRefsClause:
 | |
| 		nr.currentContext().inTableRefs = false
 | |
| 	case *ast.FieldList:
 | |
| 		nr.handleFieldList(v)
 | |
| 		nr.currentContext().inFieldList = false
 | |
| 	case *ast.GroupByClause:
 | |
| 		ctx := nr.currentContext()
 | |
| 		ctx.inGroupBy = false
 | |
| 		for _, item := range v.Items {
 | |
| 			switch x := item.Expr.(type) {
 | |
| 			case *ast.ColumnNameExpr:
 | |
| 				ctx.groupBy = append(ctx.groupBy, x.Refer)
 | |
| 			}
 | |
| 		}
 | |
| 	case *ast.HavingClause:
 | |
| 		nr.currentContext().inHaving = false
 | |
| 	case *ast.OrderByClause:
 | |
| 		nr.currentContext().inOrderBy = false
 | |
| 	case *ast.ByItem:
 | |
| 		nr.currentContext().inByItemExpression = false
 | |
| 	case *ast.PositionExpr:
 | |
| 		nr.handlePosition(v)
 | |
| 	case *ast.SelectStmt:
 | |
| 		ctx := nr.currentContext()
 | |
| 		v.SetResultFields(ctx.fieldList)
 | |
| 		if ctx.useOuterContext {
 | |
| 			nr.useOuterContext = true
 | |
| 		}
 | |
| 		nr.popContext()
 | |
| 	case *ast.SetStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.ShowStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.SubqueryExpr:
 | |
| 		if nr.useOuterContext {
 | |
| 			// TODO: check this
 | |
| 			// If there is a deep nest of subquery, there may be something wrong.
 | |
| 			v.UseOuterContext = true
 | |
| 			nr.useOuterContext = false
 | |
| 		}
 | |
| 	case *ast.TruncateTableStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.UnionStmt:
 | |
| 		ctx := nr.currentContext()
 | |
| 		v.SetResultFields(ctx.fieldList)
 | |
| 		if ctx.useOuterContext {
 | |
| 			nr.useOuterContext = true
 | |
| 		}
 | |
| 		nr.popContext()
 | |
| 	case *ast.UnionSelectList:
 | |
| 		nr.handleUnionSelectList(v)
 | |
| 	case *ast.InsertStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.DeleteStmt:
 | |
| 		nr.popContext()
 | |
| 	case *ast.UpdateStmt:
 | |
| 		nr.popContext()
 | |
| 	}
 | |
| 	return inNode, nr.Err == nil
 | |
| }
 | |
| 
 | |
| // handleTableName looks up and sets the schema information and result fields for table name.
 | |
| func (nr *nameResolver) handleTableName(tn *ast.TableName) {
 | |
| 	if tn.Schema.L == "" {
 | |
| 		tn.Schema = nr.DefaultSchema
 | |
| 	}
 | |
| 	ctx := nr.currentContext()
 | |
| 	if ctx.inCreateOrDropTable {
 | |
| 		// The table may not exist in create table or drop table statement.
 | |
| 		// Skip resolving the table to avoid error.
 | |
| 		return
 | |
| 	}
 | |
| 	if ctx.inDeleteTableList {
 | |
| 		idx, ok := ctx.tableMap[nr.tableUniqueName(tn.Schema, tn.Name)]
 | |
| 		if !ok {
 | |
| 			nr.Err = errors.Errorf("Unknown table %s", tn.Name.O)
 | |
| 			return
 | |
| 		}
 | |
| 		ts := ctx.tables[idx]
 | |
| 		tableName := ts.Source.(*ast.TableName)
 | |
| 		tn.DBInfo = tableName.DBInfo
 | |
| 		tn.TableInfo = tableName.TableInfo
 | |
| 		tn.SetResultFields(tableName.GetResultFields())
 | |
| 		return
 | |
| 	}
 | |
| 	table, err := nr.Info.TableByName(tn.Schema, tn.Name)
 | |
| 	if err != nil {
 | |
| 		nr.Err = errors.Trace(err)
 | |
| 		return
 | |
| 	}
 | |
| 	tn.TableInfo = table.Meta()
 | |
| 	dbInfo, _ := nr.Info.SchemaByName(tn.Schema)
 | |
| 	tn.DBInfo = dbInfo
 | |
| 
 | |
| 	rfs := make([]*ast.ResultField, 0, len(tn.TableInfo.Columns))
 | |
| 	for _, v := range tn.TableInfo.Columns {
 | |
| 		if v.State != model.StatePublic {
 | |
| 			continue
 | |
| 		}
 | |
| 		expr := &ast.ValueExpr{}
 | |
| 		expr.SetType(&v.FieldType)
 | |
| 		rf := &ast.ResultField{
 | |
| 			Column:    v,
 | |
| 			Table:     tn.TableInfo,
 | |
| 			DBName:    tn.Schema,
 | |
| 			Expr:      expr,
 | |
| 			TableName: tn,
 | |
| 		}
 | |
| 		rfs = append(rfs, rf)
 | |
| 	}
 | |
| 	tn.SetResultFields(rfs)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // handleTableSources checks name duplication
 | |
| // and puts the table source in current resolverContext.
 | |
| // Note:
 | |
| // "select * from t as a join (select 1) as a;" is not duplicate.
 | |
| // "select * from t as a join t as a;" is duplicate.
 | |
| // "select * from (select 1) as a join (select 1) as a;" is duplicate.
 | |
| func (nr *nameResolver) handleTableSource(ts *ast.TableSource) {
 | |
| 	for _, v := range ts.GetResultFields() {
 | |
| 		v.TableAsName = ts.AsName
 | |
| 	}
 | |
| 	ctx := nr.currentContext()
 | |
| 	switch ts.Source.(type) {
 | |
| 	case *ast.TableName:
 | |
| 		var name string
 | |
| 		if ts.AsName.L != "" {
 | |
| 			name = ts.AsName.L
 | |
| 		} else {
 | |
| 			tableName := ts.Source.(*ast.TableName)
 | |
| 			name = nr.tableUniqueName(tableName.Schema, tableName.Name)
 | |
| 		}
 | |
| 		if _, ok := ctx.tableMap[name]; ok {
 | |
| 			nr.Err = errors.Errorf("duplicated table/alias name %s", name)
 | |
| 			return
 | |
| 		}
 | |
| 		ctx.tableMap[name] = len(ctx.tables)
 | |
| 	case *ast.SelectStmt:
 | |
| 		name := ts.AsName.L
 | |
| 		if _, ok := ctx.derivedTableMap[name]; ok {
 | |
| 			nr.Err = errors.Errorf("duplicated table/alias name %s", name)
 | |
| 			return
 | |
| 		}
 | |
| 		ctx.derivedTableMap[name] = len(ctx.tables)
 | |
| 	}
 | |
| 	dupNames := make(map[string]struct{}, len(ts.GetResultFields()))
 | |
| 	for _, f := range ts.GetResultFields() {
 | |
| 		// duplicate column name in one table is not allowed.
 | |
| 		// "select * from (select 1, 1) as a;" is duplicate.
 | |
| 		name := f.ColumnAsName.L
 | |
| 		if name == "" {
 | |
| 			name = f.Column.Name.L
 | |
| 		}
 | |
| 		if _, ok := dupNames[name]; ok {
 | |
| 			nr.Err = errors.Errorf("Duplicate column name '%s'", name)
 | |
| 			return
 | |
| 		}
 | |
| 		dupNames[name] = struct{}{}
 | |
| 	}
 | |
| 	ctx.tables = append(ctx.tables, ts)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // handleJoin sets result fields for join.
 | |
| func (nr *nameResolver) handleJoin(j *ast.Join) {
 | |
| 	if j.Right == nil {
 | |
| 		j.SetResultFields(j.Left.GetResultFields())
 | |
| 		return
 | |
| 	}
 | |
| 	leftLen := len(j.Left.GetResultFields())
 | |
| 	rightLen := len(j.Right.GetResultFields())
 | |
| 	rfs := make([]*ast.ResultField, leftLen+rightLen)
 | |
| 	copy(rfs, j.Left.GetResultFields())
 | |
| 	copy(rfs[leftLen:], j.Right.GetResultFields())
 | |
| 	j.SetResultFields(rfs)
 | |
| }
 | |
| 
 | |
| // handleColumnName looks up and sets ResultField for
 | |
| // the column name.
 | |
| func (nr *nameResolver) handleColumnName(cn *ast.ColumnNameExpr) {
 | |
| 	ctx := nr.currentContext()
 | |
| 	if ctx.inOnCondition {
 | |
| 		// In on condition, only tables within current join is available.
 | |
| 		nr.resolveColumnNameInOnCondition(cn)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// Try to resolve the column name form top to bottom in the context stack.
 | |
| 	for i := len(nr.contextStack) - 1; i >= 0; i-- {
 | |
| 		if nr.resolveColumnNameInContext(nr.contextStack[i], cn) {
 | |
| 			// Column is already resolved or encountered an error.
 | |
| 			if i < len(nr.contextStack)-1 {
 | |
| 				// If in subselect, the query use outer query.
 | |
| 				nr.currentContext().useOuterContext = true
 | |
| 			}
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 	nr.Err = errors.Errorf("unknown column %s", cn.Name.Name.L)
 | |
| }
 | |
| 
 | |
| // resolveColumnNameInContext looks up and sets ResultField for a column with the ctx.
 | |
| func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast.ColumnNameExpr) bool {
 | |
| 	if ctx.inTableRefs {
 | |
| 		// In TableRefsClause, column reference only in join on condition which is handled before.
 | |
| 		return false
 | |
| 	}
 | |
| 	if ctx.inFieldList {
 | |
| 		// only resolve column using tables.
 | |
| 		return nr.resolveColumnInTableSources(cn, ctx.tables)
 | |
| 	}
 | |
| 	if ctx.inGroupBy {
 | |
| 		// From tables first, then field list.
 | |
| 		// If ctx.InByItemExpression is true, the item is not an identifier.
 | |
| 		// Otherwise it is an identifier.
 | |
| 		if ctx.inByItemExpression {
 | |
| 			// From table first, then field list.
 | |
| 			if nr.resolveColumnInTableSources(cn, ctx.tables) {
 | |
| 				return true
 | |
| 			}
 | |
| 			found := nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
 | |
| 			if nr.Err == nil && found {
 | |
| 				// Check if resolved refer is an aggregate function expr.
 | |
| 				if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
 | |
| 					nr.Err = ErrIllegalReference.Gen("Reference '%s' not supported (reference to group function)", cn.Name.Name.O)
 | |
| 				}
 | |
| 			}
 | |
| 			return found
 | |
| 		}
 | |
| 		// Resolve from table first, then from select list.
 | |
| 		found := nr.resolveColumnInTableSources(cn, ctx.tables)
 | |
| 		if nr.Err != nil {
 | |
| 			return found
 | |
| 		}
 | |
| 		// We should copy the refer here.
 | |
| 		// Because if the ByItem is an identifier, we should check if it
 | |
| 		// is ambiguous even it is already resolved from table source.
 | |
| 		// If the ByItem is not an identifier, we do not need the second check.
 | |
| 		r := cn.Refer
 | |
| 		if nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList) {
 | |
| 			if nr.Err != nil {
 | |
| 				return true
 | |
| 			}
 | |
| 			if r != nil {
 | |
| 				// It is not ambiguous and already resolved from table source.
 | |
| 				// We should restore its Refer.
 | |
| 				cn.Refer = r
 | |
| 			}
 | |
| 			if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
 | |
| 				nr.Err = ErrIllegalReference.Gen("Reference '%s' not supported (reference to group function)", cn.Name.Name.O)
 | |
| 			}
 | |
| 			return true
 | |
| 		}
 | |
| 		return found
 | |
| 	}
 | |
| 	if ctx.inHaving {
 | |
| 		// First group by, then field list.
 | |
| 		if nr.resolveColumnInResultFields(ctx, cn, ctx.groupBy) {
 | |
| 			return true
 | |
| 		}
 | |
| 		if ctx.inHavingAgg {
 | |
| 			// If cn is in an aggregate function in having clause, check tablesource first.
 | |
| 			if nr.resolveColumnInTableSources(cn, ctx.tables) {
 | |
| 				return true
 | |
| 			}
 | |
| 		}
 | |
| 		return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
 | |
| 	}
 | |
| 	if ctx.inOrderBy {
 | |
| 		if nr.resolveColumnInResultFields(ctx, cn, ctx.groupBy) {
 | |
| 			return true
 | |
| 		}
 | |
| 		if ctx.inByItemExpression {
 | |
| 			// From table first, then field list.
 | |
| 			if nr.resolveColumnInTableSources(cn, ctx.tables) {
 | |
| 				return true
 | |
| 			}
 | |
| 			return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
 | |
| 		}
 | |
| 		// Field list first, then from table.
 | |
| 		if nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList) {
 | |
| 			return true
 | |
| 		}
 | |
| 		return nr.resolveColumnInTableSources(cn, ctx.tables)
 | |
| 	}
 | |
| 	if ctx.inShow {
 | |
| 		return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
 | |
| 	}
 | |
| 	// In where clause.
 | |
| 	return nr.resolveColumnInTableSources(cn, ctx.tables)
 | |
| }
 | |
| 
 | |
| // resolveColumnNameInOnCondition resolves the column name in current join.
 | |
| func (nr *nameResolver) resolveColumnNameInOnCondition(cn *ast.ColumnNameExpr) {
 | |
| 	ctx := nr.currentContext()
 | |
| 	join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1]
 | |
| 	tableSources := appendTableSources(nil, join)
 | |
| 	if !nr.resolveColumnInTableSources(cn, tableSources) {
 | |
| 		nr.Err = errors.Errorf("unkown column name %s", cn.Name.Name.O)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (nr *nameResolver) resolveColumnInTableSources(cn *ast.ColumnNameExpr, tableSources []*ast.TableSource) (done bool) {
 | |
| 	var matchedResultField *ast.ResultField
 | |
| 	tableNameL := cn.Name.Table.L
 | |
| 	columnNameL := cn.Name.Name.L
 | |
| 	if tableNameL != "" {
 | |
| 		var matchedTable ast.ResultSetNode
 | |
| 		for _, ts := range tableSources {
 | |
| 			if tableNameL == ts.AsName.L {
 | |
| 				// different table name.
 | |
| 				matchedTable = ts
 | |
| 				break
 | |
| 			} else if ts.AsName.L != "" {
 | |
| 				// Table as name shadows table real name.
 | |
| 				continue
 | |
| 			}
 | |
| 			if tn, ok := ts.Source.(*ast.TableName); ok {
 | |
| 				if cn.Name.Schema.L != "" && cn.Name.Schema.L != tn.Schema.L {
 | |
| 					continue
 | |
| 				}
 | |
| 				if tableNameL == tn.Name.L {
 | |
| 					matchedTable = ts
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		if matchedTable != nil {
 | |
| 			resultFields := matchedTable.GetResultFields()
 | |
| 			for _, rf := range resultFields {
 | |
| 				if rf.ColumnAsName.L == columnNameL || rf.Column.Name.L == columnNameL {
 | |
| 					// resolve column.
 | |
| 					matchedResultField = rf
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	} else {
 | |
| 		for _, ts := range tableSources {
 | |
| 			rfs := ts.GetResultFields()
 | |
| 			for _, rf := range rfs {
 | |
| 				matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == columnNameL
 | |
| 				matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == columnNameL
 | |
| 				if matchAsName || matchColumnName {
 | |
| 					if matchedResultField != nil {
 | |
| 						nr.Err = errors.Errorf("column %s is ambiguous.", cn.Name.Name.O)
 | |
| 						return true
 | |
| 					}
 | |
| 					matchedResultField = rf
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if matchedResultField != nil {
 | |
| 		// Bind column.
 | |
| 		cn.Refer = matchedResultField
 | |
| 		return true
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (nr *nameResolver) resolveColumnInResultFields(ctx *resolverContext, cn *ast.ColumnNameExpr, rfs []*ast.ResultField) bool {
 | |
| 	var matched *ast.ResultField
 | |
| 	for _, rf := range rfs {
 | |
| 		if cn.Name.Table.L != "" {
 | |
| 			// Check table name
 | |
| 			if rf.TableAsName.L != "" {
 | |
| 				if cn.Name.Table.L != rf.TableAsName.L {
 | |
| 					continue
 | |
| 				}
 | |
| 			} else if cn.Name.Table.L != rf.Table.Name.L {
 | |
| 				continue
 | |
| 			}
 | |
| 		}
 | |
| 		matchAsName := cn.Name.Name.L == rf.ColumnAsName.L
 | |
| 		var matchColumnName bool
 | |
| 		if ctx.inHaving {
 | |
| 			matchColumnName = cn.Name.Name.L == rf.Column.Name.L
 | |
| 		} else {
 | |
| 			matchColumnName = rf.ColumnAsName.L == "" && cn.Name.Name.L == rf.Column.Name.L
 | |
| 		}
 | |
| 		if matchAsName || matchColumnName {
 | |
| 			if rf.Column.Name.L == "" {
 | |
| 				// This is not a real table column, resolve it directly.
 | |
| 				cn.Refer = rf
 | |
| 				return true
 | |
| 			}
 | |
| 			if matched == nil {
 | |
| 				matched = rf
 | |
| 			} else {
 | |
| 				sameColumn := matched.TableName == rf.TableName && matched.Column.Name.L == rf.Column.Name.L
 | |
| 				if !sameColumn {
 | |
| 					nr.Err = errors.Errorf("column %s is ambiguous.", cn.Name.Name.O)
 | |
| 					return true
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if matched != nil {
 | |
| 		// If in GroupBy, we clone the ResultField
 | |
| 		if ctx.inGroupBy || ctx.inHaving || ctx.inOrderBy {
 | |
| 			nf := *matched
 | |
| 			expr := matched.Expr
 | |
| 			if cexpr, ok := expr.(*ast.ColumnNameExpr); ok {
 | |
| 				expr = cexpr.Refer.Expr
 | |
| 			}
 | |
| 			nf.Expr = expr
 | |
| 			matched = &nf
 | |
| 		}
 | |
| 		// Bind column.
 | |
| 		cn.Refer = matched
 | |
| 		return true
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| // handleFieldList expands wild card field and sets fieldList in current context.
 | |
| func (nr *nameResolver) handleFieldList(fieldList *ast.FieldList) {
 | |
| 	var resultFields []*ast.ResultField
 | |
| 	for _, v := range fieldList.Fields {
 | |
| 		resultFields = append(resultFields, nr.createResultFields(v)...)
 | |
| 	}
 | |
| 	nr.currentContext().fieldList = resultFields
 | |
| }
 | |
| 
 | |
| func getInnerFromParentheses(expr ast.ExprNode) ast.ExprNode {
 | |
| 	if pexpr, ok := expr.(*ast.ParenthesesExpr); ok {
 | |
| 		return getInnerFromParentheses(pexpr.Expr)
 | |
| 	}
 | |
| 	return expr
 | |
| }
 | |
| 
 | |
| // createResultFields creates result field list for a single select field.
 | |
| func (nr *nameResolver) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) {
 | |
| 	ctx := nr.currentContext()
 | |
| 	if field.WildCard != nil {
 | |
| 		if len(ctx.tables) == 0 {
 | |
| 			nr.Err = errors.New("No table used.")
 | |
| 			return
 | |
| 		}
 | |
| 		tableRfs := []*ast.ResultField{}
 | |
| 		if field.WildCard.Table.L == "" {
 | |
| 			for _, v := range ctx.tables {
 | |
| 				tableRfs = append(tableRfs, v.GetResultFields()...)
 | |
| 			}
 | |
| 		} else {
 | |
| 			name := nr.tableUniqueName(field.WildCard.Schema, field.WildCard.Table)
 | |
| 			tableIdx, ok1 := ctx.tableMap[name]
 | |
| 			derivedTableIdx, ok2 := ctx.derivedTableMap[name]
 | |
| 			if !ok1 && !ok2 {
 | |
| 				nr.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O)
 | |
| 			}
 | |
| 			if ok1 {
 | |
| 				tableRfs = ctx.tables[tableIdx].GetResultFields()
 | |
| 			}
 | |
| 			if ok2 {
 | |
| 				tableRfs = append(tableRfs, ctx.tables[derivedTableIdx].GetResultFields()...)
 | |
| 			}
 | |
| 
 | |
| 		}
 | |
| 		for _, trf := range tableRfs {
 | |
| 			// Convert it to ColumnNameExpr
 | |
| 			cn := &ast.ColumnName{
 | |
| 				Schema: trf.DBName,
 | |
| 				Table:  trf.Table.Name,
 | |
| 				Name:   trf.ColumnAsName,
 | |
| 			}
 | |
| 			cnExpr := &ast.ColumnNameExpr{
 | |
| 				Name:  cn,
 | |
| 				Refer: trf,
 | |
| 			}
 | |
| 			ast.SetFlag(cnExpr)
 | |
| 			cnExpr.SetType(trf.Expr.GetType())
 | |
| 			rf := *trf
 | |
| 			rf.Expr = cnExpr
 | |
| 			rfs = append(rfs, &rf)
 | |
| 		}
 | |
| 		return
 | |
| 	}
 | |
| 	// The column is visited before so it must has been resolved already.
 | |
| 	rf := &ast.ResultField{ColumnAsName: field.AsName}
 | |
| 	innerExpr := getInnerFromParentheses(field.Expr)
 | |
| 	switch v := innerExpr.(type) {
 | |
| 	case *ast.ColumnNameExpr:
 | |
| 		rf.Column = v.Refer.Column
 | |
| 		rf.Table = v.Refer.Table
 | |
| 		rf.DBName = v.Refer.DBName
 | |
| 		rf.TableName = v.Refer.TableName
 | |
| 		rf.Expr = v
 | |
| 	default:
 | |
| 		rf.Column = &model.ColumnInfo{} // Empty column info.
 | |
| 		rf.Table = &model.TableInfo{}   // Empty table info.
 | |
| 		rf.Expr = v
 | |
| 	}
 | |
| 	if field.AsName.L == "" {
 | |
| 		switch x := innerExpr.(type) {
 | |
| 		case *ast.ColumnNameExpr:
 | |
| 			rf.ColumnAsName = model.NewCIStr(x.Name.Name.O)
 | |
| 		case *ast.ValueExpr:
 | |
| 			if innerExpr.Text() != "" {
 | |
| 				rf.ColumnAsName = model.NewCIStr(innerExpr.Text())
 | |
| 			} else {
 | |
| 				rf.ColumnAsName = model.NewCIStr(field.Text())
 | |
| 			}
 | |
| 		default:
 | |
| 			rf.ColumnAsName = model.NewCIStr(field.Text())
 | |
| 		}
 | |
| 	}
 | |
| 	rfs = append(rfs, rf)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) {
 | |
| 	switch v := resultSetNode.(type) {
 | |
| 	case *ast.TableSource:
 | |
| 		out = append(in, v)
 | |
| 	case *ast.Join:
 | |
| 		out = appendTableSources(in, v.Left)
 | |
| 		if v.Right != nil {
 | |
| 			out = appendTableSources(out, v.Right)
 | |
| 		}
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (nr *nameResolver) tableUniqueName(schema, table model.CIStr) string {
 | |
| 	if schema.L != "" && schema.L != nr.DefaultSchema.L {
 | |
| 		return schema.L + "." + table.L
 | |
| 	}
 | |
| 	return table.L
 | |
| }
 | |
| 
 | |
| func (nr *nameResolver) handlePosition(pos *ast.PositionExpr) {
 | |
| 	ctx := nr.currentContext()
 | |
| 	if pos.N < 1 || pos.N > len(ctx.fieldList) {
 | |
| 		nr.Err = errors.Errorf("Unknown column '%d'", pos.N)
 | |
| 		return
 | |
| 	}
 | |
| 	matched := ctx.fieldList[pos.N-1]
 | |
| 	nf := *matched
 | |
| 	expr := matched.Expr
 | |
| 	if cexpr, ok := expr.(*ast.ColumnNameExpr); ok {
 | |
| 		expr = cexpr.Refer.Expr
 | |
| 	}
 | |
| 	nf.Expr = expr
 | |
| 	pos.Refer = &nf
 | |
| 	if nr.currentContext().inGroupBy {
 | |
| 		// make sure item is not aggregate function
 | |
| 		if ast.HasAggFlag(pos.Refer.Expr) {
 | |
| 			nr.Err = errors.New("group by cannot contain aggregate function")
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (nr *nameResolver) handleUnionSelectList(u *ast.UnionSelectList) {
 | |
| 	firstSelFields := u.Selects[0].GetResultFields()
 | |
| 	unionFields := make([]*ast.ResultField, len(firstSelFields))
 | |
| 	// Copy first result fields, because we may change the result field type.
 | |
| 	for i, v := range firstSelFields {
 | |
| 		rf := *v
 | |
| 		col := *v.Column
 | |
| 		rf.Column = &col
 | |
| 		if rf.Column.Flen == 0 {
 | |
| 			rf.Column.Flen = types.UnspecifiedLength
 | |
| 		}
 | |
| 		rf.Expr = &ast.ValueExpr{}
 | |
| 		unionFields[i] = &rf
 | |
| 	}
 | |
| 	nr.currentContext().fieldList = unionFields
 | |
| }
 | |
| 
 | |
| func (nr *nameResolver) fillShowFields(s *ast.ShowStmt) {
 | |
| 	if s.DBName == "" {
 | |
| 		if s.Table != nil && s.Table.Schema.L != "" {
 | |
| 			s.DBName = s.Table.Schema.O
 | |
| 		} else {
 | |
| 			s.DBName = nr.DefaultSchema.O
 | |
| 		}
 | |
| 	} else if s.Table != nil && s.Table.Schema.L == "" {
 | |
| 		s.Table.Schema = model.NewCIStr(s.DBName)
 | |
| 	}
 | |
| 	var fields []*ast.ResultField
 | |
| 	var (
 | |
| 		names  []string
 | |
| 		ftypes []byte
 | |
| 	)
 | |
| 	switch s.Tp {
 | |
| 	case ast.ShowEngines:
 | |
| 		names = []string{"Engine", "Support", "Comment", "Transactions", "XA", "Savepoints"}
 | |
| 	case ast.ShowDatabases:
 | |
| 		names = []string{"Database"}
 | |
| 	case ast.ShowTables:
 | |
| 		names = []string{fmt.Sprintf("Tables_in_%s", s.DBName)}
 | |
| 		if s.Full {
 | |
| 			names = append(names, "Table_type")
 | |
| 		}
 | |
| 	case ast.ShowTableStatus:
 | |
| 		names = []string{"Name", "Engine", "Version", "Row_format", "Rows", "Avg_row_length",
 | |
| 			"Data_length", "Max_data_length", "Index_length", "Data_free", "Auto_increment",
 | |
| 			"Create_time", "Update_time", "Check_time", "Collation", "Checksum",
 | |
| 			"Create_options", "Comment"}
 | |
| 		ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeLonglong,
 | |
| 			mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong,
 | |
| 			mysql.TypeDatetime, mysql.TypeDatetime, mysql.TypeDatetime, mysql.TypeVarchar, mysql.TypeVarchar,
 | |
| 			mysql.TypeVarchar, mysql.TypeVarchar}
 | |
| 	case ast.ShowColumns:
 | |
| 		names = column.ColDescFieldNames(s.Full)
 | |
| 	case ast.ShowWarnings:
 | |
| 		names = []string{"Level", "Code", "Message"}
 | |
| 		ftypes = []byte{mysql.TypeVarchar, mysql.TypeLong, mysql.TypeVarchar}
 | |
| 	case ast.ShowCharset:
 | |
| 		names = []string{"Charset", "Description", "Default collation", "Maxlen"}
 | |
| 		ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong}
 | |
| 	case ast.ShowVariables:
 | |
| 		names = []string{"Variable_name", "Value"}
 | |
| 	case ast.ShowStatus:
 | |
| 		names = []string{"Variable_name", "Value"}
 | |
| 	case ast.ShowCollation:
 | |
| 		names = []string{"Collation", "Charset", "Id", "Default", "Compiled", "Sortlen"}
 | |
| 		ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong,
 | |
| 			mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong}
 | |
| 	case ast.ShowCreateTable:
 | |
| 		names = []string{"Table", "Create Table"}
 | |
| 	case ast.ShowGrants:
 | |
| 		names = []string{fmt.Sprintf("Grants for %s", s.User)}
 | |
| 	case ast.ShowTriggers:
 | |
| 		names = []string{"Trigger", "Event", "Table", "Statement", "Timing", "Created",
 | |
| 			"sql_mode", "Definer", "character_set_client", "collation_connection", "Database Collation"}
 | |
| 		ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar,
 | |
| 			mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar}
 | |
| 	case ast.ShowProcedureStatus:
 | |
| 		names = []string{}
 | |
| 		ftypes = []byte{}
 | |
| 	case ast.ShowIndex:
 | |
| 		names = []string{"Table", "Non_unique", "Key_name", "Seq_in_index",
 | |
| 			"Column_name", "Collation", "Cardinality", "Sub_part", "Packed",
 | |
| 			"Null", "Index_type", "Comment", "Index_comment"}
 | |
| 		ftypes = []byte{mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeVarchar, mysql.TypeLonglong,
 | |
| 			mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeLonglong,
 | |
| 			mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar}
 | |
| 	}
 | |
| 	for i, name := range names {
 | |
| 		f := &ast.ResultField{
 | |
| 			ColumnAsName: model.NewCIStr(name),
 | |
| 			Column:       &model.ColumnInfo{}, // Empty column info.
 | |
| 			Table:        &model.TableInfo{},  // Empty table info.
 | |
| 		}
 | |
| 		if ftypes == nil || ftypes[i] == 0 {
 | |
| 			// use varchar as the default return column type
 | |
| 			f.Column.Tp = mysql.TypeVarchar
 | |
| 		} else {
 | |
| 			f.Column.Tp = ftypes[i]
 | |
| 		}
 | |
| 		f.Column.Charset, f.Column.Collate = types.DefaultCharsetForType(f.Column.Tp)
 | |
| 		f.Expr = &ast.ValueExpr{}
 | |
| 		f.Expr.SetType(&f.Column.FieldType)
 | |
| 		fields = append(fields, f)
 | |
| 	}
 | |
| 
 | |
| 	if s.Pattern != nil && s.Pattern.Expr == nil {
 | |
| 		rf := fields[0]
 | |
| 		s.Pattern.Expr = &ast.ColumnNameExpr{
 | |
| 			Name: &ast.ColumnName{Name: rf.ColumnAsName},
 | |
| 		}
 | |
| 		ast.SetFlag(s.Pattern)
 | |
| 	}
 | |
| 	s.SetResultFields(fields)
 | |
| 	nr.currentContext().fieldList = fields
 | |
| }
 |