diff --git a/docs/bundles/ai-bundle.rst b/docs/bundles/ai-bundle.rst index ecbe540bb4..3cf4fd02ef 100644 --- a/docs/bundles/ai-bundle.rst +++ b/docs/bundles/ai-bundle.rst @@ -51,6 +51,11 @@ Advanced Example with Multiple Agents deployment: '%env(AZURE_OPENAI_GPT)%' api_key: '%env(AZURE_OPENAI_KEY)%' api_version: '%env(AZURE_GPT_VERSION)%' + bedrock: + # multiple instances possible - for example region depending + default: ~ + eu: + bedrock_runtime_client: 'async_aws.client.bedrock_runtime_eu' eleven_labs: host: '%env(ELEVEN_LABS_HOST)%' api_key: '%env(ELEVEN_LABS_API_KEY)%' @@ -100,6 +105,10 @@ Advanced Example with Multiple Agents platform: 'ai.platform.eleven_labs' model: 'text-to-speech' tools: false + nova: + platform: 'ai.platform.bedrock.default' + model: 'nova-pro' + tools: false store: chromadb: # multiple collections possible per type diff --git a/examples/.env b/examples/.env index fce85f4b88..d5d1045080 100644 --- a/examples/.env +++ b/examples/.env @@ -80,8 +80,8 @@ MAPBOX_ACCESS_TOKEN= MONGODB_URI=mongodb://symfony:symfony@127.0.0.1:27017 # For using Pinecone (store) -PINECONE_API_KEY= -PINECONE_HOST= +PINECONE_API_KEY=pclocal +PINECONE_HOST=http://127.0.0.1:5080 # For using Postgres (store) POSTGRES_URI=pdo-pgsql://postgres:postgres@127.0.1:5432/my_database diff --git a/examples/commands/stores.php b/examples/commands/stores.php index 026292b864..26f140a402 100644 --- a/examples/commands/stores.php +++ b/examples/commands/stores.php @@ -14,6 +14,7 @@ use Doctrine\DBAL\DriverManager; use Doctrine\DBAL\Tools\DsnParser; use MongoDB\Client as MongoDbClient; +use Probots\Pinecone\Client as PineconeClient; use Symfony\AI\Store\Bridge\Cache\Store as CacheStore; use Symfony\AI\Store\Bridge\ClickHouse\Store as ClickHouseStore; use Symfony\AI\Store\Bridge\Elasticsearch\Store as ElasticsearchStore; @@ -24,6 +25,7 @@ use Symfony\AI\Store\Bridge\MongoDb\Store as MongoDbStore; use Symfony\AI\Store\Bridge\Neo4j\Store as Neo4jStore; use Symfony\AI\Store\Bridge\OpenSearch\Store as OpenSearchStore; +use Symfony\AI\Store\Bridge\Pinecone\Store as PineconeStore; use Symfony\AI\Store\Bridge\Postgres\Store as PostgresStore; use Symfony\AI\Store\Bridge\Qdrant\Store as QdrantStore; use Symfony\AI\Store\Bridge\Redis\Store as RedisStore; @@ -99,6 +101,10 @@ // env('OPENSEARCH_ENDPOINT'), // 'symfony', // ), + // 'pinecone' => static fn (): PineconeStore => new PineconeStore( + // new PineconeClient(env('PINECONE_API_KEY'), env('PINECONE_HOST')), + // 'symfony', + // ), 'postgres' => static fn (): PostgresStore => PostgresStore::fromDbal( DriverManager::getConnection((new DsnParser())->parse(env('POSTGRES_URI'))), 'my_table', diff --git a/examples/compose.yaml b/examples/compose.yaml index 1423519d0e..6c60303066 100644 --- a/examples/compose.yaml +++ b/examples/compose.yaml @@ -146,6 +146,15 @@ services: ports: - '9201:9200' + pinecone: + image: ghcr.io/pinecone-io/pinecone-local:latest + platform: linux/amd64 + environment: + PORT: 5080 + PINECONE_HOST: localhost + ports: + - '5080-5090:5080-5090' + opensearch: image: opensearchproject/opensearch environment: diff --git a/examples/rag/pinecone.php b/examples/rag/pinecone.php index 072dd6b960..36049d9866 100644 --- a/examples/rag/pinecone.php +++ b/examples/rag/pinecone.php @@ -29,7 +29,7 @@ require_once dirname(__DIR__).'/bootstrap.php'; // initialize the store -$store = new Store(Pinecone::client(env('PINECONE_API_KEY'), env('PINECONE_HOST'))); +$store = new Store(Pinecone::client(env('PINECONE_API_KEY'), env('PINECONE_HOST')), 'symfony'); // create embeddings and documents $documents = []; diff --git a/src/ai-bundle/config/options.php b/src/ai-bundle/config/options.php index 1a3396e447..7da770a8c9 100644 --- a/src/ai-bundle/config/options.php +++ b/src/ai-bundle/config/options.php @@ -62,6 +62,18 @@ ->end() ->end() ->end() + ->arrayNode('bedrock') + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->stringNode('bedrock_runtime_client') + ->defaultNull() + ->info('Service ID of the Bedrock runtime client to use') + ->end() + ->stringNode('model_catalog')->defaultNull()->end() + ->end() + ->end() + ->end() ->arrayNode('cache') ->useAttributeAsKey('name') ->arrayPrototype() @@ -800,6 +812,7 @@ ->cannotBeEmpty() ->defaultValue(PineconeClient::class) ->end() + ->stringNode('index_name')->isRequired()->end() ->stringNode('namespace')->end() ->arrayNode('filter') ->scalarPrototype() diff --git a/src/ai-bundle/config/services.php b/src/ai-bundle/config/services.php index 6980570d1b..4ea8963fde 100644 --- a/src/ai-bundle/config/services.php +++ b/src/ai-bundle/config/services.php @@ -29,6 +29,7 @@ use Symfony\AI\Platform\Bridge\Anthropic\Contract\AnthropicContract; use Symfony\AI\Platform\Bridge\Anthropic\ModelCatalog as AnthropicModelCatalog; use Symfony\AI\Platform\Bridge\Azure\OpenAi\ModelCatalog as AzureOpenAiModelCatalog; +use Symfony\AI\Platform\Bridge\Bedrock\ModelCatalog as BedrockModelCatalog; use Symfony\AI\Platform\Bridge\Cartesia\ModelCatalog as CartesiaModelCatalog; use Symfony\AI\Platform\Bridge\Cerebras\ModelCatalog as CerebrasModelCatalog; use Symfony\AI\Platform\Bridge\Decart\ModelCatalog as DecartModelCatalog; @@ -96,6 +97,7 @@ ->set('ai.platform.model_catalog.albert', AlbertModelCatalog::class) ->set('ai.platform.model_catalog.anthropic', AnthropicModelCatalog::class) ->set('ai.platform.model_catalog.azure.openai', AzureOpenAiModelCatalog::class) + ->set('ai.platform.model_catalog.bedrock', BedrockModelCatalog::class) ->set('ai.platform.model_catalog.cartesia', CartesiaModelCatalog::class) ->set('ai.platform.model_catalog.cerebras', CerebrasModelCatalog::class) ->set('ai.platform.model_catalog.decart', DecartModelCatalog::class) diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index af38f73cd6..fe372e1581 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -53,6 +53,7 @@ use Symfony\AI\Platform\Bridge\Albert\PlatformFactory as AlbertPlatformFactory; use Symfony\AI\Platform\Bridge\Anthropic\PlatformFactory as AnthropicPlatformFactory; use Symfony\AI\Platform\Bridge\Azure\OpenAi\PlatformFactory as AzureOpenAiPlatformFactory; +use Symfony\AI\Platform\Bridge\Bedrock\PlatformFactory as BedrockFactory; use Symfony\AI\Platform\Bridge\Cartesia\PlatformFactory as CartesiaPlatformFactory; use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory; use Symfony\AI\Platform\Bridge\Decart\PlatformFactory as DecartPlatformFactory; @@ -408,6 +409,31 @@ private function processPlatformConfig(string $type, array $platform, ContainerB return; } + if ('bedrock' === $type) { + foreach ($platform as $name => $config) { + if (!ContainerBuilder::willBeAvailable('symfony/ai-bedrock-platform', BedrockFactory::class, ['symfony/ai-bundle'])) { + throw new RuntimeException('Bedrock platform configuration requires "symfony/ai-bedrock-platform" package. Try running "composer require symfony/ai-bedrock-platform".'); + } + + $platformId = 'ai.platform.bedrock.'.$name; + $definition = (new Definition(Platform::class)) + ->setFactory(BedrockFactory::class.'::create') + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments([ + $config['bedrock_runtime_client'] ? new Reference($config['bedrock_runtime_client'], ContainerInterface::NULL_ON_INVALID_REFERENCE) : null, + $config['model_catalog'] ? new Reference($config['model_catalog']) : new Reference('ai.platform.model_catalog.bedrock'), + null, + new Reference('event_dispatcher'), + ]) + ->addTag('ai.platform', ['name' => 'bedrock.'.$name]); + + $container->setDefinition($platformId, $definition); + } + + return; + } + if ('cache' === $type) { foreach ($platform as $name => $cachedPlatformConfig) { $definition = (new Definition(CachedPlatform::class)) @@ -1519,12 +1545,13 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde foreach ($stores as $name => $store) { $arguments = [ new Reference($store['client']), + $store['index_name'], $store['namespace'] ?? $name, $store['filter'], ]; if (\array_key_exists('top_k', $store)) { - $arguments[3] = $store['top_k']; + $arguments[4] = $store['top_k']; } $definition = new Definition(PineconeStore::class); @@ -1532,6 +1559,7 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde ->setLazy(true) ->setArguments($arguments) ->addTag('proxy', ['interface' => StoreInterface::class]) + ->addTag('proxy', ['interface' => ManagedStoreInterface::class]) ->addTag('ai.store'); $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index ed2157c22b..8a6737b760 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -11,6 +11,7 @@ namespace Symfony\AI\AiBundle\Tests\DependencyInjection; +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; use Codewithkyrian\ChromaDB\Client; use MongoDB\Client as MongoDbClient; use PHPUnit\Framework\Attributes\DoesNotPerformAssertions; @@ -2370,48 +2371,13 @@ public function testOpenSearchStoreWithCustomHttpClientCanBeConfigured() } public function testPineconeStoreCanBeConfigured() - { - $container = $this->buildContainer([ - 'ai' => [ - 'store' => [ - 'pinecone' => [ - 'my_pinecone_store' => [], - ], - ], - ], - ]); - - $this->assertTrue($container->hasDefinition('ai.store.pinecone.my_pinecone_store')); - - $definition = $container->getDefinition('ai.store.pinecone.my_pinecone_store'); - $this->assertSame(PineconeStore::class, $definition->getClass()); - - $this->assertTrue($definition->isLazy()); - $this->assertCount(3, $definition->getArguments()); - $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); - $this->assertSame(PineconeClient::class, (string) $definition->getArgument(0)); - $this->assertSame('my_pinecone_store', $definition->getArgument(1)); - $this->assertSame([], $definition->getArgument(2)); - - $this->assertTrue($definition->hasTag('proxy')); - $this->assertSame([['interface' => StoreInterface::class]], $definition->getTag('proxy')); - $this->assertTrue($definition->hasTag('ai.store')); - - $this->assertTrue($container->hasAlias('.Symfony\AI\Store\StoreInterface $my_pinecone_store')); - $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface $myPineconeStore')); - $this->assertTrue($container->hasAlias('.Symfony\AI\Store\StoreInterface $pinecone_my_pinecone_store')); - $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface $pineconeMyPineconeStore')); - $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface')); - } - - public function testPineconeStoreWithCustomNamespaceCanBeConfigured() { $container = $this->buildContainer([ 'ai' => [ 'store' => [ 'pinecone' => [ 'my_pinecone_store' => [ - 'namespace' => 'my_namespace', + 'index_name' => 'my_index', ], ], ], @@ -2424,14 +2390,15 @@ public function testPineconeStoreWithCustomNamespaceCanBeConfigured() $this->assertSame(PineconeStore::class, $definition->getClass()); $this->assertTrue($definition->isLazy()); - $this->assertCount(3, $definition->getArguments()); + $this->assertCount(4, $definition->getArguments()); $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); $this->assertSame(PineconeClient::class, (string) $definition->getArgument(0)); - $this->assertSame('my_namespace', $definition->getArgument(1)); - $this->assertSame([], $definition->getArgument(2)); + $this->assertSame('my_index', $definition->getArgument(1)); + $this->assertSame('my_pinecone_store', $definition->getArgument(2)); + $this->assertSame([], $definition->getArgument(3)); $this->assertTrue($definition->hasTag('proxy')); - $this->assertSame([['interface' => StoreInterface::class]], $definition->getTag('proxy')); + $this->assertSame([['interface' => StoreInterface::class], ['interface' => ManagedStoreInterface::class]], $definition->getTag('proxy')); $this->assertTrue($definition->hasTag('ai.store')); $this->assertTrue($container->hasAlias('.Symfony\AI\Store\StoreInterface $my_pinecone_store')); @@ -2441,14 +2408,14 @@ public function testPineconeStoreWithCustomNamespaceCanBeConfigured() $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface')); } - public function testPineconeStoreWithCustomClientCanBeConfigured() + public function testPineconeStoreWithCustomIndexNameCanBeConfigured() { $container = $this->buildContainer([ 'ai' => [ 'store' => [ 'pinecone' => [ 'my_pinecone_store' => [ - 'client' => 'foo', + 'index_name' => 'custom_index', 'namespace' => 'my_namespace', ], ], @@ -2461,54 +2428,16 @@ public function testPineconeStoreWithCustomClientCanBeConfigured() $definition = $container->getDefinition('ai.store.pinecone.my_pinecone_store'); $this->assertSame(PineconeStore::class, $definition->getClass()); - $this->assertTrue($definition->isLazy()); - $this->assertCount(3, $definition->getArguments()); - $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); - $this->assertSame('foo', (string) $definition->getArgument(0)); - $this->assertSame('my_namespace', $definition->getArgument(1)); - $this->assertSame([], $definition->getArgument(2)); - - $this->assertTrue($definition->hasTag('proxy')); - $this->assertSame([['interface' => StoreInterface::class]], $definition->getTag('proxy')); - $this->assertTrue($definition->hasTag('ai.store')); - - $this->assertTrue($container->hasAlias('.Symfony\AI\Store\StoreInterface $my_pinecone_store')); - $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface $myPineconeStore')); - $this->assertTrue($container->hasAlias('.Symfony\AI\Store\StoreInterface $pinecone_my_pinecone_store')); - $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface $pineconeMyPineconeStore')); - $this->assertTrue($container->hasAlias('Symfony\AI\Store\StoreInterface')); - } - - public function testPineconeStoreWithTopKCanBeConfigured() - { - $container = $this->buildContainer([ - 'ai' => [ - 'store' => [ - 'pinecone' => [ - 'my_pinecone_store' => [ - 'namespace' => 'my_namespace', - 'top_k' => 100, - ], - ], - ], - ], - ]); - - $this->assertTrue($container->hasDefinition('ai.store.pinecone.my_pinecone_store')); - - $definition = $container->getDefinition('ai.store.pinecone.my_pinecone_store'); - $this->assertSame(PineconeStore::class, $definition->getClass()); - $this->assertTrue($definition->isLazy()); $this->assertCount(4, $definition->getArguments()); $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); $this->assertSame(PineconeClient::class, (string) $definition->getArgument(0)); - $this->assertSame('my_namespace', $definition->getArgument(1)); - $this->assertSame([], $definition->getArgument(2)); - $this->assertSame(100, $definition->getArgument(3)); + $this->assertSame('custom_index', $definition->getArgument(1)); + $this->assertSame('my_namespace', $definition->getArgument(2)); + $this->assertSame([], $definition->getArgument(3)); $this->assertTrue($definition->hasTag('proxy')); - $this->assertSame([['interface' => StoreInterface::class]], $definition->getTag('proxy')); + $this->assertSame([['interface' => StoreInterface::class], ['interface' => ManagedStoreInterface::class]], $definition->getTag('proxy')); $this->assertTrue($definition->hasTag('ai.store')); $this->assertTrue($container->hasAlias('.Symfony\AI\Store\StoreInterface $my_pinecone_store')); @@ -7013,6 +6942,7 @@ private function buildContainer(array $configuration): ContainerBuilder $container->setParameter('kernel.environment', 'dev'); $container->setParameter('kernel.build_dir', 'public'); $container->setDefinition(ClockInterface::class, new Definition(MonotonicClock::class)); + $container->setDefinition('async_aws.client.bedrock_us', new Definition(BedrockRuntimeClient::class)); $extension = (new AiBundle())->getContainerExtension(); $extension->load($configuration, $container); @@ -7049,6 +6979,12 @@ private function getFullConfig(): array 'api_version' => '2024-02-15-preview', ], ], + 'bedrock' => [ + 'default' => [], + 'us' => [ + 'bedrock_runtime_client' => 'async_aws.client.bedrock_us', + ], + ], 'cache' => [ 'azure' => [ 'platform' => 'ai.platform.azure.my_azure_instance', @@ -7317,15 +7253,18 @@ private function getFullConfig(): array ], 'pinecone' => [ 'my_pinecone_store' => [ + 'index_name' => 'my_index', 'namespace' => 'my_namespace', 'filter' => ['category' => 'books'], 'top_k' => 10, ], 'my_pinecone_store_with_filter' => [ + 'index_name' => 'my_index', 'namespace' => 'my_namespace', 'filter' => ['category' => 'books'], ], 'my_pinecone_store_with_top_k' => [ + 'index_name' => 'my_index', 'namespace' => 'my_namespace', 'filter' => ['category' => 'books'], 'top_k' => 10, diff --git a/src/platform/src/Bridge/Bedrock/Nova/NovaModelClient.php b/src/platform/src/Bridge/Bedrock/Nova/NovaModelClient.php index a1990da623..31329a0c42 100644 --- a/src/platform/src/Bridge/Bedrock/Nova/NovaModelClient.php +++ b/src/platform/src/Bridge/Bedrock/Nova/NovaModelClient.php @@ -34,6 +34,8 @@ public function supports(Model $model): bool public function request(Model $model, array|string $payload, array $options = []): RawBedrockResult { + unset($payload['model']); + $modelOptions = []; if (isset($options['tools'])) { $modelOptions['toolConfig']['tools'] = $options['tools']; diff --git a/src/platform/src/Bridge/Bedrock/PlatformFactory.php b/src/platform/src/Bridge/Bedrock/PlatformFactory.php index 2d7632fe09..be2af84852 100644 --- a/src/platform/src/Bridge/Bedrock/PlatformFactory.php +++ b/src/platform/src/Bridge/Bedrock/PlatformFactory.php @@ -33,7 +33,7 @@ final class PlatformFactory { public static function create( - BedrockRuntimeClient $bedrockRuntimeClient = new BedrockRuntimeClient(), + ?BedrockRuntimeClient $bedrockRuntimeClient = null, ModelCatalogInterface $modelCatalog = new ModelCatalog(), ?Contract $contract = null, ?EventDispatcherInterface $eventDispatcher = null, @@ -42,6 +42,10 @@ public static function create( throw new RuntimeException('For using the Bedrock platform, the async-aws/bedrock-runtime package is required. Try running "composer require async-aws/bedrock-runtime".'); } + if (null === $bedrockRuntimeClient) { + $bedrockRuntimeClient = new BedrockRuntimeClient(); + } + return new Platform( [ new ClaudeModelClient($bedrockRuntimeClient), diff --git a/src/platform/src/Bridge/Bedrock/Tests/Anthropic/ClaudeModelClientTest.php b/src/platform/src/Bridge/Bedrock/Tests/Anthropic/ClaudeModelClientTest.php new file mode 100644 index 0000000000..fdf5100627 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Tests/Anthropic/ClaudeModelClientTest.php @@ -0,0 +1,132 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Tests\Anthropic; + +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; +use AsyncAws\BedrockRuntime\Input\InvokeModelRequest; +use AsyncAws\BedrockRuntime\Result\InvokeModelResponse; +use AsyncAws\Core\Configuration; +use PHPUnit\Framework\MockObject\MockObject; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Bridge\Bedrock\Anthropic\ClaudeModelClient; +use Symfony\AI\Platform\Bridge\Bedrock\RawBedrockResult; + +final class ClaudeModelClientTest extends TestCase +{ + private const VERSION = '2023-05-31'; + + private MockObject&BedrockRuntimeClient $bedrockClient; + private ClaudeModelClient $modelClient; + private Claude $model; + + protected function setUp(): void + { + $this->model = new Claude('claude-sonnet-4-5-20250929'); + $this->bedrockClient = $this->getMockBuilder(BedrockRuntimeClient::class) + ->setConstructorArgs([ + Configuration::create([Configuration::OPTION_REGION => Configuration::DEFAULT_REGION]), + ]) + ->onlyMethods(['invokeModel']) + ->getMock(); + } + + public function testPassesModelId() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('us.anthropic.claude-sonnet-4-5-20250929-v1:0', $arg->getModelId()); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new ClaudeModelClient($this->bedrockClient, self::VERSION); + + $response = $this->modelClient->request($this->model, ['message' => 'test']); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testUnsetsModelName() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertArrayNotHasKey('model', $body); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new ClaudeModelClient($this->bedrockClient, self::VERSION); + + $response = $this->modelClient->request($this->model, ['message' => 'test', 'model' => 'claude']); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testSetsAnthropicVersion() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertSame('bedrock-'.self::VERSION, $body['anthropic_version']); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new ClaudeModelClient($this->bedrockClient, self::VERSION); + + $response = $this->modelClient->request($this->model, ['message' => 'test']); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testSetsToolOptionsIfToolsEnabled() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertSame(['type' => 'auto'], $body['tool_choice']); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new ClaudeModelClient($this->bedrockClient, self::VERSION); + + $options = [ + 'tools' => ['Tool'], + ]; + + $response = $this->modelClient->request($this->model, ['message' => 'test'], $options); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } +} diff --git a/src/platform/src/Bridge/Bedrock/Tests/Nova/NovaModelClientTest.php b/src/platform/src/Bridge/Bedrock/Tests/Nova/NovaModelClientTest.php new file mode 100644 index 0000000000..0f1195e768 --- /dev/null +++ b/src/platform/src/Bridge/Bedrock/Tests/Nova/NovaModelClientTest.php @@ -0,0 +1,190 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Bedrock\Tests\Nova; + +use AsyncAws\BedrockRuntime\BedrockRuntimeClient; +use AsyncAws\BedrockRuntime\Input\InvokeModelRequest; +use AsyncAws\BedrockRuntime\Result\InvokeModelResponse; +use AsyncAws\Core\Configuration; +use PHPUnit\Framework\MockObject\MockObject; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\Nova; +use Symfony\AI\Platform\Bridge\Bedrock\Nova\NovaModelClient; +use Symfony\AI\Platform\Bridge\Bedrock\RawBedrockResult; + +final class NovaModelClientTest extends TestCase +{ + private MockObject&BedrockRuntimeClient $bedrockClient; + private NovaModelClient $modelClient; + private Nova $model; + + protected function setUp(): void + { + $this->model = new Nova('nova-pro'); + $this->bedrockClient = $this->getMockBuilder(BedrockRuntimeClient::class) + ->setConstructorArgs([ + Configuration::create([Configuration::OPTION_REGION => Configuration::DEFAULT_REGION]), + ]) + ->onlyMethods(['invokeModel']) + ->getMock(); + } + + public function testPassesModelId() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('us.amazon.nova-pro-v1:0', $arg->getModelId()); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new NovaModelClient($this->bedrockClient); + + $response = $this->modelClient->request($this->model, ['message' => 'test']); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testUnsetsModelName() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertArrayNotHasKey('model', $body); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new NovaModelClient($this->bedrockClient); + + $response = $this->modelClient->request($this->model, ['message' => 'test', 'model' => 'nova-pro']); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testSetsToolOptionsIfToolsEnabled() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertSame(['tools' => ['Tool']], $body['toolConfig']); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new NovaModelClient($this->bedrockClient); + + $options = [ + 'tools' => ['Tool'], + ]; + + $response = $this->modelClient->request($this->model, ['message' => 'test'], $options); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testPassesTemperature() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertArrayHasKey('inferenceConfig', $body); + $this->assertSame(['temperature' => 0.35], $body['inferenceConfig']); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new NovaModelClient($this->bedrockClient); + + $options = [ + 'temperature' => 0.35, + ]; + + $response = $this->modelClient->request($this->model, ['message' => 'test'], $options); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testPassesMaxTokens() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertArrayHasKey('inferenceConfig', $body); + $this->assertSame(['maxTokens' => 1000], $body['inferenceConfig']); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new NovaModelClient($this->bedrockClient); + + $options = [ + 'max_tokens' => 1000, + ]; + + $response = $this->modelClient->request($this->model, ['message' => 'test'], $options); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } + + public function testPassesBothTemperatureAndMaxTokens() + { + $this->bedrockClient->expects($this->once()) + ->method('invokeModel') + ->with($this->callback(function ($arg) { + $this->assertInstanceOf(InvokeModelRequest::class, $arg); + $this->assertSame('application/json', $arg->getContentType()); + $this->assertTrue(json_validate($arg->getBody())); + + $body = json_decode($arg->getBody(), true); + $this->assertArrayHasKey('inferenceConfig', $body); + $this->assertSame(['temperature' => 0.35, 'maxTokens' => 1000], $body['inferenceConfig']); + + return true; + })) + ->willReturn($this->createMock(InvokeModelResponse::class)); + + $this->modelClient = new NovaModelClient($this->bedrockClient); + + $options = [ + 'max_tokens' => 1000, + 'temperature' => 0.35, + ]; + + $response = $this->modelClient->request($this->model, ['message' => 'test'], $options); + $this->assertInstanceOf(RawBedrockResult::class, $response); + } +} diff --git a/src/store/src/Bridge/Pinecone/Store.php b/src/store/src/Bridge/Pinecone/Store.php index deabd8bbbf..7bc0ab7e7b 100644 --- a/src/store/src/Bridge/Pinecone/Store.php +++ b/src/store/src/Bridge/Pinecone/Store.php @@ -16,25 +16,53 @@ use Symfony\AI\Platform\Vector\Vector; use Symfony\AI\Store\Document\Metadata; use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\InvalidArgumentException; +use Symfony\AI\Store\ManagedStoreInterface; use Symfony\AI\Store\StoreInterface; use Symfony\Component\Uid\Uuid; /** * @author Christopher Hertel */ -final class Store implements StoreInterface +final class Store implements ManagedStoreInterface, StoreInterface { /** * @param array $filter */ public function __construct( private readonly Client $pinecone, + private readonly string $indexName, private readonly ?string $namespace = null, private readonly array $filter = [], private readonly int $topK = 3, ) { } + /** + * @param array{ + * dimension?: int, + * metric?: string, + * cloud?: string, + * region?: string, + * } $options + */ + public function setup(array $options = []): void + { + if (false === isset($options['dimension'])) { + throw new InvalidArgumentException('The "dimension" option is required.'); + } + + $this->pinecone + ->control() + ->index($this->indexName) + ->createServerless( + $options['dimension'], + $options['metric'] ?? null, + $options['cloud'] ?? null, + $options['region'] ?? null, + ); + } + public function add(VectorDocument ...$documents): void { $vectors = []; @@ -73,6 +101,14 @@ public function query(Vector $vector, array $options = []): iterable } } + public function drop(array $options = []): void + { + $this->pinecone + ->control() + ->index($this->indexName) + ->delete(); + } + private function getVectors(): VectorResource { return $this->pinecone->data()->vectors(); diff --git a/src/store/src/Bridge/Pinecone/Tests/StoreTest.php b/src/store/src/Bridge/Pinecone/Tests/StoreTest.php index a3fb23bde2..d7377f3aa5 100644 --- a/src/store/src/Bridge/Pinecone/Tests/StoreTest.php +++ b/src/store/src/Bridge/Pinecone/Tests/StoreTest.php @@ -13,6 +13,8 @@ use PHPUnit\Framework\TestCase; use Probots\Pinecone\Client; +use Probots\Pinecone\Resources\Control\IndexResource; +use Probots\Pinecone\Resources\ControlResource; use Probots\Pinecone\Resources\Data\VectorResource; use Probots\Pinecone\Resources\DataResource; use Saloon\Http\Response; @@ -20,6 +22,7 @@ use Symfony\AI\Store\Bridge\Pinecone\Store; use Symfony\AI\Store\Document\Metadata; use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\InvalidArgumentException; use Symfony\Component\Uid\Uuid; final class StoreTest extends TestCase @@ -53,10 +56,8 @@ public function testAddSingleDocument() null, ); - $store = new Store($client); - $document = new VectorDocument($uuid, new Vector([0.1, 0.2, 0.3]), new Metadata(['title' => 'Test Document'])); - $store->add($document); + self::createStore($client)->add($document); } public function testAddMultipleDocuments() @@ -94,12 +95,10 @@ public function testAddMultipleDocuments() null, ); - $store = new Store($client); - $document1 = new VectorDocument($uuid1, new Vector([0.1, 0.2, 0.3])); $document2 = new VectorDocument($uuid2, new Vector([0.4, 0.5, 0.6]), new Metadata(['title' => 'Second Document'])); - $store->add($document1, $document2); + self::createStore($client)->add($document1, $document2); } public function testAddWithNamespace() @@ -131,10 +130,8 @@ public function testAddWithNamespace() 'test-namespace', ); - $store = new Store($client, 'test-namespace'); - $document = new VectorDocument($uuid, new Vector([0.1, 0.2, 0.3])); - $store->add($document); + self::createStore($client, namespace: 'test-namespace')->add($document); } public function testAddWithEmptyDocuments() @@ -144,8 +141,7 @@ public function testAddWithEmptyDocuments() $client->expects($this->never()) ->method('data'); - $store = new Store($client); - $store->add(); + self::createStore($client)->add(); } public function testQueryReturnsDocuments() @@ -194,9 +190,7 @@ public function testQueryReturnsDocuments() ) ->willReturn($response); - $store = new Store($client); - - $results = iterator_to_array($store->query(new Vector([0.1, 0.2, 0.3]))); + $results = iterator_to_array(self::createStore($client)->query(new Vector([0.1, 0.2, 0.3]))); $this->assertCount(2, $results); $this->assertInstanceOf(VectorDocument::class, $results[0]); @@ -239,9 +233,7 @@ public function testQueryWithNamespaceAndFilter() ) ->willReturn($response); - $store = new Store($client, 'test-namespace', ['category' => 'test'], 5); - - $results = iterator_to_array($store->query(new Vector([0.1, 0.2, 0.3]))); + $results = iterator_to_array(self::createStore($client, namespace: 'test-namespace', filter: ['category' => 'test'], topK: 5)->query(new Vector([0.1, 0.2, 0.3]))); $this->assertCount(0, $results); } @@ -276,9 +268,7 @@ public function testQueryWithCustomOptions() ) ->willReturn($response); - $store = new Store($client); - - $results = iterator_to_array($store->query(new Vector([0.1, 0.2, 0.3]), [ + $results = iterator_to_array(self::createStore($client)->query(new Vector([0.1, 0.2, 0.3]), [ 'namespace' => 'custom-namespace', 'filter' => ['type' => 'document'], 'topK' => 10, @@ -310,10 +300,74 @@ public function testQueryWithEmptyResults() ->method('query') ->willReturn($response); - $store = new Store($client); - - $results = iterator_to_array($store->query(new Vector([0.1, 0.2, 0.3]))); + $results = iterator_to_array(self::createStore($client)->query(new Vector([0.1, 0.2, 0.3]))); $this->assertCount(0, $results); } + + public function testSetup() + { + $indexResource = $this->createMock(IndexResource::class); + $controlResource = $this->createMock(ControlResource::class); + $client = $this->createMock(Client::class); + + $client->expects($this->once()) + ->method('control') + ->willReturn($controlResource); + + $controlResource->expects($this->once()) + ->method('index') + ->with('my-index') + ->willReturn($indexResource); + + $indexResource->expects($this->once()) + ->method('createServerless') + ->with(1536, 'cosine', 'aws', 'us-east-1'); + + self::createStore($client, indexName: 'my-index')->setup([ + 'dimension' => 1536, + 'metric' => 'cosine', + 'cloud' => 'aws', + 'region' => 'us-east-1', + ]); + } + + public function testSetupThrowsExceptionWithoutDimension() + { + $client = $this->createMock(Client::class); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('The "dimension" option is required.'); + + self::createStore($client, indexName: 'my-index')->setup([]); + } + + public function testDrop() + { + $indexResource = $this->createMock(IndexResource::class); + $controlResource = $this->createMock(ControlResource::class); + $client = $this->createMock(Client::class); + + $client->expects($this->once()) + ->method('control') + ->willReturn($controlResource); + + $controlResource->expects($this->once()) + ->method('index') + ->with('my-index') + ->willReturn($indexResource); + + $indexResource->expects($this->once()) + ->method('delete'); + + self::createStore($client, indexName: 'my-index')->drop(); + } + + /** + * @param array $filter + */ + private static function createStore(Client $client, string $indexName = 'test-index', ?string $namespace = null, array $filter = [], int $topK = 3): Store + { + return new Store($client, $indexName, $namespace, $filter, $topK); + } }