24 | 24 | import org.springframework.core.annotation.Order; |
25 | 25 | import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; |
26 | 26 |
| 27 | +import java.time.LocalDateTime; |
27 | 28 | import java.util.concurrent.CompletableFuture; |
| 29 | +import java.util.concurrent.atomic.AtomicReference; |
28 | 30 |
29 | 31 | import static java.util.concurrent.TimeUnit.SECONDS; |
30 | 32 | import static org.assertj.core.api.Assertions.assertThat; |
@@ -106,36 +108,73 @@ void should_create_chat_model_with_default_http_client() { |
106 | 108 | void should_provide_streaming_chat_model() { |
107 | 109 | contextRunner |
108 | 110 | .withPropertyValues( |
109 | | - "langchain4j.open-ai.streaming-chat-model.base-url=" + BASE_URL, |
| 111 | + // base URL deliberately not set: call the real OpenAI API instead of the caching proxy, whose near-instant cached responses would keep the two streams from overlapping |
110 | 112 | "langchain4j.open-ai.streaming-chat-model.api-key=" + API_KEY, |
111 | 113 | "langchain4j.open-ai.streaming-chat-model.model-name=gpt-4o-mini", |
112 | | - "langchain4j.open-ai.streaming-chat-model.max-tokens=20" |
| 114 | + "langchain4j.open-ai.streaming-chat-model.max-tokens=50" |
113 | 115 | ) |
114 | 116 | .run(context -> { |
115 | 117 |
116 | 118 | StreamingChatModel model = context.getBean(StreamingChatModel.class); |
117 | 119 | assertThat(model).isInstanceOf(OpenAiStreamingChatModel.class); |
118 | 120 | assertThat(context.getBean(OpenAiStreamingChatModel.class)).isSameAs(model); |
119 | 121 |
120 | | - CompletableFuture<ChatResponse> future = new CompletableFuture<>(); |
121 | | - model.chat("What is the capital of Germany?", new StreamingChatResponseHandler() { |
| 122 | + CompletableFuture<ChatResponse> future1 = new CompletableFuture<>(); |
| 123 | + AtomicReference<LocalDateTime> streamingStarted1 = new AtomicReference<>(); |
| 124 | + AtomicReference<LocalDateTime> streamingFinished1 = new AtomicReference<>(); |
| 125 | + model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() { |
122 | 126 |
123 | 127 | @Override |
124 | 128 | public void onPartialResponse(String partialResponse) { |
| 129 | + if (streamingStarted1.get() == null) { |
| 130 | + streamingStarted1.set(LocalDateTime.now()); |
| 131 | + } |
125 | 132 | } |
126 | 133 |
127 | 134 | @Override |
128 | 135 | public void onCompleteResponse(ChatResponse completeResponse) { |
129 | | - future.complete(completeResponse); |
| 136 | + streamingFinished1.set(LocalDateTime.now()); |
| 137 | + future1.complete(completeResponse); |
130 | 138 | } |
131 | 139 |
132 | 140 | @Override |
133 | 141 | public void onError(Throwable error) { |
134 | | - future.completeExceptionally(error); |
| 142 | + future1.completeExceptionally(error); |
135 | 143 | } |
136 | 144 | }); |
137 | | - ChatResponse chatResponse = future.get(15, SECONDS); |
138 | | - assertThat(chatResponse.aiMessage().text()).contains("Berlin"); |
| 145 | + |
| 146 | + CompletableFuture<ChatResponse> future2 = new CompletableFuture<>(); |
| 147 | + AtomicReference<LocalDateTime> streamingStarted2 = new AtomicReference<>(); |
| 148 | + AtomicReference<LocalDateTime> streamingFinished2 = new AtomicReference<>(); |
| 149 | + model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() { |
| 150 | + |
| 151 | + @Override |
| 152 | + public void onPartialResponse(String partialResponse) { |
| 153 | + if (streamingStarted2.get() == null) { |
| 154 | + streamingStarted2.set(LocalDateTime.now()); |
| 155 | + } |
| 156 | + } |
| 157 | + |
| 158 | + @Override |
| 159 | + public void onCompleteResponse(ChatResponse completeResponse) { |
| 160 | + streamingFinished2.set(LocalDateTime.now()); |
| 161 | + future2.complete(completeResponse); |
| 162 | + } |
| 163 | + |
| 164 | + @Override |
| 165 | + public void onError(Throwable error) { |
| 166 | + future2.completeExceptionally(error); |
| 167 | + } |
| 168 | + }); |
| 169 | + |
| 170 | + ChatResponse chatResponse1 = future1.get(15, SECONDS); |
| 171 | + assertThat(chatResponse1.aiMessage().text()).isNotBlank(); |
| 172 | + |
| 173 | + ChatResponse chatResponse2 = future2.get(15, SECONDS); |
| 174 | + assertThat(chatResponse2.aiMessage().text()).isNotBlank(); |
| 175 | + |
| 176 | + assertThat(streamingStarted1.get()).isBefore(streamingFinished2.get()); |
| 177 | + assertThat(streamingStarted2.get()).isBefore(streamingFinished1.get()); |
139 | 178 | }); |
140 | 179 | } |
141 | 180 |
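The closing pair of `isBefore` assertions is the standard interval-overlap check: two time intervals overlap exactly when each one starts before the other finishes, so passing both proves the two responses were streamed concurrently rather than one after the other. A minimal sketch of that predicate, with a hypothetical helper name not present in the PR:

```java
import java.time.LocalDateTime;

class StreamOverlapCheck {

    // Two time intervals overlap iff each one starts before the other ends.
    // This is exactly what the paired isBefore() assertions above encode.
    static boolean overlaps(LocalDateTime start1, LocalDateTime end1,
                            LocalDateTime start2, LocalDateTime end2) {
        return start1.isBefore(end2) && start2.isBefore(end1);
    }
}
```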
@@ -232,32 +271,69 @@ public void onError(Throwable error) { |
232 | 271 | void should_create_streaming_chat_model_with_default_http_client() throws Exception { |
233 | 272 |
234 | 273 | OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder() |
235 | | - .baseUrl(BASE_URL) |
| 274 | + // base URL deliberately not set: call the real OpenAI API instead of the caching proxy, whose near-instant cached responses would keep the two streams from overlapping |
236 | 275 | .apiKey(API_KEY) |
237 | 276 | .modelName("gpt-4o-mini") |
238 | 277 | .temperature(0.0) |
239 | | - .maxTokens(20) |
| 278 | + .maxTokens(50) |
240 | 279 | .build(); |
241 | 280 |
242 | | - CompletableFuture<ChatResponse> future = new CompletableFuture<>(); |
243 | | - model.chat("What is the capital of Germany?", new StreamingChatResponseHandler() { |
| 281 | + CompletableFuture<ChatResponse> future1 = new CompletableFuture<>(); |
| 282 | + AtomicReference<LocalDateTime> streamingStarted1 = new AtomicReference<>(); |
| 283 | + AtomicReference<LocalDateTime> streamingFinished1 = new AtomicReference<>(); |
| 284 | + model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() { |
244 | 285 |
245 | 286 | @Override |
246 | 287 | public void onPartialResponse(String partialResponse) { |
| 288 | + if (streamingStarted1.get() == null) { |
| 289 | + streamingStarted1.set(LocalDateTime.now()); |
| 290 | + } |
247 | 291 | } |
248 | 292 |
249 | 293 | @Override |
250 | 294 | public void onCompleteResponse(ChatResponse completeResponse) { |
251 | | - future.complete(completeResponse); |
| 295 | + streamingFinished1.set(LocalDateTime.now()); |
| 296 | + future1.complete(completeResponse); |
252 | 297 | } |
253 | 298 |
254 | 299 | @Override |
255 | 300 | public void onError(Throwable error) { |
256 | | - future.completeExceptionally(error); |
| 301 | + future1.completeExceptionally(error); |
| 302 | + } |
| 303 | + }); |
| 304 | + |
| 305 | + CompletableFuture<ChatResponse> future2 = new CompletableFuture<>(); |
| 306 | + AtomicReference<LocalDateTime> streamingStarted2 = new AtomicReference<>(); |
| 307 | + AtomicReference<LocalDateTime> streamingFinished2 = new AtomicReference<>(); |
| 308 | + model.chat("Tell me a story exactly 50 words long", new StreamingChatResponseHandler() { |
| 309 | + |
| 310 | + @Override |
| 311 | + public void onPartialResponse(String partialResponse) { |
| 312 | + if (streamingStarted2.get() == null) { |
| 313 | + streamingStarted2.set(LocalDateTime.now()); |
| 314 | + } |
| 315 | + } |
| 316 | + |
| 317 | + @Override |
| 318 | + public void onCompleteResponse(ChatResponse completeResponse) { |
| 319 | + streamingFinished2.set(LocalDateTime.now()); |
| 320 | + future2.complete(completeResponse); |
| 321 | + } |
| 322 | + |
| 323 | + @Override |
| 324 | + public void onError(Throwable error) { |
| 325 | + future2.completeExceptionally(error); |
257 | 326 | } |
258 | 327 | }); |
259 | | - ChatResponse chatResponse = future.get(15, SECONDS); |
260 | | - assertThat(chatResponse.aiMessage().text()).contains("Berlin"); |
| 328 | + |
| 329 | + ChatResponse chatResponse1 = future1.get(15, SECONDS); |
| 330 | + assertThat(chatResponse1.aiMessage().text()).isNotBlank(); |
| 331 | + |
| 332 | + ChatResponse chatResponse2 = future2.get(15, SECONDS); |
| 333 | + assertThat(chatResponse2.aiMessage().text()).isNotBlank(); |
| 334 | + |
| 335 | + assertThat(streamingStarted1.get()).isBefore(streamingFinished2.get()); |
| 336 | + assertThat(streamingStarted2.get()).isBefore(streamingFinished1.get()); |
261 | 337 | } |
262 | 338 |
263 | 339 | @Test |
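The same future-plus-timestamps handler appears four times across the two tests. A hypothetical helper, not part of this PR, could factor it out; the sketch below assumes the same `StreamingChatModel`, `StreamingChatResponseHandler`, `ChatResponse`, and import set already used in the test file:

```java
// Hypothetical test helper (assumed names, not in the PR): bridges the
// callback-based streaming API to a CompletableFuture and records when
// the stream produced its first token and when it completed.
record TimedStream(CompletableFuture<ChatResponse> response,
                   AtomicReference<LocalDateTime> started,
                   AtomicReference<LocalDateTime> finished) {

    static TimedStream chat(StreamingChatModel model, String prompt) {
        CompletableFuture<ChatResponse> future = new CompletableFuture<>();
        AtomicReference<LocalDateTime> started = new AtomicReference<>();
        AtomicReference<LocalDateTime> finished = new AtomicReference<>();
        model.chat(prompt, new StreamingChatResponseHandler() {

            @Override
            public void onPartialResponse(String partialResponse) {
                started.compareAndSet(null, LocalDateTime.now()); // first token only
            }

            @Override
            public void onCompleteResponse(ChatResponse completeResponse) {
                finished.set(LocalDateTime.now());
                future.complete(completeResponse);
            }

            @Override
            public void onError(Throwable error) {
                future.completeExceptionally(error);
            }
        });
        return new TimedStream(future, started, finished);
    }
}
```

Each test body would then reduce to:

```java
TimedStream stream1 = TimedStream.chat(model, "Tell me a story exactly 50 words long");
TimedStream stream2 = TimedStream.chat(model, "Tell me a story exactly 50 words long");

assertThat(stream1.response().get(15, SECONDS).aiMessage().text()).isNotBlank();
assertThat(stream2.response().get(15, SECONDS).aiMessage().text()).isNotBlank();

// the two streams must overlap in time: each starts before the other finishes
assertThat(stream1.started().get()).isBefore(stream2.finished().get());
assertThat(stream2.started().get()).isBefore(stream1.finished().get());
```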