1
2
3
4
5
6
7 package sql
8
9 import (
10 "bytes"
11 "database/sql/driver"
12 "errors"
13 "fmt"
14 "reflect"
15 "strconv"
16 "time"
17 "unicode"
18 "unicode/utf8"
19 _ "unsafe"
20 )
21
var (
	// errNilPtr is returned by the scan conversion helpers when the
	// destination they are given is a typed nil pointer, so there is nowhere
	// to store the converted value.
	errNilPtr = errors.New("destination pointer is nil")
)
23
// describeNamedValue returns a short label identifying nv in error messages:
// `with name "x"` for a named argument, or "$N" (N = nv.Ordinal) for a
// positional one.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
30
// validateNamedValueName accepts an empty name (a positional argument) or a
// name whose first rune is a Unicode letter; any other name is rejected.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
41
42
43
44
45 type ccChecker struct {
46 cci driver.ColumnConverter
47 want int
48 }
49
50 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
51 if c.cci == nil {
52 return driver.ErrSkip
53 }
54
55
56
57 index := nv.Ordinal - 1
58 if c.want <= index {
59 return nil
60 }
61
62
63
64
65 if vr, ok := nv.Value.(driver.Valuer); ok {
66 sv, err := callValuerValue(vr)
67 if err != nil {
68 return err
69 }
70 if !driver.IsValue(sv) {
71 return fmt.Errorf("non-subset type %T returned from Value", sv)
72 }
73 nv.Value = sv
74 }
75
76
77
78
79
80
81
82
83 var err error
84 arg := nv.Value
85 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
86 if err != nil {
87 return err
88 }
89 if !driver.IsValue(nv.Value) {
90 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
91 }
92 return nil
93 }
94
95
96
97
// defaultCheckNamedValue is the fallback argument checker: it replaces
// nv.Value with the result of driver.DefaultParameterConverter and returns
// that converter's error. nv.Value is overwritten even when conversion fails.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
	var converted driver.Value
	converted, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
