继续说下一个SQL语句:
insert into t values('c1','ret1');
入口和之前的语句一样,也是个Exec类型的语句,这里就不再描述了。
直接进入Exec():
// Exec implements the stmt.Statement Exec interface. func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { t, err := getTable(ctx, s.TableIdent)//根据表名获得表对应信息 if err != nil { return nil, errors.Trace(err) } tableCols := t.Cols()//获取表中所有字段 cols, err := s.getColumns(tableCols)//检查SQL内的字段名是否合法,如果SQL没有指定字段名称则默认为全部字段,并检查字段是否只出现了一次 if err != nil { return nil, errors.Trace(err) }
然后运行:
// Process `insert ... (select ..) ` if s.Sel != nil { return s.execSelect(t, cols, ctx) }
检查是否包含子select语句,如果有需要先执行select语句,这里先不说select语句,后面会讲。
然后检查是否包含set col = value的插入语法内容:
// Process `insert ... set x=y...` if len(s.Setlist) > 0 { if len(s.Lists) > 0 { return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) } var l []expression.Expression for _, v := range s.Setlist { l = append(l, v.Expr) } s.Lists = append(s.Lists, l) }
如果存在则也把这些值加入Lists的插入值列表。
然后把没有赋值的字段(/列)赋上默认值:
m := map[interface{}]interface{}{} for _, v := range tableCols { var ( value interface{} ok bool ) value, ok, err = getDefaultValue(ctx, v) if ok { if err != nil { return nil, errors.Trace(err) } m[v.Name.L] = value } }
然后开始对插入的每行数据做进一步处理:
insertValueCount := len(s.Lists[0]) for i, list := range s.Lists {//遍历Lists,一个list的大小为2,包含一个列名称和一个值数据/表达式 r := make([]interface{}, len(tableCols)) valueCount := len(list)
然后检查list大小是否合法:
if insertValueCount != valueCount { // "insert into t values (), ()" is valid. // "insert into t values (), (1)" is not valid. // "insert into t values (1), ()" is not valid. // "insert into t values (1,2), (1)" is not valid. // So the value count must be same for all insert list. return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1) } if valueCount == 0 && len(s.ColNames) > 0 { // "insert into t (c1) values ()" is not valid. return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) } else if valueCount > 0 && valueCount != len(cols) { return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) }
接下来获取各个字段值,或者计算出对应表达式的值(如各种内置函数等),并把结果放在r变量中:
// Clear last insert id. variable.GetSessionVars(ctx).SetLastInsertID(0) marked := make(map[int]struct{}, len(list)) for i, expr := range list { // For "insert into t values (default)" Default Eval. m[expressions.ExprEvalDefaultName] = cols[i].Name.O val, err := expr.Eval(ctx, m)//计算表达式的值(也可能只是一个value) if err != nil { return nil, errors.Trace(err) } r[cols[i].Offset] = val//根据偏移量赋值给r marked[cols[i].Offset] = struct{}{} }
接着对SQL没有提到的字段在r中赋值为默认值:
if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil { return nil, errors.Trace(err) }
然后进入s.initDefaultValues:
// initDefaultValues fills the slots of row that are still nil after the insert
// list has been evaluated. row is indexed by position in cols. For every column
// whose slot is nil it does one of three things:
//   - leaves the nil in place when the position is present in marked and the
//     column carries neither the auto-increment nor the timestamp flag;
//   - for an auto-increment column, allocates an ID with t.AllocAutoID, stores
//     it in the row and records it as the session's last insert ID;
//   - otherwise stores whatever getDefaultValue returns for the column.
//
// The columns filled in by the last two cases are then passed to
// column.CastValues. Any error is returned wrapped with errors.Trace.
func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error {
	var err error
	var defaultValueCols []*column.Col
	for i, c := range cols {
		if row[i] != nil {
			// Column value is not nil, continue.
			continue
		}

		// If the nil value is evaluated in insert list, we will use nil except auto increment column.
		if _, ok := marked[i]; ok && !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) {
			continue
		}

		// Article author's note: handles auto-increment values or timestamps.
		if mysql.HasAutoIncrementFlag(c.Flag) {
			var id int64
			if id, err = t.AllocAutoID(); err != nil {
				return errors.Trace(err)
			}
			row[i] = id
			variable.GetSessionVars(ctx).SetLastInsertID(uint64(id))
		} else {
			// The "ok" result of getDefaultValue is deliberately discarded here;
			// the returned value is stored even if it is nil.
			var value interface{}
			value, _, err = getDefaultValue(ctx, c)
			if err != nil {
				return errors.Trace(err)
			}
			row[i] = value
		}

		defaultValueCols = append(defaultValueCols, c)
	}

	// Article author's note: maps the default values onto row. Only the columns
	// collected in defaultValueCols are passed to CastValues here.
	if err = column.CastValues(ctx, row, defaultValueCols); err != nil {
		return errors.Trace(err)
	}

	return nil
}
然后回到:
if err = column.CastValues(ctx, r, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(tableCols, r); err != nil { return nil, errors.Trace(err) }
CastValues(ctx, r, cols):按cols中每一列的类型,对r中对应偏移位置(c.Offset)的值做类型转换:
// CastValues casts values based on columns type.
//
// For each column in cols, the value at rec[c.Offset] is replaced in place by
// the result of c.CastValue. Iteration stops at the first error, which is
// returned unwrapped; at that point the failing slot has already been
// overwritten with whatever CastValue returned alongside the error.
func CastValues(ctx context.Context, rec []interface{}, cols []*Col) (err error) {
	for _, c := range cols {
		rec[c.Offset], err = c.CastValue(ctx, rec[c.Offset])
		if err != nil {
			return
		}
	}
	return
}
并且在c.CastValue(ctx, rec[c.Offset])中,根据列的具体类型对值做转换:
// CastValue casts a value based on column's type.
//
// A nil val is passed through unchanged (nil, nil). Otherwise the conversion is
// selected by c.Tp; a column type not listed in the switch yields
// c.TypeError(val). Parse and conversion failures are wrapped with
// newParseColError.
func (c *Col) CastValue(ctx context.Context, val interface{}) (casted interface{}, err error) {
	if val == nil {
		return
	}
	switch c.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		// Integer-family (and YEAR) columns: normalize first, then range-cast.
		// errCodeType short-circuits to a type error while still returning
		// the normalized value.
		intVal, errCode := c.normalizeIntegerValue(val)
		if errCode == errCodeType {
			casted = intVal
			err = c.TypeError(val)
			return
		}
		return c.castIntegerValue(intVal, errCode)
	case mysql.TypeFloat, mysql.TypeDouble:
		return c.castFloatValue(val)
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		// Time-like columns accept an int64, a string or a mysql.Time.
		switch v := val.(type) {
		case int64:
			casted, err = mysql.ParseTimeFromNum(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case string:
			casted, err = mysql.ParseTime(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			// Convert to the column's time type, then round to c.Decimal.
			var t mysql.Time
			t, err = v.Convert(c.Tp)
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeDuration:
		// Duration columns accept a string, a mysql.Time or a mysql.Duration.
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDuration(v, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			var t mysql.Duration
			t, err = v.ConvertToDuration()
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Duration:
			// NOTE(review): unlike the sibling cases, an error from RoundFrac
			// is returned here without newParseColError wrapping.
			casted, err = v.RoundFrac(c.Decimal)
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
		// String/blob columns: render the value as a string, except that a
		// []byte destined for a binary-charset column is returned as-is
		// (and is not subject to the length truncation below).
		strV := ""
		switch v := val.(type) {
		case mysql.Time:
			strV = v.String()
		case mysql.Duration:
			strV = v.String()
		case []byte:
			if c.Charset == charset.CharsetBin {
				casted = v
				return
			}
			strV = string(v)
		default:
			strV = fmt.Sprintf("%v", val)
		}
		// NOTE(review): len and the slice expression operate on bytes, so
		// truncation to c.Flen can split a multi-byte UTF-8 character.
		if (c.Flen != types.UnspecifiedLength) && (len(strV) > c.Flen) {
			strV = strV[:c.Flen]
		}
		casted = strV
	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		// Decimal columns: strings are parsed, Go numeric types are converted,
		// and an existing mysql.Decimal is passed through.
		// NOTE(review): this inner switch has no default case, so a val of any
		// other dynamic type falls through and returns (nil, nil).
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDecimal(v)
			if err != nil {
				err = newParseColError(err, c)
			}
		case int8:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int16:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int32:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int64:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case uint8:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint16:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint32:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint64:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case float32:
			casted = mysql.NewDecimalFromFloat(float64(v))
		case float64:
			casted = mysql.NewDecimalFromFloat(float64(v))
		case mysql.Decimal:
			casted = v
		}
	default:
		err = c.TypeError(val)
	}
	return
}
然后回到:
if err = column.CheckNotNull(tableCols, r); err != nil { return nil, errors.Trace(err) }
然后在这里检查:如果某个字段的值是nil,而该字段又被标记为不能为null(NOT NULL),就拒绝执行:
// CheckNotNull checks if row has nil value set to a column with NotNull flag set.
//
// Each column's own CheckNotNull method is applied to the value at
// row[c.Offset]; the first error is returned unchanged, and nil is returned
// when every column passes.
func CheckNotNull(cols []*Col, row []interface{}) error {
	for _, c := range cols {
		if err := c.CheckNotNull(row[c.Offset]); err != nil {
			return err
		}
	}
	return nil
}
接着回到:
// Notes: incompatible with mysql // MySQL will set last insert id to the first row, as follows: // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` // `insert t (c1) values(1),(2),(3);` // Last insert id will be 1, not 3. h, err := t.AddRecord(ctx, r) if err == nil { continue } if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { return nil, errors.Trace(err) }
然后进入t.AddRecord(ctx, r)插入这行数据:
// AddRecord implements table.Table AddRecord interface. func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error) { id := variable.GetSessionVars(ctx).LastInsertID // Already have auto increment ID if id != 0 { recordID = int64(id) } else { recordID, err = t.alloc.Alloc(t.ID) if err != nil { return 0, err } }
如果不存在autoID,那么进入t.alloc.Alloc(t.ID)得到autoID:
// Alloc allocs the next autoID for table with tableID.
// It gets a batch of autoIDs at a time. So it does not need to access storage for each call.
//
// A tableID of 0 is rejected with an error. Otherwise the allocator's mutex is
// held for the whole call; when the cached range is exhausted
// (alloc.base == alloc.end) a new range is reserved in a fresh transaction via
// meta.GenID, and then the next ID in the range is handed out.
func (alloc *allocator) Alloc(tableID int64) (int64, error) {
	if tableID == 0 {
		return 0, errors.New("Invalid tableID")
	}
	// Derive the key for this table's autoID from the table ID
	// (per the article author, e.g. "mTable::4_autoID").
	metaKey := meta.AutoIDKey(tableID)
	// Serialize autoID allocation on this allocator.
	alloc.mu.Lock()
	defer alloc.mu.Unlock()
	if alloc.base == alloc.end { // step
		// Reserve a batch of `step` IDs up front so that most calls do not
		// need to touch the KV store (the article author states step=1000;
		// the constant's definition is not shown in this excerpt).
		err := kv.RunInNewTxn(alloc.store, true, func(txn kv.Transaction) error {
			end, err := meta.GenID(txn, []byte(metaKey), step)
			if err != nil {
				return errors.Trace(err)
			}

			// The cached range becomes (end-step, end].
			alloc.end = end
			alloc.base = alloc.end - step
			return nil
		})

		if err != nil {
			return 0, errors.Trace(err)
		}
	}

	alloc.base++
	log.Infof("Alloc id %d, table ID:%d, from %p, store ID:%s", alloc.base, tableID, alloc, alloc.store.UUID())
	return alloc.base, nil
}
获取完autoID后回到:
recordID, err = t.alloc.Alloc(t.ID) if err != nil { return 0, err } } txn, err := ctx.GetTxn(false) if err != nil { return 0, err }
取出当前session的事务txn后来到:
for _, v := range t.indices { if v == nil { continue } colVals, _ := v.FetchValues(r)//取出当前要插入的行的该索引字段值 if err = v.X.Create(txn, colVals, recordID); err != nil { if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { // Get the duplicate row handle iter, _, terr := v.X.Seek(txn, colVals) if terr != nil { return 0, errors.Trace(terr) } _, h, terr := iter.Next() if terr != nil { return 0, errors.Trace(terr) } return h, errors.Trace(err) } return 0, errors.Trace(err) } }
这里遍历所有索引,然后对插入的这行数据都建立索引v.X.Create(txn, colVals, recordID):
// Create creates a new entry in the kvIndex data.
// If the index is unique and there already exists an entry with the same key, Create will return ErrConditionNotMatch
//
// For a non-unique index the key is written with a fixed placeholder value.
// For a unique index the key is written with the encoded handle h, but only
// when txn.Get reports the key as not found.
func (c *kvIndex) Create(txn Transaction, indexedValues []interface{}, h int64) error {
	keyBuf, err := c.genIndexKey(indexedValues, h)
	if err != nil {
		return err
	}
	if !c.unique {
		// TODO: reconsider value
		err = txn.Set(keyBuf, []byte("timestamp?"))
		return errors.Trace(err)
	}

	// unique index
	// NOTE(review): every Get outcome other than "not found" — including a
	// Get error of a different kind — falls through to ErrConditionNotMatch.
	_, err = txn.Get(keyBuf)
	if IsErrNotFound(err) {
		err = txn.Set(keyBuf, encodeHandle(h))
		return errors.Trace(err)
	}

	return errors.Trace(ErrConditionNotMatch)
}
进入c.genIndexKey(indexedValues, h)获得KEY编码:
// genIndexKey builds the storage key for an index entry: the index prefix
// followed by the encoded indexed values. For a non-unique index the row
// handle h is encoded as an extra trailing value; for a unique index h is not
// part of the key. An encoding error is returned unwrapped.
func (c *kvIndex) genIndexKey(indexedValues []interface{}, h int64) ([]byte, error) {
	var (
		encVal []byte
		err    error
	)
	// only support single value index
	if !c.unique {
		// A non-unique index needs the row handle appended so that entries
		// with equal indexed values remain distinct.
		// NOTE(review): append may write into indexedValues' backing array if
		// the caller's slice has spare capacity — confirm callers don't rely
		// on it being untouched.
		encVal, err = EncodeValue(append(indexedValues, h)...)
	} else {
		encVal, err = EncodeValue(indexedValues...)
	}
	if err != nil {
		return nil, err
	}
	// Copy the prefix into a fresh buffer before appending the encoded values.
	buf := append([]byte(nil), []byte(c.prefix)...)
	buf = append(buf, encVal...)
	return buf, nil
}
回到:
if err != nil { return err } if !c.unique { // TODO: reconsider value err = txn.Set(keyBuf, []byte("timestamp?"))//对索引设置一个值(现在固定是"timestamp?",似乎没什么用?因为行号已经包含在key里面了,todo说了以后应该会重新考虑) return errors.Trace(err) }
回到:
if err = v.X.Create(txn, colVals, recordID); err != nil { if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { // Get the duplicate row handle iter, _, terr := v.X.Seek(txn, colVals) if terr != nil { return 0, errors.Trace(terr) } _, h, terr := iter.Next() if terr != nil { return 0, errors.Trace(terr) } return h, errors.Trace(err) } return 0, errors.Trace(err) } } // split a record into multiple kv pair // first key -> LOCK k := t.RecordKey(recordID, nil) // A new row with current txn-id as lockKey err = txn.Set([]byte(k), []byte(txn.String())) if err != nil { return 0, err }
然后把一行数据拆分成多个KV对存储:
// split a record into multiple kv pair // first key -> LOCK k := t.RecordKey(recordID, nil)//获取数据行标识对应的key // A new row with current txn-id as lockKey err = txn.Set([]byte(k), []byte(txn.String())) if err != nil { return 0, err } // column key -> column value for _, c := range t.Cols() {//把每一列数据拆分成一个kv对存储 colKey := t.RecordKey(recordID, c) data, err := t.EncodeValue(r[c.Offset]) if err != nil { return 0, err } err = txn.Set([]byte(colKey), data) if err != nil { return 0, err } } variable.GetSessionVars(ctx).AddAffectedRows(1)//标记影响行+1 return recordID, nil }
其中err = txn.Set([]byte(k), []byte(txn.String())),设置行Key到Kv,值是txn.String(),其中txn.String()获取了txn.tID,txn.tID是在https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/store/localstore/kv.go#L110-L110初始化的,而globalID是初始为0的全局变量,放在内存的,每次启动归0,除了日志输出和这里的写入,暂时没发现还有什么用。
接着,回到:
h, err := t.AddRecord(ctx, r) if err == nil { continue }
由于只插入一行数据,所以循环结束后退出这个Exec函数:
for i, list := range s.Lists { r := make([]interface{}, len(tableCols)) valueCount := len(list) if insertValueCount != valueCount { // "insert into t values (), ()" is valid. // "insert into t values (), (1)" is not valid. // "insert into t values (1), ()" is not valid. // "insert into t values (1,2), (1)" is not valid. // So the value count must be same for all insert list. return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1) } if valueCount == 0 && len(s.ColNames) > 0 { // "insert into t (c1) values ()" is not valid. return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) } else if valueCount > 0 && valueCount != len(cols) { return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) } // Clear last insert id. variable.GetSessionVars(ctx).SetLastInsertID(0) marked := make(map[int]struct{}, len(list)) for i, expr := range list { // For "insert into t values (default)" Default Eval. m[expressions.ExprEvalDefaultName] = cols[i].Name.O val, err := expr.Eval(ctx, m) if err != nil { return nil, errors.Trace(err) } r[cols[i].Offset] = val marked[cols[i].Offset] = struct{}{} } if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil { return nil, errors.Trace(err) } if err = column.CastValues(ctx, r, cols); err != nil { return nil, errors.Trace(err) } if err = column.CheckNotNull(tableCols, r); err != nil { return nil, errors.Trace(err) } // Notes: incompatible with mysql // MySQL will set last insert id to the first row, as follows: // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` // `insert t (c1) values(1),(2),(3);` // Last insert id will be 1, not 3. h, err := t.AddRecord(ctx, r) if err == nil { continue } if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { return nil, errors.Trace(err) } // On duplicate key Update the duplicate row. // Evaluate the updated value. 
// TODO: report rows affected and last insert id. toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate) if err != nil { return nil, errors.Trace(err) } data, err := t.Row(ctx, h) if err != nil { return nil, errors.Trace(err) } err = updateRecord(ctx, h, data, t, toUpdateColumns, s.OnDuplicate, r) if err != nil { return nil, errors.Trace(err) } } return nil, nil }
之后就和前面一样的退出了,到此,这个语句执行完成了。
小结
插入数据的流程:检查SQL中插入字段的合法性->计算出对应表达式值(如各种内置函数等)并提取出插入的数据行->对SQL没有提到的字段设置为默认值,并检查空值->生成唯一行号->创建涉及到的所有索引->把一行数据拆分成多个KV对存储。