继续说下一个SQL语句:
insert into t values('c1','ret1');
入口和之前语句一样,也是个Exec类型的语句,这里就不再描述了。
直接进入Exec():
// Exec implements the stmt.Statement Exec interface.
func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
t, err := getTable(ctx, s.TableIdent)//根据表名获得表对应信息
if err != nil {
return nil, errors.Trace(err)
}
tableCols := t.Cols()//获取表中所有字段
cols, err := s.getColumns(tableCols)//检查SQL内的字段名是否合法,如果SQL没有指定字段名称则默认为全部字段,并检查字段是否只出现了一次
if err != nil {
return nil, errors.Trace(err)
}
然后运行:
// Process `insert ... (select ..) `
if s.Sel != nil {
return s.execSelect(t, cols, ctx)
}
检查是否包含子select语句,如果有需要先执行select语句,这里先不说select语句,后面会讲。
然后检查是否包含set col = value的插入语法内容:
// Process `insert ... set x=y...`
if len(s.Setlist) > 0 {
if len(s.Lists) > 0 {
return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent)
}
var l []expression.Expression
for _, v := range s.Setlist {
l = append(l, v.Expr)
}
s.Lists = append(s.Lists, l)
}
如果存在则也把这些值加入Lists的插入值列表。
然后把没有赋值的字段(/列)赋上默认值:
m := map[interface{}]interface{}{}
for _, v := range tableCols {
var (
value interface{}
ok bool
)
value, ok, err = getDefaultValue(ctx, v)
if ok {
if err != nil {
return nil, errors.Trace(err)
}
m[v.Name.L] = value
}
}
然后开始对插入的每行数据做进一步处理:
insertValueCount := len(s.Lists[0])
for i, list := range s.Lists {//遍历Lists,每个list对应一行待插入的数据,其中每个元素是一个值或表达式(本例中list的大小为2)
r := make([]interface{}, len(tableCols))
valueCount := len(list)
然后检查list大小是否合法:
if insertValueCount != valueCount {
// "insert into t values (), ()" is valid.
// "insert into t values (), (1)" is not valid.
// "insert into t values (1), ()" is not valid.
// "insert into t values (1,2), (1)" is not valid.
// So the value count must be same for all insert list.
return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1)
}
if valueCount == 0 && len(s.ColNames) > 0 {
// "insert into t (c1) values ()" is not valid.
return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
} else if valueCount > 0 && valueCount != len(cols) {
return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
}
接下来获取各个字段值,或者计算出对应表达式值(如各种内置函数等),并把结果放在r变量:
// Clear last insert id.
variable.GetSessionVars(ctx).SetLastInsertID(0)
marked := make(map[int]struct{}, len(list))
for i, expr := range list {
// For "insert into t values (default)" Default Eval.
m[expressions.ExprEvalDefaultName] = cols[i].Name.O
val, err := expr.Eval(ctx, m)//计算表达式的值(也可能只是一个value)
if err != nil {
return nil, errors.Trace(err)
}
r[cols[i].Offset] = val//根据偏移量赋值给r
marked[cols[i].Offset] = struct{}{}
}
接着对SQL没有提到的字段在r中赋值为默认值:
if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil {
return nil, errors.Trace(err)
}
然后进入s.initDefaultValues:
// initDefaultValues fills the slots of row that are still nil after the
// INSERT value list has been evaluated.
//
// cols is walked by index and row[i] is treated as the slot of cols[i].
// marked contains the row positions that the statement assigned explicitly.
// For every nil slot:
//   - if it was assigned explicitly and the column is neither auto-increment
//     nor timestamp, the nil is kept as is;
//   - if the column is auto-increment, a fresh ID is taken from
//     t.AllocAutoID and also recorded as the session's last insert ID;
//   - otherwise the value returned by getDefaultValue is used.
//
// All slots filled here are finally passed through column.CastValues.
func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error {
	var filled []*column.Col
	for pos, col := range cols {
		if row[pos] != nil {
			// The slot already carries a value; nothing to do.
			continue
		}

		autoInc := mysql.HasAutoIncrementFlag(col.Flag)
		_, explicit := marked[pos]
		// An explicitly inserted nil stays nil, except for auto-increment
		// and timestamp columns, which still get a generated value.
		if explicit && !autoInc && !mysql.HasTimestampFlag(col.Flag) {
			continue
		}

		if autoInc {
			id, err := t.AllocAutoID()
			if err != nil {
				return errors.Trace(err)
			}
			row[pos] = id
			variable.GetSessionVars(ctx).SetLastInsertID(uint64(id))
		} else {
			def, _, err := getDefaultValue(ctx, col)
			if err != nil {
				return errors.Trace(err)
			}
			row[pos] = def
		}
		filled = append(filled, col)
	}

	// Convert the freshly filled values to their column types.
	if err := column.CastValues(ctx, row, filled); err != nil {
		return errors.Trace(err)
	}
	return nil
}
然后回到:
if err = column.CastValues(ctx, r, cols); err != nil {
return nil, errors.Trace(err)
}
if err = column.CheckNotNull(tableCols, r); err != nil {
return nil, errors.Trace(err)
}
CastValues(ctx, r, cols):按照cols中各列的类型,对r中对应位置(c.Offset)的值逐个做类型转换:
// CastValues casts values based on columns type.
//
// For each column in cols, the element of rec at that column's Offset is
// replaced by the result of the column's CastValue. Processing stops at the
// first conversion error, which is returned unchanged; the slot that failed
// has already been overwritten with whatever CastValue returned.
func CastValues(ctx context.Context, rec []interface{}, cols []*Col) (err error) {
	for i := range cols {
		col := cols[i]
		if rec[col.Offset], err = col.CastValue(ctx, rec[col.Offset]); err != nil {
			return err
		}
	}
	return nil
}
并且在c.CastValue(ctx, rec[c.Offset]),根据具体类型做转换:
// CastValue casts a value based on column's type.
//
// A nil val is returned as (nil, nil). Otherwise the conversion is chosen by
// c.Tp:
//   - integer types and YEAR: normalizeIntegerValue followed by
//     castIntegerValue; a type error from normalization returns the
//     normalized value together with c.TypeError(val);
//   - FLOAT / DOUBLE: castFloatValue;
//   - DATE / DATETIME / TIMESTAMP: accepts int64, string and mysql.Time;
//   - DURATION: accepts string, mysql.Time and mysql.Duration;
//   - BLOB / CHAR / VARCHAR family: the value is rendered as a string and cut
//     to c.Flen bytes when a length is specified; a []byte for a column with
//     the binary charset is returned untouched;
//   - DECIMAL: accepts string, the Go integer and float types, and
//     mysql.Decimal.
//
// A value whose Go type is not accepted by the selected branch, or a column
// type that is not listed above, yields c.TypeError(val). Parse and rounding
// failures are wrapped with newParseColError.
func (c *Col) CastValue(ctx context.Context, val interface{}) (casted interface{}, err error) {
	if val == nil {
		return
	}
	switch c.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		intVal, errCode := c.normalizeIntegerValue(val)
		if errCode == errCodeType {
			casted = intVal
			err = c.TypeError(val)
			return
		}
		return c.castIntegerValue(intVal, errCode)
	case mysql.TypeFloat, mysql.TypeDouble:
		return c.castFloatValue(val)
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		switch v := val.(type) {
		case int64:
			casted, err = mysql.ParseTimeFromNum(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case string:
			casted, err = mysql.ParseTime(v, c.Tp, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			var t mysql.Time
			t, err = v.Convert(c.Tp)
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeDuration:
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDuration(v, c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Time:
			var t mysql.Duration
			t, err = v.ConvertToDuration()
			if err != nil {
				err = newParseColError(err, c)
				return
			}
			casted, err = t.RoundFrac(c.Decimal)
			if err != nil {
				err = newParseColError(err, c)
			}
		case mysql.Duration:
			casted, err = v.RoundFrac(c.Decimal)
			if err != nil {
				// Wrapped like every other parse/round failure in this function.
				err = newParseColError(err, c)
			}
		default:
			err = c.TypeError(val)
		}
	case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
		strV := ""
		switch v := val.(type) {
		case mysql.Time:
			strV = v.String()
		case mysql.Duration:
			strV = v.String()
		case []byte:
			if c.Charset == charset.CharsetBin {
				// Binary columns keep the raw bytes; no length limit is applied here.
				casted = v
				return
			}
			strV = string(v)
		default:
			strV = fmt.Sprintf("%v", val)
		}
		// The cut below works on bytes, so it can split a multi-byte UTF-8
		// sequence. NOTE(review): confirm whether Flen is meant to be a byte
		// or a character count.
		if (c.Flen != types.UnspecifiedLength) && (len(strV) > c.Flen) {
			strV = strV[:c.Flen]
		}
		casted = strV
	case mysql.TypeDecimal, mysql.TypeNewDecimal:
		switch v := val.(type) {
		case string:
			casted, err = mysql.ParseDecimal(v)
			if err != nil {
				err = newParseColError(err, c)
			}
		case int8:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int16:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int32:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int64:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case int:
			casted = mysql.NewDecimalFromInt(int64(v), 0)
		case uint8:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint16:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint32:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint64:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case uint:
			casted = mysql.NewDecimalFromUint(uint64(v), 0)
		case float32:
			casted = mysql.NewDecimalFromFloat(float64(v))
		case float64:
			casted = mysql.NewDecimalFromFloat(float64(v))
		case mysql.Decimal:
			casted = v
		default:
			// Without this branch an unsupported Go type would silently
			// turn into a nil (NULL) value with no error reported.
			err = c.TypeError(val)
		}
	default:
		err = c.TypeError(val)
	}
	return
}
然后回到:
if err = column.CheckNotNull(tableCols, r); err != nil {
return nil, errors.Trace(err)
}
如果r中有nil值,在这里检查对应字段是否带有NOT NULL标记,如果带有该标记则返回错误、拒绝执行:
// CheckNotNull checks if row has nil value set to a column with NotNull flag set.
//
// Each column in cols validates the element of row found at its Offset via
// the column's own CheckNotNull method. The first error encountered is
// returned unchanged; nil is returned when every column passes.
func CheckNotNull(cols []*Col, row []interface{}) error {
	for i := range cols {
		col := cols[i]
		err := col.CheckNotNull(row[col.Offset])
		if err != nil {
			return err
		}
	}
	return nil
}
接着回到:
// Notes: incompatible with mysql
// MySQL will set last insert id to the first row, as follows:
// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
// `insert t (c1) values(1),(2),(3);`
// Last insert id will be 1, not 3.
h, err := t.AddRecord(ctx, r)
if err == nil {
continue
}
if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
return nil, errors.Trace(err)
}
然后进入t.AddRecord(ctx, r)插入这行数据:
// AddRecord implements table.Table AddRecord interface.
func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error) {
id := variable.GetSessionVars(ctx).LastInsertID
// Already have auto increment ID
if id != 0 {
recordID = int64(id)
} else {
recordID, err = t.alloc.Alloc(t.ID)
if err != nil {
return 0, err
}
}
如果不存在autoID,那么进入t.alloc.Alloc(t.ID)得到autoID:
// Alloc allocs the next autoID for table with tableID.
// It gets a batch of autoIDs at a time. So it does not need to access storage for each call.
//
// A tableID of 0 is rejected. The allocator's mutex is held for the whole
// call. When the cached range is exhausted (base == end), a new transaction
// advances the stored counter by step via meta.GenID and the cached range is
// reset to (end-step, end]. The returned ID is the incremented base.
func (alloc *allocator) Alloc(tableID int64) (int64, error) {
	if tableID == 0 {
		return 0, errors.New("Invalid tableID")
	}

	// Key under which this table's autoID counter is stored
	// (e.g. "mTable::4_autoID" according to the original annotation).
	idKey := meta.AutoIDKey(tableID)

	// Serialize autoID allocation on this allocator.
	alloc.mu.Lock()
	defer alloc.mu.Unlock()

	if alloc.base == alloc.end {
		// Cached range used up: reserve the next batch of step IDs in one
		// storage round trip so that later calls are served from memory.
		refill := func(txn kv.Transaction) error {
			newEnd, genErr := meta.GenID(txn, []byte(idKey), step)
			if genErr != nil {
				return errors.Trace(genErr)
			}
			alloc.end = newEnd
			alloc.base = alloc.end - step
			return nil
		}
		if err := kv.RunInNewTxn(alloc.store, true, refill); err != nil {
			return 0, errors.Trace(err)
		}
	}

	alloc.base++
	log.Infof("Alloc id %d, table ID:%d, from %p, store ID:%s", alloc.base, tableID, alloc, alloc.store.UUID())
	return alloc.base, nil
}
获取完autoID后回到:
recordID, err = t.alloc.Alloc(t.ID)
if err != nil {
return 0, err
}
}
txn, err := ctx.GetTxn(false)
if err != nil {
return 0, err
}
取出当前session的事务txn后来到:
for _, v := range t.indices {
if v == nil {
continue
}
colVals, _ := v.FetchValues(r)//取出当前要插入的行的该索引字段值
if err = v.X.Create(txn, colVals, recordID); err != nil {
if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
// Get the duplicate row handle
iter, _, terr := v.X.Seek(txn, colVals)
if terr != nil {
return 0, errors.Trace(terr)
}
_, h, terr := iter.Next()
if terr != nil {
return 0, errors.Trace(terr)
}
return h, errors.Trace(err)
}
return 0, errors.Trace(err)
}
}
这里遍历所有索引,然后对插入的这行数据都建立索引v.X.Create(txn, colVals, recordID):
// Create creates a new entry in the kvIndex data.
// If the index is unique and there already exists an entry with the same key,
// Create will return ErrConditionNotMatch.
//
// The key is built by genIndexKey from indexedValues and the row handle h.
// For a non-unique index the key is written with a placeholder value. For a
// unique index the key is looked up first: if it is not found, the encoded
// handle is stored as its value; if it is found, ErrConditionNotMatch is
// returned. A lookup failure other than "not found" is returned as that
// error, so that storage problems are not mistaken for a duplicate key.
func (c *kvIndex) Create(txn Transaction, indexedValues []interface{}, h int64) error {
	keyBuf, err := c.genIndexKey(indexedValues, h)
	if err != nil {
		return err
	}
	if !c.unique {
		// TODO: reconsider value
		err = txn.Set(keyBuf, []byte("timestamp?"))
		return errors.Trace(err)
	}

	// unique index
	_, err = txn.Get(keyBuf)
	switch {
	case IsErrNotFound(err):
		// No entry yet: claim the key for this row handle.
		err = txn.Set(keyBuf, encodeHandle(h))
		return errors.Trace(err)
	case err != nil:
		// The lookup itself failed; this is not a uniqueness conflict.
		return errors.Trace(err)
	}
	return errors.Trace(ErrConditionNotMatch)
}
进入c.genIndexKey(indexedValues, h)获得KEY编码:
// genIndexKey builds the storage key of an index entry: c.prefix followed by
// the EncodeValue encoding of the indexed values.
//
// For a non-unique index the row handle h is appended to the encoded values
// so that rows sharing the same indexed values get distinct keys. For a
// unique index h is not part of the key. The caller's indexedValues slice is
// never modified. An encoding error is returned unchanged with a nil key.
func (c *kvIndex) genIndexKey(indexedValues []interface{}, h int64) ([]byte, error) {
	// only support single value index
	vals := indexedValues
	if !c.unique {
		// Limit the capacity so that append always copies instead of writing
		// h into spare capacity of the caller's backing array.
		vals = append(indexedValues[:len(indexedValues):len(indexedValues)], h)
	}

	encVal, err := EncodeValue(vals...)
	if err != nil {
		return nil, err
	}

	buf := append([]byte(nil), []byte(c.prefix)...)
	buf = append(buf, encVal...)
	return buf, nil
}
回到:
if err != nil {
return err
}
if !c.unique {
// TODO: reconsider value
err = txn.Set(keyBuf, []byte("timestamp?"))//对索引设置一个值(现在固定是"timestamp?",似乎没什么用?因为行号已经包含在key里面了,todo说了以后应该会重新考虑)
return errors.Trace(err)
}
回到:
if err = v.X.Create(txn, colVals, recordID); err != nil {
if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
// Get the duplicate row handle
iter, _, terr := v.X.Seek(txn, colVals)
if terr != nil {
return 0, errors.Trace(terr)
}
_, h, terr := iter.Next()
if terr != nil {
return 0, errors.Trace(terr)
}
return h, errors.Trace(err)
}
return 0, errors.Trace(err)
}
}
// split a record into multiple kv pair
// first key -> LOCK
k := t.RecordKey(recordID, nil)
// A new row with current txn-id as lockKey
err = txn.Set([]byte(k), []byte(txn.String()))
if err != nil {
return 0, err
}
然后把一行数据拆分成多个KV对存储:
// split a record into multiple kv pair
// first key -> LOCK
k := t.RecordKey(recordID, nil)//获取数据行标识对应的key
// A new row with current txn-id as lockKey
err = txn.Set([]byte(k), []byte(txn.String()))
if err != nil {
return 0, err
}
// column key -> column value
for _, c := range t.Cols() {//把每一列数据拆分成一个kv对存储
colKey := t.RecordKey(recordID, c)
data, err := t.EncodeValue(r[c.Offset])
if err != nil {
return 0, err
}
err = txn.Set([]byte(colKey), data)
if err != nil {
return 0, err
}
}
variable.GetSessionVars(ctx).AddAffectedRows(1)//标记影响行+1
return recordID, nil
}
其中err = txn.Set([]byte(k), []byte(txn.String())),设置行Key到Kv,值是txn.String(),其中txn.String()获取了txn.tID,txn.tID是在https://github.com/pingcap/tidb/blob/0d6f270068e8ff2aedc1c314e907771b6a508ebd/store/localstore/kv.go#L110-L110初始化的,而globalID是初始为0的全局变量,放在内存的,每次启动归0,除了日志输出和这里的写入,暂时没发现还有什么用。
接着,回到:
h, err := t.AddRecord(ctx, r)
if err == nil {
continue
}
由于只插入一行数据,所以退出这个Exec函数:
for i, list := range s.Lists {
r := make([]interface{}, len(tableCols))
valueCount := len(list)
if insertValueCount != valueCount {
// "insert into t values (), ()" is valid.
// "insert into t values (), (1)" is not valid.
// "insert into t values (1), ()" is not valid.
// "insert into t values (1,2), (1)" is not valid.
// So the value count must be same for all insert list.
return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1)
}
if valueCount == 0 && len(s.ColNames) > 0 {
// "insert into t (c1) values ()" is not valid.
return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
} else if valueCount > 0 && valueCount != len(cols) {
return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
}
// Clear last insert id.
variable.GetSessionVars(ctx).SetLastInsertID(0)
marked := make(map[int]struct{}, len(list))
for i, expr := range list {
// For "insert into t values (default)" Default Eval.
m[expressions.ExprEvalDefaultName] = cols[i].Name.O
val, err := expr.Eval(ctx, m)
if err != nil {
return nil, errors.Trace(err)
}
r[cols[i].Offset] = val
marked[cols[i].Offset] = struct{}{}
}
if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil {
return nil, errors.Trace(err)
}
if err = column.CastValues(ctx, r, cols); err != nil {
return nil, errors.Trace(err)
}
if err = column.CheckNotNull(tableCols, r); err != nil {
return nil, errors.Trace(err)
}
// Notes: incompatible with mysql
// MySQL will set last insert id to the first row, as follows:
// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
// `insert t (c1) values(1),(2),(3);`
// Last insert id will be 1, not 3.
h, err := t.AddRecord(ctx, r)
if err == nil {
continue
}
if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) {
return nil, errors.Trace(err)
}
// On duplicate key Update the duplicate row.
// Evaluate the updated value.
// TODO: report rows affected and last insert id.
toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate)
if err != nil {
return nil, errors.Trace(err)
}
data, err := t.Row(ctx, h)
if err != nil {
return nil, errors.Trace(err)
}
err = updateRecord(ctx, h, data, t, toUpdateColumns, s.OnDuplicate, r)
if err != nil {
return nil, errors.Trace(err)
}
}
return nil, nil
}
之后就和前面一样的退出了,到此,这个语句执行完成了。
小结
插入数据的流程:检查SQL中插入字段的合法性->计算出对应表达式值(如各种内置函数等)并提取出插入的数据行->对SQL没有提到的字段设置为默认值,并检查空值->生成唯一行号->创建涉及到的所有索引->把一行数据拆分成多个KV对存储。