102
103
104
105
106
107
108
109 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
110 nvargs := make([]driver.NamedValue, len(args))
111
112
113
114
115 want := -1
116
117 var si driver.Stmt
118 var cc ccChecker
119 if ds != nil {
120 si = ds.si
121 want = ds.si.NumInput()
122 cc.want = want
123 }
124
125
126
127
128
129 nvc, ok := si.(driver.NamedValueChecker)
130 if !ok {
131 nvc, _ = ci.(driver.NamedValueChecker)
132 }
133 cci, ok := si.(driver.ColumnConverter)
134 if ok {
135 cc.cci = cci
136 }
137
138
139
140
141
142
143 var err error
144 var n int
145 for _, arg := range args {
146 nv := &nvargs[n]
147 if np, ok := arg.(NamedArg); ok {
148 if err = validateNamedValueName(np.Name); err != nil {
149 return nil, err
150 }
151 arg = np.Value
152 nv.Name = np.Name
153 }
154 nv.Ordinal = n + 1
155 nv.Value = arg
156
157
158
159
160
161
162
163
164
165
166
167
168 checker := defaultCheckNamedValue
169 nextCC := false
170 switch {
171 case nvc != nil:
172 nextCC = cci != nil
173 checker = nvc.CheckNamedValue
174 case cci != nil:
175 checker = cc.CheckNamedValue
176 }
177
178 nextCheck:
179 err = checker(nv)
180 switch err {
181 case nil:
182 n++
183 continue
184 case driver.ErrRemoveArgument:
185 nvargs = nvargs[:len(nvargs)-1]
186 continue
187 case driver.ErrSkip:
188 if nextCC {
189 nextCC = false
190 checker = cc.CheckNamedValue
191 } else {
192 checker = defaultCheckNamedValue
193 }
194 goto nextCheck
195 default:
196 return nil, fmt.Errorf("sql: converting argument %s type: %w", describeNamedValue(nv), err)
197 }
198 }
199
200
201
202 if want != -1 && len(nvargs) != want {
203 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
204 }
205
206 return nvargs, nil
207 }
208
209
210
211
212
213
214
215
216
217
218
219
220
221 func convertAssign(dest, src any) error {
222 return convertAssignRows(dest, src, nil)
223 }
224
225
226
227
228
229
230 func convertAssignRows(dest, src any, rows *Rows) error {
231
232 switch s := src.(type) {
233 case string:
234 switch d := dest.(type) {
235 case *string:
236 if d == nil {
237 return errNilPtr
238 }
239 *d = s
240 return nil
241 case *[]byte:
242 if d == nil {
243 return errNilPtr
244 }
245 *d = []byte(s)
246 return nil
247 case *RawBytes:
248 if d == nil {
249 return errNilPtr
250 }
251 *d = rows.setrawbuf(append(rows.rawbuf(), s...))
252 return nil
253 }
254 case []byte:
255 switch d := dest.(type) {
256 case *string:
257 if d == nil {
258 return errNilPtr
259 }
260 *d = string(s)
261 return nil
262 case *any:
263 if d == nil {
264 return errNilPtr
265 }
266 *d = bytes.Clone(s)
267 return nil
268 case *[]byte:
269 if d == nil {
270 return errNilPtr
271 }
272 *d = bytes.Clone(s)
273 return nil
274 case *RawBytes:
275 if d == nil {
276 return errNilPtr
277 }
278 *d = s
279 return nil
280 }
281 case time.Time:
282 switch d := dest.(type) {
283 case *time.Time:
284 *d = s
285 return nil
286 case *string:
287 *d = s.Format(time.RFC3339Nano)
288 return nil
289 case *[]byte:
290 if d == nil {
291 return errNilPtr
292 }
293 *d = []byte(s.Format(time.RFC3339Nano))
294 return nil
295 case *RawBytes:
296 if d == nil {
297 return errNilPtr
298 }
299 *d = rows.setrawbuf(s.AppendFormat(rows.rawbuf(), time.RFC3339Nano))
300 return nil
301 }
302 case decimalDecompose:
303 switch d := dest.(type) {
304 case decimalCompose:
305 return d.Compose(s.Decompose(nil))
306 }
307 case nil:
308 switch d := dest.(type) {
309 case *any:
310 if d == nil {
311 return errNilPtr
312 }
313 *d = nil
314 return nil
315 case *[]byte:
316 if d == nil {
317 return errNilPtr
318 }
319 *d = nil
320 return nil
321 case *RawBytes:
322 if d == nil {
323 return errNilPtr
324 }
325 *d = nil
326 return nil
327 }
328
329 case driver.Rows:
330 switch d := dest.(type) {
331 case *Rows:
332 if d == nil {
333 return errNilPtr
334 }
335 if rows == nil {
336 return errors.New("invalid context to convert cursor rows, missing parent *Rows")
337 }
338 rows.closemu.Lock()
339 *d = Rows{
340 dc: rows.dc,
341 releaseConn: func(error) {},
342 rowsi: s,
343 }
344
345 parentCancel := rows.cancel
346 rows.cancel = func() {
347
348
349 d.close(rows.lasterr)
350 if parentCancel != nil {
351 parentCancel()
352 }
353 }
354 rows.closemu.Unlock()
355 return nil
356 }
357 }
358
359 var sv reflect.Value
360
361 switch d := dest.(type) {
362 case *string:
363 sv = reflect.ValueOf(src)
364 switch sv.Kind() {
365 case reflect.Bool,
366 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
367 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
368 reflect.Float32, reflect.Float64:
369 *d = asString(src)
370 return nil
371 }
372 case *[]byte:
373 sv = reflect.ValueOf(src)
374 if b, ok := asBytes(nil, sv); ok {
375 *d = b
376 return nil
377 }
378 case *RawBytes:
379 sv = reflect.ValueOf(src)
380 if b, ok := asBytes(rows.rawbuf(), sv); ok {
381 *d = rows.setrawbuf(b)
382 return nil
383 }
384 case *bool:
385 bv, err := driver.Bool.ConvertValue(src)
386 if err == nil {
387 *d = bv.(bool)
388 }
389 return err
390 case *any:
391 *d = src
392 return nil
393 }
394
395 if scanner, ok := dest.(Scanner); ok {
396 return scanner.Scan(src)
397 }
398
399 dpv := reflect.ValueOf(dest)
400 if dpv.Kind() != reflect.Pointer {
401 return errors.New("destination not a pointer")
402 }
403 if dpv.IsNil() {
404 return errNilPtr
405 }
406
407 if !sv.IsValid() {
408 sv = reflect.ValueOf(src)
409 }
410
411 dv := reflect.Indirect(dpv)
412 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
413 switch b := src.(type) {
414 case []byte:
415 dv.Set(reflect.ValueOf(bytes.Clone(b)))
416 default:
417 dv.Set(sv)
418 }
419 return nil
420 }
421
422 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
423 dv.Set(sv.Convert(dv.Type()))
424 return nil
425 }
426
427
428
429
430
431
432 switch dv.Kind() {
433 case reflect.Pointer:
434 if src == nil {
435 dv.SetZero()
436 return nil
437 }
438 dv.Set(reflect.New(dv.Type().Elem()))
439 return convertAssignRows(dv.Interface(), src, rows)
440 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
441 if src == nil {
442 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
443 }
444 s := asString(src)
445 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
446 if err != nil {
447 err = strconvErr(err)
448 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
449 }
450 dv.SetInt(i64)
451 return nil
452 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
453 if src == nil {
454 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
455 }
456 s := asString(src)
457 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
458 if err != nil {
459 err = strconvErr(err)
460 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
461 }
462 dv.SetUint(u64)
463 return nil
464 case reflect.Float32, reflect.Float64:
465 if src == nil {
466 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
467 }
468 s := asString(src)
469 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
470 if err != nil {
471 err = strconvErr(err)
472 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
473 }
474 dv.SetFloat(f64)
475 return nil
476 case reflect.String:
477 if src == nil {
478 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
479 }
480 switch v := src.(type) {
481 case string:
482 dv.SetString(v)
483 return nil
484 case []byte:
485 dv.SetString(string(v))
486 return nil
487 }
488 }
489
490 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
491 }
492
// strconvErr unwraps a *strconv.NumError to its underlying cause (such as
// strconv.ErrSyntax or strconv.ErrRange) so error messages omit the
// repeated input text; any other error is returned unchanged.
func strconvErr(err error) error {
	switch numErr := err.(type) {
	case *strconv.NumError:
		return numErr.Err
	default:
		return err
	}
}
499
// asString renders src as text. string and []byte are returned verbatim;
// values whose kind is boolean, integer, unsigned integer or float are
// formatted with strconv (floats in shortest 'g' form at their own bit
// size); anything else falls back to fmt's %v formatting.
func asString(src any) string {
	if s, ok := src.(string); ok {
		return s
	}
	if b, ok := src.([]byte); ok {
		return string(b)
	}

	v := reflect.ValueOf(src)
	switch v.Kind() {
	case reflect.Bool:
		return strconv.FormatBool(v.Bool())
	case reflect.Float32, reflect.Float64:
		return strconv.FormatFloat(v.Float(), 'g', -1, v.Type().Bits())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(v.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(v.Int(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
522
// asBytes appends the textual form of rv to buf and reports whether rv's
// kind was supported: boolean, integer, unsigned integer, float (shortest
// 'g' form at its own bit size) or string. For any other kind it returns
// (nil, false) and buf is left untouched.
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch rv.Kind() {
	case reflect.Bool:
		b = strconv.AppendBool(buf, rv.Bool())
	case reflect.Float32, reflect.Float64:
		b = strconv.AppendFloat(buf, rv.Float(), 'g', -1, rv.Type().Bits())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		b = strconv.AppendUint(buf, rv.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		b = strconv.AppendInt(buf, rv.Int(), 10)
	case reflect.String:
		b = append(buf, rv.String()...)
	default:
		return nil, false
	}
	return b, true
}
541
// valuerReflectType is the reflect.Type of the driver.Valuer interface,
// computed once for callValuerValue.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// callValuerValue returns vr.Value(), except when vr holds a nil pointer
// whose element type implements driver.Valuer (i.e. Value is declared on the
// value receiver). Calling Value through such a pointer would have to
// dereference nil and panic, so the nil pointer is treated as a nil value
// instead.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	isNilPtr := rv.Kind() == reflect.Pointer && rv.IsNil()
	if isNilPtr && rv.Type().Elem().Implements(valuerReflectType) {
		return nil, nil
	}
	return vr.Value()
}
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
// The decimal interfaces describe values that can be taken apart into, and
// rebuilt from, a form byte, a sign, a coefficient and an exponent. When a
// source implements decimalDecompose and the scan destination implements
// decimalCompose, convertAssignRows passes the parts directly:
// dest.Compose(src.Decompose(nil)).
type (
	// decimal is implemented by types that can both decompose and compose
	// themselves.
	decimal interface {
		decimalDecompose
		decimalCompose
	}

	// decimalDecompose returns the value's form, sign, coefficient and
	// exponent. NOTE(review): buf is presumably scratch space the
	// implementation may use for the returned coefficient — confirm.
	decimalDecompose interface {
		Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
	}

	// decimalCompose sets the receiver from the given parts and reports an
	// implementation-defined error if it cannot.
	decimalCompose interface {
		Compose(form byte, negative bool, coefficient []byte, exponent int32) error
	}
)
602
View as plain text