diff --git a/CHANGELOG.md b/CHANGELOG.md index a2b75c64e..afe4b60eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ * Masked the sensitive credential data in the connection string (DSN, data source name) from error messages for security reasons +* Added `WithConcurrentResultSets` option for `db.Query().Query()` ## v3.121.0 * Changed internal pprof label to pyroscope supported format diff --git a/internal/query/client.go b/internal/query/client.go index 62a93c1e5..f89ffde41 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -404,6 +404,7 @@ func clientQuery(ctx context.Context, pool sessionPool, q string, opts ...option if err != nil { return xerrors.WithStackTrace(err) } + defer func() { _ = streamResult.Close(ctx) }() diff --git a/internal/query/client_test.go b/internal/query/client_test.go index 2421f6259..2021e8177 100644 --- a/internal/query/client_test.go +++ b/internal/query/client_test.go @@ -852,241 +852,350 @@ func TestClient(t *testing.T) { }) }) t.Run("Query", func(t *testing.T) { + mkU64 := func(v uint64) *Ydb.Value { + return &Ydb.Value{Value: &Ydb.Value_Uint64Value{Uint64Value: v}} + } + mkStr := func(v string) *Ydb.Value { + return &Ydb.Value{Value: &Ydb.Value_TextValue{TextValue: v}} + } + mkBool := func(v bool) *Ydb.Value { + return &Ydb.Value{Value: &Ydb.Value_BoolValue{BoolValue: v}} + } t.Run("HappyWay", func(t *testing.T) { ctrl := gomock.NewController(t) - r, err := clientQuery(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { - stream := NewMockQueryService_ExecuteQueryClient(ctrl) - stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ - Status: Ydb.StatusIds_SUCCESS, - TxMeta: &Ydb_Query.TransactionMeta{ - Id: "456", + + colsAB := []*Ydb.Column{ + {Name: "a", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT64}}}, + {Name: "b", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UTF8}}}, + } + + colsCDE := []*Ydb.Column{ + {Name: "c", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT64}}}, + {Name: "d", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UTF8}}}, + {Name: "e", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_BOOL}}}, + } + + respParts := []struct { + idx int + columns []*Ydb.Column + rows [][]*Ydb.Value + }{ + { + idx: 0, + columns: colsAB, + rows: [][]*Ydb.Value{ + {mkU64(1), mkStr("1")}, + {mkU64(2), mkStr("2")}, + {mkU64(3), mkStr("3")}, }, - ResultSetIndex: 0, - ResultSet: &Ydb.ResultSet{ - Columns: []*Ydb.Column{ - { - Name: "a", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_UINT64, - }, - }, - }, - { - Name: "b", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_UTF8, - }, - }, - }, - }, - Rows: []*Ydb.Value{ - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 1, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "1", - }, - }}, - }, - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 2, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "2", - }, - }}, - }, - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 3, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "3", - }, - }}, - }, - }, + }, + { + idx: 0, + rows: [][]*Ydb.Value{ + {mkU64(4), mkStr("4")}, + {mkU64(5), mkStr("5")}, }, - }, nil) - stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ - Status: Ydb.StatusIds_SUCCESS, - ResultSetIndex: 0, - ResultSet: &Ydb.ResultSet{ - Rows: []*Ydb.Value{ - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 4, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "4", - }, - }}, - }, - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 5, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "5", - }, - }}, - }, - }, + }, + { + idx: 1, + columns: colsCDE, + rows: [][]*Ydb.Value{ + {mkU64(1), mkStr("1"), mkBool(true)}, + {mkU64(2), mkStr("2"), mkBool(false)}, }, - }, nil) - stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ - Status: Ydb.StatusIds_SUCCESS, - ResultSetIndex: 1, - ResultSet: &Ydb.ResultSet{ - Columns: []*Ydb.Column{ - { - Name: "c", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_UINT64, - }, - }, - }, - { - Name: "d", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_UTF8, - }, - }, - }, - { - Name: "e", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_BOOL, - }, - }, - }, - }, - Rows: []*Ydb.Value{ - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 1, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "1", - }, - }, { - Value: &Ydb.Value_BoolValue{ - BoolValue: true, - }, - }}, - }, - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 2, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "2", - }, - }, { - Value: &Ydb.Value_BoolValue{ - BoolValue: false, - }, - }}, - }, + }, + } + + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + + for _, p := range respParts { + stream.EXPECT().Recv().Return( + &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{Id: "456"}, + ResultSetIndex: int64(p.idx), + ResultSet: &Ydb.ResultSet{ + Columns: p.columns, + Rows: func() []*Ydb.Value { + out := make([]*Ydb.Value, len(p.rows)) + for i, items := range p.rows { + out[i] = &Ydb.Value{Items: items} + } + + return out + }(), }, }, - }, nil) - stream.EXPECT().Recv().Return(nil, io.EOF) - client := NewMockQueryServiceClient(ctrl) - client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) + nil, + ) + } + + stream.EXPECT().Recv().Return(nil, io.EOF) + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) + + r, err := clientQuery(ctx, testPool(ctx, func(context.Context) (*Session, error) { return newTestSessionWithClient("123", client, true), nil }), "") require.NoError(t, err) + { rs, err := r.NextResultSet(ctx) require.NoError(t, err) - r1, err := rs.NextRow(ctx) - require.NoError(t, err) - var ( + for _, want := range []struct { a uint64 b string - ) - err = r1.Scan(&a, &b) - require.NoError(t, err) - require.EqualValues(t, 1, a) - require.EqualValues(t, "1", b) - r2, err := rs.NextRow(ctx) - require.NoError(t, err) - err = r2.Scan(&a, &b) - require.NoError(t, err) - require.EqualValues(t, 2, a) - require.EqualValues(t, "2", b) - r3, err := rs.NextRow(ctx) - require.NoError(t, err) - err = r3.Scan(&a, &b) - require.NoError(t, err) - require.EqualValues(t, 3, a) - require.EqualValues(t, "3", b) - r4, err := rs.NextRow(ctx) - require.NoError(t, err) - err = r4.Scan(&a, &b) - require.NoError(t, err) - require.EqualValues(t, 4, a) - require.EqualValues(t, "4", b) - r5, err := rs.NextRow(ctx) - require.NoError(t, err) - err = r5.Scan(&a, &b) - require.NoError(t, err) - require.EqualValues(t, 5, a) - require.EqualValues(t, "5", b) - r6, err := rs.NextRow(ctx) + }{ + {1, "1"}, + {2, "2"}, + {3, "3"}, + {4, "4"}, + {5, "5"}, + } { + row, err := rs.NextRow(ctx) + if want.a == 5 { + require.NoError(t, err) + } + if errors.Is(err, io.EOF) { + require.Fail(t, "unexpected EOF") + } + var a uint64 + var b string + require.NoError(t, row.Scan(&a, &b)) + require.EqualValues(t, want.a, a) + require.EqualValues(t, want.b, b) + } + row, err := rs.NextRow(ctx) require.ErrorIs(t, err, io.EOF) - require.Nil(t, r6) + require.Nil(t, row) } + { rs, err := r.NextResultSet(ctx) require.NoError(t, err) - r1, err := rs.NextRow(ctx) - require.NoError(t, err) - var ( + + for _, want := range []struct { a uint64 b string c bool + }{ + {1, "1", true}, + {2, "2", false}, + } { + row, err := rs.NextRow(ctx) + require.NoError(t, err) + var a uint64 + var b string + var c bool + require.NoError(t, row.Scan(&a, &b, &c)) + require.EqualValues(t, want.a, a) + require.EqualValues(t, want.b, b) + require.EqualValues(t, want.c, c) + } + + row, err := rs.NextRow(ctx) + require.ErrorIs(t, err, io.EOF) + require.Nil(t, row) + } + }) + t.Run("ConcurrentResultSets", func(t *testing.T) { + ctrl := gomock.NewController(t) + + colsAB := []*Ydb.Column{ + {Name: "a", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT64}}}, + {Name: "b", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UTF8}}}, + } + + colsCDE := []*Ydb.Column{ + {Name: "c", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT64}}}, + {Name: "d", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UTF8}}}, + {Name: "e", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_BOOL}}}, + } + + respParts := []struct { + idx int + columns []*Ydb.Column + rows [][]*Ydb.Value + }{ + { + idx: 0, + columns: colsAB, + rows: [][]*Ydb.Value{ + {mkU64(1), mkStr("1")}, + {mkU64(2), mkStr("2")}, + {mkU64(3), mkStr("3")}, + }, + }, + { + idx: 1, + columns: colsCDE, + rows: [][]*Ydb.Value{ + {mkU64(1), mkStr("1"), mkBool(true)}, + {mkU64(2), mkStr("2"), mkBool(false)}, + }, + }, + { + idx: 0, + rows: [][]*Ydb.Value{ + {mkU64(4), mkStr("4")}, + {mkU64(5), mkStr("5")}, + }, + }, + } + + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + + for _, p := range respParts { + stream.EXPECT().Recv().Return( + &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{Id: "456"}, + ResultSetIndex: int64(p.idx), + ResultSet: &Ydb.ResultSet{ + Columns: p.columns, + Rows: func() []*Ydb.Value { + out := make([]*Ydb.Value, len(p.rows)) + for i, items := range p.rows { + out[i] = &Ydb.Value{Items: items} + } + + return out + }(), + }, + }, + nil, ) - err = r1.Scan(&a, &b, &c) - require.NoError(t, err) - require.EqualValues(t, 1, a) - require.EqualValues(t, "1", b) - require.EqualValues(t, true, c) - r2, err := rs.NextRow(ctx) + } + + stream.EXPECT().Recv().Return(nil, io.EOF) + + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) + + r, err := clientQuery(ctx, testPool(ctx, func(context.Context) (*Session, error) { + return newTestSessionWithClient("123", client, true), nil + }), "", query.WithConcurrentResultSets(true)) + require.NoError(t, err) + + { + rs, err := r.NextResultSet(ctx) require.NoError(t, err) - err = r2.Scan(&a, &b, &c) + + for _, want := range []struct { + a uint64 + b string + }{ + {1, "1"}, + {2, "2"}, + {3, "3"}, + {4, "4"}, + {5, "5"}, + } { + row, err := rs.NextRow(ctx) + if errors.Is(err, io.EOF) { + require.Fail(t, "unexpected EOF in RS0") + } + require.NoError(t, err) + var a uint64 + var b string + require.NoError(t, row.Scan(&a, &b)) + require.EqualValues(t, want.a, a) + require.EqualValues(t, want.b, b) + } + + row, err := rs.NextRow(ctx) + require.ErrorIs(t, err, io.EOF) + require.Nil(t, row) + } + + { + rs, err := r.NextResultSet(ctx) require.NoError(t, err) - require.EqualValues(t, 2, a) - require.EqualValues(t, "2", b) - require.EqualValues(t, false, c) - r3, err := rs.NextRow(ctx) + + for _, want := range []struct { + a uint64 + b string + c bool + }{ + {1, "1", true}, + {2, "2", false}, + } { + row, err := rs.NextRow(ctx) + require.NoError(t, err) + var a uint64 + var b string + var c bool + require.NoError(t, row.Scan(&a, &b, &c)) + require.EqualValues(t, want.a, a) + require.EqualValues(t, want.b, b) + require.EqualValues(t, want.c, c) + } + + row, err := rs.NextRow(ctx) require.ErrorIs(t, err, io.EOF) - require.Nil(t, r3) + require.Nil(t, row) } }) + t.Run("CancelWhileReadResult", func(t *testing.T) { + ctrl := gomock.NewController(t) + executeCtx, cancel := context.WithCancel(xtest.Context(t)) + + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + cancel() + + <-executeCtx.Done() + + return nil, executeCtx.Err() + }) + + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) + + _, err := clientQuery(executeCtx, testPool(ctx, func(context.Context) (*Session, error) { + return newTestSessionWithClient("123", client, true), nil + }), "", query.WithConcurrentResultSets(true)) + + require.ErrorIs(t, err, context.Canceled) + }) + t.Run("EmptyResultSet", func(t *testing.T) { + ctrl := gomock.NewController(t) + + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + + stream.EXPECT().Recv().Return( + &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{Id: "456"}, + ResultSetIndex: int64(0), + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + {Name: "a", Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT64}}}, + }, + Rows: []*Ydb.Value{}, + }, + }, + nil, + ) + + stream.EXPECT().Recv().Return(nil, io.EOF) + + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) + + r, err := clientQuery(ctx, testPool(ctx, func(context.Context) (*Session, error) { + return newTestSessionWithClient("123", client, true), nil + }), "", query.WithConcurrentResultSets(true)) + require.NoError(t, err) + + rs, err := r.NextResultSet(ctx) + require.NoError(t, err) + + row, err := rs.NextRow(ctx) + require.ErrorIs(t, err, io.EOF) + require.Nil(t, row) + }) t.Run("AllowImplicitSessions", func(t *testing.T) { _, err := mockClientForImplicitSessionTest(ctx, t). Query(ctx, "SELECT 1") diff --git a/internal/query/execute_query.go b/internal/query/execute_query.go index 6cdcb622c..243d8ad1a 100644 --- a/internal/query/execute_query.go +++ b/internal/query/execute_query.go @@ -32,6 +32,7 @@ type executeSettings interface { ResourcePool() string ResponsePartLimitSizeBytes() int64 Label() string + ConcurrentResultSets() bool } type executeScriptConfig interface { @@ -92,7 +93,7 @@ func executeQueryRequest(sessionID, q string, cfg executeSettings) ( }, Parameters: params, StatsMode: Ydb_Query.StatsMode(cfg.StatsMode()), - ConcurrentResultSets: false, + ConcurrentResultSets: cfg.ConcurrentResultSets(), PoolId: cfg.ResourcePool(), ResponsePartLimitBytes: cfg.ResponsePartLimitSizeBytes(), } diff --git a/internal/query/options/execute.go b/internal/query/options/execute.go index 6853c02d9..4e272c601 100644 --- a/internal/query/options/execute.go +++ b/internal/query/options/execute.go @@ -46,6 +46,7 @@ type ( issueCallback func(issues []*Ydb_Issue.IssueMessage) responsePartLimitBytes int64 label string + concurrentResultSets bool } // Execute is an interface for execute method options @@ -72,9 +73,12 @@ type ( } execModeOption = ExecMode responsePartLimitBytes int64 - issuesOption struct { + + issuesOption struct { callback func([]*Ydb_Issue.IssueMessage) } + + concurrentResultSets bool ) func (poolID resourcePool) applyExecuteOption(s *executeSettings) { @@ -132,6 +136,10 @@ func (opts issuesOption) applyExecuteOption(s *executeSettings) { s.issueCallback = opts.callback } +func (opt concurrentResultSets) applyExecuteOption(s *executeSettings) { + s.concurrentResultSets = bool(opt) +} + const ( ExecModeParse = ExecMode(Ydb_Query.ExecMode_EXEC_MODE_PARSE) ExecModeValidate = ExecMode(Ydb_Query.ExecMode_EXEC_MODE_VALIDATE) @@ -205,6 +213,10 @@ func (s *executeSettings) Label() string { return s.label } +func (s *executeSettings) ConcurrentResultSets() bool { + return s.concurrentResultSets +} + func WithParameters(params params.Parameters) parametersOption { return parametersOption{ params: params, @@ -237,6 +249,10 @@ func WithResponsePartLimitSizeBytes(size int64) responsePartLimitBytes { return responsePartLimitBytes(size) } +func WithConcurrentResultSets(isEnabled bool) concurrentResultSets { + return concurrentResultSets(isEnabled) +} + func (size responsePartLimitBytes) applyExecuteOption(s *executeSettings) { s.responsePartLimitBytes = int64(size) } diff --git a/internal/query/result.go b/internal/query/result.go index b26a97a29..797be8bba 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -8,12 +8,14 @@ import ( "time" "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Issue" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stats" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter" "github.com/ydb-platform/ydb-go-sdk/v3/query" @@ -433,11 +435,38 @@ func exactlyOneResultSetFromResult(ctx context.Context, r result.Result) (rs res return MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows), nil } -func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Result, error) { - var resultSets []result.Set +func resultToMaterializedResult(ctx context.Context, r *streamResult) (result.Result, error) { + type resultSet struct { + rows []query.Row + columns []*Ydb.Column + } + resultSetByIndex := make(map[int64]resultSet) + maxIndex := int64(-1) for { - rs, err := r.NextResultSet(ctx) + if ctx.Err() != nil { + return nil, xerrors.WithStackTrace(ctx.Err()) + } + if r.closer.Err() != nil { + return nil, xerrors.WithStackTrace(r.closer.Err()) + } + + curIndex := r.lastPart.GetResultSetIndex() + maxIndex = max(maxIndex, curIndex) + + rs := resultSetByIndex[curIndex] + if len(rs.columns) == 0 { + rs.columns = r.lastPart.GetResultSet().GetColumns() + } + rows := make([]query.Row, len(r.lastPart.GetResultSet().GetRows())) + for i := range r.lastPart.GetResultSet().GetRows() { + rows[i] = NewRow(rs.columns, r.lastPart.GetResultSet().GetRows()[i]) + } + rs.rows = append(rs.rows, rows...) + resultSetByIndex[curIndex] = rs + + var err error + r.lastPart, err = r.nextPart(ctx) if err != nil { if xerrors.Is(err, io.EOF) { break @@ -445,22 +474,22 @@ func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Re return nil, xerrors.WithStackTrace(err) } + if r.lastPart.GetExecStats() != nil && r.statsCallback != nil { + r.statsCallback(stats.FromQueryStats(r.lastPart.GetExecStats())) + } + } - var rows []query.Row - for { - row, err := rs.NextRow(ctx) - if err != nil { - if xerrors.Is(err, io.EOF) { - break - } - - return nil, xerrors.WithStackTrace(err) - } + resultSets := make([]result.Set, maxIndex+1) + for rsIndex, rs := range resultSetByIndex { + columnNames := make([]string, len(rs.columns)) + columnTypes := make([]types.Type, len(rs.columns)) - rows = append(rows, row) + for i := range rs.columns { + columnNames[i] = rs.columns[i].GetName() + columnTypes[i] = types.TypeFromYDB(rs.columns[i].GetType()) } - resultSets = append(resultSets, MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows)) + resultSets[rsIndex] = MaterializedResultSet(int(rsIndex), columnNames, columnTypes, rs.rows) } return &materializedResult{ diff --git a/query/execute_options.go b/query/execute_options.go index c0f6cfb0a..3e188dcd7 100644 --- a/query/execute_options.go +++ b/query/execute_options.go @@ -70,6 +70,10 @@ func WithResponsePartLimitSizeBytes(size int64) ExecuteOption { return options.WithResponsePartLimitSizeBytes(size) } +func WithConcurrentResultSets(isEnabled bool) ExecuteOption { + return options.WithConcurrentResultSets(isEnabled) +} + func WithCallOptions(opts ...grpc.CallOption) ExecuteOption { return options.WithCallOptions(opts...) } diff --git a/tests/integration/query_execute_test.go b/tests/integration/query_execute_test.go index c54ae91f9..bb21ef10d 100644 --- a/tests/integration/query_execute_test.go +++ b/tests/integration/query_execute_test.go @@ -999,3 +999,50 @@ func TestIssue1872QueryWarning(t *testing.T) { issueList[0].Issues[0].Issues[0].Message) }) } + +// https://github.com/ydb-platform/ydb-go-sdk/issues/1878 +func TestIssue1878ConcurrentResultSet(t *testing.T) { + ctx, cancel := context.WithCancel(xtest.Context(t)) + defer cancel() + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ydb.WithTraceQuery( + log.Query( + log.Default(os.Stdout, + log.WithLogQuery(), + log.WithColoring(), + log.WithMinLevel(log.INFO), + ), + trace.QueryEvents, + ), + ), + ) + require.NoError(t, err) + t.Run("Select with enabled option", func(t *testing.T) { + q := db.Query() + res, err := q.Query(ctx, ` + SELECT 1; + SELECT 2; + SELECT 3; + SELECT 4; + SELECT 5; + `, + query.WithSyntax(query.SyntaxYQL), + query.WithIdempotent(), + query.WithConcurrentResultSets(true), + ) + require.NoError(t, err) + rsCount := 0 + for rs, err := range res.ResultSets(ctx) { + rsCount++ + require.NoError(t, err) + row, err := rs.NextRow(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(row.Values())) + require.EqualValues(t, rsCount, row.Values()[0]) + } + require.NoError(t, res.Close(ctx)) + require.Equal(t, 5, rsCount) + }) +}