Commit c624c9e

1 parent a7de4b4 commit c624c9e

File tree

4 files changed: +101 −16 lines changed
  • langchain4j-http-client-spring-restclient/src/main/java/dev/langchain4j/http/client/spring/restclient
  • langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring
  • langchain4j-open-ai-spring-boot-starter/src


langchain4j-http-client-spring-restclient/src/main/java/dev/langchain4j/http/client/spring/restclient/SpringRestClient.java

Lines changed: 1 addition & 0 deletions

```diff
@@ -55,6 +55,7 @@ public SpringRestClient(SpringRestClientBuilder builder) {
 
     private static AsyncTaskExecutor createDefaultStreamingRequestExecutor() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.initialize();
         return taskExecutor;
     }
```
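
This one-line change is the heart of the commit. ThreadPoolTaskExecutor defaults to a core pool size of 1 with an unbounded queue, and a java.util.concurrent.ThreadPoolExecutor only grows beyond its core size once the queue rejects work, so every streaming request was serialized onto a single thread. With setQueueCapacity(0) the executor backs onto a SynchronousQueue, and each submitted task gets its own thread (up to the default unbounded max pool size), so concurrent streams actually run concurrently. A minimal standalone sketch of the difference (illustration only, not part of the commit; the class and variable names are made up):

```java
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class QueueCapacityDemo {

    public static void main(String[] args) throws InterruptedException {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setQueueCapacity(0); // remove this line and the two tasks run one after another
        executor.initialize();

        // each task signals that it has started, then stays busy for a while
        CountDownLatch bothStarted = new CountDownLatch(2);
        Runnable task = () -> {
            bothStarted.countDown();
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        };
        executor.submit(task);
        executor.submit(task);

        // with queueCapacity = 0 both tasks start immediately on separate threads;
        // with the default unbounded queue the second task waits behind the first
        // on the single core thread, so the latch never reaches zero in time
        boolean concurrent = bothStarted.await(250, TimeUnit.MILLISECONDS);
        System.out.println("ran concurrently: " + concurrent);

        executor.shutdown();
    }
}
```

The extra threads are not leaked: ThreadPoolTaskExecutor keeps its default 60-second keep-alive, so threads above the core size terminate once a stream finishes and they sit idle.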

langchain4j-ollama-spring-boot-starter/src/main/java/dev/langchain4j/ollama/spring/AutoConfig.java

Lines changed: 4 additions & 0 deletions

```diff
@@ -139,6 +139,7 @@ HttpClientBuilder ollamaStreamingChatModelHttpClientBuilder(
     @ConditionalOnClass(name = "io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor ollamaStreamingChatModelTaskExecutorWithContextPropagation() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
         return taskExecutor;
@@ -150,6 +151,7 @@ AsyncTaskExecutor ollamaStreamingChatModelTaskExecutorWithContextPropagation() {
     @ConditionalOnMissingClass("io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor ollamaStreamingChatModelTaskExecutor() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         return taskExecutor;
     }
@@ -233,6 +235,7 @@ HttpClientBuilder ollamaStreamingLanguageModelHttpClientBuilder(
     @ConditionalOnClass(name = "io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor ollamaStreamingLanguageModelTaskExecutorWithContextPropagation() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
         return taskExecutor;
@@ -244,6 +247,7 @@ AsyncTaskExecutor ollamaStreamingLanguageModelTaskExecutorWithContextPropagation
     @ConditionalOnMissingClass("io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor ollamaStreamingLanguageModelTaskExecutor() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         return taskExecutor;
     }
```
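
Each of the four executor beans above, and the matching four in the OpenAI starter below, follows the same conditional pairing: a context-propagating variant when Micrometer's ContextSnapshotFactory is present (the string-based @ConditionalOnClass(name = ...) form keeps micrometer-context-propagation an optional dependency), and a plain variant otherwise. A condensed sketch of the pattern with the new queueCapacity(0) applied in both branches (the configuration class, bean method names, and thread-name prefix here are illustrative, not the actual auto-configuration):

```java
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

@Configuration
class StreamingExecutorConfigSketch {

    // chosen when micrometer-context-propagation is on the classpath: the decorator
    // captures MDC/tracing/observation context on the calling thread and restores
    // it on the streaming worker thread
    @Bean
    @ConditionalOnClass(name = "io.micrometer.context.ContextSnapshotFactory")
    AsyncTaskExecutor streamingExecutorWithContextPropagation() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setQueueCapacity(0); // no queueing: one thread per in-flight stream
        taskExecutor.setThreadNamePrefix("LangChain4j-Streaming-"); // illustrative prefix
        taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
        return taskExecutor; // the container calls initialize() via afterPropertiesSet()
    }

    // fallback when the context-propagation library is absent
    @Bean
    @ConditionalOnMissingClass("io.micrometer.context.ContextSnapshotFactory")
    AsyncTaskExecutor streamingExecutor() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setQueueCapacity(0);
        taskExecutor.setThreadNamePrefix("LangChain4j-Streaming-");
        return taskExecutor;
    }
}
```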

langchain4j-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/openai/spring/AutoConfig.java

Lines changed: 4 additions & 0 deletions

```diff
@@ -159,6 +159,7 @@ HttpClientBuilder openAiStreamingChatModelHttpClientBuilder(
     @ConditionalOnClass(name = "io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor openAiStreamingChatModelTaskExecutorWithContextPropagation() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
         return taskExecutor;
@@ -170,6 +171,7 @@ AsyncTaskExecutor openAiStreamingChatModelTaskExecutorWithContextPropagation() {
     @ConditionalOnMissingClass("io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor openAiStreamingChatModelTaskExecutor() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         return taskExecutor;
     }
@@ -247,6 +249,7 @@ HttpClientBuilder openAiStreamingLanguageModelHttpClientBuilder(
     @ConditionalOnClass(name = "io.micrometer.context.ContextSnapshotFactory")
     AsyncTaskExecutor openAiStreamingLanguageModelTaskExecutorWithContextPropagation() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
         return taskExecutor;
@@ -258,6 +261,7 @@ AsyncTaskExecutor openAiStreamingLanguageModelTaskExecutorWithContextPropagation
     @ConditionalOnMissingClass("io.micrometer.context.ContextSnapshotFactory")
    AsyncTaskExecutor openAiStreamingLanguageModelTaskExecutor() {
         ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
+        taskExecutor.setQueueCapacity(0);
         taskExecutor.setThreadNamePrefix(TASK_EXECUTOR_THREAD_NAME_PREFIX);
         return taskExecutor;
     }
```

langchain4j-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/openai/spring/AutoConfigIT.java

Lines changed: 92 additions & 16 deletions

```diff
@@ -24,7 +24,9 @@
 import org.springframework.core.annotation.Order;
 import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
 
+import java.time.LocalDateTime;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -106,36 +108,73 @@ void should_create_chat_model_with_default_http_client() {
     void should_provide_streaming_chat_model() {
         contextRunner
                 .withPropertyValues(
-                        "langchain4j.open-ai.streaming-chat-model.base-url=" + BASE_URL,
+                        // not setting base URL to use OpenAI API without caching proxy (proxy responds way faster)
                         "langchain4j.open-ai.streaming-chat-model.api-key=" + API_KEY,
                         "langchain4j.open-ai.streaming-chat-model.model-name=gpt-4o-mini",
-                        "langchain4j.open-ai.streaming-chat-model.max-tokens=20"
+                        "langchain4j.open-ai.streaming-chat-model.max-tokens=50"
                 )
                 .run(context -> {
 
                     StreamingChatModel model = context.getBean(StreamingChatModel.class);
                     assertThat(model).isInstanceOf(OpenAiStreamingChatModel.class);
                     assertThat(context.getBean(OpenAiStreamingChatModel.class)).isSameAs(model);
 
-                    CompletableFuture<ChatResponse> future = new CompletableFuture<>();
-                    model.chat("What is the capital of Germany?", new StreamingChatResponseHandler() {
+                    CompletableFuture<ChatResponse> future1 = new CompletableFuture<>();
+                    AtomicReference<LocalDateTime> streamingStarted1 = new AtomicReference<>();
+                    AtomicReference<LocalDateTime> streamingFinished1 = new AtomicReference<>();
+                    model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() {
 
                         @Override
                         public void onPartialResponse(String partialResponse) {
+                            if (streamingStarted1.get() == null) {
+                                streamingStarted1.set(LocalDateTime.now());
+                            }
                         }
 
                         @Override
                         public void onCompleteResponse(ChatResponse completeResponse) {
-                            future.complete(completeResponse);
+                            streamingFinished1.set(LocalDateTime.now());
+                            future1.complete(completeResponse);
                         }
 
                         @Override
                         public void onError(Throwable error) {
-                            future.completeExceptionally(error);
+                            future1.completeExceptionally(error);
                         }
                     });
-                    ChatResponse chatResponse = future.get(15, SECONDS);
-                    assertThat(chatResponse.aiMessage().text()).contains("Berlin");
+
+                    CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
+                    AtomicReference<LocalDateTime> streamingStarted2 = new AtomicReference<>();
+                    AtomicReference<LocalDateTime> streamingFinished2 = new AtomicReference<>();
+                    model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() {
+
+                        @Override
+                        public void onPartialResponse(String partialResponse) {
+                            if (streamingStarted2.get() == null) {
+                                streamingStarted2.set(LocalDateTime.now());
+                            }
+                        }
+
+                        @Override
+                        public void onCompleteResponse(ChatResponse completeResponse) {
+                            streamingFinished2.set(LocalDateTime.now());
+                            future2.complete(completeResponse);
+                        }
+
+                        @Override
+                        public void onError(Throwable error) {
+                            future2.completeExceptionally(error);
+                        }
+                    });
+
+                    ChatResponse chatResponse1 = future1.get(15, SECONDS);
+                    assertThat(chatResponse1.aiMessage().text()).isNotBlank();
+
+                    ChatResponse chatResponse2 = future2.get(15, SECONDS);
+                    assertThat(chatResponse2.aiMessage().text()).isNotBlank();
+
+                    assertThat(streamingStarted1.get()).isBefore(streamingFinished2.get());
+                    assertThat(streamingStarted2.get()).isBefore(streamingFinished1.get());
                 });
     }
 
@@ -232,32 +271,69 @@
     void should_create_streaming_chat_model_with_default_http_client() throws Exception {
 
         OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder()
-                .baseUrl(BASE_URL)
+                // not setting base URL to use OpenAI API without caching proxy (proxy responds way faster)
                 .apiKey(API_KEY)
                 .modelName("gpt-4o-mini")
                 .temperature(0.0)
-                .maxTokens(20)
+                .maxTokens(50)
                 .build();
 
-        CompletableFuture<ChatResponse> future = new CompletableFuture<>();
-        model.chat("What is the capital of Germany?", new StreamingChatResponseHandler() {
+        CompletableFuture<ChatResponse> future1 = new CompletableFuture<>();
+        AtomicReference<LocalDateTime> streamingStarted1 = new AtomicReference<>();
+        AtomicReference<LocalDateTime> streamingFinished1 = new AtomicReference<>();
+        model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() {
 
            @Override
            public void onPartialResponse(String partialResponse) {
+                if (streamingStarted1.get() == null) {
+                    streamingStarted1.set(LocalDateTime.now());
+                }
            }
 
            @Override
            public void onCompleteResponse(ChatResponse completeResponse) {
-                future.complete(completeResponse);
+                streamingFinished1.set(LocalDateTime.now());
+                future1.complete(completeResponse);
            }
 
            @Override
            public void onError(Throwable error) {
-                future.completeExceptionally(error);
+                future1.completeExceptionally(error);
+            }
+        });
+
+        CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
+        AtomicReference<LocalDateTime> streamingStarted2 = new AtomicReference<>();
+        AtomicReference<LocalDateTime> streamingFinished2 = new AtomicReference<>();
+        model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() {
+
+            @Override
+            public void onPartialResponse(String partialResponse) {
+                if (streamingStarted2.get() == null) {
+                    streamingStarted2.set(LocalDateTime.now());
+                }
+            }
+
+            @Override
+            public void onCompleteResponse(ChatResponse completeResponse) {
+                streamingFinished2.set(LocalDateTime.now());
+                future2.complete(completeResponse);
+            }
+
+            @Override
+            public void onError(Throwable error) {
+                future2.completeExceptionally(error);
            }
        });
-        ChatResponse chatResponse = future.get(15, SECONDS);
-        assertThat(chatResponse.aiMessage().text()).contains("Berlin");
+
+        ChatResponse chatResponse1 = future1.get(15, SECONDS);
+        assertThat(chatResponse1.aiMessage().text()).isNotBlank();
+
+        ChatResponse chatResponse2 = future2.get(15, SECONDS);
+        assertThat(chatResponse2.aiMessage().text()).isNotBlank();
+
+        assertThat(streamingStarted1.get()).isBefore(streamingFinished2.get());
+        assertThat(streamingStarted2.get()).isBefore(streamingFinished1.get());
    }
 
    @Test
```

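The closing pair of timestamp assertions in each rewritten test is what actually pins down the fix: stream 1 starting before stream 2 finishes, together with stream 2 starting before stream 1 finishes, holds exactly when the two streams overlap in time. With the old unbounded queue the second chat call would sit behind the first on the executor's single core thread, streamingStarted2 would land after streamingFinished1, and the second assertion would fail.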