diff --git a/CHANGELOG.md b/CHANGELOG.md index b91df0bc5..90e483fcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Moved `internal/decimal` package to `pkg/decimal` for public usage + ## v3.122.0 * Added `trace.NodeHintInfo` field for OnPoolGet trace callback which stores info for node hint misses * Added `ydb_go_sdk_ydb_table_pool_node_hint_miss` and `ydb_go_sdk_ydb_query_pool_node_hint_miss` metrics for node hint misses @@ -11,7 +13,7 @@ * Changed internal pprof label to pyroscope supported format * Added `query.ImplicitTxControl()` transaction control (the same as `query.NoTx()` and `query.EmptyTxControl()`). See more about implicit transactions on [ydb.tech](https://ydb.tech/docs/en/concepts/transactions?version=v25.2#implicit) * Added `SnapshotReadWrite` isolation mode support to `database/sql` driver using `sql.TxOptions{Isolation: sql.LevelSnapshot, ReadOnly: false}` -* Move `internal/ratelimiter/options` to `ratelimiter/options` for public usage +* Moved `internal/ratelimiter/options` to `ratelimiter/options` for public usage ## v3.120.0 * Added support of `SnapshotReadWrite` isolation mode into query and table clients diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go deleted file mode 100644 index d6945135f..000000000 --- a/internal/decimal/decimal_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package decimal - -import ( - "encoding/binary" - "testing" -) - -func TestFromBytes(t *testing.T) { - for _, test := range []struct { - name string - bts []byte - precision uint32 - scale uint32 - format string - }{ - { - bts: uint128(0xffffffffffffffff, 0xffffffffffffffff), - precision: 22, - scale: 9, - format: "-0.000000001", - }, - { - bts: uint128(0xffffffffffffffff, 0), - precision: 22, - scale: 9, - format: "-18446744073.709551616", - }, - { - bts: uint128(0x4000000000000000, 0), - precision: 22, - scale: 9, - format: "inf", - }, - { - bts: uint128(0x8000000000000000, 0), - precision: 22, - scale: 9, - format: "-inf", - }, - { - bts: uint128s(1000000000), - precision: 22, - scale: 9, - format: "1.000000000", - }, - { - bts: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 250, 240, 128}, - precision: 22, - scale: 9, - format: "0.050000000", - }, - } { - t.Run(test.name, func(t *testing.T) { - x := FromBytes(test.bts, test.precision, test.scale) - p := Append(nil, x) - y := FromBytes(p, test.precision, test.scale) - if x.Cmp(y) != 0 { - t.Errorf( - "parsed bytes serialized to different value: %v; want %v", - x, y, - ) - } - formatted := Format(x, test.precision, test.scale) - if test.format != formatted { - t.Errorf("unexpected decimal format. Expected: %s, actual %s", test.format, formatted) - } - t.Logf( - "%s %s", - Format(x, test.precision, test.scale), - Format(y, test.precision, test.scale), - ) - }) - } -} - -func uint128(hi, lo uint64) []byte { - p := make([]byte, 16) - binary.BigEndian.PutUint64(p[:8], hi) - binary.BigEndian.PutUint64(p[8:], lo) - - return p -} - -func uint128s(lo uint64) []byte { - return uint128(0, lo) -} diff --git a/internal/decimal/type.go b/internal/decimal/type.go deleted file mode 100644 index 89956a761..000000000 --- a/internal/decimal/type.go +++ /dev/null @@ -1,19 +0,0 @@ -package decimal - -import "math/big" - -type Decimal struct { - Bytes [16]byte - Precision uint32 - Scale uint32 -} - -func (d *Decimal) String() string { - v := FromInt128(d.Bytes, d.Precision, d.Scale) - - return Format(v, d.Precision, d.Scale) -} - -func (d *Decimal) BigInt() *big.Int { - return FromInt128(d.Bytes, d.Precision, d.Scale) -} diff --git a/internal/query/scanner/struct_test.go b/internal/query/scanner/struct_test.go index 8fa23b03b..5fdb441f4 100644 --- a/internal/query/scanner/struct_test.go +++ b/internal/query/scanner/struct_test.go @@ -9,10 +9,10 @@ import ( "github.com/stretchr/testify/require" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xtest" - ttypes "github.com/ydb-platform/ydb-go-sdk/v3/table/types" + "github.com/ydb-platform/ydb-go-sdk/v3/table/types" ) func TestFieldName(t *testing.T) { @@ -934,9 +934,9 @@ func TestScannerDecimal(t *testing.T) { }, )) var row struct { - A ttypes.Decimal + A types.Decimal } - expected := ttypes.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(10200000000), 22, 9), Precision: 22, Scale: 9} + expected := types.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(10200000000), 22), Precision: 22, Scale: 9} err := scanner.ScanStruct(&row) require.NoError(t, err) require.Equal(t, expected, row.A) @@ -964,9 +964,9 @@ func TestScannerDecimalNegative(t *testing.T) { }, )) var row struct { - A ttypes.Decimal + A types.Decimal } - expected := ttypes.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(-2005000000), 22, 9), Precision: 22, Scale: 9} + expected := types.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(-2005000000), 22), Precision: 22, Scale: 9} err := scanner.ScanStruct(&row) require.NoError(t, err) require.Equal(t, expected, row.A) @@ -995,7 +995,7 @@ func TestScannerDecimalBigDecimal(t *testing.T) { }, )) var row struct { - A ttypes.Decimal + A types.Decimal } expectedVal := decimal.Decimal{ Bytes: [16]byte{0, 19, 66, 97, 114, 199, 77, 130, 43, 135, 143, 232, 0, 0, 0, 0}, diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index cdf53fc08..cdb4e8493 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -6,9 +6,9 @@ import ( "github.com/google/uuid" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" ) // RawValue scanning non-primitive yql types or for own implementation scanner native API diff --git a/internal/table/scanner/scan_raw.go b/internal/table/scanner/scan_raw.go index 4ef4e565c..db385bf52 100644 --- a/internal/table/scanner/scan_raw.go +++ b/internal/table/scanner/scan_raw.go @@ -12,10 +12,10 @@ import ( "github.com/google/uuid" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xstring" ) diff --git a/internal/table/scanner/scanner.go b/internal/table/scanner/scanner.go index a138544df..31473353d 100644 --- a/internal/table/scanner/scanner.go +++ b/internal/table/scanner/scanner.go @@ -12,12 +12,12 @@ import ( "github.com/google/uuid" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/scanner" internalTypes "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xstring" "github.com/ydb-platform/ydb-go-sdk/v3/table/options" "github.com/ydb-platform/ydb-go-sdk/v3/table/result" diff --git a/internal/value/any.go b/internal/value/any.go index 340398e30..bfe170f7e 100644 --- a/internal/value/any.go +++ b/internal/value/any.go @@ -2,6 +2,7 @@ package value import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xstring" ) @@ -88,6 +89,8 @@ func Any(v Value) (any, error) { //nolint:funlen,gocyclo return xstring.ToBytes(string(vv)), nil case jsonDocumentValue: return xstring.ToBytes(string(vv)), nil + case *decimalValue: + return decimal.ToDecimal(vv), nil default: return v, nil } diff --git a/internal/value/value.go b/internal/value/value.go index 6832f34cb..2b01b3e88 100644 --- a/internal/value/value.go +++ b/internal/value/value.go @@ -15,9 +15,9 @@ import ( "github.com/google/uuid" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xstring" ) @@ -581,29 +581,15 @@ func Datetime64ValueFromTime(t time.Time) datetime64Value { return datetime64Value(t.Unix()) } -var _ DecimalValuer = (*decimalValue)(nil) +var _ decimal.Interface = (*decimalValue)(nil) type decimalValue struct { value [16]byte innerType *types.Decimal } -func (v *decimalValue) Value() [16]byte { - return v.value -} - -func (v *decimalValue) Precision() uint32 { - return v.innerType.Precision() -} - -func (v *decimalValue) Scale() uint32 { - return v.innerType.Scale() -} - -type DecimalValuer interface { - Value() [16]byte - Precision() uint32 - Scale() uint32 +func (v *decimalValue) Decimal() (bytes [16]byte, precision uint32, scale uint32) { + return v.value, v.innerType.Precision(), v.innerType.Scale() } func (v *decimalValue) castTo(dst any) error { @@ -613,8 +599,7 @@ func (v *decimalValue) castTo(dst any) error { return nil case *decimal.Decimal: - decVal := decimal.Decimal{Bytes: v.value, Precision: v.Precision(), Scale: v.Scale()} - *dstValue = decVal + *dstValue = *decimal.ToDecimal(v) return nil default: @@ -631,7 +616,7 @@ func (v *decimalValue) Yql() string { buffer.WriteString(v.innerType.Name()) buffer.WriteByte('(') buffer.WriteByte('"') - s := decimal.FromBytes(v.value[:], v.innerType.Precision(), v.innerType.Scale()).String() + s := decimal.FromBytes(v.value[:], v.innerType.Precision()).String() if len(s) < int(v.innerType.Scale()) { s = strings.Repeat("0", int(v.innerType.Scale())-len(s)) + s } @@ -665,7 +650,7 @@ func (v *decimalValue) toYDB() *Ydb.Value { } func DecimalValueFromBigInt(v *big.Int, precision, scale uint32) *decimalValue { - b := decimal.BigIntToByte(v, precision, scale) + b := decimal.BigIntToByte(v, precision) return DecimalValue(b, precision, scale) } diff --git a/internal/value/value_test.go b/internal/value/value_test.go index 4dcc46fdf..55ac2f736 100644 --- a/internal/value/value_test.go +++ b/internal/value/value_test.go @@ -1926,9 +1926,10 @@ func TestDecimalValue(t *testing.T) { decBytes := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} v := DecimalValue(decBytes, 22, 9) require.NotNil(t, v) - require.Equal(t, decBytes, v.Value()) - require.Equal(t, uint32(22), v.Precision()) - require.Equal(t, uint32(9), v.Scale()) + bytes, precision, scale := v.Decimal() + require.Equal(t, decBytes, bytes) + require.Equal(t, uint32(22), precision) + require.Equal(t, uint32(9), scale) }) t.Run("FromString", func(t *testing.T) { diff --git a/internal/decimal/README.md b/pkg/decimal/README.md similarity index 100% rename from internal/decimal/README.md rename to pkg/decimal/README.md diff --git a/internal/decimal/decimal.go b/pkg/decimal/decimal.go similarity index 76% rename from internal/decimal/decimal.go rename to pkg/decimal/decimal.go index ffd07c661..85e78ae0e 100644 --- a/internal/decimal/decimal.go +++ b/pkg/decimal/decimal.go @@ -1,9 +1,13 @@ package decimal import ( + "database/sql/driver" + "fmt" "math/big" "math/bits" + "strings" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xstring" ) @@ -31,6 +35,70 @@ const ( errorTag = "" ) +// ParseDecimal parses a decimal string into a big.Int and exponent. +// Returns (n, e) such that n * 10^(-e) equals the original number. +func ParseDecimal(s string) (_ *big.Int, exp uint32, _ error) { + dotIndex := strings.Index(s, ".") + if dotIndex == -1 { + n := &big.Int{} + if _, ok := n.SetString(s, 10); !ok { + return nil, 0, xerrors.WithStackTrace(fmt.Errorf("invalid integer: %s", s)) + } + + return n, 0, nil + } + + integerPart := s[:dotIndex] + fractionalPart := s[dotIndex+1:] + + combined := integerPart + fractionalPart + n := &big.Int{} + if _, ok := n.SetString(combined, 10); !ok { + return nil, 0, xerrors.WithStackTrace(fmt.Errorf("invalid number: %s", s)) + } + + return n, uint32(len(fractionalPart)), nil +} + +func (d *Decimal) apply(value any) error { + switch v := value.(type) { + case *Decimal: + d.Bytes = v.Bytes + d.Precision = v.Precision + d.Scale = v.Scale + + return nil + case Interface: + d.Bytes, d.Precision, d.Scale = v.Decimal() + + return nil + case string: + vv, exp, err := ParseDecimal(v) + if err != nil { + return xerrors.WithStackTrace(err) + } + + d.Scale = exp + d.Precision = bufferSize - exp - 3 + d.Bytes = BigIntToByte(vv, d.Precision) + + return nil + case driver.Valuer: + vv, err := v.Value() + if err != nil { + return xerrors.WithStackTrace(err) + } + + if err := d.apply(vv); err != nil { + return xerrors.WithStackTrace(err) + } + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf("cannot apply '%T' to '%T'", v, d)) + } +} + // IsInf reports whether x is an infinity. func IsInf(x *big.Int) bool { return x.CmpAbs(inf) == 0 } @@ -54,7 +122,7 @@ func Err() *big.Int { return big.NewInt(0).Set(err) } // // If given bytes contains value that is greater than given precision it // returns infinity or negative infinity value accordingly the bytes sign. -func FromBytes(bts []byte, precision, scale uint32) *big.Int { +func FromBytes(bts []byte, precision uint32) *big.Int { v := big.NewInt(0) if len(bts) == 0 { return v @@ -82,8 +150,8 @@ func FromBytes(bts []byte, precision, scale uint32) *big.Int { // FromInt128 returns big integer from given array. That is, it interprets // 16-byte array as 128-bit integer. -func FromInt128(p [16]byte, precision, scale uint32) *big.Int { - return FromBytes(p[:], precision, scale) +func FromInt128(p [16]byte, precision uint32) *big.Int { + return FromBytes(p[:], precision) } // Parse interprets a string s with the given precision and scale and returns @@ -194,24 +262,22 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { // scale. // //nolint:funlen -func Format(x *big.Int, precision, scale uint32) string { +func Format(x *big.Int, precision, scale uint32, trimTrailingZeros bool) string { switch { + case x == nil: + return "0" case x.CmpAbs(inf) == 0: if x.Sign() < 0 { return "-inf" } return "inf" - - case x.CmpAbs(nan) == 0: + case x.CmpAbs(nan) == 0, precision == 0: if x.Sign() < 0 { return "-nan" } return "nan" - - case x == nil: - return "0" } v := big.NewInt(0).Set(x) @@ -237,12 +303,17 @@ func Format(x *big.Int, precision, scale uint32) string { d := int(digit.Int64()) if d != 0 || scale == 0 || pos > 0 { const numbers = "0123456789" - pos-- - bts[pos] = numbers[d] + if d != 0 { + trimTrailingZeros = false + } + if !trimTrailingZeros { + pos-- + bts[pos] = numbers[d] + } } if scale > 0 { scale-- - if scale == 0 && pos > 0 { + if scale == 0 && pos > 0 && pos < bufferSize { pos-- bts[pos] = '.' } @@ -277,7 +348,7 @@ func Format(x *big.Int, precision, scale uint32) string { // // If x value does not fit in 16 bytes with given precision, it returns 16-byte // representation of infinity or negative infinity value accordingly to x's sign. -func BigIntToByte(x *big.Int, precision, scale uint32) (p [16]byte) { +func BigIntToByte(x *big.Int, precision uint32) (p [16]byte) { if !IsInf(x) && !IsNaN(x) && !IsErr(x) && x.CmpAbs(pow(ten, precision)) >= 0 { if x.Sign() < 0 { x = neginf diff --git a/pkg/decimal/decimal_test.go b/pkg/decimal/decimal_test.go new file mode 100644 index 000000000..40ac85297 --- /dev/null +++ b/pkg/decimal/decimal_test.go @@ -0,0 +1,830 @@ +package decimal + +import ( + "database/sql/driver" + "encoding/binary" + "errors" + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFromBytes(t *testing.T) { + for _, tt := range []struct { + name string + bts []byte + precision uint32 + scale uint32 + format map[bool]string + }{ + { + bts: uint128(0xffffffffffffffff, 0xffffffffffffffff), + precision: 22, + scale: 9, + format: map[bool]string{ + false: "-0.000000001", + true: "-0.000000001", + }, + }, + { + bts: uint128(0xffffffffffffffff, 0), + precision: 22, + scale: 9, + format: map[bool]string{ + false: "-18446744073.709551616", + true: "-18446744073.709551616", + }, + }, + { + bts: uint128(0x4000000000000000, 0), + precision: 22, + scale: 9, + format: map[bool]string{ + false: "inf", + true: "inf", + }, + }, + { + bts: uint128(0x8000000000000000, 0), + precision: 22, + scale: 9, + format: map[bool]string{ + false: "-inf", + true: "-inf", + }, + }, + { + bts: uint128s(1000000000), + precision: 22, + scale: 9, + format: map[bool]string{ + false: "1.000000000", + true: "1", + }, + }, + { + bts: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 250, 240, 128}, + precision: 22, + scale: 9, + format: map[bool]string{ + false: "0.050000000", + true: "0.05", + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + x := FromBytes(tt.bts, tt.precision) + p := Append(nil, x) + y := FromBytes(p, tt.precision) + if x.Cmp(y) != 0 { + t.Errorf( + "parsed bytes serialized to different value: %v; want %v", + x, y, + ) + } + require.Equal(t, tt.format[false], Format(x, tt.precision, tt.scale, false)) + require.Equal(t, tt.format[true], Format(x, tt.precision, tt.scale, true)) + }) + } +} + +func uint128(hi, lo uint64) []byte { + p := make([]byte, 16) + binary.BigEndian.PutUint64(p[:8], hi) + binary.BigEndian.PutUint64(p[8:], lo) + + return p +} + +func uint128s(lo uint64) []byte { + return uint128(0, lo) +} + +func TestParseDecimal(t *testing.T) { + for _, tt := range []struct { + s string + n *big.Int + exp uint32 + err bool + }{ + { + s: "123456789", + n: big.NewInt(123456789), + exp: 0, + }, + { + s: "123.456", + n: big.NewInt(123456), + exp: 3, + }, + { + s: "0.123456789", + n: big.NewInt(123456789), + exp: 9, + }, + { + s: ".123456789", + n: big.NewInt(123456789), + exp: 9, + }, + { + s: "-123456789", + n: big.NewInt(-123456789), + exp: 0, + }, + { + s: "-123.456", + n: big.NewInt(-123456), + exp: 3, + }, + { + s: "-0.123456789", + n: big.NewInt(-123456789), + exp: 9, + }, + { + s: "invalid", + err: true, + }, + { + s: "123.invalid", + err: true, + }, + } { + t.Run(tt.s, func(t *testing.T) { + n, exp, err := ParseDecimal(tt.s) + if tt.err { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.n, n) + require.Equal(t, tt.exp, exp) + } + }) + } +} + +func TestParse(t *testing.T) { + for _, tt := range []struct { + name string + s string + precision uint32 + scale uint32 + expected *big.Int + err bool + }{ + { + name: "empty string", + s: "", + precision: 22, + scale: 9, + expected: big.NewInt(0), + }, + { + name: "positive integer", + s: "123", + precision: 22, + scale: 9, + expected: big.NewInt(123000000000), + }, + { + name: "negative integer", + s: "-123", + precision: 22, + scale: 9, + expected: big.NewInt(-123000000000), + }, + { + name: "positive with plus sign", + s: "+123", + precision: 22, + scale: 9, + expected: big.NewInt(123000000000), + }, + { + name: "decimal number", + s: "123.456", + precision: 22, + scale: 9, + expected: big.NewInt(123456000000), + }, + { + name: "decimal with trailing zeros truncated", + s: "123.4567890123", + precision: 22, + scale: 9, + expected: big.NewInt(123456789012), + }, + { + name: "inf lowercase", + s: "inf", + precision: 22, + scale: 9, + expected: Inf(), + }, + { + name: "inf uppercase", + s: "INF", + precision: 22, + scale: 9, + expected: Inf(), + }, + { + name: "inf mixed case", + s: "InF", + precision: 22, + scale: 9, + expected: Inf(), + }, + { + name: "negative inf", + s: "-inf", + precision: 22, + scale: 9, + expected: big.NewInt(0).Neg(Inf()), + }, + { + name: "positive inf with plus", + s: "+inf", + precision: 22, + scale: 9, + expected: Inf(), + }, + { + name: "nan lowercase", + s: "nan", + precision: 22, + scale: 9, + expected: NaN(), + }, + { + name: "nan uppercase", + s: "NAN", + precision: 22, + scale: 9, + expected: NaN(), + }, + { + name: "nan mixed case", + s: "NaN", + precision: 22, + scale: 9, + expected: NaN(), + }, + { + name: "negative nan", + s: "-nan", + precision: 22, + scale: 9, + expected: big.NewInt(0).Neg(NaN()), + }, + { + name: "scale greater than precision", + s: "123", + precision: 5, + scale: 10, + err: true, + }, + { + name: "double dot syntax error", + s: "123..456", + precision: 22, + scale: 9, + err: true, + }, + { + name: "invalid character", + s: "12a34", + precision: 22, + scale: 9, + err: true, + }, + { + name: "invalid character after dot", + s: "12.3a4", + precision: 22, + scale: 9, + err: true, + }, + { + name: "overflow to infinity", + s: "9999999999999999999999999", + precision: 10, + scale: 0, + expected: Inf(), + }, + { + name: "negative overflow to negative infinity", + s: "-9999999999999999999999999", + precision: 10, + scale: 0, + expected: big.NewInt(0).Neg(Inf()), + }, + { + name: "rounding up when digit > 5", + s: "1.236", + precision: 22, + scale: 2, + expected: big.NewInt(124), + }, + { + name: "rounding with digit = 5 and odd last", + s: "1.235", + precision: 22, + scale: 2, + expected: big.NewInt(124), + }, + { + name: "rounding with digit = 5 and trailing non-zero", + s: "1.2451", + precision: 22, + scale: 2, + expected: big.NewInt(125), + }, + { + name: "rounding with digit = 5 and even last no trailing", + s: "1.245", + precision: 22, + scale: 2, + expected: big.NewInt(124), // banker's rounding - even stays + }, + { + name: "invalid digit in rounding sequence", + s: "1.23a", + precision: 22, + scale: 2, + err: true, + }, + { + name: "rounding causes overflow to infinity", + s: "9999999999.99999999999", + precision: 10, + scale: 0, + expected: Inf(), + }, + { + name: "invalid char in trailing digits after 5 with even last digit", + s: "1.245x", + precision: 22, + scale: 2, + err: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + result, err := Parse(tt.s, tt.precision, tt.scale) + if tt.err { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, 0, tt.expected.Cmp(result), "expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestIsInfNaNErr(t *testing.T) { + t.Run("IsInf", func(t *testing.T) { + require.True(t, IsInf(Inf())) + require.True(t, IsInf(big.NewInt(0).Neg(Inf()))) + require.False(t, IsInf(big.NewInt(123))) + require.False(t, IsInf(NaN())) + }) + + t.Run("IsNaN", func(t *testing.T) { + require.True(t, IsNaN(NaN())) + require.True(t, IsNaN(big.NewInt(0).Neg(NaN()))) + require.False(t, IsNaN(big.NewInt(123))) + require.False(t, IsNaN(Inf())) + }) + + t.Run("IsErr", func(t *testing.T) { + require.True(t, IsErr(Err())) + require.False(t, IsErr(big.NewInt(123))) + require.False(t, IsErr(Inf())) + require.False(t, IsErr(NaN())) + }) +} + +func TestInfNaNErr(t *testing.T) { + t.Run("Inf returns copy", func(t *testing.T) { + i1 := Inf() + i2 := Inf() + require.Equal(t, 0, i1.Cmp(i2)) + i1.SetInt64(0) + require.NotEqual(t, 0, Inf().Cmp(i1)) + }) + + t.Run("NaN returns copy", func(t *testing.T) { + n1 := NaN() + n2 := NaN() + require.Equal(t, 0, n1.Cmp(n2)) + n1.SetInt64(0) + require.NotEqual(t, 0, NaN().Cmp(n1)) + }) + + t.Run("Err returns copy", func(t *testing.T) { + e1 := Err() + e2 := Err() + require.Equal(t, 0, e1.Cmp(e2)) + e1.SetInt64(0) + require.NotEqual(t, 0, Err().Cmp(e1)) + }) +} + +func TestFromInt128(t *testing.T) { + t.Run("simple positive", func(t *testing.T) { + var p [16]byte + binary.BigEndian.PutUint64(p[8:], 1000000000) + result := FromInt128(p, 22) + require.Equal(t, 0, big.NewInt(1000000000).Cmp(result)) + }) + + t.Run("zero bytes", func(t *testing.T) { + var p [16]byte + result := FromInt128(p, 22) + require.Equal(t, 0, big.NewInt(0).Cmp(result)) + }) +} + +func TestBigIntToByte(t *testing.T) { + t.Run("normal value", func(t *testing.T) { + x := big.NewInt(123456789) + p := BigIntToByte(x, 22) + result := FromInt128(p, 22) + require.Equal(t, 0, x.Cmp(result)) + }) + + t.Run("negative value", func(t *testing.T) { + x := big.NewInt(-123456789) + p := BigIntToByte(x, 22) + result := FromInt128(p, 22) + require.Equal(t, 0, x.Cmp(result)) + }) + + t.Run("overflow positive becomes inf", func(t *testing.T) { + x := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(25), nil) + p := BigIntToByte(x, 22) + result := FromInt128(p, 22) + require.True(t, IsInf(result)) + }) + + t.Run("overflow negative becomes neginf", func(t *testing.T) { + x := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(25), nil) + x.Neg(x) + p := BigIntToByte(x, 22) + result := FromInt128(p, 22) + require.True(t, IsInf(result)) + require.True(t, result.Sign() < 0) + }) + + t.Run("inf stays inf", func(t *testing.T) { + x := Inf() + p := BigIntToByte(x, 22) + result := FromInt128(p, 22) + require.True(t, IsInf(result)) + }) + + t.Run("nan converted to bytes", func(t *testing.T) { + x := NaN() + p := BigIntToByte(x, 22) + // NaN is larger than any precision, so FromInt128 will interpret as inf + result := FromInt128(p, 22) + require.True(t, IsInf(result)) + }) + + t.Run("err converted to bytes", func(t *testing.T) { + x := Err() + p := BigIntToByte(x, 22) + // Err is larger than any precision, so FromInt128 will interpret as inf + result := FromInt128(p, 22) + require.True(t, IsInf(result)) + }) +} + +func TestFormat(t *testing.T) { + for _, tt := range []struct { + name string + x *big.Int + precision uint32 + scale uint32 + trimTrailingZeros bool + expected string + }{ + { + name: "nil value", + x: nil, + precision: 22, + scale: 9, + expected: "0", + }, + { + name: "zero precision returns nan", + x: big.NewInt(123), + precision: 0, + scale: 0, + expected: "nan", + }, + { + name: "negative zero precision returns -nan", + x: big.NewInt(-123), + precision: 0, + scale: 0, + expected: "-nan", + }, + { + name: "positive inf", + x: Inf(), + precision: 22, + scale: 9, + expected: "inf", + }, + { + name: "negative inf", + x: big.NewInt(0).Neg(Inf()), + precision: 22, + scale: 9, + expected: "-inf", + }, + { + name: "positive nan", + x: NaN(), + precision: 22, + scale: 9, + expected: "nan", + }, + { + name: "negative nan", + x: big.NewInt(0).Neg(NaN()), + precision: 22, + scale: 9, + expected: "-nan", + }, + { + name: "simple integer", + x: big.NewInt(123000000000), + precision: 22, + scale: 9, + expected: "123.000000000", + }, + { + name: "simple integer with trim", + x: big.NewInt(123000000000), + precision: 22, + scale: 9, + trimTrailingZeros: true, + expected: "123", + }, + { + name: "negative number", + x: big.NewInt(-123456000000), + precision: 22, + scale: 9, + expected: "-123.456000000", + }, + { + name: "negative number with trim", + x: big.NewInt(-123456000000), + precision: 22, + scale: 9, + trimTrailingZeros: true, + expected: "-123.456", + }, + { + name: "zero", + x: big.NewInt(0), + precision: 22, + scale: 9, + expected: "0.000000000", + }, + { + name: "zero with trim still shows zeros due to scale handling", + x: big.NewInt(0), + precision: 22, + scale: 9, + trimTrailingZeros: true, + expected: "0.000000000", + }, + { + name: "small decimal", + x: big.NewInt(1), + precision: 22, + scale: 9, + expected: "0.000000001", + }, + { + name: "precision exhausted returns error tag", + x: big.NewInt(0).Exp(big.NewInt(10), big.NewInt(50), nil), + precision: 10, + scale: 0, + expected: "", + }, + { + name: "precision exhausted in scale fill returns error tag", + x: big.NewInt(0), + precision: 2, + scale: 10, + expected: "", + }, + } { + t.Run(tt.name, func(t *testing.T) { + result := Format(tt.x, tt.precision, tt.scale, tt.trimTrailingZeros) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestFromBytesEdgeCases(t *testing.T) { + t.Run("empty bytes", func(t *testing.T) { + result := FromBytes([]byte{}, 22) + require.Equal(t, big.NewInt(0), result) + }) + + t.Run("positive overflow to inf", func(t *testing.T) { + bts := uint128(0x7fffffffffffffff, 0xffffffffffffffff) + result := FromBytes(bts, 10) + require.True(t, IsInf(result)) + require.True(t, result.Sign() > 0) + }) +} + +func TestDecimalType(t *testing.T) { + t.Run("ToDecimal", func(t *testing.T) { + original := &Decimal{ + Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + Precision: 22, + Scale: 9, + } + result := ToDecimal(original) + require.Equal(t, original.Bytes, result.Bytes) + require.Equal(t, original.Precision, result.Precision) + require.Equal(t, original.Scale, result.Scale) + }) + + t.Run("Decimal method", func(t *testing.T) { + d := &Decimal{ + Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + Precision: 22, + Scale: 9, + } + bytes, precision, scale := d.Decimal() + require.Equal(t, d.Bytes, bytes) + require.Equal(t, d.Precision, precision) + require.Equal(t, d.Scale, scale) + }) + + t.Run("String", func(t *testing.T) { + d := &Decimal{ + Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 59, 154, 202, 0}, + Precision: 22, + Scale: 9, + } + result := d.String() + require.Equal(t, "1.000000000", result) + }) + + t.Run("Format method", func(t *testing.T) { + d := &Decimal{ + Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 59, 154, 202, 0}, + Precision: 22, + Scale: 9, + } + require.Equal(t, "1.000000000", d.Format(false)) + require.Equal(t, "1", d.Format(true)) + }) + + t.Run("BigInt", func(t *testing.T) { + d := &Decimal{ + Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 59, 154, 202, 0}, + Precision: 22, + Scale: 9, + } + result := d.BigInt() + require.Equal(t, big.NewInt(1000000000), result) + }) +} + +type testDecimalInterface struct { + bytes [16]byte + precision uint32 + scale uint32 +} + +func (t *testDecimalInterface) Decimal() ([16]byte, uint32, uint32) { + return t.bytes, t.precision, t.scale +} + +type testValuer struct { + value any + err error +} + +func (t *testValuer) Value() (driver.Value, error) { + return t.value, t.err +} + +func TestDecimalScan(t *testing.T) { + t.Run("scan from Interface", func(t *testing.T) { + iface := &testDecimalInterface{ + bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + precision: 22, + scale: 9, + } + var d Decimal + err := d.Scan(iface) + require.NoError(t, err) + require.Equal(t, iface.bytes, d.Bytes) + require.Equal(t, iface.precision, d.Precision) + require.Equal(t, iface.scale, d.Scale) + }) + + t.Run("scan from *Decimal", func(t *testing.T) { + original := &Decimal{ + Bytes: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + Precision: 22, + Scale: 9, + } + var d Decimal + err := d.Scan(original) + require.NoError(t, err) + require.Equal(t, original.Bytes, d.Bytes) + require.Equal(t, original.Precision, d.Precision) + require.Equal(t, original.Scale, d.Scale) + }) + + t.Run("scan from string", func(t *testing.T) { + var d Decimal + err := d.Scan("123.456") + require.NoError(t, err) + require.NotEmpty(t, d.Bytes) + }) + + t.Run("scan from invalid string", func(t *testing.T) { + var d Decimal + err := d.Scan("invalid") + require.Error(t, err) + }) + + t.Run("scan from driver.Valuer", func(t *testing.T) { + valuer := &testValuer{value: "123.456", err: nil} + var d Decimal + err := d.Scan(valuer) + require.NoError(t, err) + }) + + t.Run("scan from driver.Valuer with error", func(t *testing.T) { + valuer := &testValuer{value: nil, err: errors.New("valuer error")} + var d Decimal + err := d.Scan(valuer) + require.Error(t, err) + }) + + t.Run("scan from driver.Valuer with invalid value", func(t *testing.T) { + valuer := &testValuer{value: "invalid", err: nil} + var d Decimal + err := d.Scan(valuer) + require.Error(t, err) + }) + + t.Run("scan from unsupported type", func(t *testing.T) { + var d Decimal + err := d.Scan(12345) + require.Error(t, err) + }) +} + +func TestParseError(t *testing.T) { + t.Run("Error method", func(t *testing.T) { + pe := &ParseError{ + Err: errors.New("test error"), + Input: "test input", + } + errStr := pe.Error() + require.Contains(t, errStr, "test input") + require.Contains(t, errStr, "test error") + }) + + t.Run("Unwrap method", func(t *testing.T) { + innerErr := errors.New("inner error") + pe := &ParseError{ + Err: innerErr, + Input: "test input", + } + require.Equal(t, innerErr, pe.Unwrap()) + }) + + t.Run("syntax error through Parse", func(t *testing.T) { + _, err := Parse("12a34", 22, 9) + require.Error(t, err) + var pe *ParseError + require.True(t, errors.As(err, &pe)) + }) + + t.Run("precision error through Parse", func(t *testing.T) { + _, err := Parse("123", 5, 10) + require.Error(t, err) + var pe *ParseError + require.True(t, errors.As(err, &pe)) + require.Contains(t, pe.Error(), "precision") + }) +} diff --git a/internal/decimal/errors.go b/pkg/decimal/errors.go similarity index 100% rename from internal/decimal/errors.go rename to pkg/decimal/errors.go diff --git a/pkg/decimal/type.go b/pkg/decimal/type.go new file mode 100644 index 000000000..fdd51eefa --- /dev/null +++ b/pkg/decimal/type.go @@ -0,0 +1,60 @@ +package decimal + +import ( + "database/sql" + "math/big" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" +) + +var ( + _ sql.Scanner = (*Decimal)(nil) + _ Interface = (*Decimal)(nil) +) + +type ( + Interface interface { + Decimal() (bytes [16]byte, precision uint32, scale uint32) + } + Decimal struct { + Bytes [16]byte + Precision uint32 + Scale uint32 + } +) + +func ToDecimal(v Interface) *Decimal { + var d Decimal + + d.Bytes, d.Precision, d.Scale = v.Decimal() + + return &d +} + +func (d *Decimal) Decimal() (bytes [16]byte, precision uint32, scale uint32) { + return d.Bytes, d.Precision, d.Scale +} + +func (d *Decimal) Scan(value any) error { + if err := d.apply(value); err != nil { + return xerrors.WithStackTrace(err) + } + + return nil +} + +func (d Decimal) String() string { + v := FromInt128(d.Bytes, d.Precision) + + return Format(v, d.Precision, d.Scale, false) +} + +func (d Decimal) Format(trimTrailingZeros bool) string { + v := FromInt128(d.Bytes, d.Precision) + + return Format(v, d.Precision, d.Scale, trimTrailingZeros) +} + +func (d Decimal) BigInt() *big.Int { + return FromInt128(d.Bytes, d.Precision) +} diff --git a/table/types/cast.go b/table/types/cast.go index 69ac0a572..ea488f9e2 100644 --- a/table/types/cast.go +++ b/table/types/cast.go @@ -6,6 +6,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" ) var errNilValue = errors.New("nil value") @@ -45,12 +46,8 @@ func Unwrap(v Value) Value { // ToDecimal returns Decimal struct from abstract Value func ToDecimal(v Value) (*Decimal, error) { - if valuer, isDecimalValuer := v.(value.DecimalValuer); isDecimalValuer { - return &Decimal{ - Bytes: valuer.Value(), - Precision: valuer.Precision(), - Scale: valuer.Scale(), - }, nil + if d, has := v.(decimal.Interface); has { + return decimal.ToDecimal(d), nil } return nil, xerrors.WithStackTrace(fmt.Errorf("value type '%s' is not decimal type", v.Type().Yql())) diff --git a/table/types/value.go b/table/types/value.go index 591affdbd..67d8d497d 100644 --- a/table/types/value.go +++ b/table/types/value.go @@ -6,8 +6,8 @@ import ( "github.com/google/uuid" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xstring" ) diff --git a/tests/integration/basic_example_database_sql_test.go b/tests/integration/basic_example_database_sql_test.go index 3e4a0d1b6..595877186 100644 --- a/tests/integration/basic_example_database_sql_test.go +++ b/tests/integration/basic_example_database_sql_test.go @@ -139,17 +139,17 @@ func TestBasicExampleDatabaseSql(t *testing.T) { } _, err = db.ExecContext(ctx, ` - CREATE TABLE `+"`"+tablePath+"`"+` ( - series_id Uint64, - title UTF8, - series_info UTF8, - release_date Date, - comment UTF8, - PRIMARY KEY ( - series_id - ) - ); - `) + CREATE TABLE `+"`"+tablePath+"`"+` ( + series_id Uint64, + title UTF8, + series_info UTF8, + release_date Date, + comment UTF8, + PRIMARY KEY ( + series_id + ) + ); + `) require.NoError(t, err) }) t.Run("seasons", func(t *testing.T) { @@ -168,18 +168,18 @@ func TestBasicExampleDatabaseSql(t *testing.T) { } _, err = db.ExecContext(ctx, ` - CREATE TABLE `+"`"+tablePath+"`"+` ( - series_id Uint64, - season_id Uint64, - title UTF8, - first_aired Date, - last_aired Date, - PRIMARY KEY ( - series_id, - season_id - ) - ); - `) + CREATE TABLE `+"`"+tablePath+"`"+` ( + series_id Uint64, + season_id Uint64, + title UTF8, + first_aired Date, + last_aired Date, + PRIMARY KEY ( + series_id, + season_id + ) + ); + `) require.NoError(t, err) }) t.Run("episodes", func(t *testing.T) { @@ -198,20 +198,20 @@ func TestBasicExampleDatabaseSql(t *testing.T) { } _, err = db.ExecContext(ctx, ` - CREATE TABLE `+"`"+tablePath+"`"+` ( - series_id Uint64, - season_id Uint64, - episode_id Uint64, - title UTF8, - air_date Date, - views Uint64, - PRIMARY KEY ( - series_id, - season_id, - episode_id - ) - ); - `) + CREATE TABLE `+"`"+tablePath+"`"+` ( + series_id Uint64, + season_id Uint64, + episode_id Uint64, + title UTF8, + air_date Date, + views Uint64, + PRIMARY KEY ( + series_id, + season_id, + episode_id + ) + ); + `) require.NoError(t, err) }) }) @@ -221,35 +221,35 @@ func TestBasicExampleDatabaseSql(t *testing.T) { t.Run("upsert", func(t *testing.T) { err = retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error { stmt, err := cc.PrepareContext(ctx, ` - PRAGMA TablePathPrefix("`+path.Join(nativeDriver.Name(), folder)+`"); - - DECLARE $seriesData AS List>>; + PRAGMA TablePathPrefix("`+path.Join(nativeDriver.Name(), folder)+`"); + + DECLARE $seriesData AS List>>; + + DECLARE $seasonsData AS List>; + + DECLARE $episodesData AS List>; + + REPLACE INTO series SELECT * FROM AS_TABLE($seriesData); - DECLARE $seasonsData AS List>; + REPLACE INTO seasons SELECT * FROM AS_TABLE($seasonsData); - DECLARE $episodesData AS List>; - - REPLACE INTO series SELECT * FROM AS_TABLE($seriesData); - - REPLACE INTO seasons SELECT * FROM AS_TABLE($seasonsData); - - REPLACE INTO episodes SELECT * FROM AS_TABLE($episodesData); - `) + REPLACE INTO episodes SELECT * FROM AS_TABLE($episodesData); + `) if err != nil { return fmt.Errorf("failed to prepare query: %w", err) } @@ -269,18 +269,18 @@ func TestBasicExampleDatabaseSql(t *testing.T) { t.Run("query", func(t *testing.T) { query := ` - PRAGMA TablePathPrefix("` + path.Join(nativeDriver.Name(), folder) + `"); - - DECLARE $seriesID AS Uint64; - DECLARE $seasonID AS Uint64; - DECLARE $episodeID AS Uint64; - - SELECT views - FROM episodes - WHERE - series_id = $seriesID AND - season_id = $seasonID AND - episode_id = $episodeID;` + PRAGMA TablePathPrefix("` + path.Join(nativeDriver.Name(), folder) + `"); + + DECLARE $seriesID AS Uint64; + DECLARE $seasonID AS Uint64; + DECLARE $episodeID AS Uint64; + + SELECT views + FROM episodes + WHERE + series_id = $seriesID AND + season_id = $seasonID AND + episode_id = $episodeID;` t.Run("explain", func(t *testing.T) { row := db.QueryRowContext( ydb.WithQueryMode(ctx, ydb.ExplainQueryMode), query, @@ -322,16 +322,16 @@ func TestBasicExampleDatabaseSql(t *testing.T) { } // increment `views` _, err = tx.ExecContext(ctx, ` - PRAGMA TablePathPrefix("`+path.Join(nativeDriver.Name(), folder)+`"); - - DECLARE $seriesID AS Uint64; - DECLARE $seasonID AS Uint64; - DECLARE $episodeID AS Uint64; - DECLARE $views AS Uint64; - - UPSERT INTO episodes ( series_id, season_id, episode_id, views ) - VALUES ( $seriesID, $seasonID, $episodeID, $views ); - `, + PRAGMA TablePathPrefix("`+path.Join(nativeDriver.Name(), folder)+`"); + + DECLARE $seriesID AS Uint64; + DECLARE $seasonID AS Uint64; + DECLARE $episodeID AS Uint64; + DECLARE $views AS Uint64; + + UPSERT INTO episodes ( series_id, season_id, episode_id, views ) + VALUES ( $seriesID, $seasonID, $episodeID, $views ); + `, sql.Named("seriesID", uint64(1)), sql.Named("seasonID", uint64(1)), sql.Named("episodeID", uint64(1)), @@ -349,18 +349,18 @@ func TestBasicExampleDatabaseSql(t *testing.T) { t.Run("isolation", func(t *testing.T) { t.Run("snapshot", func(t *testing.T) { query := ` - PRAGMA TablePathPrefix("` + path.Join(nativeDriver.Name(), folder) + `"); - - DECLARE $seriesID AS Uint64; - DECLARE $seasonID AS Uint64; - DECLARE $episodeID AS Uint64; - - SELECT views FROM episodes - WHERE - series_id = $seriesID AND - season_id = $seasonID AND - episode_id = $episodeID; - ` + PRAGMA TablePathPrefix("` + path.Join(nativeDriver.Name(), folder) + `"); + + DECLARE $seriesID AS Uint64; + DECLARE $seasonID AS Uint64; + DECLARE $episodeID AS Uint64; + + SELECT views FROM episodes + WHERE + series_id = $seriesID AND + season_id = $seasonID AND + episode_id = $episodeID; + ` err = retry.DoTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error { row := tx.QueryRowContext(ctx, query, @@ -401,27 +401,27 @@ func TestBasicExampleDatabaseSql(t *testing.T) { airDate *time.Time views sql.NullFloat64 query = ` - PRAGMA TablePathPrefix("` + path.Join(nativeDriver.Name(), folder) + `"); - - DECLARE $seriesID AS Optional; - DECLARE $seasonID AS Optional; - DECLARE $episodeID AS Optional; - - SELECT - series_id, - season_id, - episode_id, - title, - air_date, - views - FROM episodes - WHERE - (series_id >= $seriesID OR $seriesID IS NULL) AND - (season_id >= $seasonID OR $seasonID IS NULL) AND - (episode_id >= $episodeID OR $episodeID IS NULL) - ORDER BY - series_id, season_id, episode_id; - ` + PRAGMA TablePathPrefix("` + path.Join(nativeDriver.Name(), folder) + `"); + + DECLARE $seriesID AS Optional; + DECLARE $seasonID AS Optional; + DECLARE $episodeID AS Optional; + + SELECT + series_id, + season_id, + episode_id, + title, + air_date, + views + FROM episodes + WHERE + (series_id >= $seriesID OR $seriesID IS NULL) AND + (season_id >= $seasonID OR $seasonID IS NULL) AND + (episode_id >= $episodeID OR $episodeID IS NULL) + ORDER BY + series_id, season_id, episode_id; + ` ) err := retry.DoTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error { diff --git a/tests/integration/decimal_test.go b/tests/integration/decimal_test.go new file mode 100644 index 000000000..755d63131 --- /dev/null +++ b/tests/integration/decimal_test.go @@ -0,0 +1,443 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "database/sql" + "math/big" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xtest" + "github.com/ydb-platform/ydb-go-sdk/v3/query" + "github.com/ydb-platform/ydb-go-sdk/v3/table" + "github.com/ydb-platform/ydb-go-sdk/v3/table/types" +) + +func TestIssue1234UnexpectedDecimalRepresentation(t *testing.T) { + scope := newScope(t) + driver := scope.Driver() + + tests := []struct { + name string + bts [16]byte + precision uint32 + scale uint32 + expectedFormat string + }{ + { + bts: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 250, 240, 128}, + precision: 22, + scale: 9, + expectedFormat: "0.050000000", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected := decimal.Decimal{ + Bytes: tt.bts, + Precision: tt.precision, + Scale: tt.scale, + } + var actual decimal.Decimal + + err := driver.Table().Do(scope.Ctx, func(ctx context.Context, s table.Session) error { + _, result, err := s.Execute(ctx, table.DefaultTxControl(), ` + DECLARE $value AS Decimal(22,9); + SELECT $value;`, + table.NewQueryParameters( + table.ValueParam("$value", types.DecimalValue(&expected)), + ), + ) + if err != nil { + return err + } + for result.NextResultSet(ctx) { + for result.NextRow() { + err = result.Scan(&actual) + if err != nil { + return err + } + } + } + return nil + }) + require.NoError(t, err) + require.Equal(t, expected, actual) + require.Equal(t, tt.expectedFormat, actual.String()) + }) + } +} + +func TestQueryDecimalScan(t *testing.T) { + ctx, cancel := context.WithCancel(xtest.Context(t)) + defer cancel() + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ) + require.NoError(t, err) + defer func() { + _ = db.Close(ctx) + }() + + t.Run("DirectScan", func(t *testing.T) { + row, err := db.Query().QueryRow(ctx, + `SELECT Decimal('100.500', 33, 12)`, + query.WithIdempotent(), + ) + require.NoError(t, err) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(12), dst.Scale) + require.Equal(t, uint32(33), dst.Precision) + require.Equal(t, big.NewInt(100500000000000), dst.BigInt()) + require.Equal(t, "100.500000000000", dst.String()) + require.Equal(t, "100.5", dst.Format(true)) + require.Equal(t, "100.500000000000", dst.Format(false)) + }) + + t.Run("DirectScanNegative", func(t *testing.T) { + row, err := db.Query().QueryRow(ctx, + `SELECT Decimal('-5.33', 22, 9)`, + query.WithIdempotent(), + ) + require.NoError(t, err) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(22), dst.Precision) + require.Equal(t, uint32(9), dst.Scale) + require.Equal(t, big.NewInt(-5330000000), dst.BigInt()) + require.Equal(t, "-5.330000000", dst.String()) + require.Equal(t, "-5.33", dst.Format(true)) + require.Equal(t, "-5.330000000", dst.Format(false)) + }) + + t.Run("DirectScanWithOtherTypes", func(t *testing.T) { + row, err := db.Query().QueryRow(ctx, + `SELECT 42u AS id, Decimal('10.01', 22, 9) AS amount`, + query.WithIdempotent(), + ) + require.NoError(t, err) + + var id uint64 + var amount decimal.Decimal + err = row.Scan(&id, &amount) + require.NoError(t, err) + require.Equal(t, uint64(42), id) + require.Equal(t, uint32(22), amount.Precision) + require.Equal(t, uint32(9), amount.Scale) + require.Equal(t, big.NewInt(10010000000), amount.BigInt()) + require.Equal(t, "10.010000000", amount.String()) + require.Equal(t, "10.01", amount.Format(true)) + require.Equal(t, "10.010000000", amount.Format(false)) + }) + + t.Run("DirectScanOptional", func(t *testing.T) { + row, err := db.Query().QueryRow(ctx, + `SELECT CAST(NULL AS Decimal(22, 9))`, + query.WithIdempotent(), + ) + require.NoError(t, err) + + var dst *decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Nil(t, dst) + }) + + t.Run("DirectScanOptionalNonNull", func(t *testing.T) { + row, err := db.Query().QueryRow(ctx, + `SELECT JUST(Decimal('99.99', 22, 9))`, + query.WithIdempotent(), + ) + require.NoError(t, err) + + var dst *decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.NotNil(t, dst) + require.Equal(t, uint32(22), dst.Precision) + require.Equal(t, uint32(9), dst.Scale) + require.Equal(t, big.NewInt(99990000000), dst.BigInt()) + require.Equal(t, "99.990000000", dst.String()) + require.Equal(t, "99.99", dst.Format(true)) + require.Equal(t, "99.990000000", dst.Format(false)) + }) +} + +func TestDatabaseSqlDecimalScan(t *testing.T) { + ctx, cancel := context.WithCancel(xtest.Context(t)) + defer cancel() + + nativeDriver, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ) + require.NoError(t, err) + defer func() { + _ = nativeDriver.Close(ctx) + }() + + connector, err := ydb.Connector(nativeDriver, + ydb.WithQueryService(true), + ) + require.NoError(t, err) + defer func() { + _ = connector.Close() + }() + + db := sql.OpenDB(connector) + + t.Run("DirectScan", func(t *testing.T) { + row := db.QueryRowContext(ctx, `SELECT Decimal('100.500', 33, 12)`) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(33), dst.Precision) + require.Equal(t, uint32(12), dst.Scale) + require.Equal(t, big.NewInt(100500000000000), dst.BigInt()) + require.Equal(t, "100.500000000000", dst.String()) + require.Equal(t, "100.5", dst.Format(true)) + require.Equal(t, "100.500000000000", dst.Format(false)) + }) + + t.Run("DirectScanNegative", func(t *testing.T) { + row := db.QueryRowContext(ctx, `SELECT Decimal('-5.33', 22, 9)`) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(22), dst.Precision) + require.Equal(t, uint32(9), dst.Scale) + require.Equal(t, big.NewInt(-5330000000), dst.BigInt()) + require.Equal(t, "-5.330000000", dst.String()) + require.Equal(t, "-5.33", dst.Format(true)) + require.Equal(t, "-5.330000000", dst.Format(false)) + }) + + t.Run("DirectScanWithOtherTypes", func(t *testing.T) { + row := db.QueryRowContext(ctx, `SELECT 42u AS id, Decimal('10.01', 22, 9) AS amount`) + + var id uint64 + var amount decimal.Decimal + err = row.Scan(&id, &amount) + require.NoError(t, err) + + require.Equal(t, uint64(42), id) + require.Equal(t, uint32(22), amount.Precision) + require.Equal(t, uint32(9), amount.Scale) + require.Equal(t, big.NewInt(10010000000), amount.BigInt()) + require.Equal(t, "10.010000000", amount.String()) + require.Equal(t, "10.01", amount.Format(true)) + require.Equal(t, "10.010000000", amount.Format(false)) + }) + + t.Run("DirectScanOptional", func(t *testing.T) { + row := db.QueryRowContext(ctx, `SELECT CAST(NULL AS Decimal(22, 9))`) + + var dst *decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Nil(t, dst) + }) + + t.Run("DirectScanOptionalNonNull", func(t *testing.T) { + row := db.QueryRowContext(ctx, `SELECT JUST(Decimal('99.99', 22, 9))`) + + var dst *decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.NotNil(t, dst) + require.Equal(t, uint32(22), dst.Precision) + require.Equal(t, uint32(9), dst.Scale) + require.Equal(t, big.NewInt(99990000000), dst.BigInt()) + require.Equal(t, "99.990000000", dst.String()) + require.Equal(t, "99.99", dst.Format(true)) + require.Equal(t, "99.990000000", dst.Format(false)) + }) +} + +func TestQueryDecimalParam(t *testing.T) { + ctx, cancel := context.WithCancel(xtest.Context(t)) + defer cancel() + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ) + require.NoError(t, err) + defer func() { + _ = db.Close(ctx) + }() + + t.Run("DirectScan", func(t *testing.T) { + d, err := types.DecimalValueFromString("100.5", 33, 12) + require.NoError(t, err) + row, err := db.Query().QueryRow(ctx, ` + DECLARE $p AS Decimal(33,12); + SELECT $p; + `, query.WithParameters(ydb.ParamsBuilder(). + Param("$p").Any(d). + Build(), + ), query.WithIdempotent()) + require.NoError(t, err) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(12), dst.Scale) + require.Equal(t, uint32(33), dst.Precision) + require.Equal(t, big.NewInt(100500000000000), dst.BigInt()) + require.Equal(t, "100.500000000000", dst.String()) + require.Equal(t, "100.5", dst.Format(true)) + require.Equal(t, "100.500000000000", dst.Format(false)) + }) + + t.Run("DirectScanNegative", func(t *testing.T) { + d, err := types.DecimalValueFromString("-5.33", 22, 9) + require.NoError(t, err) + row, err := db.Query().QueryRow(ctx, ` + DECLARE $p AS Decimal(22,9); + SELECT $p; + `, query.WithParameters(ydb.ParamsBuilder(). + Param("$p").Any(d). + Build(), + ), query.WithIdempotent()) + require.NoError(t, err) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(22), dst.Precision) + require.Equal(t, uint32(9), dst.Scale) + require.Equal(t, big.NewInt(-5330000000), dst.BigInt()) + require.Equal(t, "-5.330000000", dst.String()) + require.Equal(t, "-5.33", dst.Format(true)) + require.Equal(t, "-5.330000000", dst.Format(false)) + }) + + t.Run("DirectScanWithOtherTypes", func(t *testing.T) { + d, err := types.DecimalValueFromString("10.01", 22, 9) + require.NoError(t, err) + row, err := db.Query().QueryRow(ctx, ` + DECLARE $p1 AS Uint64; + DECLARE $p2 AS Decimal(22,9); + SELECT $p1 AS id, $p2 AS amount; + `, query.WithParameters(ydb.ParamsBuilder(). + Param("$p1").Uint64(42). + Param("$p2").Any(d). + Build(), + ), query.WithIdempotent()) + require.NoError(t, err) + + var id uint64 + var amount decimal.Decimal + err = row.Scan(&id, &amount) + require.NoError(t, err) + require.Equal(t, uint64(42), id) + require.Equal(t, uint32(22), amount.Precision) + require.Equal(t, uint32(9), amount.Scale) + require.Equal(t, big.NewInt(10010000000), amount.BigInt()) + require.Equal(t, "10.010000000", amount.String()) + require.Equal(t, "10.01", amount.Format(true)) + require.Equal(t, "10.010000000", amount.Format(false)) + }) +} + +func TestDatabaseSqlDecimalParam(t *testing.T) { + ctx, cancel := context.WithCancel(xtest.Context(t)) + defer cancel() + + nativeDriver, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ) + require.NoError(t, err) + defer func() { + _ = nativeDriver.Close(ctx) + }() + + connector, err := ydb.Connector(nativeDriver, + ydb.WithQueryService(true), + ) + require.NoError(t, err) + defer func() { + _ = connector.Close() + }() + + db := sql.OpenDB(connector) + + t.Run("DirectScan", func(t *testing.T) { + d, err := types.DecimalValueFromString("100.5", 33, 12) + require.NoError(t, err) + row := db.QueryRowContext(ctx, ` + DECLARE $p AS Decimal(33,12); + SELECT $p; + `, sql.Named("p", d)) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(33), dst.Precision) + require.Equal(t, uint32(12), dst.Scale) + require.Equal(t, big.NewInt(100500000000000), dst.BigInt()) + require.Equal(t, "100.500000000000", dst.String()) + require.Equal(t, "100.5", dst.Format(true)) + require.Equal(t, "100.500000000000", dst.Format(false)) + }) + + t.Run("DirectScanNegative", func(t *testing.T) { + d, err := types.DecimalValueFromString("-5.33", 22, 9) + require.NoError(t, err) + row := db.QueryRowContext(ctx, ` + DECLARE $p AS Decimal(22,9); + SELECT $p; + `, sql.Named("p", d)) + + var dst decimal.Decimal + err = row.Scan(&dst) + require.NoError(t, err) + require.Equal(t, uint32(22), dst.Precision) + require.Equal(t, uint32(9), dst.Scale) + require.Equal(t, big.NewInt(-5330000000), dst.BigInt()) + require.Equal(t, "-5.330000000", dst.String()) + require.Equal(t, "-5.33", dst.Format(true)) + require.Equal(t, "-5.330000000", dst.Format(false)) + }) + + t.Run("DirectScanWithOtherTypes", func(t *testing.T) { + d, err := types.DecimalValueFromString("10.01", 22, 9) + require.NoError(t, err) + row := db.QueryRowContext(ctx, ` + DECLARE $p1 AS Uint64; + DECLARE $p2 AS Decimal(22,9); + SELECT $p1 AS id, $p2 AS amount; + `, sql.Named("p1", uint64(42)), sql.Named("p2", d)) + + var id uint64 + var amount decimal.Decimal + err = row.Scan(&id, &amount) + require.NoError(t, err) + + require.Equal(t, uint64(42), id) + require.Equal(t, uint32(22), amount.Precision) + require.Equal(t, uint32(9), amount.Scale) + require.Equal(t, big.NewInt(10010000000), amount.BigInt()) + require.Equal(t, "10.010000000", amount.String()) + require.Equal(t, "10.01", amount.Format(true)) + require.Equal(t, "10.010000000", amount.Format(false)) + }) +} diff --git a/tests/integration/query_execute_test.go b/tests/integration/query_execute_test.go index c54ae91f9..a834e5e99 100644 --- a/tests/integration/query_execute_test.go +++ b/tests/integration/query_execute_test.go @@ -22,10 +22,10 @@ import ( "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Issue" "github.com/ydb-platform/ydb-go-sdk/v3" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" "github.com/ydb-platform/ydb-go-sdk/v3/internal/version" "github.com/ydb-platform/ydb-go-sdk/v3/log" + "github.com/ydb-platform/ydb-go-sdk/v3/pkg/decimal" "github.com/ydb-platform/ydb-go-sdk/v3/pkg/xtest" "github.com/ydb-platform/ydb-go-sdk/v3/query" "github.com/ydb-platform/ydb-go-sdk/v3/table/types" @@ -732,8 +732,8 @@ func TestIssue1785FillDecimalFields(t *testing.T) { require.NoError(t, err) t.Run("Query", func(t *testing.T) { type RowData struct { - Id uint64 `sql:"id"` - DecimalVal types.Decimal `sql:"dc"` + Id uint64 `sql:"id"` + DecimalVal decimal.Decimal `sql:"dc"` } result, err := db.Query().Query(ctx, ` SELECT id, dc @@ -757,18 +757,18 @@ func TestIssue1785FillDecimalFields(t *testing.T) { err = row.ScanStruct(&rd) require.NoError(t, err) require.EqualValues(t, uint64(1), rd.Id) - require.EqualValues(t, types.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(10010000000), 22, 9), Precision: 22, Scale: 9}, rd.DecimalVal) + require.EqualValues(t, decimal.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(10010000000), 22), Precision: 22, Scale: 9}, rd.DecimalVal) row, err = resultSet.NextRow(ctx) require.NoError(t, err) err = row.ScanStruct(&rd) require.NoError(t, err) require.EqualValues(t, uint64(2), rd.Id) - require.EqualValues(t, types.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(-5330000000), 22, 9), Precision: 22, Scale: 9}, rd.DecimalVal) + require.EqualValues(t, decimal.Decimal{Bytes: decimal.BigIntToByte(big.NewInt(-5330000000), 22), Precision: 22, Scale: 9}, rd.DecimalVal) row, err = resultSet.NextRow(ctx) require.NoError(t, err) err = row.ScanStruct(&rd) require.NoError(t, err) - expectedVal := types.Decimal{Bytes: [16]byte{0, 19, 66, 97, 114, 199, 77, 130, 43, 135, 143, 232, 0, 0, 0, 0}, Precision: 22, Scale: 9} + expectedVal := decimal.Decimal{Bytes: [16]byte{0, 19, 66, 97, 114, 199, 77, 130, 43, 135, 143, 232, 0, 0, 0, 0}, Precision: 22, Scale: 9} require.EqualValues(t, expectedVal, rd.DecimalVal) }) } diff --git a/tests/integration/unexpected_decimal_parse_test.go b/tests/integration/unexpected_decimal_parse_test.go deleted file mode 100644 index 2a682d2e7..000000000 --- a/tests/integration/unexpected_decimal_parse_test.go +++ /dev/null @@ -1,70 +0,0 @@ -//go:build integration -// +build integration - -package integration - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" - "github.com/ydb-platform/ydb-go-sdk/v3/table" - "github.com/ydb-platform/ydb-go-sdk/v3/table/types" -) - -func TestIssue1234UnexpectedDecimalRepresentation(t *testing.T) { - scope := newScope(t) - driver := scope.Driver() - - tests := []struct { - name string - bts [16]byte - precision uint32 - scale uint32 - expectedFormat string - }{ - { - bts: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 250, 240, 128}, - precision: 22, - scale: 9, - expectedFormat: "0.050000000", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expected := decimal.Decimal{ - Bytes: tt.bts, - Precision: tt.precision, - Scale: tt.scale, - } - var actual decimal.Decimal - - err := driver.Table().Do(scope.Ctx, func(ctx context.Context, s table.Session) error { - _, result, err := s.Execute(ctx, table.DefaultTxControl(), ` - DECLARE $value AS Decimal(22,9); - SELECT $value;`, - table.NewQueryParameters( - table.ValueParam("$value", types.DecimalValue(&expected)), - ), - ) - if err != nil { - return err - } - for result.NextResultSet(ctx) { - for result.NextRow() { - err = result.Scan(&actual) - if err != nil { - return err - } - } - } - return nil - }) - require.NoError(t, err) - require.Equal(t, expected, actual) - require.Equal(t, tt.expectedFormat, actual.String()) - }) - } -}