Commit c7891c1b authored by gwenn's avatar gwenn

Check statement column count is consistent.

parent bf209594
......@@ -350,7 +350,14 @@ func (c *Conn) Exists(query string, args ...interface{}) (bool, error) {
return false, err
}
defer s.Finalize()
return s.Next()
ok, err := s.Next()
if err != nil {
return false, err
}
if s.ColumnCount() == 0 {
return false, s.specificError("don't use Exists with query that returns no data such as %q", query)
}
return ok, nil
}
// OneValue is used with SELECT that returns only one row with only one column.
......@@ -366,6 +373,9 @@ func (c *Conn) OneValue(query string, value interface{}, args ...interface{}) er
if err != nil {
return err
} else if !b {
if s.ColumnCount() == 0 {
return s.specificError("don't use OneValue with query that returns no data such as %q", query)
}
return io.EOF
}
return s.Scan(value)
......
......@@ -163,6 +163,9 @@ func (s *Stmt) exec() error {
}
return s.error(rv, "Stmt.exec")
}
if s.ColumnCount() > 0 {
return s.specificError("don't use exec with anything that returns data such as %q", s.SQL())
}
return nil
}
......@@ -211,6 +214,9 @@ func (s *Stmt) Select(rowCallbackHandler func(s *Stmt) error, args ...interface{
return err
}
}
if s.ColumnCount() == 0 {
return s.specificError("don't use Select with query that returns no data such as %q", s.SQL())
}
for {
if ok, err := s.Next(); err != nil {
return err
......@@ -233,6 +239,9 @@ func (s *Stmt) SelectOneRow(args ...interface{}) (bool, error) {
if ok, err := s.Next(); err != nil {
return false, err
} else if !ok {
if s.ColumnCount() == 0 {
return false, s.specificError("don't use SelectOneRow with query that returns no data such as %q", s.SQL())
}
return false, nil
}
return true, s.Scan(args...)
......
......@@ -291,8 +291,14 @@ func TestStmtSelectWithInsert(t *testing.T) {
defer s.Finalize()
exists, err := s.SelectOneRow()
checkNoError(t, err, "SELECT error: %s")
assert.T(t, !exists, "no row expected")
assert.T(t, err != nil, "error expected")
//println(err.Error())
if serr, ok := err.(*StmtError); ok {
assert.Equal(t, ErrSpecific, serr.Code())
} else {
t.Errorf("Expected StmtError but got %s", reflect.TypeOf(err))
}
assert.T(t, !exists, "false expected")
}
func TestNamedBind(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment