TIDB源码分析-从github第一次提交说起(4)

继续说下一个SQL语句:

insert into t values('c1','ret1');

入口和之前的语句一样,也是Exec类型的语句,这里就不再描述了。

直接进入Exec():

// Exec implements the stmt.Statement Exec interface.
func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
	t, err := getTable(ctx, s.TableIdent)//根据表名获得表对应信息
	if err != nil {
		return nil, errors.Trace(err)
	}

	tableCols := t.Cols()//获取表中所有字段
	cols, err := s.getColumns(tableCols)//检查SQL内的字段名是否合法,如果SQL没有指定字段名称则默认为全部字段,并检查字段是否只出现了一次
	if err != nil {
		return nil, errors.Trace(err)
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L179-L190

然后运行:

	// Process `insert ... (select ..) `
	if s.Sel != nil {
		return s.execSelect(t, cols, ctx)
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L192-L195

检查是否包含子select语句(即insert ... select的形式),如果有,则直接通过s.execSelect执行select语句并插入其结果;select语句这里先不展开,后面会讲。

然后检查是否包含set col = value的插入语法内容:

	// Process `insert ... set x=y...`
	if len(s.Setlist) > 0 {
		if len(s.Lists) > 0 {
			return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent)
		}

		var l []expression.Expression
		for _, v := range s.Setlist {
			l = append(l, v.Expr)
		}
		s.Lists = append(s.Lists, l)
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L197-L208

如果存在则也把这些值加入Lists的插入值列表。

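然后为表中的每个字段(列)计算默认值:如果getDefaultValue返回ok,就以字段名(v.Name.L)为key把默认值存入map m。这个map在后面计算插入表达式时会传给expr.Eval(例如用于处理insert into t values (default)这类写法):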

	m := map[interface{}]interface{}{}
	for _, v := range tableCols {
		var (
			value interface{}
			ok    bool
		)
		value, ok, err = getDefaultValue(ctx, v)
		if ok {
			if err != nil {
				return nil, errors.Trace(err)
			}

			m[v.Name.L] = value
		}
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L210-L224

然后开始对插入的每行数据做进一步处理:

	insertValueCount := len(s.Lists[0])
	for i, list := range s.Lists {//遍历Lists,一个list的大小为2,包含一个列名称和一个值数据/表达式
		r := make([]interface{}, len(tableCols))
		valueCount := len(list)

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L226-L229

然后检查list大小是否合法:

		if insertValueCount != valueCount {
			// "insert into t values (), ()" is valid.
			// "insert into t values (), (1)" is not valid.
			// "insert into t values (1), ()" is not valid.
			// "insert into t values (1,2), (1)" is not valid.
			// So the value count must be same for all insert list.
			return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1)
		}

		if valueCount == 0 && len(s.ColNames) > 0 {
			// "insert into t (c1) values ()" is not valid.
			return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
		} else if valueCount > 0 && valueCount != len(cols) {
			return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L231-L245

接下来获取各个字段值,或者计算出对应表达式的值(如各种内置函数等),并按字段的偏移量把结果放入r变量:

		// Clear last insert id.
		variable.GetSessionVars(ctx).SetLastInsertID(0)

		marked := make(map[int]struct{}, len(list))
		for i, expr := range list {
			// For "insert into t values (default)" Default Eval.
			m[expressions.ExprEvalDefaultName] = cols[i].Name.O

			val, err := expr.Eval(ctx, m)//计算表达式的值(也可能只是一个value)
			if err != nil {
				return nil, errors.Trace(err)
			}
			r[cols[i].Offset] = val//根据偏移量赋值给r
			marked[cols[i].Offset] = struct{}{}
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L247-L261

接着对SQL没有提到的字段在r中赋值为默认值:

		if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil {
			return nil, errors.Trace(err)
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L262-L265

然后进入s.initDefaultValues:

// initDefaultValues fills the nil slots of row with generated or default values.
//
// row[idx] is paired with cols[idx]. A slot is filled only when it is nil and
// either it was not assigned by the insert list (idx is absent from marked) or
// its column carries the auto-increment or timestamp flag. Auto-increment
// columns receive an ID from t.AllocAutoID, which is also recorded as the
// session's last insert ID; every other column receives the value returned by
// getDefaultValue. All filled slots are then cast to their column types with
// column.CastValues. The first error encountered is returned, traced.
func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error {
	var filled []*column.Col
	for idx, col := range cols {
		if row[idx] != nil {
			// An explicit non-nil value is already present.
			continue
		}

		autoInc := mysql.HasAutoIncrementFlag(col.Flag)
		_, assigned := marked[idx]
		// A NULL that came from the insert list is kept as-is, except for
		// auto-increment and timestamp columns, which are always generated.
		if assigned && !autoInc && !mysql.HasTimestampFlag(col.Flag) {
			continue
		}

		if autoInc {
			id, err := t.AllocAutoID()
			if err != nil {
				return errors.Trace(err)
			}
			row[idx] = id
			variable.GetSessionVars(ctx).SetLastInsertID(uint64(id))
		} else {
			value, _, err := getDefaultValue(ctx, col)
			if err != nil {
				return errors.Trace(err)
			}
			row[idx] = value
		}

		filled = append(filled, col)
	}

	if err := column.CastValues(ctx, row, filled); err != nil {
		return errors.Trace(err)
	}
	return nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L306-L345

然后回到:

		if err = column.CastValues(ctx, r, cols); err != nil {
			return nil, errors.Trace(err)
		}
		if err = column.CheckNotNull(tableCols, r); err != nil {
			return nil, errors.Trace(err)
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L267-L272

CastValues(ctx, r, cols):按照cols中每个字段的类型,把r中对应位置(c.Offset)的值转换成该字段的类型:

// CastValues casts values based on columns type.
// CastValues casts values based on columns type.
//
// For each column in cols (in order), rec[c.Offset] is replaced by the value
// returned from c.CastValue. The slot is overwritten even when the cast fails;
// processing then stops and that error is returned unchanged. nil is returned
// when every column casts successfully.
func CastValues(ctx context.Context, rec []interface{}, cols []*Col) (err error) {
	for i := range cols {
		col := cols[i]
		casted, castErr := col.CastValue(ctx, rec[col.Offset])
		rec[col.Offset] = casted
		if castErr != nil {
			return castErr
		}
	}
	return nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/column/column.go#L546-L555

具体的转换在c.CastValue(ctx, rec[c.Offset])中完成,它根据字段的具体类型做转换:

// CastValue casts a value based on column's type.
// CastValue casts a value based on column's type.
//
// A nil val yields (nil, nil). Otherwise the conversion depends on c.Tp:
//   - integer types (including YEAR): normalized by normalizeIntegerValue and
//     finished by castIntegerValue;
//   - FLOAT/DOUBLE: delegated to castFloatValue;
//   - DATE/DATETIME/TIMESTAMP: parsed from an int64 or string, or converted
//     from a mysql.Time and rounded to c.Decimal fractional digits;
//   - TIME (duration): parsed from a string, or converted from a mysql.Time or
//     mysql.Duration and rounded to c.Decimal fractional digits;
//   - string/blob types: rendered as a string and cut to c.Flen bytes when a
//     length is specified; a []byte is kept as-is for the binary charset;
//   - DECIMAL/NEWDECIMAL: parsed from a string or []byte, or converted from a
//     Go integer, float or mysql.Decimal.
// Parse/round failures are wrapped with newParseColError. A value of an
// unsupported Go type, or a column of an unsupported type, yields
// c.TypeError(val).
func (c *Col) CastValue(ctx context.Context, val interface{}) (casted interface{}, err error) {
	if val == nil {
		return
	}
	switch c.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		intVal, errCode := c.normalizeIntegerValue(val)
		if errCode == errCodeType {
			casted = intVal
			err = c.TypeError(val)
			return
		}
		return c.castIntegerValue(intVal, errCode)
	case mysql.TypeFloat, mysql.TypeDouble:
		return c.castFloatValue(val)
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		switch v := val.(type) {
		case int64:
			casted, err = mysql.ParseTimeFromNum(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case string:
			casted, err = mysql.ParseTime(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			var t mysql.Time
			t, err = v.Convert(c.Tp)
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeDuration:
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDuration(v, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			var t mysql.Duration
			t, err = v.ConvertToDuration()
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Duration:
			casted, err = v.RoundFrac(c.Decimal)
			if err != nil {
				// Wrap like every other parse/round failure in this function.
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
		strV := ""
		switch v := val.(type) {
		case mysql.Time:
			strV = v.String()
		case mysql.Duration:
			strV = v.String()
		case []byte:
			if c.Charset == charset.CharsetBin {
				casted = v
				return
			}
			strV = string(v)
		default:
			strV = fmt.Sprintf("%v", val)
		}
		// NOTE(review): this cuts by bytes, which can split a multi-byte UTF-8
		// character; confirm whether c.Flen is meant as a byte or char length.
		if (c.Flen != types.UnspecifiedLength) && (len(strV) > c.Flen) {
			strV = strV[:c.Flen]
		}
		casted = strV
	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDecimal(v)
			if err != nil {
				err = newParseColError(err, c)
			}
		case []byte:
			// Raw bytes are parsed like their string form, mirroring how the
			// string column branch treats []byte.
			casted, err = mysql.ParseDecimal(string(v))
			if err != nil {
				err = newParseColError(err, c)
			}
		case int8:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int16:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int32:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int64:
			casted = mysql.NewDecimalFromInt(v, 0)
		case int:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case uint8:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint16:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint32:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint64:
			casted = mysql.NewDecimalFromUint(v, 0)
		case uint:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case float32:
			casted = mysql.NewDecimalFromFloat(float64(v))
		case float64:
			casted = mysql.NewDecimalFromFloat(v)
		case mysql.Decimal:
			casted = v
		default:
			// Without this branch an unsupported type would return
			// (nil, nil) and the value would silently be stored as NULL.
			err = c.TypeError(val)
		}
	default:
		err = c.TypeError(val)
	}
	return
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/column/column.go#L130-L253

然后回到:

		if err = column.CheckNotNull(tableCols, r); err != nil {
			return nil, errors.Trace(err)
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L270-L272

然后逐个检查字段:如果某个字段的值为nil,而该字段带有NOT NULL标记,就返回错误、拒绝执行:

// CheckNotNull checks if row has nil value set to a column with NotNull flag set.
// CheckNotNull checks if row has nil value set to a column with NotNull flag set.
//
// Each column in cols is checked, in order, against row[c.Offset] via
// Col.CheckNotNull. The first error reported is returned unchanged; nil is
// returned when every column passes.
func CheckNotNull(cols []*Col, row []interface{}) error {
	for i := 0; i < len(cols); i++ {
		col := cols[i]
		err := col.CheckNotNull(row[col.Offset])
		if err != nil {
			return err
		}
	}
	return nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/column/column.go#L663-L671

接着回到:

		// Notes: incompatible with mysql
		// MySQL will set last insert id to the first row, as follows:
		// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
		// `insert t (c1) values(1),(2),(3);`
		// Last insert id will be 1, not 3.
		h, err := t.AddRecord(ctx, r)
		if err == nil {
			continue
		}
		if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
			return nil, errors.Trace(err)
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L274-L285

然后进入t.AddRecord(ctx, r)插入这行数据:

// AddRecord implements table.Table AddRecord interface.
func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error) {
	id := variable.GetSessionVars(ctx).LastInsertID
	// Already have auto increment ID
	if id != 0 {
		recordID = int64(id)
	} else {
		recordID, err = t.alloc.Alloc(t.ID)
		if err != nil {
			return 0, err
		}
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/table/tables/tables.go#L321-L332

如果当前会话的LastInsertID为0(即在initDefaultValues中没有为这一行分配过自增ID),那么进入t.alloc.Alloc(t.ID)分配一个新的ID作为行ID:

// Alloc allocs the next autoID for table with tableID.
// It gets a batch of autoIDs at a time. So it does not need to access storage for each call.
// Alloc allocs the next autoID for table with tableID.
// It gets a batch of autoIDs at a time. So it does not need to access storage for each call.
//
// IDs are handed out from the cached range (base, end]. When the range is used
// up (base == end), a new transaction calls meta.GenID on the key returned by
// meta.AutoIDKey(tableID). GenID presumably advances the stored counter by
// step and returns the new value; verify. The cached range then becomes
// (end-step, end]. alloc.mu serializes callers that share this allocator.
// A tableID of 0 is rejected.
func (alloc *allocator) Alloc(tableID int64) (int64, error) {
	if tableID == 0 {
		return 0, errors.New("Invalid tableID")
	}
	metaKey := meta.AutoIDKey(tableID)
	alloc.mu.Lock()
	defer alloc.mu.Unlock()
	if alloc.base == alloc.end {
		// The cached range is exhausted: reserve the next batch of step IDs
		// from storage so that most calls avoid a storage round trip.
		err := kv.RunInNewTxn(alloc.store, true, func(txn kv.Transaction) error {
			end, err := meta.GenID(txn, []byte(metaKey), step)
			if err != nil {
				return errors.Trace(err)
			}

			alloc.end = end
			alloc.base = alloc.end - step
			return nil
		})

		if err != nil {
			return 0, errors.Trace(err)
		}
	}

	alloc.base++
	// This runs once per allocated ID (i.e. for every inserted row), so it is
	// logged at debug level to keep the insert hot path from flooding the log.
	log.Debugf("Alloc id %d, table ID:%d, from %p, store ID:%s", alloc.base, tableID, alloc, alloc.store.UUID())
	return alloc.base, nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/meta/autoid/autoid.go#L42-L71

获取完autoID后回到:

		recordID, err = t.alloc.Alloc(t.ID)
		if err != nil {
			return 0, err
		}
	}
	txn, err := ctx.GetTxn(false)
	if err != nil {
		return 0, err
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/table/tables/tables.go#L328-L336

取出当前session的事务txn后来到:

	for _, v := range t.indices {
		if v == nil {
			continue
		}
		colVals, _ := v.FetchValues(r)//取出当前要插入的行的该索引字段值
		if err = v.X.Create(txn, colVals, recordID); err != nil {
			if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
				// Get the duplicate row handle
				iter, _, terr := v.X.Seek(txn, colVals)
				if terr != nil {
					return 0, errors.Trace(terr)
				}
				_, h, terr := iter.Next()
				if terr != nil {
					return 0, errors.Trace(terr)
				}
				return h, errors.Trace(err)
			}
			return 0, errors.Trace(err)
		}
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/table/tables/tables.go#L337-L357

这里遍历表的所有索引,通过v.X.Create(txn, colVals, recordID)为插入的这行数据在每个索引中创建一条索引项:

// Create creates a new entry in the kvIndex data.
// If the index is unique and there already exists an entry with the same key, Create will return ErrConditionNotMatch
// Create creates a new entry in the kvIndex data.
// If the index is unique and there already exists an entry with the same key, Create will return ErrConditionNotMatch
//
// The key comes from genIndexKey.
//   - Non-unique index: the key already embeds the handle h, so the entry is
//     written unconditionally with a placeholder value.
//   - Unique index: the key holds only the indexed values. The entry (value:
//     the encoded handle) is written only when the key is not found.
//
// Any other error from the existence check is returned as-is instead of being
// reported as a duplicate key.
func (c *kvIndex) Create(txn Transaction, indexedValues []interface{}, h int64) error {
	keyBuf, err := c.genIndexKey(indexedValues, h)
	if err != nil {
		return errors.Trace(err)
	}
	if !c.unique {
		// TODO: reconsider value
		err = txn.Set(keyBuf, []byte("timestamp?"))
		return errors.Trace(err)
	}

	// unique index
	_, err = txn.Get(keyBuf)
	if IsErrNotFound(err) {
		err = txn.Set(keyBuf, encodeHandle(h))
		return errors.Trace(err)
	}
	if err != nil {
		// A storage failure is not a duplicate key; surface it so callers do
		// not take the duplicate-key / ON DUPLICATE KEY path on it.
		return errors.Trace(err)
	}

	// The key exists: the unique constraint would be violated.
	return errors.Trace(ErrConditionNotMatch)
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/kv/index_iter.go#L140-L161

进入c.genIndexKey(indexedValues, h)获得KEY编码:

// genIndexKey builds the storage key of an index entry: c.prefix followed by
// the EncodeValue encoding of the indexed values.
//
// For a non-unique index the handle h is appended to the values before
// encoding, so rows that share the same indexed values still get distinct
// keys. For a unique index h is not part of the key. The caller's
// indexedValues slice is never modified.
func (c *kvIndex) genIndexKey(indexedValues []interface{}, h int64) ([]byte, error) {
	var (
		encVal []byte
		err    error
	)
	// NOTE(review): upstream says "only support single value index", but the
	// code below encodes any number of values; confirm the intended limit.
	if !c.unique {
		// Build a fresh slice: append(indexedValues, h) could write h into
		// spare capacity of the caller's backing array.
		vals := make([]interface{}, 0, len(indexedValues)+1)
		vals = append(vals, indexedValues...)
		vals = append(vals, h)
		encVal, err = EncodeValue(vals...)
	} else {
		encVal, err = EncodeValue(indexedValues...)
	}
	if err != nil {
		return nil, err
	}
	buf := make([]byte, 0, len(c.prefix)+len(encVal))
	buf = append(buf, c.prefix...)
	buf = append(buf, encVal...)
	return buf, nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/kv/index_iter.go#L121-L138

回到:

if err != nil {
		return err
	}
	if !c.unique {
		// TODO: reconsider value
		err = txn.Set(keyBuf, []byte("timestamp?"))//对索引设置一个值(现在固定是"timestamp?",似乎没什么用?因为行号已经包含在key里面了,todo说了以后应该会重新考虑)
		return errors.Trace(err)
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/kv/index_iter.go#L144-L151

回到:

if err = v.X.Create(txn, colVals, recordID); err != nil {
			if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
				// Get the duplicate row handle
				iter, _, terr := v.X.Seek(txn, colVals)
				if terr != nil {
					return 0, errors.Trace(terr)
				}
				_, h, terr := iter.Next()
				if terr != nil {
					return 0, errors.Trace(terr)
				}
				return h, errors.Trace(err)
			}
			return 0, errors.Trace(err)
		}
	}

	// split a record into multiple kv pair
	// first key -> LOCK
	k := t.RecordKey(recordID, nil)
	// A new row with current txn-id as lockKey
	err = txn.Set([]byte(k), []byte(txn.String()))
	if err != nil {
		return 0, err
	}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/table/tables/tables.go#L342-L366

然后把一行数据拆分成多个KV对存储:

// split a record into multiple kv pair
	// first key -> LOCK
	k := t.RecordKey(recordID, nil)//获取数据行标识对应的key
	// A new row with current txn-id as lockKey
	err = txn.Set([]byte(k), []byte(txn.String()))
	if err != nil {
		return 0, err
	}
	// column key -> column value
	for _, c := range t.Cols() {//把每一列数据拆分成一个kv对存储
		colKey := t.RecordKey(recordID, c)
		data, err := t.EncodeValue(r[c.Offset])
		if err != nil {
			return 0, err
		}
		err = txn.Set([]byte(colKey), data)
		if err != nil {
			return 0, err
		}
	}
	variable.GetSessionVars(ctx).AddAffectedRows(1)//标记影响行+1
	return recordID, nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/table/tables/tables.go#L359-L381

其中err = txn.Set([]byte(k), []byte(txn.String()))把行Key写入KV,值为txn.String()。txn.String()中用到了txn.tID。txn.tID是在 https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/store/localstore/kv.go#L110-L110 根据全局变量globalID初始化的。globalID初始为0,只保存在内存中,每次启动都会归0。除了日志输出和这里的写入,暂时没发现它还有什么用。

 

接着,回到:

		h, err := t.AddRecord(ctx, r)
		if err == nil {
			continue
		}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L279-L282

由于只插入了一行数据,循环到此结束,随后退出这个Exec函数:

	for i, list := range s.Lists {
		r := make([]interface{}, len(tableCols))
		valueCount := len(list)

		if insertValueCount != valueCount {
			// "insert into t values (), ()" is valid.
			// "insert into t values (), (1)" is not valid.
			// "insert into t values (1), ()" is not valid.
			// "insert into t values (1,2), (1)" is not valid.
			// So the value count must be same for all insert list.
			return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1)
		}

		if valueCount == 0 && len(s.ColNames) > 0 {
			// "insert into t (c1) values ()" is not valid.
			return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
		} else if valueCount > 0 && valueCount != len(cols) {
			return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
		}

		// Clear last insert id.
		variable.GetSessionVars(ctx).SetLastInsertID(0)

		marked := make(map[int]struct{}, len(list))
		for i, expr := range list {
			// For "insert into t values (default)" Default Eval.
			m[expressions.ExprEvalDefaultName] = cols[i].Name.O

			val, err := expr.Eval(ctx, m)
			if err != nil {
				return nil, errors.Trace(err)
			}
			r[cols[i].Offset] = val
			marked[cols[i].Offset] = struct{}{}
		}

		if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil {
			return nil, errors.Trace(err)
		}

		if err = column.CastValues(ctx, r, cols); err != nil {
			return nil, errors.Trace(err)
		}
		if err = column.CheckNotNull(tableCols, r); err != nil {
			return nil, errors.Trace(err)
		}

		// Notes: incompatible with mysql
		// MySQL will set last insert id to the first row, as follows:
		// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
		// `insert t (c1) values(1),(2),(3);`
		// Last insert id will be 1, not 3.
		h, err := t.AddRecord(ctx, r)
		if err == nil {
			continue
		}
		if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
			return nil, errors.Trace(err)
		}
		// On duplicate key Update the duplicate row.
		// Evaluate the updated value.
		// TODO: report rows affected and last insert id.
		toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate)
		if err != nil {
			return nil, errors.Trace(err)
		}
		data, err := t.Row(ctx, h)
		if err != nil {
			return nil, errors.Trace(err)
		}
		err = updateRecord(ctx, h, data, t, toUpdateColumns, s.OnDuplicate, r)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	return nil, nil
}

https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/stmt/stmts/insert.go#L226-L304

之后就和前面一样的退出了,到此,这个语句执行完成了。

小结

插入数据的流程:检查SQL中插入字段的合法性->计算出对应表达式的值(如各种内置函数等)并得到要插入的数据行->为SQL没有提到的字段设置默认值,做类型转换并检查非空约束->生成唯一行号->为涉及到的所有索引创建索引项->把一行数据拆分成多个KV对存储。

发表评论

邮箱地址不会被公开。 必填项已用*标注

请输入正确的验证码